Files
efir-api-server/internal/repository/sqlite/message_repo.go

268 lines
6.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
"time"
)
// MessageRepository stores and queries chat messages through the database
// handle it wraps.
type MessageRepository struct {
	db *DB
}

// NewMessageRepository returns a MessageRepository that issues all of its
// queries through db.
func NewMessageRepository(db *DB) *MessageRepository {
	repo := &MessageRepository{}
	repo.db = db
	return repo
}
// Create inserts message into the messages table and writes the generated
// row ID back into message.ID.
//
// When message.AttachmentID is set, the referenced attachments row is then
// updated so that its message_id points at the newly created message.
//
// NOTE(review): the insert and the attachment link are two independent
// statements (no transaction), and the link UPDATE does not verify that an
// attachment row actually matched. If linking fails, the message row has
// already been inserted and message.ID is already populated.
func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error {
	const insertQuery = `
INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id
`
	row := r.db.QueryRowContext(ctx, insertQuery,
		message.ChatID,
		message.SenderID,
		message.EncryptedBody,
		message.AttachmentID,
		message.IsRead,
		message.CreatedAt,
	)
	if err := row.Scan(&message.ID); err != nil {
		return fmt.Errorf("failed to create message: %w", err)
	}

	if message.AttachmentID == nil {
		return nil
	}

	// Back-link the attachment to the message we just created.
	const linkQuery = `UPDATE attachments SET message_id = ? WHERE id = ?`
	if _, err := r.db.ExecContext(ctx, linkQuery, message.ID, *message.AttachmentID); err != nil {
		return fmt.Errorf("failed to link attachment to message: %w", err)
	}
	return nil
}
// FindByID loads a single message by its primary key.
//
// It returns (nil, nil) when no message with that id exists, so callers must
// check for a nil message in addition to a non-nil error.
func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) {
	const query = `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE id = ?
`
	var (
		msg          models.Message
		attachmentID sql.NullInt64
	)
	switch err := r.db.QueryRowContext(ctx, query, id).Scan(
		&msg.ID, &msg.ChatID, &msg.SenderID, &msg.EncryptedBody,
		&attachmentID, &msg.IsRead, &msg.CreatedAt,
	); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to find message by id: %w", err)
	}

	// attachment_id is nullable; leave AttachmentID nil when it is NULL.
	if attachmentID.Valid {
		msg.AttachmentID = &attachmentID.Int64
	}
	return &msg, nil
}
// GetChatHistory returns up to limit messages from chat chatID in
// chronological (oldest-first) order.
//
// When before is non-zero, only messages whose created_at is strictly earlier
// than before are considered, which lets callers page backwards through the
// history. The newest matching rows are selected (ORDER BY ... DESC LIMIT ?)
// and then reversed in memory so the returned slice reads oldest-first.
//
// It returns a nil slice (and no error) when no messages match.
//
// NOTE(review): because the cursor is `created_at < before`, messages that
// share exactly the cursor timestamp are skipped on the next page.
func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) {
	query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?`
	args := []any{chatID}

	// A zero before means "start from the newest message".
	if !before.IsZero() {
		query += ` AND created_at < ?`
		args = append(args, before)
	}

	// id breaks ties between messages with the same created_at so the page
	// contents and order are deterministic.
	query += `
ORDER BY created_at DESC, id DESC
LIMIT ?
`
	args = append(args, limit)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to get chat history: %w", err)
	}
	defer rows.Close()

	var messages []*models.Message
	for rows.Next() {
		var message models.Message
		var attachmentID sql.NullInt64
		if err := rows.Scan(
			&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
			&attachmentID, &message.IsRead, &message.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan message: %w", err)
		}
		if attachmentID.Valid {
			message.AttachmentID = &attachmentID.Int64
		}
		messages = append(messages, &message)
	}
	// rows.Next returns false both on exhaustion and on error; without this
	// check a mid-iteration failure would be returned as a silently truncated page.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate chat history: %w", err)
	}

	// Rows were fetched newest-first; flip them into chronological order.
	for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
		messages[i], messages[j] = messages[j], messages[i]
	}
	return messages, nil
}
// GetLastMessage returns the most recent message in chat chatID, ordered by
// created_at. It returns (nil, nil) when the chat has no messages.
func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) {
	// id DESC is a tiebreaker: without it, the "last" message is arbitrary
	// when several messages share the newest created_at.
	query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?
ORDER BY created_at DESC, id DESC
LIMIT 1
`
	var message models.Message
	var attachmentID sql.NullInt64
	err := r.db.QueryRowContext(ctx, query, chatID).Scan(
		&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
		&attachmentID, &message.IsRead, &message.CreatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get last message: %w", err)
	}
	// attachment_id is nullable; leave AttachmentID nil when it is NULL.
	if attachmentID.Valid {
		message.AttachmentID = &attachmentID.Int64
	}
	return &message, nil
}
// MarkAsRead sets is_read = 1 on the message with the given ID.
//
// Marking a message that is already read, or an ID that does not exist, is
// not an error. The `AND is_read = 0` filter only avoids rewriting rows that
// are already marked read.
func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error {
	const query = `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0`
	// The affected-row count is deliberately not inspected: zero rows means
	// "already read or not found", and neither case is treated as an error.
	if _, err := r.db.ExecContext(ctx, query, messageID); err != nil {
		return fmt.Errorf("failed to mark message as read: %w", err)
	}
	return nil
}
// Delete removes the message with the given ID.
//
// If the message referenced an attachment, the attachment row itself is kept,
// but its message_id is cleared to NULL. Delete returns a "message not found"
// error when no message with messageID exists.
//
// NOTE(review): the DELETE and the attachment unlink run as separate
// statements, not in a transaction. If the unlink fails, the message is
// already gone and the attachment still points at its old ID. Consider
// wrapping both in a transaction.
func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error {
	// Read the attachment link first; it is lost once the message row is deleted.
	var attachmentID sql.NullInt64
	query := `SELECT attachment_id FROM messages WHERE id = ?`
	err := r.db.QueryRowContext(ctx, query, messageID).Scan(&attachmentID)
	if err == sql.ErrNoRows {
		return fmt.Errorf("message not found: %d", messageID)
	}
	if err != nil {
		return fmt.Errorf("failed to get attachment info: %w", err)
	}

	// Delete the message; the attachment row stays, detached from it.
	deleteQuery := `DELETE FROM messages WHERE id = ?`
	result, err := r.db.ExecContext(ctx, deleteQuery, messageID)
	if err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check deleted rows: %w", err)
	}
	// Guards against the row disappearing between the SELECT and the DELETE.
	if rows == 0 {
		return fmt.Errorf("message not found: %d", messageID)
	}

	// If the message had an attachment, clear that attachment's back-reference.
	if attachmentID.Valid {
		updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?`
		if _, err := r.db.ExecContext(ctx, updateQuery, attachmentID.Int64); err != nil {
			return fmt.Errorf("failed to unlink attachment: %w", err)
		}
	}
	return nil
}
// Update overwrites encrypted_body, attachment_id and is_read of an existing
// message. The row is matched on both message.ID and message.SenderID, so
// only a message owned by that sender is changed.
//
// It returns an error when no row matched, i.e. the message does not exist
// or belongs to a different sender.
//
// NOTE(review): unlike Create, this does not update attachments.message_id
// when attachment_id changes. Confirm that this is intended.
func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error {
	query := `
UPDATE messages
SET encrypted_body = ?, attachment_id = ?, is_read = ?
WHERE id = ? AND sender_id = ?
`
	result, err := r.db.ExecContext(ctx, query,
		message.EncryptedBody, message.AttachmentID, message.IsRead,
		message.ID, message.SenderID,
	)
	if err != nil {
		return fmt.Errorf("failed to update message: %w", err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check updated rows: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("message not found or not owned by user: %d", message.ID)
	}
	return nil
}
// GetUnreadCount returns how many messages in chat chatID are still unread
// (is_read = 0) and were sent by someone other than userID.
//
// NOTE(review): is_read is a single flag on each message row, not tracked
// per recipient. In a chat with more than two members, one reader marking a
// message read would presumably clear it for everyone. Confirm this is intended.
func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
	const query = `
SELECT COUNT(*)
FROM messages m
WHERE m.chat_id = ?
AND m.sender_id != ?
AND m.is_read = 0
`
	var unread int
	if err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&unread); err != nil {
		return 0, fmt.Errorf("failed to get unread count: %w", err)
	}
	return unread, nil
}