Files
efir-api-server/internal/repository/sqlite/message_repo.go

268 lines
6.9 KiB
Go
Raw Normal View History

package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
"time"
)
// MessageRepository reads and writes rows of the messages table. Create and
// Delete additionally maintain the attachments.message_id back-reference.
type MessageRepository struct {
	// db is the shared database handle used for every query.
	// NOTE(review): presumably wraps *sql.DB — confirm against its definition.
	db *DB
}

// NewMessageRepository returns a MessageRepository that issues all of its
// queries through db.
func NewMessageRepository(db *DB) *MessageRepository {
	repo := new(MessageRepository)
	repo.db = db
	return repo
}
// Create inserts message into the messages table and, on success, stores the
// generated row id in message.ID.
//
// When message.AttachmentID is set, the referenced attachments row has its
// message_id pointed at the new message. The insert and that link run in a
// single transaction: if linking fails (including when no attachment with
// that id exists) nothing is committed and message.ID is left unchanged.
func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin create message transaction: %w", err)
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
	// error is intentionally ignored; on any early return it undoes the insert.
	defer func() { _ = tx.Rollback() }()

	query := `
INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id
`
	// Scan into a copy so message.ID is only updated once everything commits.
	newID := message.ID
	err = tx.QueryRowContext(ctx, query,
		message.ChatID, message.SenderID, message.EncryptedBody,
		message.AttachmentID, message.IsRead, message.CreatedAt,
	).Scan(&newID)
	if err != nil {
		return fmt.Errorf("failed to create message: %w", err)
	}

	// Link the attachment back to the new message, if there is one.
	if message.AttachmentID != nil {
		updateQuery := `UPDATE attachments SET message_id = ? WHERE id = ?`
		result, err := tx.ExecContext(ctx, updateQuery, newID, *message.AttachmentID)
		if err != nil {
			return fmt.Errorf("failed to link attachment to message: %w", err)
		}
		linked, err := result.RowsAffected()
		if err != nil {
			return fmt.Errorf("failed to link attachment to message: %w", err)
		}
		if linked == 0 {
			return fmt.Errorf("failed to link attachment to message: attachment %d not found", *message.AttachmentID)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit message creation: %w", err)
	}
	message.ID = newID
	return nil
}
// FindByID loads the message whose primary key is id.
//
// It returns (nil, nil) when no such message exists, and a wrapped error for
// any other database failure. A NULL attachment_id leaves AttachmentID nil.
func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) {
	const selectByID = `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE id = ?
`
	msg := new(models.Message)
	var attachment sql.NullInt64

	row := r.db.QueryRowContext(ctx, selectByID, id)
	scanErr := row.Scan(
		&msg.ID, &msg.ChatID, &msg.SenderID, &msg.EncryptedBody,
		&attachment, &msg.IsRead, &msg.CreatedAt,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("failed to find message by id: %w", scanErr)
	}

	if attachment.Valid {
		attachmentID := attachment.Int64
		msg.AttachmentID = &attachmentID
	}
	return msg, nil
}
// GetChatHistory returns up to limit messages of chat chatID in chronological
// (oldest-first) order.
//
// If before is non-zero, only messages with created_at strictly earlier than
// before are considered. Callers can page backwards by passing the CreatedAt
// of the oldest message they already hold. The newest matching rows are
// selected (ORDER BY ... DESC LIMIT ?) and then reversed in memory. An empty
// result is returned as a nil slice.
func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) {
	query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?`
	args := []any{chatID}
	// A zero before means "no upper bound": start from the newest message.
	if !before.IsZero() {
		query += ` AND created_at < ?`
		args = append(args, before)
	}
	// id breaks ties between messages sharing a created_at, so the selected
	// page and its order are deterministic.
	query += `
ORDER BY created_at DESC, id DESC
LIMIT ?`
	args = append(args, limit)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to get chat history: %w", err)
	}
	defer rows.Close()

	var messages []*models.Message
	for rows.Next() {
		var message models.Message
		var attachmentID sql.NullInt64
		err := rows.Scan(
			&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
			&attachmentID, &message.IsRead, &message.CreatedAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan message: %w", err)
		}
		if attachmentID.Valid {
			message.AttachmentID = &attachmentID.Int64
		}
		messages = append(messages, &message)
	}
	// rows.Next returns false on iteration errors too; without this check a
	// driver/IO failure would be returned as a silently truncated history.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate chat history: %w", err)
	}

	// Rows arrive newest-first; reverse them into chronological order.
	for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
		messages[i], messages[j] = messages[j], messages[i]
	}
	return messages, nil
}
// GetLastMessage returns the most recent message of chat chatID, or
// (nil, nil) when the chat has no messages.
//
// "Most recent" means the greatest created_at. Among messages with the same
// created_at, the one with the greatest id wins, so the result is deterministic.
func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) {
	query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?
ORDER BY created_at DESC, id DESC
LIMIT 1
`
	var message models.Message
	var attachmentID sql.NullInt64
	err := r.db.QueryRowContext(ctx, query, chatID).Scan(
		&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
		&attachmentID, &message.IsRead, &message.CreatedAt,
	)
	if err == sql.ErrNoRows {
		// Empty chat: not an error.
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get last message: %w", err)
	}
	if attachmentID.Valid {
		message.AttachmentID = &attachmentID.Int64
	}
	return &message, nil
}
// MarkAsRead sets is_read = 1 on the message with id messageID.
//
// Repeated calls are harmless: a message that is already read, or that does
// not exist, matches no row and is not reported as an error. Only a failure
// of the UPDATE itself is returned.
func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error {
	query := `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0`
	// Zero affected rows (already read / not found) is deliberately not an
	// error, so the result's RowsAffected is not consulted.
	if _, err := r.db.ExecContext(ctx, query, messageID); err != nil {
		return fmt.Errorf("failed to mark message as read: %w", err)
	}
	return nil
}
// Delete removes the message with id messageID.
//
// If the message references an attachment (messages.attachment_id), that
// attachments row is kept, but its message_id is reset to NULL. No attachment
// file is removed here. The lookup, unlink and delete run in one transaction,
// so a failure part-way cannot leave the attachment detached from a message
// that still exists, or linked to one that is gone.
//
// It returns an error if no message with that id exists.
func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin delete message transaction: %w", err)
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
	// error is intentionally ignored; on any early return it undoes the unlink.
	defer func() { _ = tx.Rollback() }()

	// Find the attachment to unlink. A missing message is reported below by
	// the DELETE affecting zero rows, so ErrNoRows is not an error here.
	var attachmentID sql.NullInt64
	query := `SELECT attachment_id FROM messages WHERE id = ?`
	err = tx.QueryRowContext(ctx, query, messageID).Scan(&attachmentID)
	if err != nil && err != sql.ErrNoRows {
		return fmt.Errorf("failed to get attachment info: %w", err)
	}

	// Clear the back-reference before deleting, so attachments.message_id
	// never points at a missing row even transiently. NOTE(review): this
	// matters if that column has an enforced foreign key — schema not visible
	// here.
	if attachmentID.Valid {
		updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?`
		if _, err := tx.ExecContext(ctx, updateQuery, attachmentID.Int64); err != nil {
			return fmt.Errorf("failed to unlink attachment: %w", err)
		}
	}

	deleteQuery := `DELETE FROM messages WHERE id = ?`
	result, err := tx.ExecContext(ctx, deleteQuery, messageID)
	if err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("message not found: %d", messageID)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit message deletion: %w", err)
	}
	return nil
}
// Update overwrites encrypted_body, attachment_id and is_read of an existing
// message with the values in message.
//
// The row must match both message.ID and message.SenderID, so only the
// original sender's message is changed. When nothing matches, an error is
// returned.
//
// NOTE(review): unlike Create and Delete, this does not update
// attachments.message_id when AttachmentID changes — confirm that is intended.
func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error {
	const updateStmt = `
UPDATE messages
SET encrypted_body = ?, attachment_id = ?, is_read = ?
WHERE id = ? AND sender_id = ?
`
	res, execErr := r.db.ExecContext(ctx, updateStmt,
		message.EncryptedBody, message.AttachmentID, message.IsRead,
		message.ID, message.SenderID,
	)
	if execErr != nil {
		return fmt.Errorf("failed to update message: %w", execErr)
	}

	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return countErr
	case affected == 0:
		// Either the id does not exist or it belongs to another sender.
		return fmt.Errorf("message not found or not owned by user: %d", message.ID)
	default:
		return nil
	}
}
// GetUnreadCount returns the number of messages in chat chatID that were sent
// by someone other than userID and still have is_read = 0. On error it returns
// 0 and a wrapped error.
//
// NOTE(review): is_read is a single flag per message, not per recipient, so in
// chats with more than two participants this is not a per-user unread count —
// confirm this matches the intended semantics.
func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
	const countUnread = `
SELECT COUNT(*)
FROM messages m
WHERE m.chat_id = ?
AND m.sender_id != ?
AND m.is_read = 0
`
	row := r.db.QueryRowContext(ctx, countUnread, chatID, userID)

	var unread int
	if scanErr := row.Scan(&unread); scanErr != nil {
		return 0, fmt.Errorf("failed to get unread count: %w", scanErr)
	}
	return unread, nil
}