268 lines
6.9 KiB
Go
268 lines
6.9 KiB
Go
package sqlite
|
||
|
||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"messenger/internal/models"
)
|
||
|
||
type MessageRepository struct {
|
||
db *DB
|
||
}
|
||
|
||
func NewMessageRepository(db *DB) *MessageRepository {
|
||
return &MessageRepository{db: db}
|
||
}
|
||
|
||
func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error {
|
||
query := `
|
||
INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
RETURNING id
|
||
`
|
||
|
||
err := r.db.QueryRowContext(ctx, query,
|
||
message.ChatID, message.SenderID, message.EncryptedBody,
|
||
message.AttachmentID, message.IsRead, message.CreatedAt,
|
||
).Scan(&message.ID)
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create message: %w", err)
|
||
}
|
||
|
||
// Если есть attachment, обновляем его message_id
|
||
if message.AttachmentID != nil {
|
||
updateQuery := `UPDATE attachments SET message_id = ? WHERE id = ?`
|
||
_, err = r.db.ExecContext(ctx, updateQuery, message.ID, *message.AttachmentID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to link attachment to message: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) {
|
||
query := `
|
||
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
||
FROM messages
|
||
WHERE id = ?
|
||
`
|
||
|
||
var message models.Message
|
||
var attachmentID sql.NullInt64
|
||
|
||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
||
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
||
&attachmentID, &message.IsRead, &message.CreatedAt,
|
||
)
|
||
|
||
if err == sql.ErrNoRows {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to find message by id: %w", err)
|
||
}
|
||
|
||
if attachmentID.Valid {
|
||
message.AttachmentID = &attachmentID.Int64
|
||
}
|
||
|
||
return &message, nil
|
||
}
|
||
|
||
func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) {
|
||
var query string
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
// Если before не zero time, используем пагинацию
|
||
if !before.IsZero() {
|
||
query = `
|
||
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
||
FROM messages
|
||
WHERE chat_id = ? AND created_at < ?
|
||
ORDER BY created_at DESC
|
||
LIMIT ?
|
||
`
|
||
rows, err = r.db.QueryContext(ctx, query, chatID, before, limit)
|
||
} else {
|
||
query = `
|
||
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
||
FROM messages
|
||
WHERE chat_id = ?
|
||
ORDER BY created_at DESC
|
||
LIMIT ?
|
||
`
|
||
rows, err = r.db.QueryContext(ctx, query, chatID, limit)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get chat history: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []*models.Message
|
||
for rows.Next() {
|
||
var message models.Message
|
||
var attachmentID sql.NullInt64
|
||
|
||
err := rows.Scan(
|
||
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
||
&attachmentID, &message.IsRead, &message.CreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to scan message: %w", err)
|
||
}
|
||
|
||
if attachmentID.Valid {
|
||
message.AttachmentID = &attachmentID.Int64
|
||
}
|
||
|
||
messages = append(messages, &message)
|
||
}
|
||
|
||
// Переворачиваем, чтобы получить в хронологическом порядке
|
||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||
messages[i], messages[j] = messages[j], messages[i]
|
||
}
|
||
|
||
return messages, nil
|
||
}
|
||
|
||
func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) {
|
||
query := `
|
||
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
||
FROM messages
|
||
WHERE chat_id = ?
|
||
ORDER BY created_at DESC
|
||
LIMIT 1
|
||
`
|
||
|
||
var message models.Message
|
||
var attachmentID sql.NullInt64
|
||
|
||
err := r.db.QueryRowContext(ctx, query, chatID).Scan(
|
||
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
||
&attachmentID, &message.IsRead, &message.CreatedAt,
|
||
)
|
||
|
||
if err == sql.ErrNoRows {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get last message: %w", err)
|
||
}
|
||
|
||
if attachmentID.Valid {
|
||
message.AttachmentID = &attachmentID.Int64
|
||
}
|
||
|
||
return &message, nil
|
||
}
|
||
|
||
func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error {
|
||
query := `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0`
|
||
|
||
result, err := r.db.ExecContext(ctx, query, messageID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to mark message as read: %w", err)
|
||
}
|
||
|
||
rows, err := result.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if rows == 0 {
|
||
// Сообщение уже прочитано или не найдено - не ошибка
|
||
return nil
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error {
|
||
// Сначала получаем attachment_id, чтобы потом удалить файл
|
||
var attachmentID sql.NullInt64
|
||
query := `SELECT attachment_id FROM messages WHERE id = ?`
|
||
err := r.db.QueryRowContext(ctx, query, messageID).Scan(&attachmentID)
|
||
if err != nil && err != sql.ErrNoRows {
|
||
return fmt.Errorf("failed to get attachment info: %w", err)
|
||
}
|
||
|
||
// Удаляем сообщение (attachment останется, но без связи с сообщением)
|
||
deleteQuery := `DELETE FROM messages WHERE id = ?`
|
||
result, err := r.db.ExecContext(ctx, deleteQuery, messageID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete message: %w", err)
|
||
}
|
||
|
||
rows, err := result.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if rows == 0 {
|
||
return fmt.Errorf("message not found: %d", messageID)
|
||
}
|
||
|
||
// Если был attachment, обновляем его message_id на NULL
|
||
if attachmentID.Valid {
|
||
updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?`
|
||
_, err = r.db.ExecContext(ctx, updateQuery, attachmentID.Int64)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to unlink attachment: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error {
|
||
query := `
|
||
UPDATE messages
|
||
SET encrypted_body = ?, attachment_id = ?, is_read = ?
|
||
WHERE id = ? AND sender_id = ?
|
||
`
|
||
|
||
result, err := r.db.ExecContext(ctx, query,
|
||
message.EncryptedBody, message.AttachmentID, message.IsRead,
|
||
message.ID, message.SenderID,
|
||
)
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to update message: %w", err)
|
||
}
|
||
|
||
rows, err := result.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if rows == 0 {
|
||
return fmt.Errorf("message not found or not owned by user: %d", message.ID)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
|
||
query := `
|
||
SELECT COUNT(*)
|
||
FROM messages m
|
||
WHERE m.chat_id = ?
|
||
AND m.sender_id != ?
|
||
AND m.is_read = 0
|
||
`
|
||
|
||
var count int
|
||
err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&count)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("failed to get unread count: %w", err)
|
||
}
|
||
|
||
return count, nil
|
||
} |