package sqlite import ( "context" "database/sql" "fmt" "messenger/internal/models" "time" ) type MessageRepository struct { db *DB } func NewMessageRepository(db *DB) *MessageRepository { return &MessageRepository{db: db} } func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error { query := ` INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at) VALUES (?, ?, ?, ?, ?, ?) RETURNING id ` err := r.db.QueryRowContext(ctx, query, message.ChatID, message.SenderID, message.EncryptedBody, message.AttachmentID, message.IsRead, message.CreatedAt, ).Scan(&message.ID) if err != nil { return fmt.Errorf("failed to create message: %w", err) } // Если есть attachment, обновляем его message_id if message.AttachmentID != nil { updateQuery := `UPDATE attachments SET message_id = ? WHERE id = ?` _, err = r.db.ExecContext(ctx, updateQuery, message.ID, *message.AttachmentID) if err != nil { return fmt.Errorf("failed to link attachment to message: %w", err) } } return nil } func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) { query := ` SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at FROM messages WHERE id = ? ` var message models.Message var attachmentID sql.NullInt64 err := r.db.QueryRowContext(ctx, query, id).Scan( &message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody, &attachmentID, &message.IsRead, &message.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to find message by id: %w", err) } if attachmentID.Valid { message.AttachmentID = &attachmentID.Int64 } return &message, nil } func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) { var query string var rows *sql.Rows var err error // Если before не zero time, используем пагинацию if !before.IsZero() { query = ` SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at FROM messages WHERE chat_id = ? AND created_at < ? ORDER BY created_at DESC LIMIT ? ` rows, err = r.db.QueryContext(ctx, query, chatID, before, limit) } else { query = ` SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at FROM messages WHERE chat_id = ? ORDER BY created_at DESC LIMIT ? ` rows, err = r.db.QueryContext(ctx, query, chatID, limit) } if err != nil { return nil, fmt.Errorf("failed to get chat history: %w", err) } defer rows.Close() var messages []*models.Message for rows.Next() { var message models.Message var attachmentID sql.NullInt64 err := rows.Scan( &message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody, &attachmentID, &message.IsRead, &message.CreatedAt, ) if err != nil { return nil, fmt.Errorf("failed to scan message: %w", err) } if attachmentID.Valid { message.AttachmentID = &attachmentID.Int64 } messages = append(messages, &message) } // Переворачиваем, чтобы получить в хронологическом порядке for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { messages[i], messages[j] = messages[j], messages[i] } return messages, nil } func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) { query := ` SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at FROM messages WHERE chat_id = ? ORDER BY created_at DESC LIMIT 1 ` var message models.Message var attachmentID sql.NullInt64 err := r.db.QueryRowContext(ctx, query, chatID).Scan( &message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody, &attachmentID, &message.IsRead, &message.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to get last message: %w", err) } if attachmentID.Valid { message.AttachmentID = &attachmentID.Int64 } return &message, nil } func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error { query := `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0` result, err := r.db.ExecContext(ctx, query, messageID) if err != nil { return fmt.Errorf("failed to mark message as read: %w", err) } rows, err := result.RowsAffected() if err != nil { return err } if rows == 0 { // Сообщение уже прочитано или не найдено - не ошибка return nil } return nil } func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error { // Сначала получаем attachment_id, чтобы потом удалить файл var attachmentID sql.NullInt64 query := `SELECT attachment_id FROM messages WHERE id = ?` err := r.db.QueryRowContext(ctx, query, messageID).Scan(&attachmentID) if err != nil && err != sql.ErrNoRows { return fmt.Errorf("failed to get attachment info: %w", err) } // Удаляем сообщение (attachment останется, но без связи с сообщением) deleteQuery := `DELETE FROM messages WHERE id = ?` result, err := r.db.ExecContext(ctx, deleteQuery, messageID) if err != nil { return fmt.Errorf("failed to delete message: %w", err) } rows, err := result.RowsAffected() if err != nil { return err } if rows == 0 { return fmt.Errorf("message not found: %d", messageID) } // Если был attachment, обновляем его message_id на NULL if attachmentID.Valid { updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?` _, err = r.db.ExecContext(ctx, updateQuery, attachmentID.Int64) if err != nil { return fmt.Errorf("failed to unlink attachment: %w", err) } } return nil } func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error { query := ` UPDATE messages SET encrypted_body = ?, attachment_id = ?, is_read = ? WHERE id = ? AND sender_id = ? ` result, err := r.db.ExecContext(ctx, query, message.EncryptedBody, message.AttachmentID, message.IsRead, message.ID, message.SenderID, ) if err != nil { return fmt.Errorf("failed to update message: %w", err) } rows, err := result.RowsAffected() if err != nil { return err } if rows == 0 { return fmt.Errorf("message not found or not owned by user: %d", message.ID) } return nil } func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) { query := ` SELECT COUNT(*) FROM messages m WHERE m.chat_id = ? AND m.sender_id != ? AND m.is_read = 0 ` var count int err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to get unread count: %w", err) } return count, nil }