414 lines
9.6 KiB
Go
414 lines
9.6 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"messenger/internal/pkg/logger"
|
|
"messenger/internal/service"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type BroadcastMessage struct {
|
|
Type MessageType
|
|
ChatID int64
|
|
MessageID int64
|
|
UserID int64
|
|
SenderID int64
|
|
Data interface{}
|
|
Client *Client
|
|
}
|
|
|
|
type Hub struct {
|
|
register chan *Client
|
|
unregister chan *Client
|
|
broadcast chan *BroadcastMessage
|
|
clients map[int64]*Client
|
|
clientsMu sync.RWMutex
|
|
rooms map[int64]*Room
|
|
roomsMu sync.RWMutex
|
|
messageService *service.MessageService
|
|
chatService *service.ChatService
|
|
running bool
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewHub(messageService *service.MessageService, chatService *service.ChatService) *Hub {
|
|
return &Hub{
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
broadcast: make(chan *BroadcastMessage, 256),
|
|
clients: make(map[int64]*Client),
|
|
rooms: make(map[int64]*Room),
|
|
messageService: messageService,
|
|
chatService: chatService,
|
|
running: true,
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Run() {
|
|
logger.Info("WebSocket Hub started")
|
|
|
|
for h.running {
|
|
select {
|
|
case client := <-h.register:
|
|
h.handleRegister(client)
|
|
case client := <-h.unregister:
|
|
h.handleUnregister(client)
|
|
case broadcast := <-h.broadcast:
|
|
h.handleBroadcast(broadcast)
|
|
}
|
|
}
|
|
|
|
logger.Info("WebSocket Hub stopped")
|
|
}
|
|
|
|
func (h *Hub) Stop() {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
h.running = false
|
|
close(h.register)
|
|
close(h.unregister)
|
|
close(h.broadcast)
|
|
|
|
h.clientsMu.Lock()
|
|
for _, client := range h.clients {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
logger.Warn("Recovered while closing client", "error", r)
|
|
}
|
|
}()
|
|
close(client.send)
|
|
client.conn.Close()
|
|
}()
|
|
}
|
|
h.clients = make(map[int64]*Client)
|
|
h.clientsMu.Unlock()
|
|
}
|
|
|
|
func (h *Hub) handleRegister(client *Client) {
|
|
h.clientsMu.Lock()
|
|
if oldClient, exists := h.clients[client.userID]; exists {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
logger.Warn("Recovered while closing old client", "error", r)
|
|
}
|
|
}()
|
|
close(oldClient.send)
|
|
oldClient.conn.Close()
|
|
}()
|
|
}
|
|
h.clients[client.userID] = client
|
|
h.clientsMu.Unlock()
|
|
|
|
go h.addClientToChats(client)
|
|
go h.notifyUserOnline(client.userID, true)
|
|
|
|
logger.Info("Client registered", "user_id", client.userID)
|
|
}
|
|
|
|
func (h *Hub) handleUnregister(client *Client) {
|
|
h.clientsMu.Lock()
|
|
delete(h.clients, client.userID)
|
|
h.clientsMu.Unlock()
|
|
|
|
// Закрываем канал безопасно
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
logger.Warn("Recovered while closing send channel", "error", r)
|
|
}
|
|
}()
|
|
close(client.send)
|
|
}()
|
|
|
|
h.roomsMu.Lock()
|
|
for _, room := range h.rooms {
|
|
room.RemoveClient(client)
|
|
}
|
|
h.roomsMu.Unlock()
|
|
|
|
go h.notifyUserOnline(client.userID, false)
|
|
|
|
logger.Info("Client unregistered", "user_id", client.userID)
|
|
}
|
|
|
|
func (h *Hub) handleBroadcast(broadcast *BroadcastMessage) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
logger.Error("Panic in handleBroadcast", "error", r)
|
|
}
|
|
}()
|
|
|
|
switch broadcast.Type {
|
|
case MsgTypeNewMessage:
|
|
h.handleNewMessageBroadcast(broadcast)
|
|
case MsgTypeReadReceipt:
|
|
h.handleReadReceiptBroadcast(broadcast)
|
|
case MsgTypeEditMessage:
|
|
h.handleEditMessageBroadcast(broadcast)
|
|
case MsgTypeDeleteMessage:
|
|
h.handleDeleteMessageBroadcast(broadcast)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) handleNewMessageBroadcast(broadcast *BroadcastMessage) {
|
|
req, ok := broadcast.Data.(NewMessageRequest)
|
|
if !ok {
|
|
logger.Error("Invalid broadcast data for new_message")
|
|
return
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
message, err := h.messageService.SendMessage(ctx, broadcast.Client.userID, req.ChatID, req.Plaintext, req.AttachmentID)
|
|
if err != nil {
|
|
logger.Error("Failed to save message", "error", err)
|
|
broadcast.Client.sendError(500, "Failed to send message")
|
|
return
|
|
}
|
|
|
|
response := NewMessageResponse{
|
|
ID: message.ID,
|
|
ChatID: message.ChatID,
|
|
SenderID: message.SenderID,
|
|
Plaintext: message.Plaintext,
|
|
CreatedAt: message.CreatedAt,
|
|
TempID: req.TempID,
|
|
}
|
|
|
|
msgBytes, err := CreateMessage(MsgTypeNewMessageResp, response)
|
|
if err != nil {
|
|
logger.Error("Failed to create message", "error", err)
|
|
return
|
|
}
|
|
|
|
h.roomsMu.RLock()
|
|
room, exists := h.rooms[req.ChatID]
|
|
h.roomsMu.RUnlock()
|
|
|
|
if exists {
|
|
room.BroadcastToAll(msgBytes)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) handleReadReceiptBroadcast(broadcast *BroadcastMessage) {
|
|
ctx := context.Background()
|
|
|
|
err := h.messageService.MarkMessageAsRead(ctx, broadcast.MessageID, broadcast.UserID)
|
|
if err != nil {
|
|
logger.Error("Failed to mark message as read", "error", err)
|
|
return
|
|
}
|
|
|
|
message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
|
|
if err != nil {
|
|
logger.Error("Failed to get message for read receipt", "error", err)
|
|
return
|
|
}
|
|
|
|
response := ReadReceiptResponse{
|
|
MessageID: broadcast.MessageID,
|
|
UserID: broadcast.UserID,
|
|
ReadAt: time.Now(),
|
|
}
|
|
|
|
msgBytes, err := CreateMessage(MsgTypeMessageRead, response)
|
|
if err != nil {
|
|
logger.Error("Failed to create read receipt message", "error", err)
|
|
return
|
|
}
|
|
|
|
h.roomsMu.RLock()
|
|
room, exists := h.rooms[message.ChatID]
|
|
h.roomsMu.RUnlock()
|
|
|
|
if exists {
|
|
room.BroadcastToAll(msgBytes)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) handleEditMessageBroadcast(broadcast *BroadcastMessage) {
|
|
req, ok := broadcast.Data.(EditMessageRequest)
|
|
if !ok {
|
|
logger.Error("Invalid broadcast data for edit_message")
|
|
return
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
err := h.messageService.EditMessage(ctx, broadcast.SenderID, req.MessageID, req.Plaintext)
|
|
if err != nil {
|
|
logger.Error("Failed to edit message", "error", err)
|
|
if broadcast.Client != nil {
|
|
broadcast.Client.sendError(403, "Failed to edit message")
|
|
}
|
|
return
|
|
}
|
|
|
|
message, err := h.messageService.GetMessageByID(ctx, req.MessageID)
|
|
if err != nil {
|
|
logger.Error("Failed to get edited message", "error", err)
|
|
return
|
|
}
|
|
|
|
response := EditMessageResponse{
|
|
MessageID: req.MessageID,
|
|
NewPlaintext: req.Plaintext,
|
|
EditedAt: time.Now(),
|
|
}
|
|
|
|
msgBytes, err := CreateMessage(MsgTypeMessageEdited, response)
|
|
if err != nil {
|
|
logger.Error("Failed to create edit message", "error", err)
|
|
return
|
|
}
|
|
|
|
h.roomsMu.RLock()
|
|
room, exists := h.rooms[message.ChatID]
|
|
h.roomsMu.RUnlock()
|
|
|
|
if exists {
|
|
room.BroadcastToAll(msgBytes)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) handleDeleteMessageBroadcast(broadcast *BroadcastMessage) {
|
|
ctx := context.Background()
|
|
|
|
message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
|
|
if err != nil {
|
|
logger.Error("Failed to get message for deletion", "error", err)
|
|
return
|
|
}
|
|
|
|
canDelete, err := h.messageService.CanDeleteMessage(ctx, broadcast.SenderID, broadcast.MessageID)
|
|
if err != nil || !canDelete {
|
|
logger.Error("User cannot delete message", "user_id", broadcast.SenderID, "message_id", broadcast.MessageID)
|
|
if broadcast.Client != nil {
|
|
broadcast.Client.sendError(403, "Cannot delete this message")
|
|
}
|
|
return
|
|
}
|
|
|
|
err = h.messageService.DeleteMessage(ctx, broadcast.MessageID)
|
|
if err != nil {
|
|
logger.Error("Failed to delete message", "error", err)
|
|
return
|
|
}
|
|
|
|
response := DeleteMessageResponse{
|
|
MessageID: broadcast.MessageID,
|
|
DeletedAt: time.Now(),
|
|
}
|
|
|
|
msgBytes, err := CreateMessage(MsgTypeMessageDeleted, response)
|
|
if err != nil {
|
|
logger.Error("Failed to create delete message", "error", err)
|
|
return
|
|
}
|
|
|
|
h.roomsMu.RLock()
|
|
room, exists := h.rooms[message.ChatID]
|
|
h.roomsMu.RUnlock()
|
|
|
|
if exists {
|
|
room.BroadcastToAll(msgBytes)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) addClientToChats(client *Client) {
|
|
ctx := context.Background()
|
|
|
|
chats, err := h.chatService.GetUserChats(ctx, client.userID)
|
|
if err != nil {
|
|
logger.Error("Failed to get user chats", "user_id", client.userID, "error", err)
|
|
return
|
|
}
|
|
|
|
for _, chat := range chats {
|
|
h.roomsMu.Lock()
|
|
room, exists := h.rooms[chat.ID]
|
|
if !exists {
|
|
room = NewRoom(chat.ID)
|
|
h.rooms[chat.ID] = room
|
|
}
|
|
h.roomsMu.Unlock()
|
|
|
|
room.AddClient(client)
|
|
}
|
|
|
|
logger.Info("Client added to rooms", "user_id", client.userID, "room_count", len(chats))
|
|
}
|
|
|
|
func (h *Hub) notifyUserOnline(userID int64, isOnline bool) {
|
|
// Защита от паники
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
logger.Warn("Recovered in notifyUserOnline", "error", r)
|
|
}
|
|
}()
|
|
|
|
response := UserOnlineResponse{
|
|
UserID: userID,
|
|
IsOnline: isOnline,
|
|
}
|
|
|
|
msgBytes, err := CreateMessage(MsgTypeUserOnline, response)
|
|
if err != nil {
|
|
logger.Error("Failed to create user online message", "error", err)
|
|
return
|
|
}
|
|
|
|
h.roomsMu.RLock()
|
|
defer h.roomsMu.RUnlock()
|
|
|
|
for _, room := range h.rooms {
|
|
if room.HasClientByUserID(userID) {
|
|
room.BroadcastToAll(msgBytes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) GetRoom(chatID int64) (*Room, bool) {
|
|
h.roomsMu.RLock()
|
|
defer h.roomsMu.RUnlock()
|
|
room, exists := h.rooms[chatID]
|
|
return room, exists
|
|
}
|
|
|
|
func (h *Hub) GetClient(userID int64) (*Client, bool) {
|
|
h.clientsMu.RLock()
|
|
defer h.clientsMu.RUnlock()
|
|
client, exists := h.clients[userID]
|
|
return client, exists
|
|
}
|
|
|
|
func (h *Hub) SendToUser(userID int64, message []byte) bool {
|
|
h.clientsMu.RLock()
|
|
client, exists := h.clients[userID]
|
|
h.clientsMu.RUnlock()
|
|
|
|
if exists {
|
|
client.SendMessage(message)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetRegisterChan возвращает канал для регистрации клиентов
|
|
func (h *Hub) GetRegisterChan() chan<- *Client {
|
|
return h.register
|
|
}
|
|
|
|
// GetUnregisterChan возвращает канал для отмены регистрации клиентов
|
|
func (h *Hub) GetUnregisterChan() chan<- *Client {
|
|
return h.unregister
|
|
}
|
|
|
|
// GetBroadcastChan возвращает канал для широковещательных сообщений
|
|
func (h *Hub) GetBroadcastChan() chan<- *BroadcastMessage {
|
|
return h.broadcast
|
|
} |