// Package websocket contains the per-connection Client used by the messenger
// hub: it pumps frames between a gorilla/websocket connection and the hub's
// channels and dispatches inbound protocol messages to their handlers.
package websocket

import (
	"encoding/json"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"messenger/internal/models"
)

const (
	// writeWait is the deadline applied to every write to the peer.
	writeWait = 10 * time.Second

	// pongWait is the read deadline; it is re-armed each time a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is how often WritePump sends a ping. It is shorter than
	// pongWait so a pong can arrive before the read deadline expires.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the read limit, in bytes, set on the connection.
	maxMessageSize = 512 * 1024
)

// Client ties one WebSocket connection to one authenticated user.
type Client struct {
	hub    *Hub
	conn   *websocket.Conn
	send   chan []byte // outbound frames, drained by WritePump
	user   *models.User
	userID int64

	mu       sync.RWMutex   // guards rooms and lastPing
	rooms    map[int64]bool // chat IDs this client has joined
	lastPing time.Time      // last time a pong frame or ping message was seen
}

// NewClient builds a Client for user on conn, registered against hub.
// user must be non-nil: its ID is read immediately.
func NewClient(hub *Hub, conn *websocket.Conn, user *models.User) *Client {
	return &Client{
		hub:      hub,
		conn:     conn,
		send:     make(chan []byte, 256),
		user:     user,
		userID:   user.ID,
		rooms:    make(map[int64]bool),
		lastPing: time.Now(),
	}
}

// ReadPump reads frames from the connection until a read fails, passing each
// one to handleMessage. On exit it sends the client to the hub's unregister
// channel and closes the connection. It blocks, so callers normally run it in
// its own goroutine.
func (c *Client) ReadPump() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Recovered in ReadPump: %v", r)
		}
		c.hub.GetUnregisterChan() <- c
		c.conn.Close()
	}()

	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		// Every pong pushes the read deadline out and refreshes liveness.
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		c.mu.Lock()
		c.lastPing = time.Now()
		c.mu.Unlock()
		return nil
	})

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			break
		}
		// Log only the size: frames carry user message bodies and can be up
		// to maxMessageSize bytes of client-controlled data.
		log.Printf("Received message from client %d (%d bytes)", c.userID, len(message))
		c.handleMessage(message)
	}
}

// WritePump forwards frames queued on c.send to the connection as text
// messages and sends a ping every pingPeriod. It returns when a write fails
// or when c.send is closed (after writing a close frame), and closes the
// connection on exit. It blocks, so callers normally run it in its own
// goroutine.
func (c *Client) WritePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Recovered in WritePump: %v", r)
		}
		ticker.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The send channel was closed: tell the peer we are done.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// handleMessage parses one raw inbound frame and dispatches it by message
// type. Unparseable frames and unknown types are answered with a 400 error.
// A panic in any handler is recovered and logged so the read loop survives.
func (c *Client) handleMessage(raw []byte) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Recovered in handleMessage: %v", r)
		}
	}()

	msg, err := ParseMessage(raw)
	if err != nil {
		c.sendError(400, "Invalid message format")
		return
	}

	switch msg.Type {
	case MsgTypeNewMessage:
		c.handleNewMessage(msg.Data)
	case MsgTypeTyping:
		c.handleTyping(msg.Data)
	case MsgTypeReadReceipt:
		c.handleReadReceipt(msg.Data)
	case MsgTypeEditMessage:
		c.handleEditMessage(msg.Data)
	case MsgTypeDeleteMessage:
		c.handleDeleteMessage(msg.Data)
	case MsgTypePing:
		c.handlePing()
	default:
		c.sendError(400, "Unknown message type")
	}
}

// handleNewMessage decodes a new-message request and, if this client has
// joined the target chat, forwards it to the hub's broadcast channel.
// Non-members receive a 403 error; undecodable payloads a 400.
func (c *Client) handleNewMessage(data []byte) {
	var req NewMessageRequest
	if err := json.Unmarshal(data, &req); err != nil {
		log.Printf("Failed to parse new_message: %v", err)
		c.sendError(400, "Invalid new_message data")
		return
	}

	// The message body is deliberately not logged.
	log.Printf("New message from user %d to chat %d", c.userID, req.ChatID)

	if !c.isMemberOfChat(req.ChatID) {
		log.Printf("User %d is not a member of chat %d", c.userID, req.ChatID)
		c.sendError(403, "Not a member of this chat")
		return
	}

	c.hub.GetBroadcastChan() <- &BroadcastMessage{
		Type:     MsgTypeNewMessage,
		ChatID:   req.ChatID,
		SenderID: c.userID,
		Data:     req,
		Client:   c,
	}
}

// handleTyping relays a typing indicator straight to the chat's room,
// passing this client to Broadcast. The sender must have joined the chat,
// matching the rule applied to new messages; otherwise a 403 is returned.
func (c *Client) handleTyping(data []byte) {
	var req TypingRequest
	if err := json.Unmarshal(data, &req); err != nil {
		c.sendError(400, "Invalid typing data")
		return
	}

	// This path bypasses the hub, so membership has to be enforced here.
	if !c.isMemberOfChat(req.ChatID) {
		c.sendError(403, "Not a member of this chat")
		return
	}

	resp := TypingResponse{
		ChatID:   req.ChatID,
		UserID:   c.userID,
		IsTyping: req.IsTyping,
	}
	msgBytes, err := CreateMessage(MsgTypeUserTyping, resp)
	if err != nil {
		log.Printf("Failed to encode typing notification for user %d: %v", c.userID, err)
		return
	}

	c.hub.roomsMu.RLock()
	room, exists := c.hub.rooms[req.ChatID]
	c.hub.roomsMu.RUnlock()

	if exists {
		room.Broadcast(msgBytes, c)
	}
}

// handleReadReceipt decodes a read receipt and hands it to the hub.
// Undecodable payloads are answered with a 400 error.
func (c *Client) handleReadReceipt(data []byte) {
	var req ReadReceiptRequest
	if err := json.Unmarshal(data, &req); err != nil {
		c.sendError(400, "Invalid read_receipt data")
		return
	}

	c.hub.broadcast <- &BroadcastMessage{
		Type:      MsgTypeReadReceipt,
		MessageID: req.MessageID,
		UserID:    c.userID,
	}
}

// handleEditMessage decodes an edit request and hands it to the hub.
// Undecodable payloads are answered with a 400 error.
func (c *Client) handleEditMessage(data []byte) {
	var req EditMessageRequest
	if err := json.Unmarshal(data, &req); err != nil {
		c.sendError(400, "Invalid edit_message data")
		return
	}

	c.hub.broadcast <- &BroadcastMessage{
		Type:      MsgTypeEditMessage,
		MessageID: req.MessageID,
		SenderID:  c.userID,
		Data:      req,
	}
}

// handleDeleteMessage decodes a delete request and hands it to the hub.
// Undecodable payloads are answered with a 400 error.
func (c *Client) handleDeleteMessage(data []byte) {
	var req DeleteMessageRequest
	if err := json.Unmarshal(data, &req); err != nil {
		c.sendError(400, "Invalid delete_message data")
		return
	}

	c.hub.broadcast <- &BroadcastMessage{
		Type:      MsgTypeDeleteMessage,
		MessageID: req.MessageID,
		SenderID:  c.userID,
	}
}

// handlePing records the ping time and queues a pong reply. The reply is
// dropped, not waited for, when the send buffer is full.
func (c *Client) handlePing() {
	c.mu.Lock()
	c.lastPing = time.Now()
	c.mu.Unlock()

	pongMsg, err := CreateMessage(MsgTypePong, nil)
	if err != nil {
		log.Printf("Failed to encode pong for user %d: %v", c.userID, err)
		return
	}

	select {
	case c.send <- pongMsg:
	default:
	}
}

// sendError queues an error response carrying code and message for this
// client. The response is dropped, not waited for, when the send buffer is
// full.
func (c *Client) sendError(code int, message string) {
	errResp := ErrorResponse{
		Code:    code,
		Message: message,
	}
	msgBytes, err := CreateMessage(MsgTypeError, errResp)
	if err != nil {
		log.Printf("Failed to encode error response for user %d: %v", c.userID, err)
		return
	}

	select {
	case c.send <- msgBytes:
	default:
	}
}

// isMemberOfChat reports whether this client has joined chatID.
func (c *Client) isMemberOfChat(chatID int64) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.rooms[chatID]
}

// JoinRoom marks the client as a member of chatID.
func (c *Client) JoinRoom(chatID int64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.rooms[chatID] = true
}

// LeaveRoom removes chatID from the client's joined rooms.
func (c *Client) LeaveRoom(chatID int64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.rooms, chatID)
}

// GetRooms returns the IDs of all rooms the client has joined, in no
// particular order. The returned slice is a fresh copy.
func (c *Client) GetRooms() []int64 {
	c.mu.RLock()
	defer c.mu.RUnlock()

	rooms := make([]int64, 0, len(c.rooms))
	for chatID := range c.rooms {
		rooms = append(rooms, chatID)
	}
	return rooms
}

// SendMessage queues msg for delivery without blocking. If the send buffer
// is full the client is sent to the hub's unregister channel and its
// connection is closed.
//
// NOTE(review): the unregister send blocks until it is accepted, and
// ReadPump's cleanup also sends an unregister once the connection closes.
// Confirm that the hub never calls SendMessage from the goroutine that
// receives unregister requests, and that it tolerates a duplicate request.
func (c *Client) SendMessage(msg []byte) {
	select {
	case c.send <- msg:
	default:
		log.Printf("Client %d send channel full", c.userID)
		c.hub.GetUnregisterChan() <- c
		c.conn.Close()
	}
}

// GetUserID returns the ID of the user behind this connection.
func (c *Client) GetUserID() int64 {
	return c.userID
}

// GetUser returns the user behind this connection.
func (c *Client) GetUser() *models.User {
	return c.user
}

// IsAlive reports whether a pong frame or ping message has been seen from
// the peer within the last pongWait.
func (c *Client) IsAlive() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return time.Since(c.lastPing) < pongWait
}