// Package websocket contains the hub that tracks connected clients, groups
// them into per-chat rooms and fans chat events out to room members.
package websocket

import (
	"context"
	"sync"
	"time"

	"messenger/internal/pkg/logger"
	"messenger/internal/service"
)

// BroadcastMessage is a unit of work submitted to the hub's broadcast channel.
//
// Which fields are consulted depends on Type:
//   - MsgTypeNewMessage:    Data (NewMessageRequest) and Client.
//   - MsgTypeReadReceipt:   MessageID and UserID.
//   - MsgTypeEditMessage:   Data (EditMessageRequest), SenderID and Client.
//   - MsgTypeDeleteMessage: MessageID, SenderID and Client.
type BroadcastMessage struct {
	Type      MessageType
	ChatID    int64
	MessageID int64
	UserID    int64
	SenderID  int64
	Data      interface{}
	Client    *Client
}

// Hub owns the set of connected clients (keyed by user ID) and the set of
// rooms (keyed by chat ID), and processes register, unregister and broadcast
// requests in its Run loop.
type Hub struct {
	register   chan *Client
	unregister chan *Client
	broadcast  chan *BroadcastMessage

	clients   map[int64]*Client
	clientsMu sync.RWMutex // guards clients

	rooms   map[int64]*Room
	roomsMu sync.RWMutex // guards rooms

	messageService *service.MessageService
	chatService    *service.ChatService

	running bool
	mu      sync.RWMutex // guards running
}

// NewHub builds a hub in the running state. The register and unregister
// channels are unbuffered; the broadcast channel buffers up to 256 requests.
func NewHub(messageService *service.MessageService, chatService *service.ChatService) *Hub {
	return &Hub{
		register:       make(chan *Client),
		unregister:     make(chan *Client),
		broadcast:      make(chan *BroadcastMessage, 256),
		clients:        make(map[int64]*Client),
		rooms:          make(map[int64]*Room),
		messageService: messageService,
		chatService:    chatService,
		running:        true,
	}
}

// Run processes hub requests until Stop is called. It is meant to be run in
// its own goroutine.
func (h *Hub) Run() {
	logger.Info("WebSocket Hub started")
	defer logger.Info("WebSocket Hub stopped")

	for h.isRunning() {
		select {
		case client, ok := <-h.register:
			if !ok {
				// Stop closed the channel; a nil client must not be handled.
				return
			}
			h.handleRegister(client)
		case client, ok := <-h.unregister:
			if !ok {
				return
			}
			h.handleUnregister(client)
		case broadcast, ok := <-h.broadcast:
			if !ok {
				return
			}
			h.handleBroadcast(broadcast)
		}
	}
}

// isRunning reports the running flag under h.mu so that Run does not race
// with Stop.
func (h *Hub) isRunning() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.running
}

// Stop marks the hub as stopped, closes its request channels and closes every
// registered client's send channel and connection. Calling Stop more than
// once is a no-op.
func (h *Hub) Stop() {
	h.mu.Lock()
	defer h.mu.Unlock()

	if !h.running {
		// Already stopped: closing the channels again would panic.
		return
	}
	h.running = false

	close(h.register)
	close(h.unregister)
	close(h.broadcast)

	h.clientsMu.Lock()
	defer h.clientsMu.Unlock()
	for _, client := range h.clients {
		h.closeClient(client, "Recovered while closing client")
	}
	h.clients = make(map[int64]*Client)
}

// closeClient closes the client's send channel and connection. Any panic
// (for example, the send channel already being closed) is recovered and
// logged as a warning with recoverMsg.
func (h *Hub) closeClient(client *Client, recoverMsg string) {
	defer func() {
		if r := recover(); r != nil {
			logger.Warn(recoverMsg, "error", r)
		}
	}()
	close(client.send)
	client.conn.Close()
}

// handleRegister stores client as the active connection for its user,
// closing a previous, different connection of the same user, and then adds
// the client to its chats' rooms and announces the user as online.
func (h *Hub) handleRegister(client *Client) {
	if client == nil {
		return
	}

	h.clientsMu.Lock()
	// Re-registering the very same client must not close its own channel.
	if oldClient, exists := h.clients[client.userID]; exists && oldClient != client {
		h.closeClient(oldClient, "Recovered while closing old client")
	}
	h.clients[client.userID] = client
	h.clientsMu.Unlock()

	// Room membership must be in place before the online notification is
	// sent: notifyUserOnline only targets rooms that already contain the user.
	go func() {
		h.addClientToChats(client)
		h.notifyUserOnline(client.userID, true)
	}()

	logger.Info("Client registered", "user_id", client.userID)
}

// handleUnregister detaches client from the hub: it drops the client from the
// client map (only if it is still the active connection for its user), closes
// its send channel, removes it from every room and, when the active
// connection went away, announces the user as offline.
func (h *Hub) handleUnregister(client *Client) {
	if client == nil {
		return
	}

	h.clientsMu.Lock()
	// A connection replaced by a newer one for the same user must not evict
	// the newer connection when its own unregister request arrives.
	isCurrent := h.clients[client.userID] == client
	if isCurrent {
		delete(h.clients, client.userID)
	}
	h.clientsMu.Unlock()

	// Close the channel safely: it may already have been closed.
	func() {
		defer func() {
			if r := recover(); r != nil {
				logger.Warn("Recovered while closing send channel", "error", r)
			}
		}()
		close(client.send)
	}()

	// Record which rooms contained the user before removal, otherwise there
	// would be no room left to deliver the offline notification to.
	// NOTE(review): whether Room.RemoveClient matches by pointer or by user ID
	// is not visible here — confirm it cannot drop a newer connection.
	var notifyRooms []*Room
	h.roomsMu.Lock()
	for _, room := range h.rooms {
		if isCurrent && room.HasClientByUserID(client.userID) {
			notifyRooms = append(notifyRooms, room)
		}
		room.RemoveClient(client)
	}
	h.roomsMu.Unlock()

	if isCurrent {
		go h.broadcastUserOnline(notifyRooms, client.userID, false)
	}

	logger.Info("Client unregistered", "user_id", client.userID)
}

// handleBroadcast dispatches a broadcast request to the handler for its type.
// Unknown types are ignored. Panics raised by a handler are recovered and
// logged so that they do not terminate the Run loop.
func (h *Hub) handleBroadcast(broadcast *BroadcastMessage) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error("Panic in handleBroadcast", "error", r)
		}
	}()

	switch broadcast.Type {
	case MsgTypeNewMessage:
		h.handleNewMessageBroadcast(broadcast)
	case MsgTypeReadReceipt:
		h.handleReadReceiptBroadcast(broadcast)
	case MsgTypeEditMessage:
		h.handleEditMessageBroadcast(broadcast)
	case MsgTypeDeleteMessage:
		h.handleDeleteMessageBroadcast(broadcast)
	}
}

// handleNewMessageBroadcast stores a new message through the message service
// on behalf of the sending client and broadcasts the stored message to the
// chat's room, if that room exists.
func (h *Hub) handleNewMessageBroadcast(broadcast *BroadcastMessage) {
	req, ok := broadcast.Data.(NewMessageRequest)
	if !ok {
		logger.Error("Invalid broadcast data for new_message")
		return
	}
	// The sender is taken from the client, so a request without one cannot
	// be processed.
	if broadcast.Client == nil {
		logger.Error("Missing client for new_message")
		return
	}

	ctx := context.Background()

	message, err := h.messageService.SendMessage(ctx, broadcast.Client.userID, req.ChatID, req.Plaintext, req.AttachmentID)
	if err != nil {
		logger.Error("Failed to save message", "error", err)
		broadcast.Client.sendError(500, "Failed to send message")
		return
	}

	response := NewMessageResponse{
		ID:        message.ID,
		ChatID:    message.ChatID,
		SenderID:  message.SenderID,
		Plaintext: message.Plaintext,
		CreatedAt: message.CreatedAt,
		TempID:    req.TempID,
	}

	msgBytes, err := CreateMessage(MsgTypeNewMessageResp, response)
	if err != nil {
		logger.Error("Failed to create message", "error", err)
		return
	}

	if room, exists := h.GetRoom(req.ChatID); exists {
		room.BroadcastToAll(msgBytes)
	}
}

// handleReadReceiptBroadcast marks a message as read by broadcast.UserID and
// broadcasts the read receipt to the room of the message's chat, if that room
// exists.
func (h *Hub) handleReadReceiptBroadcast(broadcast *BroadcastMessage) {
	ctx := context.Background()

	if err := h.messageService.MarkMessageAsRead(ctx, broadcast.MessageID, broadcast.UserID); err != nil {
		logger.Error("Failed to mark message as read", "error", err)
		return
	}

	// The message is loaded only to learn which chat it belongs to.
	message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
	if err != nil {
		logger.Error("Failed to get message for read receipt", "error", err)
		return
	}

	response := ReadReceiptResponse{
		MessageID: broadcast.MessageID,
		UserID:    broadcast.UserID,
		ReadAt:    time.Now(),
	}

	msgBytes, err := CreateMessage(MsgTypeMessageRead, response)
	if err != nil {
		logger.Error("Failed to create read receipt message", "error", err)
		return
	}

	if room, exists := h.GetRoom(message.ChatID); exists {
		room.BroadcastToAll(msgBytes)
	}
}

// handleEditMessageBroadcast edits a message on behalf of broadcast.SenderID
// and broadcasts the new text to the room of the message's chat, if that room
// exists. When the edit fails and a client is attached, the client receives a
// 403 error.
func (h *Hub) handleEditMessageBroadcast(broadcast *BroadcastMessage) {
	req, ok := broadcast.Data.(EditMessageRequest)
	if !ok {
		logger.Error("Invalid broadcast data for edit_message")
		return
	}

	ctx := context.Background()

	if err := h.messageService.EditMessage(ctx, broadcast.SenderID, req.MessageID, req.Plaintext); err != nil {
		logger.Error("Failed to edit message", "error", err)
		if broadcast.Client != nil {
			broadcast.Client.sendError(403, "Failed to edit message")
		}
		return
	}

	// The message is loaded only to learn which chat it belongs to.
	message, err := h.messageService.GetMessageByID(ctx, req.MessageID)
	if err != nil {
		logger.Error("Failed to get edited message", "error", err)
		return
	}

	response := EditMessageResponse{
		MessageID:    req.MessageID,
		NewPlaintext: req.Plaintext,
		EditedAt:     time.Now(),
	}

	msgBytes, err := CreateMessage(MsgTypeMessageEdited, response)
	if err != nil {
		logger.Error("Failed to create edit message", "error", err)
		return
	}

	if room, exists := h.GetRoom(message.ChatID); exists {
		room.BroadcastToAll(msgBytes)
	}
}

// handleDeleteMessageBroadcast deletes a message on behalf of
// broadcast.SenderID, provided the message service allows it, and broadcasts
// the deletion to the room of the message's chat, if that room exists. When
// deletion is not allowed and a client is attached, the client receives a 403
// error.
func (h *Hub) handleDeleteMessageBroadcast(broadcast *BroadcastMessage) {
	ctx := context.Background()

	// Load the message first: its chat ID is needed after it is deleted.
	message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
	if err != nil {
		logger.Error("Failed to get message for deletion", "error", err)
		return
	}

	canDelete, err := h.messageService.CanDeleteMessage(ctx, broadcast.SenderID, broadcast.MessageID)
	if err != nil || !canDelete {
		logger.Error("User cannot delete message", "user_id", broadcast.SenderID, "message_id", broadcast.MessageID)
		if broadcast.Client != nil {
			broadcast.Client.sendError(403, "Cannot delete this message")
		}
		return
	}

	if err = h.messageService.DeleteMessage(ctx, broadcast.MessageID); err != nil {
		logger.Error("Failed to delete message", "error", err)
		return
	}

	response := DeleteMessageResponse{
		MessageID: broadcast.MessageID,
		DeletedAt: time.Now(),
	}

	msgBytes, err := CreateMessage(MsgTypeMessageDeleted, response)
	if err != nil {
		logger.Error("Failed to create delete message", "error", err)
		return
	}

	if room, exists := h.GetRoom(message.ChatID); exists {
		room.BroadcastToAll(msgBytes)
	}
}

// addClientToChats loads the client's chats from the chat service and adds
// the client to the room of each chat, creating rooms that do not exist yet.
func (h *Hub) addClientToChats(client *Client) {
	ctx := context.Background()

	chats, err := h.chatService.GetUserChats(ctx, client.userID)
	if err != nil {
		logger.Error("Failed to get user chats", "user_id", client.userID, "error", err)
		return
	}

	for _, chat := range chats {
		h.roomsMu.Lock()
		room, exists := h.rooms[chat.ID]
		if !exists {
			room = NewRoom(chat.ID)
			h.rooms[chat.ID] = room
		}
		h.roomsMu.Unlock()

		room.AddClient(client)
	}

	logger.Info("Client added to rooms", "user_id", client.userID, "room_count", len(chats))
}

// notifyUserOnline broadcasts the user's online status to every room that
// currently contains the user. Panics are recovered and logged.
func (h *Hub) notifyUserOnline(userID int64, isOnline bool) {
	// Panic protection.
	defer func() {
		if r := recover(); r != nil {
			logger.Warn("Recovered in notifyUserOnline", "error", r)
		}
	}()

	h.broadcastUserOnline(h.roomsWithUser(userID), userID, isOnline)
}

// roomsWithUser returns the rooms for which HasClientByUserID(userID) is true.
func (h *Hub) roomsWithUser(userID int64) []*Room {
	h.roomsMu.RLock()
	defer h.roomsMu.RUnlock()

	var rooms []*Room
	for _, room := range h.rooms {
		if room.HasClientByUserID(userID) {
			rooms = append(rooms, room)
		}
	}
	return rooms
}

// broadcastUserOnline sends a MsgTypeUserOnline message describing the user's
// status to each of the given rooms. Panics are recovered and logged, since
// this may run in its own goroutine.
func (h *Hub) broadcastUserOnline(rooms []*Room, userID int64, isOnline bool) {
	defer func() {
		if r := recover(); r != nil {
			logger.Warn("Recovered in notifyUserOnline", "error", r)
		}
	}()

	if len(rooms) == 0 {
		return
	}

	response := UserOnlineResponse{
		UserID:   userID,
		IsOnline: isOnline,
	}

	msgBytes, err := CreateMessage(MsgTypeUserOnline, response)
	if err != nil {
		logger.Error("Failed to create user online message", "error", err)
		return
	}

	for _, room := range rooms {
		room.BroadcastToAll(msgBytes)
	}
}

// GetRoom returns the room for chatID and whether it exists.
func (h *Hub) GetRoom(chatID int64) (*Room, bool) {
	h.roomsMu.RLock()
	defer h.roomsMu.RUnlock()
	room, exists := h.rooms[chatID]
	return room, exists
}

// GetClient returns the registered client for userID and whether it exists.
func (h *Hub) GetClient(userID int64) (*Client, bool) {
	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()
	client, exists := h.clients[userID]
	return client, exists
}

// SendToUser passes message to the registered client of userID. It returns
// false when no client is registered for that user.
func (h *Hub) SendToUser(userID int64, message []byte) bool {
	client, exists := h.GetClient(userID)
	if !exists {
		return false
	}
	client.SendMessage(message)
	return true
}

// GetRegisterChan returns the send-only channel used to register clients.
func (h *Hub) GetRegisterChan() chan<- *Client {
	return h.register
}

// GetUnregisterChan returns the send-only channel used to unregister clients.
func (h *Hub) GetUnregisterChan() chan<- *Client {
	return h.unregister
}

// GetBroadcastChan returns the send-only channel used to submit broadcast
// requests.
func (h *Hub) GetBroadcastChan() chan<- *BroadcastMessage {
	return h.broadcast
}