// File: efir-api-server/internal/websocket/hub.go (414 lines, 9.6 KiB, Go)

package websocket
import (
"context"
"messenger/internal/pkg/logger"
"messenger/internal/service"
"sync"
"time"
)
// BroadcastMessage is one unit of work queued on the hub's broadcast channel.
// Type decides which handler processes it; the other fields are read only by
// the handlers that need them.
type BroadcastMessage struct {
	// Type selects the handler the hub dispatches this message to.
	Type MessageType

	// ChatID is not read by the handlers in this file.
	// NOTE(review): confirm whether code elsewhere relies on it.
	ChatID int64
	// MessageID identifies the target message for read receipts and deletions.
	MessageID int64
	// UserID is the user marking a message as read.
	UserID int64
	// SenderID is the user requesting an edit or a deletion.
	SenderID int64

	// Data carries the request payload: a NewMessageRequest for new messages
	// and an EditMessageRequest for edits.
	Data interface{}
	// Client is the originating connection. It supplies the sender ID for new
	// messages and receives error replies.
	Client *Client
}
// Hub keeps track of connected clients and chat rooms and processes
// registration, unregistration and broadcast events received on its channels.
type Hub struct {
	// Event channels served by Run.
	register   chan *Client
	unregister chan *Client
	broadcast  chan *BroadcastMessage

	// clients maps a user ID to that user's registered connection;
	// clientsMu guards the map.
	clients   map[int64]*Client
	clientsMu sync.RWMutex

	// rooms maps a chat ID to its room; roomsMu guards the map.
	rooms   map[int64]*Room
	roomsMu sync.RWMutex

	// Services used by the broadcast handlers and by room membership lookup.
	messageService *service.MessageService
	chatService    *service.ChatService

	// running is the Run loop condition; Stop clears it while holding mu.
	running bool
	mu      sync.RWMutex
}
// NewHub builds a Hub wired to the given services. The register and
// unregister channels are unbuffered, the broadcast channel buffers up to 256
// messages, and the hub starts in the running state with empty client and
// room maps.
func NewHub(messageService *service.MessageService, chatService *service.ChatService) *Hub {
	const broadcastBuffer = 256

	hub := new(Hub)
	hub.register = make(chan *Client)
	hub.unregister = make(chan *Client)
	hub.broadcast = make(chan *BroadcastMessage, broadcastBuffer)
	hub.clients = make(map[int64]*Client)
	hub.rooms = make(map[int64]*Room)
	hub.messageService = messageService
	hub.chatService = chatService
	hub.running = true
	return hub
}
// Run is the hub's event loop. It serves the register, unregister and
// broadcast channels one event at a time and blocks until the hub is stopped.
//
// The loop ends when the running flag has been cleared or when any of the
// channels has been closed (Stop does both). A receive from a closed channel
// yields a nil pointer, so closure is detected with the two-value receive
// form instead of handing nil to the handlers.
func (h *Hub) Run() {
	logger.Info("WebSocket Hub started")

loop:
	for {
		// running is written by Stop under h.mu, so it is read under the same
		// lock to avoid a data race.
		h.mu.RLock()
		running := h.running
		h.mu.RUnlock()
		if !running {
			break
		}

		select {
		case client, ok := <-h.register:
			if !ok {
				break loop
			}
			h.handleRegister(client)
		case client, ok := <-h.unregister:
			if !ok {
				break loop
			}
			h.handleUnregister(client)
		case broadcast, ok := <-h.broadcast:
			if !ok {
				break loop
			}
			h.handleBroadcast(broadcast)
		}
	}

	logger.Info("WebSocket Hub stopped")
}
// Stop shuts the hub down: it clears the running flag, closes the three event
// channels, closes every registered client's send channel and connection, and
// resets the client map.
//
// Stop is idempotent; calls after the first return immediately, because
// closing an already-closed channel panics.
//
// NOTE(review): sending on a closed channel panics, so producers must not use
// the hub's channels after Stop has been called.
func (h *Hub) Stop() {
	h.mu.Lock()
	defer h.mu.Unlock()

	if !h.running {
		return
	}
	h.running = false

	close(h.register)
	close(h.unregister)
	close(h.broadcast)

	h.clientsMu.Lock()
	defer h.clientsMu.Unlock()

	for _, client := range h.clients {
		func() {
			defer func() {
				if r := recover(); r != nil {
					logger.Warn("Recovered while closing client", "error", r)
				}
			}()
			// Deferred so the connection is still closed when closing an
			// already-closed send channel panics.
			defer func() {
				client.conn.Close()
			}()
			close(client.send)
		}()
	}
	h.clients = make(map[int64]*Client)
}
// handleRegister installs client as the registered connection for its user.
//
// If the user already has a registered connection, that connection's send
// channel and socket are closed before it is replaced. The client is then
// added to the user's chat rooms and, only after that has finished, the rooms
// are told that the user is online. notifyUserOnline broadcasts solely to
// rooms that already contain the user, so running it concurrently with the
// room join would usually notify nobody. Both steps run in one background
// goroutine so the hub loop does not wait on them.
//
// A nil client is ignored.
func (h *Hub) handleRegister(client *Client) {
	// A receive from a closed channel yields nil; there is nothing to register.
	if client == nil {
		return
	}

	h.clientsMu.Lock()
	if oldClient, exists := h.clients[client.userID]; exists {
		func() {
			defer func() {
				if r := recover(); r != nil {
					logger.Warn("Recovered while closing old client", "error", r)
				}
			}()
			// Deferred so the old socket is still closed when closing an
			// already-closed send channel panics.
			defer func() {
				oldClient.conn.Close()
			}()
			close(oldClient.send)
		}()
	}
	h.clients[client.userID] = client
	h.clientsMu.Unlock()

	go func() {
		h.addClientToChats(client)
		h.notifyUserOnline(client.userID, true)
	}()

	logger.Info("Client registered", "user_id", client.userID)
}
// handleUnregister tears down a client connection: it closes the client's send
// channel, removes the client from every room and, when the client is the
// connection currently registered for its user, removes it from the client
// map and tells the user's rooms that the user went offline.
//
// A connection that handleRegister has already replaced is "stale": its late
// unregister must neither delete the newer connection stored under the same
// user ID nor announce the user as offline.
//
// The rooms to notify are captured before the client is removed from them;
// NOTE(review): this assumes RemoveClient makes HasClientByUserID report
// false, in which case a lookup after removal would find no room to notify.
//
// A nil client is ignored.
func (h *Hub) handleUnregister(client *Client) {
	// A receive from a closed channel yields nil; there is nothing to tear down.
	if client == nil {
		return
	}

	h.clientsMu.Lock()
	current, registered := h.clients[client.userID]
	isCurrent := registered && current == client
	if isCurrent {
		delete(h.clients, client.userID)
	}
	h.clientsMu.Unlock()

	// Close the send channel safely: it may already have been closed.
	func() {
		defer func() {
			if r := recover(); r != nil {
				logger.Warn("Recovered while closing send channel", "error", r)
			}
		}()
		close(client.send)
	}()

	var roomsToNotify []*Room
	h.roomsMu.Lock()
	for _, room := range h.rooms {
		if isCurrent && room.HasClientByUserID(client.userID) {
			roomsToNotify = append(roomsToNotify, room)
		}
		room.RemoveClient(client)
	}
	h.roomsMu.Unlock()

	if !isCurrent {
		logger.Info("Stale client unregistered", "user_id", client.userID)
		return
	}

	go func(userID int64, rooms []*Room) {
		// Broadcasting may panic; keep that from crashing the process.
		defer func() {
			if r := recover(); r != nil {
				logger.Warn("Recovered while notifying user offline", "error", r)
			}
		}()
		msgBytes, err := CreateMessage(MsgTypeUserOnline, UserOnlineResponse{
			UserID:   userID,
			IsOnline: false,
		})
		if err != nil {
			logger.Error("Failed to create user online message", "error", err)
			return
		}
		for _, room := range rooms {
			room.BroadcastToAll(msgBytes)
		}
	}(client.userID, roomsToNotify)

	logger.Info("Client unregistered", "user_id", client.userID)
}
// handleBroadcast routes a queued broadcast to the handler matching its Type.
// Messages of any other type are dropped without further action. A panic
// raised while dispatching or handling is recovered and logged.
func (h *Hub) handleBroadcast(broadcast *BroadcastMessage) {
	defer func() {
		if recovered := recover(); recovered != nil {
			logger.Error("Panic in handleBroadcast", "error", recovered)
		}
	}()

	var handle func(*BroadcastMessage)
	switch broadcast.Type {
	case MsgTypeNewMessage:
		handle = h.handleNewMessageBroadcast
	case MsgTypeReadReceipt:
		handle = h.handleReadReceiptBroadcast
	case MsgTypeEditMessage:
		handle = h.handleEditMessageBroadcast
	case MsgTypeDeleteMessage:
		handle = h.handleDeleteMessageBroadcast
	default:
		return
	}
	handle(broadcast)
}
// handleNewMessageBroadcast stores a new chat message through the message
// service and broadcasts the stored message to the chat's room.
//
// broadcast.Data must hold a NewMessageRequest and broadcast.Client must be
// the sending connection: its user ID is used as the message sender and it
// receives a 500 error reply when saving fails. A broadcast with the wrong
// payload type or without a client is logged and dropped. If no room exists
// for the chat, the message is saved but nothing is broadcast.
func (h *Hub) handleNewMessageBroadcast(broadcast *BroadcastMessage) {
	req, ok := broadcast.Data.(NewMessageRequest)
	if !ok {
		logger.Error("Invalid broadcast data for new_message")
		return
	}

	// The edit and delete handlers tolerate a nil Client; here the client is
	// required because it identifies the sender, so reject instead of panicking.
	sender := broadcast.Client
	if sender == nil {
		logger.Error("Missing client for new_message broadcast")
		return
	}

	ctx := context.Background()
	message, err := h.messageService.SendMessage(ctx, sender.userID, req.ChatID, req.Plaintext, req.AttachmentID)
	if err != nil {
		logger.Error("Failed to save message", "error", err)
		sender.sendError(500, "Failed to send message")
		return
	}

	response := NewMessageResponse{
		ID:        message.ID,
		ChatID:    message.ChatID,
		SenderID:  message.SenderID,
		Plaintext: message.Plaintext,
		CreatedAt: message.CreatedAt,
		// Echo the request's TempID back in the response.
		TempID: req.TempID,
	}
	msgBytes, err := CreateMessage(MsgTypeNewMessageResp, response)
	if err != nil {
		logger.Error("Failed to create message", "error", err)
		return
	}

	h.roomsMu.RLock()
	room, exists := h.rooms[req.ChatID]
	h.roomsMu.RUnlock()
	if exists {
		room.BroadcastToAll(msgBytes)
	}
}
// handleReadReceiptBroadcast marks broadcast.MessageID as read by
// broadcast.UserID, then announces the receipt to the room of the chat the
// message belongs to. Any failure is logged and ends processing; if no room
// exists for the chat, nothing is broadcast.
func (h *Hub) handleReadReceiptBroadcast(broadcast *BroadcastMessage) {
	ctx := context.Background()
	messageID, readerID := broadcast.MessageID, broadcast.UserID

	if err := h.messageService.MarkMessageAsRead(ctx, messageID, readerID); err != nil {
		logger.Error("Failed to mark message as read", "error", err)
		return
	}

	// The message is loaded only to learn which chat room to notify.
	message, err := h.messageService.GetMessageByID(ctx, messageID)
	if err != nil {
		logger.Error("Failed to get message for read receipt", "error", err)
		return
	}

	payload, err := CreateMessage(MsgTypeMessageRead, ReadReceiptResponse{
		MessageID: messageID,
		UserID:    readerID,
		ReadAt:    time.Now(),
	})
	if err != nil {
		logger.Error("Failed to create read receipt message", "error", err)
		return
	}

	if room, found := h.GetRoom(message.ChatID); found {
		room.BroadcastToAll(payload)
	}
}
// handleEditMessageBroadcast applies an edit requested by broadcast.SenderID
// and announces the new text to the room of the message's chat.
//
// broadcast.Data must hold an EditMessageRequest. When the edit is refused or
// fails, the error is logged and, if a client is attached, it receives a 403
// error reply. If no room exists for the chat, nothing is broadcast.
func (h *Hub) handleEditMessageBroadcast(broadcast *BroadcastMessage) {
	editReq, valid := broadcast.Data.(EditMessageRequest)
	if !valid {
		logger.Error("Invalid broadcast data for edit_message")
		return
	}

	ctx := context.Background()
	if err := h.messageService.EditMessage(ctx, broadcast.SenderID, editReq.MessageID, editReq.Plaintext); err != nil {
		logger.Error("Failed to edit message", "error", err)
		if requester := broadcast.Client; requester != nil {
			requester.sendError(403, "Failed to edit message")
		}
		return
	}

	// The message is loaded only to learn which chat room to notify.
	edited, err := h.messageService.GetMessageByID(ctx, editReq.MessageID)
	if err != nil {
		logger.Error("Failed to get edited message", "error", err)
		return
	}

	payload, err := CreateMessage(MsgTypeMessageEdited, EditMessageResponse{
		MessageID:    editReq.MessageID,
		NewPlaintext: editReq.Plaintext,
		EditedAt:     time.Now(),
	})
	if err != nil {
		logger.Error("Failed to create edit message", "error", err)
		return
	}

	if room, found := h.GetRoom(edited.ChatID); found {
		room.BroadcastToAll(payload)
	}
}
// handleDeleteMessageBroadcast deletes broadcast.MessageID on behalf of
// broadcast.SenderID and announces the deletion to the room of the message's
// chat.
//
// The message is loaded first, because its chat ID is needed for the
// announcement after the delete. A failed permission check and an explicit
// denial both answer the attached client with 403, but are logged separately
// so the underlying error is not lost. A failed delete answers with 500,
// matching how the new-message handler reports a failed save. If no room
// exists for the chat, nothing is broadcast.
func (h *Hub) handleDeleteMessageBroadcast(broadcast *BroadcastMessage) {
	ctx := context.Background()

	message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
	if err != nil {
		logger.Error("Failed to get message for deletion", "error", err)
		return
	}

	canDelete, err := h.messageService.CanDeleteMessage(ctx, broadcast.SenderID, broadcast.MessageID)
	if err != nil {
		logger.Error("Failed to check delete permission", "user_id", broadcast.SenderID, "message_id", broadcast.MessageID, "error", err)
		if broadcast.Client != nil {
			broadcast.Client.sendError(403, "Cannot delete this message")
		}
		return
	}
	if !canDelete {
		logger.Error("User cannot delete message", "user_id", broadcast.SenderID, "message_id", broadcast.MessageID)
		if broadcast.Client != nil {
			broadcast.Client.sendError(403, "Cannot delete this message")
		}
		return
	}

	if err := h.messageService.DeleteMessage(ctx, broadcast.MessageID); err != nil {
		logger.Error("Failed to delete message", "error", err)
		if broadcast.Client != nil {
			broadcast.Client.sendError(500, "Failed to delete message")
		}
		return
	}

	response := DeleteMessageResponse{
		MessageID: broadcast.MessageID,
		DeletedAt: time.Now(),
	}
	msgBytes, err := CreateMessage(MsgTypeMessageDeleted, response)
	if err != nil {
		logger.Error("Failed to create delete message", "error", err)
		return
	}

	h.roomsMu.RLock()
	room, exists := h.rooms[message.ChatID]
	h.roomsMu.RUnlock()
	if exists {
		room.BroadcastToAll(msgBytes)
	}
}
// addClientToChats joins client to the room of every chat returned by the
// chat service for the client's user, creating rooms that do not exist yet.
// If the chat lookup fails, the error is logged and no room is touched.
func (h *Hub) addClientToChats(client *Client) {
	userID := client.userID

	chats, err := h.chatService.GetUserChats(context.Background(), userID)
	if err != nil {
		logger.Error("Failed to get user chats", "user_id", userID, "error", err)
		return
	}

	// roomFor returns the room for chatID, creating and storing it when
	// missing. The rooms lock is held only for the map access.
	roomFor := func(chatID int64) *Room {
		h.roomsMu.Lock()
		defer h.roomsMu.Unlock()
		if existing, ok := h.rooms[chatID]; ok {
			return existing
		}
		created := NewRoom(chatID)
		h.rooms[chatID] = created
		return created
	}

	for _, chat := range chats {
		roomFor(chat.ID).AddClient(client)
	}

	logger.Info("Client added to rooms", "user_id", userID, "room_count", len(chats))
}
// notifyUserOnline broadcasts the user's online/offline state to every room
// that currently contains a client of that user. The rooms read lock is held
// for the whole iteration. A panic raised while building or broadcasting the
// message is recovered and logged.
func (h *Hub) notifyUserOnline(userID int64, isOnline bool) {
	// Panic guard.
	defer func() {
		if recovered := recover(); recovered != nil {
			logger.Warn("Recovered in notifyUserOnline", "error", recovered)
		}
	}()

	payload, err := CreateMessage(MsgTypeUserOnline, UserOnlineResponse{
		UserID:   userID,
		IsOnline: isOnline,
	})
	if err != nil {
		logger.Error("Failed to create user online message", "error", err)
		return
	}

	h.roomsMu.RLock()
	defer h.roomsMu.RUnlock()

	for _, room := range h.rooms {
		if !room.HasClientByUserID(userID) {
			continue
		}
		room.BroadcastToAll(payload)
	}
}
// GetRoom looks up the room for chatID under the rooms read lock. The boolean
// reports whether such a room exists.
func (h *Hub) GetRoom(chatID int64) (*Room, bool) {
	h.roomsMu.RLock()
	room, found := h.rooms[chatID]
	h.roomsMu.RUnlock()

	return room, found
}
// GetClient looks up the registered connection for userID under the clients
// read lock. The boolean reports whether the user has one.
func (h *Hub) GetClient(userID int64) (*Client, bool) {
	h.clientsMu.RLock()
	client, found := h.clients[userID]
	h.clientsMu.RUnlock()

	return client, found
}
// SendToUser hands message to the registered connection of userID via
// Client.SendMessage. It returns false when the user has no registered
// connection and true otherwise; the result does not reflect whether the
// message was actually delivered.
func (h *Hub) SendToUser(userID int64, message []byte) bool {
	target, registered := h.GetClient(userID)
	if !registered {
		return false
	}

	target.SendMessage(message)
	return true
}
// GetRegisterChan returns the client registration channel as a send-only
// channel. Run passes every client received on it to handleRegister.
func (h *Hub) GetRegisterChan() chan<- *Client {
	var registerCh chan<- *Client = h.register
	return registerCh
}
// GetUnregisterChan returns the client unregistration channel as a send-only
// channel. Run passes every client received on it to handleUnregister.
func (h *Hub) GetUnregisterChan() chan<- *Client {
	var unregisterCh chan<- *Client = h.unregister
	return unregisterCh
}
// GetBroadcastChan returns the broadcast channel as a send-only channel. Run
// passes every message received on it to handleBroadcast.
func (h *Hub) GetBroadcastChan() chan<- *BroadcastMessage {
	var broadcastCh chan<- *BroadcastMessage = h.broadcast
	return broadcastCh
}