268 lines
6.9 KiB
Go
268 lines
6.9 KiB
Go
|
|
package sqlite
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"database/sql"
|
|||
|
|
"fmt"
|
|||
|
|
"messenger/internal/models"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type MessageRepository struct {
|
|||
|
|
db *DB
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NewMessageRepository(db *DB) *MessageRepository {
|
|||
|
|
return &MessageRepository{db: db}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error {
|
|||
|
|
query := `
|
|||
|
|
INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at)
|
|||
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|||
|
|
RETURNING id
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
err := r.db.QueryRowContext(ctx, query,
|
|||
|
|
message.ChatID, message.SenderID, message.EncryptedBody,
|
|||
|
|
message.AttachmentID, message.IsRead, message.CreatedAt,
|
|||
|
|
).Scan(&message.ID)
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to create message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Если есть attachment, обновляем его message_id
|
|||
|
|
if message.AttachmentID != nil {
|
|||
|
|
updateQuery := `UPDATE attachments SET message_id = ? WHERE id = ?`
|
|||
|
|
_, err = r.db.ExecContext(ctx, updateQuery, message.ID, *message.AttachmentID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to link attachment to message: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) {
|
|||
|
|
query := `
|
|||
|
|
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
|||
|
|
FROM messages
|
|||
|
|
WHERE id = ?
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
var message models.Message
|
|||
|
|
var attachmentID sql.NullInt64
|
|||
|
|
|
|||
|
|
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
|||
|
|
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
|||
|
|
&attachmentID, &message.IsRead, &message.CreatedAt,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if err == sql.ErrNoRows {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to find message by id: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if attachmentID.Valid {
|
|||
|
|
message.AttachmentID = &attachmentID.Int64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &message, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) {
|
|||
|
|
var query string
|
|||
|
|
var rows *sql.Rows
|
|||
|
|
var err error
|
|||
|
|
|
|||
|
|
// Если before не zero time, используем пагинацию
|
|||
|
|
if !before.IsZero() {
|
|||
|
|
query = `
|
|||
|
|
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
|||
|
|
FROM messages
|
|||
|
|
WHERE chat_id = ? AND created_at < ?
|
|||
|
|
ORDER BY created_at DESC
|
|||
|
|
LIMIT ?
|
|||
|
|
`
|
|||
|
|
rows, err = r.db.QueryContext(ctx, query, chatID, before, limit)
|
|||
|
|
} else {
|
|||
|
|
query = `
|
|||
|
|
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
|||
|
|
FROM messages
|
|||
|
|
WHERE chat_id = ?
|
|||
|
|
ORDER BY created_at DESC
|
|||
|
|
LIMIT ?
|
|||
|
|
`
|
|||
|
|
rows, err = r.db.QueryContext(ctx, query, chatID, limit)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to get chat history: %w", err)
|
|||
|
|
}
|
|||
|
|
defer rows.Close()
|
|||
|
|
|
|||
|
|
var messages []*models.Message
|
|||
|
|
for rows.Next() {
|
|||
|
|
var message models.Message
|
|||
|
|
var attachmentID sql.NullInt64
|
|||
|
|
|
|||
|
|
err := rows.Scan(
|
|||
|
|
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
|||
|
|
&attachmentID, &message.IsRead, &message.CreatedAt,
|
|||
|
|
)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to scan message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if attachmentID.Valid {
|
|||
|
|
message.AttachmentID = &attachmentID.Int64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
messages = append(messages, &message)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Переворачиваем, чтобы получить в хронологическом порядке
|
|||
|
|
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
|||
|
|
messages[i], messages[j] = messages[j], messages[i]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return messages, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) {
|
|||
|
|
query := `
|
|||
|
|
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
|
|||
|
|
FROM messages
|
|||
|
|
WHERE chat_id = ?
|
|||
|
|
ORDER BY created_at DESC
|
|||
|
|
LIMIT 1
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
var message models.Message
|
|||
|
|
var attachmentID sql.NullInt64
|
|||
|
|
|
|||
|
|
err := r.db.QueryRowContext(ctx, query, chatID).Scan(
|
|||
|
|
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
|
|||
|
|
&attachmentID, &message.IsRead, &message.CreatedAt,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if err == sql.ErrNoRows {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to get last message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if attachmentID.Valid {
|
|||
|
|
message.AttachmentID = &attachmentID.Int64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &message, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error {
|
|||
|
|
query := `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0`
|
|||
|
|
|
|||
|
|
result, err := r.db.ExecContext(ctx, query, messageID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to mark message as read: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
rows, err := result.RowsAffected()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if rows == 0 {
|
|||
|
|
// Сообщение уже прочитано или не найдено - не ошибка
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error {
|
|||
|
|
// Сначала получаем attachment_id, чтобы потом удалить файл
|
|||
|
|
var attachmentID sql.NullInt64
|
|||
|
|
query := `SELECT attachment_id FROM messages WHERE id = ?`
|
|||
|
|
err := r.db.QueryRowContext(ctx, query, messageID).Scan(&attachmentID)
|
|||
|
|
if err != nil && err != sql.ErrNoRows {
|
|||
|
|
return fmt.Errorf("failed to get attachment info: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Удаляем сообщение (attachment останется, но без связи с сообщением)
|
|||
|
|
deleteQuery := `DELETE FROM messages WHERE id = ?`
|
|||
|
|
result, err := r.db.ExecContext(ctx, deleteQuery, messageID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to delete message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
rows, err := result.RowsAffected()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if rows == 0 {
|
|||
|
|
return fmt.Errorf("message not found: %d", messageID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Если был attachment, обновляем его message_id на NULL
|
|||
|
|
if attachmentID.Valid {
|
|||
|
|
updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?`
|
|||
|
|
_, err = r.db.ExecContext(ctx, updateQuery, attachmentID.Int64)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to unlink attachment: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error {
|
|||
|
|
query := `
|
|||
|
|
UPDATE messages
|
|||
|
|
SET encrypted_body = ?, attachment_id = ?, is_read = ?
|
|||
|
|
WHERE id = ? AND sender_id = ?
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
result, err := r.db.ExecContext(ctx, query,
|
|||
|
|
message.EncryptedBody, message.AttachmentID, message.IsRead,
|
|||
|
|
message.ID, message.SenderID,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to update message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
rows, err := result.RowsAffected()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if rows == 0 {
|
|||
|
|
return fmt.Errorf("message not found or not owned by user: %d", message.ID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
|
|||
|
|
query := `
|
|||
|
|
SELECT COUNT(*)
|
|||
|
|
FROM messages m
|
|||
|
|
WHERE m.chat_id = ?
|
|||
|
|
AND m.sender_id != ?
|
|||
|
|
AND m.is_read = 0
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
var count int
|
|||
|
|
err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&count)
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("failed to get unread count: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return count, nil
|
|||
|
|
}
|