Initial commit: Эфир мессенджер

This commit is contained in:
2026-04-06 14:57:36 +03:00
commit ff93679b6d
50 changed files with 5642 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
package handlers
import (
"messenger/internal/api/responses"
"messenger/internal/service"
"net/http"
//"github.com/go-chi/chi/v5"
)
type AdminHandler struct {
adminService *service.AdminService
}
func NewAdminHandler(adminService *service.AdminService) *AdminHandler {
return &AdminHandler{
adminService: adminService,
}
}
func (h *AdminHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
// userID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
// if err != nil {
// responses.BadRequest(w, "invalid user id")
// return
// }
// TODO: Реализовать удаление пользователя
responses.Success(w, http.StatusOK, map[string]string{"message": "user deleted"})
}
func (h *AdminHandler) DeleteMessage(w http.ResponseWriter, r *http.Request) {
// messageID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
// if err != nil {
// responses.BadRequest(w, "invalid message id")
// return
// }
// TODO: Реализовать удаление сообщения
responses.Success(w, http.StatusOK, map[string]string{"message": "message deleted"})
}

View File

@@ -0,0 +1,80 @@
package handlers
import (
"encoding/json"
"messenger/internal/api/middleware"
"messenger/internal/api/responses"
"messenger/internal/service"
"net/http"
)
type AuthHandler struct {
authService *service.AuthService
}
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{authService: authService}
}
type RegisterRequest struct {
Login string `json:"login"`
Password string `json:"password"`
}
type LoginRequest struct {
Login string `json:"login"`
Password string `json:"password"`
}
type AuthResponse struct {
Token string `json:"token"`
User interface{} `json:"user"`
}
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
user, token, err := h.authService.Register(r.Context(), req.Login, req.Password)
if err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusCreated, AuthResponse{
Token: token,
User: user.ToSafe(),
})
}
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
user, token, err := h.authService.Login(r.Context(), req.Login, req.Password)
if err != nil {
responses.Unauthorized(w, err.Error())
return
}
responses.Success(w, http.StatusOK, AuthResponse{
Token: token,
User: user.ToSafe(),
})
}
func (h *AuthHandler) GetMe(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
responses.Success(w, http.StatusOK, user.ToSafe())
}

View File

@@ -0,0 +1,228 @@
package handlers
import (
"encoding/json"
"messenger/internal/api/middleware"
"messenger/internal/api/responses"
"messenger/internal/models"
"messenger/internal/service"
"net/http"
"strconv"
"github.com/go-chi/chi/v5"
)
type ChatHandler struct {
chatService *service.ChatService
userService *service.UserService
}
func NewChatHandler(chatService *service.ChatService, userService *service.UserService) *ChatHandler {
return &ChatHandler{
chatService: chatService,
userService: userService,
}
}
func (h *ChatHandler) CreatePrivateChat(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
var req models.CreatePrivateChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
if req.TargetLogin == "" {
responses.BadRequest(w, "target_login is required")
return
}
// Находим целевого пользователя по логину
targetUser, err := h.userService.GetUserByLogin(r.Context(), req.TargetLogin)
if err != nil {
responses.NotFound(w, "target user not found")
return
}
if targetUser == nil {
responses.NotFound(w, "target user not found")
return
}
// Создаем приватный чат
chat, err := h.chatService.CreatePrivateChat(r.Context(), user.ID, targetUser.ID)
if err != nil {
responses.InternalServerError(w, err.Error())
return
}
responses.Success(w, http.StatusCreated, chat)
}
func (h *ChatHandler) CreateGroupChat(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
var req models.CreateGroupChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
chat, err := h.chatService.CreateGroupChat(r.Context(), user.ID, req.Title, req.MemberLogins)
if err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusCreated, chat)
}
func (h *ChatHandler) GetMyChats(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chats, err := h.chatService.GetUserChatsWithDetails(r.Context(), user.ID)
if err != nil {
responses.InternalServerError(w, err.Error())
return
}
responses.Success(w, http.StatusOK, chats)
}
func (h *ChatHandler) GetChatByID(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
chat, err := h.chatService.GetChatByID(r.Context(), chatID, user.ID)
if err != nil {
responses.NotFound(w, err.Error())
return
}
responses.Success(w, http.StatusOK, chat)
}
func (h *ChatHandler) GetChatMembers(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
members, err := h.chatService.GetChatMembers(r.Context(), chatID, user.ID)
if err != nil {
responses.Forbidden(w, err.Error())
return
}
responses.Success(w, http.StatusOK, members)
}
func (h *ChatHandler) AddMembers(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
var req models.AddMembersRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
if err := h.chatService.AddMembers(r.Context(), chatID, user.ID, req.UserLogins); err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusOK, map[string]string{"message": "members added"})
}
func (h *ChatHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
memberID, err := strconv.ParseInt(chi.URLParam(r, "user_id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid member id")
return
}
if err := h.chatService.RemoveMember(r.Context(), chatID, user.ID, memberID); err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusOK, map[string]string{"message": "member removed"})
}
func (h *ChatHandler) UpdateChatTitle(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
var req models.UpdateChatTitleRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
if err := h.chatService.UpdateChatTitle(r.Context(), chatID, user.ID, req.Title); err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusOK, map[string]string{"message": "chat title updated"})
}

View File

@@ -0,0 +1,82 @@
package handlers
import (
"messenger/internal/api/middleware"
"messenger/internal/api/responses"
"messenger/internal/service"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
)
type FileHandler struct {
fileService *service.FileService
}
func NewFileHandler(fileService *service.FileService) *FileHandler {
return &FileHandler{
fileService: fileService,
}
}
func (h *FileHandler) UploadFile(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
// Parse multipart form (20 MB max)
if err := r.ParseMultipartForm(20 << 20); err != nil {
responses.BadRequest(w, "failed to parse form")
return
}
file, header, err := r.FormFile("file")
if err != nil {
responses.BadRequest(w, "file is required")
return
}
defer file.Close()
attachment, err := h.fileService.UploadFile(r.Context(), chatID, user.ID, header)
if err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusOK, attachment)
}
func (h *FileHandler) DownloadFile(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
attachmentID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid attachment id")
return
}
fileName, file, mimeType, err := h.fileService.DownloadFile(r.Context(), attachmentID, user.ID)
if err != nil {
responses.NotFound(w, err.Error())
return
}
defer file.Close()
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(fileName))
http.ServeContent(w, r, fileName, time.Now(), file)
}

View File

@@ -0,0 +1,83 @@
package handlers
import (
//"encoding/json"
"messenger/internal/api/middleware"
"messenger/internal/api/responses"
//"messenger/internal/models"
"messenger/internal/service"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
)
type MessageHandler struct {
messageService *service.MessageService
chatService *service.ChatService
}
func NewMessageHandler(messageService *service.MessageService, chatService *service.ChatService) *MessageHandler {
return &MessageHandler{
messageService: messageService,
chatService: chatService,
}
}
func (h *MessageHandler) GetMessages(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
chatID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid chat id")
return
}
limit := 50
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 && parsed <= 100 {
limit = parsed
}
}
var before time.Time
if b := r.URL.Query().Get("before"); b != "" {
if parsed, err := time.Parse(time.RFC3339, b); err == nil {
before = parsed
}
}
messages, err := h.messageService.GetChatHistory(r.Context(), chatID, user.ID, limit, before)
if err != nil {
responses.InternalServerError(w, err.Error())
return
}
responses.Success(w, http.StatusOK, messages)
}
func (h *MessageHandler) MarkAsRead(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
messageID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
if err != nil {
responses.BadRequest(w, "invalid message id")
return
}
if err := h.messageService.MarkMessageAsRead(r.Context(), messageID, user.ID); err != nil {
responses.InternalServerError(w, err.Error())
return
}
responses.Success(w, http.StatusOK, map[string]string{"message": "marked as read"})
}

View File

@@ -0,0 +1,103 @@
package handlers
import (
"context"
"encoding/json"
"messenger/internal/api/middleware"
"messenger/internal/api/responses"
"messenger/internal/models"
"messenger/internal/service"
"net/http"
"strconv"
)
type UserHandler struct {
userService *service.UserService
}
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{userService: userService}
}
type UpdateProfileRequest struct {
DisplayName *string `json:"display_name,omitempty"`
Bio *string `json:"bio,omitempty"`
}
func (h *UserHandler) GetProfile(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
profile, err := h.userService.GetProfile(r.Context(), user.ID)
if err != nil {
responses.NotFound(w, err.Error())
return
}
responses.Success(w, http.StatusOK, profile)
}
func (h *UserHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUserFromContext(r.Context())
if user == nil {
responses.Unauthorized(w, "user not found")
return
}
var req UpdateProfileRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
responses.BadRequest(w, "invalid request body")
return
}
if err := h.userService.UpdateProfile(r.Context(), user.ID, req.DisplayName, req.Bio); err != nil {
responses.BadRequest(w, err.Error())
return
}
responses.Success(w, http.StatusOK, map[string]string{"message": "profile updated"})
}
func (h *UserHandler) SearchUsers(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query().Get("q")
if query == "" {
responses.BadRequest(w, "search query required")
return
}
limit := 20
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 && parsed <= 100 {
limit = parsed
}
}
users, err := h.userService.SearchUsers(r.Context(), query, limit)
if err != nil {
responses.InternalServerError(w, err.Error())
return
}
responses.Success(w, http.StatusOK, users)
}
func (h *UserHandler) GetUserByID(w http.ResponseWriter, r *http.Request) {
// Реализация с chi.URLParam будет в main.go
// Пока заглушка
responses.Success(w, http.StatusOK, nil)
}
// GetUserByLoginFromContext - метод для получения пользователя по логину (используется в других хендлерах)
func (h *UserHandler) GetUserByLogin(ctx context.Context, login string) (*models.SafeUser, error) {
user, err := h.userService.GetUserByLogin(ctx, login)
if err != nil {
return nil, err
}
if user == nil {
return nil, nil
}
return user.ToSafe(), nil
}

View File

@@ -0,0 +1,58 @@
package handlers
import (
"net/http"
gorillaWS "github.com/gorilla/websocket"
"messenger/internal/pkg/logger"
"messenger/internal/service"
ws "messenger/internal/websocket"
)
var upgrader = gorillaWS.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type WebSocketHandler struct {
hub *ws.Hub
authService *service.AuthService
}
func NewWebSocketHandler(hub *ws.Hub, authService *service.AuthService) *WebSocketHandler {
return &WebSocketHandler{
hub: hub,
authService: authService,
}
}
func (h *WebSocketHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "missing token", http.StatusUnauthorized)
return
}
user, err := h.authService.ValidateToken(token)
if err != nil {
logger.Error("WebSocket auth failed", "error", err)
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}
logger.Info("WebSocket connecting", "user_id", user.ID, "login", user.Login)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Error("WebSocket upgrade failed", "error", err)
return
}
client := ws.NewClient(h.hub, conn, user)
h.hub.GetRegisterChan() <- client
go client.WritePump()
client.ReadPump()
}

View File

@@ -0,0 +1,66 @@
package middleware
import (
"context"
"messenger/internal/api/responses"
"messenger/internal/models"
"messenger/internal/service"
"net/http"
"strings"
)
type contextKey string
const UserContextKey contextKey = "user"
func JWTAuth(authService *service.AuthService) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Получаем токен из заголовка Authorization
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
responses.Unauthorized(w, "missing authorization header")
return
}
// Проверяем формат Bearer token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
responses.Unauthorized(w, "invalid authorization header format")
return
}
token := parts[1]
// Валидируем токен
user, err := authService.ValidateToken(token)
if err != nil {
responses.Unauthorized(w, "invalid or expired token")
return
}
// Сохраняем пользователя в контексте
ctx := context.WithValue(r.Context(), UserContextKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func GetUserFromContext(ctx context.Context) *models.User {
user, ok := ctx.Value(UserContextKey).(*models.User)
if !ok {
return nil
}
return user
}
func RequireGlobalAdmin(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := GetUserFromContext(r.Context())
if user == nil || !user.IsGlobalAdmin() {
responses.Forbidden(w, "global admin access required")
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,38 @@
package middleware
import (
"net/http"
//"strings"
)
func CORS(allowedOrigins []string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Проверяем разрешен ли origin
allowed := false
for _, o := range allowedOrigins {
if o == "*" || o == origin {
allowed = true
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Credentials", "true")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,36 @@
package middleware
import (
"log/slog"
"net/http"
"time"
)
func Logging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Оборачиваем ResponseWriter для захвата статуса
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapped, r)
slog.Info("HTTP request",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.statusCode,
"duration", time.Since(start),
"remote_addr", r.RemoteAddr,
)
})
}
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}

View File

@@ -0,0 +1,23 @@
package middleware
import (
"log/slog"
"net/http"
"runtime/debug"
)
func Recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
slog.Error("panic recovered",
"error", err,
"stack", string(debug.Stack()),
"path", r.URL.Path,
)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,54 @@
package responses
import (
"encoding/json"
"net/http"
)
type Response struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
Code int `json:"code,omitempty"`
}
func JSON(w http.ResponseWriter, statusCode int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(data)
}
func Success(w http.ResponseWriter, statusCode int, data interface{}) {
JSON(w, statusCode, Response{
Success: true,
Data: data,
})
}
func Error(w http.ResponseWriter, statusCode int, errMsg string) {
JSON(w, statusCode, Response{
Success: false,
Error: errMsg,
Code: statusCode,
})
}
func BadRequest(w http.ResponseWriter, errMsg string) {
Error(w, http.StatusBadRequest, errMsg)
}
func Unauthorized(w http.ResponseWriter, errMsg string) {
Error(w, http.StatusUnauthorized, errMsg)
}
func Forbidden(w http.ResponseWriter, errMsg string) {
Error(w, http.StatusForbidden, errMsg)
}
func NotFound(w http.ResponseWriter, errMsg string) {
Error(w, http.StatusNotFound, errMsg)
}
func InternalServerError(w http.ResponseWriter, errMsg string) {
Error(w, http.StatusInternalServerError, errMsg)
}

69
internal/config/config.go Normal file
View File

@@ -0,0 +1,69 @@
package config
import (
"log"
"os"
"strconv"
"strings"
"github.com/joho/godotenv"
)
type Config struct {
ServerPort string
Environment string
DBDriver string
DBPath string
JWTSecret []byte
JWTExpiryHours int64
EncryptionKey []byte
StoragePath string
MaxFileSizeMB int64
CORSAllowedOrigins []string
}
func Load() *Config {
// Загружаем .env файл (игнорируем ошибку если файла нет)
_ = godotenv.Load()
cfg := &Config{
ServerPort: getEnv("SERVER_PORT", "8080"),
Environment: getEnv("ENVIRONMENT", "development"),
DBDriver: getEnv("DB_DRIVER", "sqlite"),
DBPath: getEnv("DB_PATH", "./messenger.db"),
JWTSecret: []byte(getEnv("JWT_SECRET", "default-secret-key-change-me")),
JWTExpiryHours: getEnvAsInt64("JWT_EXPIRY_HOURS", 720),
EncryptionKey: []byte(getEnv("ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")),
StoragePath: getEnv("STORAGE_PATH", "./storage/attachments"),
MaxFileSizeMB: getEnvAsInt64("MAX_FILE_SIZE_MB", 20),
CORSAllowedOrigins: strings.Split(getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:3000"), ","),
}
// Валидация ключа шифрования (должен быть 32 байта для AES-256)
if len(cfg.EncryptionKey) != 32 {
log.Printf("Warning: ENCRYPTION_KEY length is %d bytes, expected 32 bytes for AES-256", len(cfg.EncryptionKey))
}
// Создаём директорию для файлов если её нет
if err := os.MkdirAll(cfg.StoragePath, 0755); err != nil {
log.Printf("Warning: failed to create storage directory: %v", err)
}
return cfg
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvAsInt64(key string, defaultValue int64) int64 {
if value := os.Getenv(key); value != "" {
if intVal, err := strconv.ParseInt(value, 10, 64); err == nil {
return intVal
}
}
return defaultValue
}

172
internal/crypto/aes.go Normal file
View File

@@ -0,0 +1,172 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
type Encryptor struct {
key []byte // 32 bytes for AES-256
}
// NewEncryptor создает новый шифровальщик с заданным ключом
// key должен быть 32 байта для AES-256
func NewEncryptor(key []byte) (*Encryptor, error) {
if len(key) != 32 {
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
}
return &Encryptor{
key: key,
}, nil
}
// Encrypt шифрует plaintext с использованием AES-GCM
// Возвращает base64 encoded ciphertext с appended nonce
func (e *Encryptor) Encrypt(plaintext []byte) (string, error) {
if len(plaintext) == 0 {
return "", nil
}
// Создаем блок AES
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
}
// Создаем GCM режим
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM: %w", err)
}
// Генерируем случайный nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
// Шифруем: ciphertext = nonce + encrypted_data
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
// Кодируем в base64 для хранения в БД
encoded := base64.StdEncoding.EncodeToString(ciphertext)
return encoded, nil
}
// Decrypt расшифровывает base64 encoded ciphertext
// Ожидает формат: nonce + encrypted_data
func (e *Encryptor) Decrypt(encodedCiphertext string) ([]byte, error) {
if encodedCiphertext == "" {
return []byte{}, nil
}
// Декодируем из base64
ciphertext, err := base64.StdEncoding.DecodeString(encodedCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decode base64: %w", err)
}
if len(ciphertext) == 0 {
return []byte{}, nil
}
// Создаем блок AES
block, err := aes.NewCipher(e.key)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
// Создаем GCM режим
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
// Извлекаем nonce и зашифрованные данные
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
// Расшифровываем
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt: %w", err)
}
return plaintext, nil
}
// EncryptString удобная обертка для шифрования строк
func (e *Encryptor) EncryptString(plaintext string) (string, error) {
return e.Encrypt([]byte(plaintext))
}
// DecryptString удобная обертка для расшифровки в строку
func (e *Encryptor) DecryptString(encodedCiphertext string) (string, error) {
plaintext, err := e.Decrypt(encodedCiphertext)
if err != nil {
return "", err
}
return string(plaintext), nil
}
// MustEncrypt паникует при ошибке шифрования (для инициализации)
func (e *Encryptor) MustEncrypt(plaintext string) string {
encrypted, err := e.EncryptString(plaintext)
if err != nil {
panic(fmt.Sprintf("encryption failed: %v", err))
}
return encrypted
}
// MustDecrypt паникует при ошибке расшифровки (для инициализации)
func (e *Encryptor) MustDecrypt(encodedCiphertext string) string {
decrypted, err := e.DecryptString(encodedCiphertext)
if err != nil {
panic(fmt.Sprintf("decryption failed: %v", err))
}
return decrypted
}
// GenerateRandomKey генерирует случайный 32-байтный ключ для AES-256
func GenerateRandomKey() ([]byte, error) {
key := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate random key: %w", err)
}
return key, nil
}
// GenerateRandomKeyString генерирует случайный ключ в виде hex строки
func GenerateRandomKeyString() (string, error) {
key, err := GenerateRandomKey()
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
// KeyFromString создает ключ из строки (ожидается 32 байта в base64 или plain)
func KeyFromString(keyStr string) ([]byte, error) {
// Пробуем декодировать как base64
key, err := base64.StdEncoding.DecodeString(keyStr)
if err == nil && len(key) == 32 {
return key, nil
}
// Если не base64, используем как есть (должно быть 32 байта)
if len(keyStr) == 32 {
return []byte(keyStr), nil
}
return nil, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyStr))
}

177
internal/crypto/aes_test.go Normal file
View File

@@ -0,0 +1,177 @@
package crypto
import (
"testing"
)
func TestEncryptDecrypt(t *testing.T) {
// Генерируем тестовый ключ
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
encryptor, err := NewEncryptor(key)
if err != nil {
t.Fatalf("Failed to create encryptor: %v", err)
}
testCases := []struct {
name string
plaintext string
}{
{"Empty string", ""},
{"Simple text", "Hello, World!"},
{"Russian text", "Привет, мир!"},
{"Long text", "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."},
{"Special chars", "!@#$%^&*()_+-=[]{}|;:',.<>?/`~"},
{"Emoji", "Hello 👋 World 🌍"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Шифруем
encrypted, err := encryptor.EncryptString(tc.plaintext)
if err != nil {
t.Fatalf("Encryption failed: %v", err)
}
// Проверяем, что зашифрованный текст не пустой (для непустого plaintext)
if tc.plaintext != "" && encrypted == "" {
t.Error("Encrypted text is empty for non-empty plaintext")
}
// Расшифровываем
decrypted, err := encryptor.DecryptString(encrypted)
if err != nil {
t.Fatalf("Decryption failed: %v", err)
}
// Проверяем соответствие
if decrypted != tc.plaintext {
t.Errorf("Decrypted text doesn't match original.\nExpected: %s\nGot: %s", tc.plaintext, decrypted)
}
})
}
}
func TestEncryptDecryptBinary(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
encryptor, err := NewEncryptor(key)
if err != nil {
t.Fatalf("Failed to create encryptor: %v", err)
}
// Тестируем бинарные данные
binaryData := []byte{0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}
encrypted, err := encryptor.Encrypt(binaryData)
if err != nil {
t.Fatalf("Encryption failed: %v", err)
}
decrypted, err := encryptor.Decrypt(encrypted)
if err != nil {
t.Fatalf("Decryption failed: %v", err)
}
if len(decrypted) != len(binaryData) {
t.Errorf("Length mismatch: expected %d, got %d", len(binaryData), len(decrypted))
}
for i := range binaryData {
if decrypted[i] != binaryData[i] {
t.Errorf("Byte mismatch at index %d: expected %d, got %d", i, binaryData[i], decrypted[i])
}
}
}
func TestInvalidKey(t *testing.T) {
// Тестируем ключ неправильной длины
invalidKey := make([]byte, 16)
_, err := NewEncryptor(invalidKey)
if err == nil {
t.Error("Expected error for invalid key length, got nil")
}
}
func TestDecryptInvalidData(t *testing.T) {
key := make([]byte, 32)
encryptor, _ := NewEncryptor(key)
// Тестируем расшифровку некорректных данных
invalidData := "not-a-valid-base64"
_, err := encryptor.DecryptString(invalidData)
if err == nil {
t.Error("Expected error for invalid base64, got nil")
}
// Тестируем расшифровку слишком коротких данных
shortData := "aGVsbG8=" // "hello" в base64
_, err = encryptor.DecryptString(shortData)
if err == nil {
t.Error("Expected error for too short ciphertext, got nil")
}
}
func TestGenerateRandomKey(t *testing.T) {
key1, err := GenerateRandomKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
if len(key1) != 32 {
t.Errorf("Expected key length 32, got %d", len(key1))
}
key2, err := GenerateRandomKey()
if err != nil {
t.Fatalf("Failed to generate second key: %v", err)
}
// Проверяем, что ключи разные
same := true
for i := range key1 {
if key1[i] != key2[i] {
same = false
break
}
}
if same {
t.Error("Generated keys are identical")
}
}
func BenchmarkEncrypt(b *testing.B) {
key := make([]byte, 32)
encryptor, _ := NewEncryptor(key)
plaintext := "This is a test message that will be encrypted repeatedly"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := encryptor.EncryptString(plaintext)
if err != nil {
b.Fatalf("Encryption failed: %v", err)
}
}
}
func BenchmarkDecrypt(b *testing.B) {
key := make([]byte, 32)
encryptor, _ := NewEncryptor(key)
plaintext := "This is a test message that will be encrypted repeatedly"
encrypted, _ := encryptor.EncryptString(plaintext)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := encryptor.DecryptString(encrypted)
if err != nil {
b.Fatalf("Decryption failed: %v", err)
}
}
}

View File

@@ -0,0 +1,31 @@
package models
import (
"time"
)
type Attachment struct {
ID int64 `json:"id"`
MessageID *int64 `json:"message_id,omitempty"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
StoragePath string `json:"-"` // never send to client
MimeType string `json:"mime_type"`
UploadedAt time.Time `json:"uploaded_at"`
}
// AttachmentResponse используется для отправки информации о файле клиенту
type AttachmentResponse struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
MimeType string `json:"mime_type"`
UploadURL string `json:"upload_url"` // URL для скачивания
UploadedAt time.Time `json:"uploaded_at"`
}
// UploadFileRequest DTO для загрузки файла
type UploadFileRequest struct {
ChatID int64 `json:"chat_id"`
// файл будет в multipart form
}

70
internal/models/chat.go Normal file
View File

@@ -0,0 +1,70 @@
package models
import (
"time"
)
type ChatType string
const (
ChatTypePrivate ChatType = "private"
ChatTypeGroup ChatType = "group"
)
type MemberRole string
const (
MemberRoleMember MemberRole = "member"
MemberRoleAdmin MemberRole = "admin"
)
type Chat struct {
ID int64 `json:"id"`
Type ChatType `json:"type"`
Title *string `json:"title,omitempty"` // только для групп
CreatedAt time.Time `json:"created_at"`
}
type ChatMember struct {
ChatID int64 `json:"chat_id"`
UserID int64 `json:"user_id"`
Role MemberRole `json:"role"`
JoinedAt time.Time `json:"joined_at"`
}
// ChatWithDetails используется для возврата чата с дополнительной информацией
type ChatWithDetails struct {
ID int64 `json:"id"`
Type ChatType `json:"type"`
Title *string `json:"title,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastMessage *Message `json:"last_message,omitempty"`
UnreadCount int `json:"unread_count,omitempty"`
ParticipantIDs []int64 `json:"participant_ids,omitempty"`
}
// CreatePrivateChatRequest DTO для создания личного чата
type CreatePrivateChatRequest struct {
TargetLogin string `json:"target_login"`
}
// CreateGroupChatRequest DTO для создания группового чата
type CreateGroupChatRequest struct {
Title string `json:"title"`
MemberLogins []string `json:"member_logins"`
}
// AddMembersRequest DTO для добавления участников в группу
type AddMembersRequest struct {
UserLogins []string `json:"user_logins"`
}
// UpdateMemberRoleRequest DTO для обновления роли участника
type UpdateMemberRoleRequest struct {
Role MemberRole `json:"role"`
}
// UpdateChatTitleRequest DTO для обновления названия чата
type UpdateChatTitleRequest struct {
Title string `json:"title"`
}

View File

@@ -0,0 +1,45 @@
package models
import (
"time"
)
type Message struct {
ID int64 `json:"id"`
ChatID int64 `json:"chat_id"`
SenderID int64 `json:"sender_id"`
EncryptedBody []byte `json:"-"` // never send raw encrypted body
Plaintext string `json:"plaintext,omitempty"` // используется только после расшифровки
AttachmentID *int64 `json:"attachment_id,omitempty"`
IsRead bool `json:"is_read"`
CreatedAt time.Time `json:"created_at"`
}
// MessageResponse используется для отправки сообщений клиенту
type MessageResponse struct {
ID int64 `json:"id"`
ChatID int64 `json:"chat_id"`
SenderID int64 `json:"sender_id"`
Plaintext string `json:"plaintext"` // расшифрованный текст
Attachment *Attachment `json:"attachment,omitempty"`
IsRead bool `json:"is_read"`
CreatedAt time.Time `json:"created_at"`
}
// SendMessageRequest DTO для отправки сообщения
type SendMessageRequest struct {
ChatID int64 `json:"chat_id"`
Plaintext string `json:"plaintext"`
AttachmentID *int64 `json:"attachment_id,omitempty"`
}
// EditMessageRequest DTO для редактирования сообщения
type EditMessageRequest struct {
Plaintext string `json:"plaintext"`
}
// GetMessagesRequest DTO для запроса истории
type GetMessagesRequest struct {
Limit int `json:"limit"` // количество сообщений
Before string `json:"before"` // timestamp (RFC3339)
}

View File

@@ -0,0 +1,16 @@
package models
type Profile struct {
UserID int64 `json:"user_id"`
DisplayName *string `json:"display_name,omitempty"`
Bio *string `json:"bio,omitempty"`
AvatarURL *string `json:"avatar_url,omitempty"`
}
// ProfileWithUser объединяет профиль с безопасными данными пользователя
type ProfileWithUser struct {
User *SafeUser `json:"user"`
DisplayName *string `json:"display_name,omitempty"`
Bio *string `json:"bio,omitempty"`
AvatarURL *string `json:"avatar_url,omitempty"`
}

45
internal/models/user.go Normal file
View File

@@ -0,0 +1,45 @@
package models
import (
"time"
)
type UserRole string
const (
RoleUser UserRole = "user"
RoleGlobalAdmin UserRole = "global_admin"
)
type User struct {
ID int64 `json:"id"`
Login string `json:"login"`
PasswordHash string `json:"-"` // never send to client
Role UserRole `json:"role"`
LastSeen *time.Time `json:"last_seen,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SafeUser возвращает пользователя без чувствительных данных
type SafeUser struct {
ID int64 `json:"id"`
Login string `json:"login"`
Role UserRole `json:"role"`
LastSeen *time.Time `json:"last_seen,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
func (u *User) ToSafe() *SafeUser {
return &SafeUser{
ID: u.ID,
Login: u.Login,
Role: u.Role,
LastSeen: u.LastSeen,
CreatedAt: u.CreatedAt,
}
}
// IsGlobalAdmin проверяет, является ли пользователь глобальным админом
func (u *User) IsGlobalAdmin() bool {
return u.Role == RoleGlobalAdmin
}

View File

@@ -0,0 +1,40 @@
package logger
import (
"log/slog"
"os"
)
var Log *slog.Logger
func Init(environment string) {
var handler slog.Handler
if environment == "production" {
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelInfo,
})
} else {
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
})
}
Log = slog.New(handler)
slog.SetDefault(Log)
}
func Info(msg string, args ...any) {
Log.Info(msg, args...)
}
func Debug(msg string, args ...any) {
Log.Debug(msg, args...)
}
func Error(msg string, args ...any) {
Log.Error(msg, args...)
}
func Warn(msg string, args ...any) {
Log.Warn(msg, args...)
}

View File

@@ -0,0 +1,52 @@
package validator
import (
"regexp"
"strings"
"unicode"
)
// ValidateLogin проверяет логин (только буквы, цифры, underscore, от 3 до 32 символов)
func ValidateLogin(login string) bool {
if len(login) < 3 || len(login) > 32 {
return false
}
matched, _ := regexp.MatchString("^[a-zA-Z0-9_]+$", login)
return matched
}
// ValidatePassword проверяет пароль (минимум 6 символов)
func ValidatePassword(password string) bool {
return len(password) >= 6
}
// SanitizeString очищает строку от лишних пробелов
func SanitizeString(s string) string {
return strings.TrimSpace(s)
}
// ValidateDisplayName проверяет отображаемое имя (не более 100 символов)
func ValidateDisplayName(name string) bool {
return len(name) <= 100
}
// ValidateBio проверяет биографию (не более 500 символов)
func ValidateBio(bio string) bool {
return len(bio) <= 500
}
// IsValidUUID проверяет валидность UUID
func IsValidUUID(uuid string) bool {
matched, _ := regexp.MatchString("^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$", uuid)
return matched
}
// ContainsOnlyPrintable проверяет, что строка содержит только печатаемые символы
func ContainsOnlyPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
return false
}
}
return true
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"context"
"time"
"messenger/internal/models"
)
type UserRepository interface {
Create(ctx context.Context, user *models.User) error
FindByID(ctx context.Context, id int64) (*models.User, error)
FindByLogin(ctx context.Context, login string) (*models.User, error)
SearchByLogin(ctx context.Context, query string, limit int) ([]*models.User, error)
UpdateLastSeen(ctx context.Context, userID int64) error
UpdateRole(ctx context.Context, userID int64, role models.UserRole) error
Delete(ctx context.Context, userID int64) error
Exists(ctx context.Context, login string) (bool, error)
}
type ProfileRepository interface {
Create(ctx context.Context, profile *models.Profile) error
FindByUserID(ctx context.Context, userID int64) (*models.Profile, error)
Update(ctx context.Context, profile *models.Profile) error
UpdateAvatar(ctx context.Context, userID int64, avatarURL *string) error
}
type ChatRepository interface {
Create(ctx context.Context, chat *models.Chat) error
FindByID(ctx context.Context, id int64) (*models.Chat, error)
GetUserChats(ctx context.Context, userID int64) ([]*models.Chat, error)
UpdateTitle(ctx context.Context, chatID int64, title string) error
Delete(ctx context.Context, chatID int64) error
// Member operations
AddMember(ctx context.Context, chatID, userID int64, role models.MemberRole) error
RemoveMember(ctx context.Context, chatID, userID int64) error
GetMembers(ctx context.Context, chatID int64) ([]*models.ChatMember, error)
GetMemberRole(ctx context.Context, chatID, userID int64) (*models.MemberRole, error)
UpdateMemberRole(ctx context.Context, chatID, userID int64, role models.MemberRole) error
IsMember(ctx context.Context, chatID, userID int64) (bool, error)
// Private chat specific
FindPrivateChat(ctx context.Context, userID1, userID2 int64) (*models.Chat, error)
}
type MessageRepository interface {
Create(ctx context.Context, message *models.Message) error
FindByID(ctx context.Context, id int64) (*models.Message, error)
GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error)
GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error)
MarkAsRead(ctx context.Context, messageID int64) error
Delete(ctx context.Context, messageID int64) error
Update(ctx context.Context, message *models.Message) error
GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error)
}
type AttachmentRepository interface {
Create(ctx context.Context, attachment *models.Attachment) error
FindByID(ctx context.Context, id int64) (*models.Attachment, error)
UpdateMessageID(ctx context.Context, attachmentID, messageID int64) error
Delete(ctx context.Context, attachmentID int64) error
GetByMessageID(ctx context.Context, messageID int64) (*models.Attachment, error)
}

View File

@@ -0,0 +1,133 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
)
type AttachmentRepository struct {
db *DB
}
func NewAttachmentRepository(db *DB) *AttachmentRepository {
return &AttachmentRepository{db: db}
}
func (r *AttachmentRepository) Create(ctx context.Context, attachment *models.Attachment) error {
query := `
INSERT INTO attachments (message_id, file_name, file_size, storage_path, mime_type, uploaded_at)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id
`
err := r.db.QueryRowContext(ctx, query,
attachment.MessageID, attachment.FileName, attachment.FileSize,
attachment.StoragePath, attachment.MimeType, attachment.UploadedAt,
).Scan(&attachment.ID)
if err != nil {
return fmt.Errorf("failed to create attachment: %w", err)
}
return nil
}
func (r *AttachmentRepository) FindByID(ctx context.Context, id int64) (*models.Attachment, error) {
query := `
SELECT id, message_id, file_name, file_size, storage_path, mime_type, uploaded_at
FROM attachments
WHERE id = ?
`
var attachment models.Attachment
var messageID sql.NullInt64
err := r.db.QueryRowContext(ctx, query, id).Scan(
&attachment.ID, &messageID, &attachment.FileName, &attachment.FileSize,
&attachment.StoragePath, &attachment.MimeType, &attachment.UploadedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find attachment by id: %w", err)
}
if messageID.Valid {
attachment.MessageID = &messageID.Int64
}
return &attachment, nil
}
func (r *AttachmentRepository) UpdateMessageID(ctx context.Context, attachmentID, messageID int64) error {
query := `UPDATE attachments SET message_id = ? WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, messageID, attachmentID)
if err != nil {
return fmt.Errorf("failed to update attachment message_id: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("attachment not found: %d", attachmentID)
}
return nil
}
func (r *AttachmentRepository) Delete(ctx context.Context, attachmentID int64) error {
query := `DELETE FROM attachments WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, attachmentID)
if err != nil {
return fmt.Errorf("failed to delete attachment: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("attachment not found: %d", attachmentID)
}
return nil
}
func (r *AttachmentRepository) GetByMessageID(ctx context.Context, messageID int64) (*models.Attachment, error) {
query := `
SELECT id, message_id, file_name, file_size, storage_path, mime_type, uploaded_at
FROM attachments
WHERE message_id = ?
`
var attachment models.Attachment
var msgID sql.NullInt64
err := r.db.QueryRowContext(ctx, query, messageID).Scan(
&attachment.ID, &msgID, &attachment.FileName, &attachment.FileSize,
&attachment.StoragePath, &attachment.MimeType, &attachment.UploadedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get attachment by message_id: %w", err)
}
if msgID.Valid {
attachment.MessageID = &msgID.Int64
}
return &attachment, nil
}

View File

@@ -0,0 +1,261 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
)
type ChatRepository struct {
db *DB
}
func NewChatRepository(db *DB) *ChatRepository {
return &ChatRepository{db: db}
}
func (r *ChatRepository) Create(ctx context.Context, chat *models.Chat) error {
query := `
INSERT INTO chats (type, title, created_at)
VALUES (?, ?, ?)
RETURNING id
`
err := r.db.QueryRowContext(ctx, query, chat.Type, chat.Title, chat.CreatedAt).Scan(&chat.ID)
if err != nil {
return fmt.Errorf("failed to create chat: %w", err)
}
return nil
}
func (r *ChatRepository) FindByID(ctx context.Context, id int64) (*models.Chat, error) {
query := `
SELECT id, type, title, created_at
FROM chats
WHERE id = ?
`
var chat models.Chat
err := r.db.QueryRowContext(ctx, query, id).Scan(&chat.ID, &chat.Type, &chat.Title, &chat.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find chat by id: %w", err)
}
return &chat, nil
}
func (r *ChatRepository) GetUserChats(ctx context.Context, userID int64) ([]*models.Chat, error) {
query := `
SELECT c.id, c.type, c.title, c.created_at
FROM chats c
INNER JOIN chat_members cm ON c.id = cm.chat_id
WHERE cm.user_id = ?
ORDER BY c.created_at DESC
`
rows, err := r.db.QueryContext(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("failed to get user chats: %w", err)
}
defer rows.Close()
var chats []*models.Chat
for rows.Next() {
var chat models.Chat
err := rows.Scan(&chat.ID, &chat.Type, &chat.Title, &chat.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan chat: %w", err)
}
chats = append(chats, &chat)
}
return chats, nil
}
func (r *ChatRepository) UpdateTitle(ctx context.Context, chatID int64, title string) error {
query := `UPDATE chats SET title = ? WHERE id = ? AND type = 'group'`
result, err := r.db.ExecContext(ctx, query, title, chatID)
if err != nil {
return fmt.Errorf("failed to update chat title: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("chat not found or not a group: %d", chatID)
}
return nil
}
func (r *ChatRepository) Delete(ctx context.Context, chatID int64) error {
query := `DELETE FROM chats WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, chatID)
if err != nil {
return fmt.Errorf("failed to delete chat: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("chat not found: %d", chatID)
}
return nil
}
func (r *ChatRepository) AddMember(ctx context.Context, chatID, userID int64, role models.MemberRole) error {
query := `
INSERT INTO chat_members (chat_id, user_id, role, joined_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
`
_, err := r.db.ExecContext(ctx, query, chatID, userID, role)
if err != nil {
return fmt.Errorf("failed to add member: %w", err)
}
return nil
}
func (r *ChatRepository) RemoveMember(ctx context.Context, chatID, userID int64) error {
query := `DELETE FROM chat_members WHERE chat_id = ? AND user_id = ?`
result, err := r.db.ExecContext(ctx, query, chatID, userID)
if err != nil {
return fmt.Errorf("failed to remove member: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("member not found in chat")
}
return nil
}
func (r *ChatRepository) GetMembers(ctx context.Context, chatID int64) ([]*models.ChatMember, error) {
query := `
SELECT chat_id, user_id, role, joined_at
FROM chat_members
WHERE chat_id = ?
ORDER BY joined_at ASC
`
rows, err := r.db.QueryContext(ctx, query, chatID)
if err != nil {
return nil, fmt.Errorf("failed to get members: %w", err)
}
defer rows.Close()
var members []*models.ChatMember
for rows.Next() {
var member models.ChatMember
err := rows.Scan(&member.ChatID, &member.UserID, &member.Role, &member.JoinedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan member: %w", err)
}
members = append(members, &member)
}
return members, nil
}
func (r *ChatRepository) GetMemberRole(ctx context.Context, chatID, userID int64) (*models.MemberRole, error) {
query := `SELECT role FROM chat_members WHERE chat_id = ? AND user_id = ?`
var role models.MemberRole
err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&role)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get member role: %w", err)
}
return &role, nil
}
func (r *ChatRepository) UpdateMemberRole(ctx context.Context, chatID, userID int64, role models.MemberRole) error {
query := `UPDATE chat_members SET role = ? WHERE chat_id = ? AND user_id = ?`
result, err := r.db.ExecContext(ctx, query, role, chatID, userID)
if err != nil {
return fmt.Errorf("failed to update member role: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("member not found in chat")
}
return nil
}
func (r *ChatRepository) IsMember(ctx context.Context, chatID, userID int64) (bool, error) {
query := `SELECT EXISTS(SELECT 1 FROM chat_members WHERE chat_id = ? AND user_id = ?)`
var exists bool
err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check membership: %w", err)
}
return exists, nil
}
func (r *ChatRepository) FindPrivateChat(ctx context.Context, userID1, userID2 int64) (*models.Chat, error) {
query := `
SELECT c.id, c.type, c.title, c.created_at
FROM chats c
INNER JOIN chat_members cm1 ON c.id = cm1.chat_id
INNER JOIN chat_members cm2 ON c.id = cm2.chat_id
WHERE c.type = 'private'
AND cm1.user_id = ?
AND cm2.user_id = ?
AND c.id IN (
SELECT chat_id
FROM chat_members
WHERE user_id IN (?, ?)
GROUP BY chat_id
HAVING COUNT(DISTINCT user_id) = 2
)
LIMIT 1
`
var chat models.Chat
err := r.db.QueryRowContext(ctx, query, userID1, userID2, userID1, userID2).Scan(
&chat.ID, &chat.Type, &chat.Title, &chat.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find private chat: %w", err)
}
return &chat, nil
}

View File

@@ -0,0 +1,52 @@
package sqlite
import (
"database/sql"
"fmt"
"messenger/internal/pkg/logger"
"time"
_ "modernc.org/sqlite"
)
type DB struct {
*sql.DB
}
func NewDB(dbPath string) (*DB, error) {
// Подключение к SQLite с оптимизациями
db, err := sql.Open("sqlite", fmt.Sprintf("%s?_journal=WAL&_foreign_keys=on&_busy_timeout=5000", dbPath))
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Настройка пула соединений
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(10)
db.SetConnMaxLifetime(5 * time.Minute)
// Проверка соединения
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
// Автоматически выполняем миграции
if err := RunMigrations(db, "./migrations"); err != nil {
logger.Error("Failed to run migrations", "error", err)
// Не возвращаем ошибку, продолжаем работу
}
logger.Info("SQLite database connected", "path", dbPath)
return &DB{DB: db}, nil
}
func (db *DB) Close() error {
logger.Info("Closing SQLite database connection")
return db.DB.Close()
}
// BeginTx начинает транзакцию
func (db *DB) BeginTx() (*sql.Tx, error) {
return db.DB.Begin()
}

View File

@@ -0,0 +1,268 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
"time"
)
type MessageRepository struct {
db *DB
}
func NewMessageRepository(db *DB) *MessageRepository {
return &MessageRepository{db: db}
}
func (r *MessageRepository) Create(ctx context.Context, message *models.Message) error {
query := `
INSERT INTO messages (chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id
`
err := r.db.QueryRowContext(ctx, query,
message.ChatID, message.SenderID, message.EncryptedBody,
message.AttachmentID, message.IsRead, message.CreatedAt,
).Scan(&message.ID)
if err != nil {
return fmt.Errorf("failed to create message: %w", err)
}
// Если есть attachment, обновляем его message_id
if message.AttachmentID != nil {
updateQuery := `UPDATE attachments SET message_id = ? WHERE id = ?`
_, err = r.db.ExecContext(ctx, updateQuery, message.ID, *message.AttachmentID)
if err != nil {
return fmt.Errorf("failed to link attachment to message: %w", err)
}
}
return nil
}
func (r *MessageRepository) FindByID(ctx context.Context, id int64) (*models.Message, error) {
query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE id = ?
`
var message models.Message
var attachmentID sql.NullInt64
err := r.db.QueryRowContext(ctx, query, id).Scan(
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
&attachmentID, &message.IsRead, &message.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find message by id: %w", err)
}
if attachmentID.Valid {
message.AttachmentID = &attachmentID.Int64
}
return &message, nil
}
func (r *MessageRepository) GetChatHistory(ctx context.Context, chatID int64, limit int, before time.Time) ([]*models.Message, error) {
var query string
var rows *sql.Rows
var err error
// Если before не zero time, используем пагинацию
if !before.IsZero() {
query = `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ? AND created_at < ?
ORDER BY created_at DESC
LIMIT ?
`
rows, err = r.db.QueryContext(ctx, query, chatID, before, limit)
} else {
query = `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?
ORDER BY created_at DESC
LIMIT ?
`
rows, err = r.db.QueryContext(ctx, query, chatID, limit)
}
if err != nil {
return nil, fmt.Errorf("failed to get chat history: %w", err)
}
defer rows.Close()
var messages []*models.Message
for rows.Next() {
var message models.Message
var attachmentID sql.NullInt64
err := rows.Scan(
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
&attachmentID, &message.IsRead, &message.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan message: %w", err)
}
if attachmentID.Valid {
message.AttachmentID = &attachmentID.Int64
}
messages = append(messages, &message)
}
// Переворачиваем, чтобы получить в хронологическом порядке
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
messages[i], messages[j] = messages[j], messages[i]
}
return messages, nil
}
func (r *MessageRepository) GetLastMessage(ctx context.Context, chatID int64) (*models.Message, error) {
query := `
SELECT id, chat_id, sender_id, encrypted_body, attachment_id, is_read, created_at
FROM messages
WHERE chat_id = ?
ORDER BY created_at DESC
LIMIT 1
`
var message models.Message
var attachmentID sql.NullInt64
err := r.db.QueryRowContext(ctx, query, chatID).Scan(
&message.ID, &message.ChatID, &message.SenderID, &message.EncryptedBody,
&attachmentID, &message.IsRead, &message.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get last message: %w", err)
}
if attachmentID.Valid {
message.AttachmentID = &attachmentID.Int64
}
return &message, nil
}
func (r *MessageRepository) MarkAsRead(ctx context.Context, messageID int64) error {
query := `UPDATE messages SET is_read = 1 WHERE id = ? AND is_read = 0`
result, err := r.db.ExecContext(ctx, query, messageID)
if err != nil {
return fmt.Errorf("failed to mark message as read: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
// Сообщение уже прочитано или не найдено - не ошибка
return nil
}
return nil
}
func (r *MessageRepository) Delete(ctx context.Context, messageID int64) error {
// Сначала получаем attachment_id, чтобы потом удалить файл
var attachmentID sql.NullInt64
query := `SELECT attachment_id FROM messages WHERE id = ?`
err := r.db.QueryRowContext(ctx, query, messageID).Scan(&attachmentID)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to get attachment info: %w", err)
}
// Удаляем сообщение (attachment останется, но без связи с сообщением)
deleteQuery := `DELETE FROM messages WHERE id = ?`
result, err := r.db.ExecContext(ctx, deleteQuery, messageID)
if err != nil {
return fmt.Errorf("failed to delete message: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("message not found: %d", messageID)
}
// Если был attachment, обновляем его message_id на NULL
if attachmentID.Valid {
updateQuery := `UPDATE attachments SET message_id = NULL WHERE id = ?`
_, err = r.db.ExecContext(ctx, updateQuery, attachmentID.Int64)
if err != nil {
return fmt.Errorf("failed to unlink attachment: %w", err)
}
}
return nil
}
func (r *MessageRepository) Update(ctx context.Context, message *models.Message) error {
query := `
UPDATE messages
SET encrypted_body = ?, attachment_id = ?, is_read = ?
WHERE id = ? AND sender_id = ?
`
result, err := r.db.ExecContext(ctx, query,
message.EncryptedBody, message.AttachmentID, message.IsRead,
message.ID, message.SenderID,
)
if err != nil {
return fmt.Errorf("failed to update message: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("message not found or not owned by user: %d", message.ID)
}
return nil
}
func (r *MessageRepository) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
query := `
SELECT COUNT(*)
FROM messages m
WHERE m.chat_id = ?
AND m.sender_id != ?
AND m.is_read = 0
`
var count int
err := r.db.QueryRowContext(ctx, query, chatID, userID).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to get unread count: %w", err)
}
return count, nil
}

View File

@@ -0,0 +1,46 @@
package sqlite
import (
"database/sql"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
)
func RunMigrations(db *sql.DB, migrationsPath string) error {
// Находим все .up.sql файлы
files, err := filepath.Glob(filepath.Join(migrationsPath, "*.up.sql"))
if err != nil {
return fmt.Errorf("failed to find migrations: %w", err)
}
for _, file := range files {
fmt.Printf("Applying migration: %s\n", file)
content, err := ioutil.ReadFile(file)
if err != nil {
return fmt.Errorf("failed to read migration file %s: %w", file, err)
}
// Разделяем SQL statements по точке с запятой
statements := strings.Split(string(content), ";")
for _, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
// Выполняем SQL
if _, err := db.Exec(stmt); err != nil {
// Игнорируем ошибку "table already exists"
if !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("failed to execute migration %s: %w\nSQL: %s", file, err, stmt)
}
}
}
}
return nil
}

View File

@@ -0,0 +1,99 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
)
type ProfileRepository struct {
db *DB
}
func NewProfileRepository(db *DB) *ProfileRepository {
return &ProfileRepository{db: db}
}
func (r *ProfileRepository) Create(ctx context.Context, profile *models.Profile) error {
query := `
INSERT INTO profiles (user_id, display_name, bio, avatar_url)
VALUES (?, ?, ?, ?)
`
_, err := r.db.ExecContext(ctx, query,
profile.UserID, profile.DisplayName, profile.Bio, profile.AvatarURL)
if err != nil {
return fmt.Errorf("failed to create profile: %w", err)
}
return nil
}
func (r *ProfileRepository) FindByUserID(ctx context.Context, userID int64) (*models.Profile, error) {
query := `
SELECT user_id, display_name, bio, avatar_url
FROM profiles
WHERE user_id = ?
`
var profile models.Profile
err := r.db.QueryRowContext(ctx, query, userID).Scan(
&profile.UserID, &profile.DisplayName, &profile.Bio, &profile.AvatarURL,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find profile by user_id: %w", err)
}
return &profile, nil
}
func (r *ProfileRepository) Update(ctx context.Context, profile *models.Profile) error {
query := `
UPDATE profiles
SET display_name = COALESCE(?, display_name),
bio = COALESCE(?, bio)
WHERE user_id = ?
`
result, err := r.db.ExecContext(ctx, query, profile.DisplayName, profile.Bio, profile.UserID)
if err != nil {
return fmt.Errorf("failed to update profile: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("profile not found for user: %d", profile.UserID)
}
return nil
}
func (r *ProfileRepository) UpdateAvatar(ctx context.Context, userID int64, avatarURL *string) error {
query := `UPDATE profiles SET avatar_url = ? WHERE user_id = ?`
result, err := r.db.ExecContext(ctx, query, avatarURL, userID)
if err != nil {
return fmt.Errorf("failed to update avatar: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("profile not found for user: %d", userID)
}
return nil
}

View File

@@ -0,0 +1,197 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"messenger/internal/models"
//"time"
)
type UserRepository struct {
db *DB
}
func NewUserRepository(db *DB) *UserRepository {
return &UserRepository{db: db}
}
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
query := `
INSERT INTO users (login, password_hash, role, created_at)
VALUES (?, ?, ?, ?)
RETURNING id
`
err := r.db.QueryRowContext(ctx, query, user.Login, user.PasswordHash, user.Role, user.CreatedAt).Scan(&user.ID)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
func (r *UserRepository) FindByID(ctx context.Context, id int64) (*models.User, error) {
query := `
SELECT id, login, password_hash, role, last_seen, created_at
FROM users
WHERE id = ?
`
var user models.User
var lastSeen sql.NullTime
err := r.db.QueryRowContext(ctx, query, id).Scan(
&user.ID, &user.Login, &user.PasswordHash, &user.Role,
&lastSeen, &user.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find user by id: %w", err)
}
if lastSeen.Valid {
user.LastSeen = &lastSeen.Time
}
return &user, nil
}
func (r *UserRepository) FindByLogin(ctx context.Context, login string) (*models.User, error) {
query := `
SELECT id, login, password_hash, role, last_seen, created_at
FROM users
WHERE login = ?
`
var user models.User
var lastSeen sql.NullTime
err := r.db.QueryRowContext(ctx, query, login).Scan(
&user.ID, &user.Login, &user.PasswordHash, &user.Role,
&lastSeen, &user.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to find user by login: %w", err)
}
if lastSeen.Valid {
user.LastSeen = &lastSeen.Time
}
return &user, nil
}
func (r *UserRepository) SearchByLogin(ctx context.Context, query string, limit int) ([]*models.User, error) {
sqlQuery := `
SELECT id, login, role, last_seen, created_at
FROM users
WHERE login LIKE ? || '%'
ORDER BY login
LIMIT ?
`
rows, err := r.db.QueryContext(ctx, sqlQuery, query, limit)
if err != nil {
return nil, fmt.Errorf("failed to search users: %w", err)
}
defer rows.Close()
var users []*models.User
for rows.Next() {
var user models.User
var lastSeen sql.NullTime
err := rows.Scan(&user.ID, &user.Login, &user.Role, &lastSeen, &user.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan user: %w", err)
}
if lastSeen.Valid {
user.LastSeen = &lastSeen.Time
}
users = append(users, &user)
}
return users, nil
}
func (r *UserRepository) UpdateLastSeen(ctx context.Context, userID int64) error {
query := `UPDATE users SET last_seen = CURRENT_TIMESTAMP WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, userID)
if err != nil {
return fmt.Errorf("failed to update last_seen: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("user not found: %d", userID)
}
return nil
}
func (r *UserRepository) UpdateRole(ctx context.Context, userID int64, role models.UserRole) error {
query := `UPDATE users SET role = ? WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, role, userID)
if err != nil {
return fmt.Errorf("failed to update role: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("user not found: %d", userID)
}
return nil
}
func (r *UserRepository) Delete(ctx context.Context, userID int64) error {
query := `DELETE FROM users WHERE id = ?`
result, err := r.db.ExecContext(ctx, query, userID)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("user not found: %d", userID)
}
return nil
}
func (r *UserRepository) Exists(ctx context.Context, login string) (bool, error) {
query := `SELECT EXISTS(SELECT 1 FROM users WHERE login = ?)`
var exists bool
err := r.db.QueryRowContext(ctx, query, login).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check user existence: %w", err)
}
return exists, nil
}

View File

@@ -0,0 +1,180 @@
package service
import (
"context"
"errors"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/repository"
)
type AdminService struct {
userRepo repository.UserRepository
chatRepo repository.ChatRepository
messageRepo repository.MessageRepository
fileService *FileService
}
func NewAdminService(
userRepo repository.UserRepository,
chatRepo repository.ChatRepository,
messageRepo repository.MessageRepository,
fileService *FileService,
) *AdminService {
return &AdminService{
userRepo: userRepo,
chatRepo: chatRepo,
messageRepo: messageRepo,
fileService: fileService,
}
}
// DeleteUser удаляет пользователя (каскадно)
func (s *AdminService) DeleteUser(ctx context.Context, adminUserID, targetUserID int64) error {
// Проверяем, что администратор имеет права
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return err
}
if admin == nil || !admin.IsGlobalAdmin() {
return errors.New("only global admin can delete users")
}
// Нельзя удалить самого себя через эту функцию
if adminUserID == targetUserID {
return errors.New("cannot delete yourself")
}
// Проверяем существование пользователя
target, err := s.userRepo.FindByID(ctx, targetUserID)
if err != nil {
return err
}
if target == nil {
return errors.New("user not found")
}
// Удаляем пользователя (каскадное удаление сработает через foreign keys)
if err := s.userRepo.Delete(ctx, targetUserID); err != nil {
logger.Error("Failed to delete user", "error", err)
return errors.New("failed to delete user")
}
logger.Info("User deleted by admin", "admin_id", adminUserID, "target_user_id", targetUserID)
return nil
}
// DeleteMessage удаляет любое сообщение (админское право)
func (s *AdminService) DeleteMessage(ctx context.Context, adminUserID, messageID int64) error {
// Проверяем права администратора
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return err
}
if admin == nil || !admin.IsGlobalAdmin() {
return errors.New("only global admin can delete any message")
}
if err := s.messageRepo.Delete(ctx, messageID); err != nil {
logger.Error("Failed to delete message by admin", "error", err)
return errors.New("failed to delete message")
}
logger.Info("Message deleted by admin", "admin_id", adminUserID, "message_id", messageID)
return nil
}
// DeleteChat удаляет любой чат (админское право)
func (s *AdminService) DeleteChat(ctx context.Context, adminUserID, chatID int64) error {
// Проверяем права администратора
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return err
}
if admin == nil || !admin.IsGlobalAdmin() {
return errors.New("only global admin can delete any chat")
}
if err := s.chatRepo.Delete(ctx, chatID); err != nil {
logger.Error("Failed to delete chat by admin", "error", err)
return errors.New("failed to delete chat")
}
logger.Info("Chat deleted by admin", "admin_id", adminUserID, "chat_id", chatID)
return nil
}
// PromoteToAdmin повышает пользователя до глобального администратора
func (s *AdminService) PromoteToAdmin(ctx context.Context, adminUserID, targetUserID int64) error {
// Проверяем права администратора
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return err
}
if admin == nil || !admin.IsGlobalAdmin() {
return errors.New("only global admin can promote users")
}
if err := s.userRepo.UpdateRole(ctx, targetUserID, models.RoleGlobalAdmin); err != nil {
logger.Error("Failed to promote user to admin", "error", err)
return errors.New("failed to promote user")
}
logger.Info("User promoted to global admin", "admin_id", adminUserID, "target_user_id", targetUserID)
return nil
}
// DemoteFromAdmin понижает глобального администратора до обычного пользователя
func (s *AdminService) DemoteFromAdmin(ctx context.Context, adminUserID, targetUserID int64) error {
// Проверяем права администратора
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return err
}
if admin == nil || !admin.IsGlobalAdmin() {
return errors.New("only global admin can demote users")
}
// Нельзя понизить самого себя
if adminUserID == targetUserID {
return errors.New("cannot demote yourself")
}
if err := s.userRepo.UpdateRole(ctx, targetUserID, models.RoleUser); err != nil {
logger.Error("Failed to demote user", "error", err)
return errors.New("failed to demote user")
}
logger.Info("User demoted from global admin", "admin_id", adminUserID, "target_user_id", targetUserID)
return nil
}
// GetSystemStats возвращает системную статистику
func (s *AdminService) GetSystemStats(ctx context.Context, adminUserID int64) (map[string]interface{}, error) {
// Проверяем права администратора
admin, err := s.userRepo.FindByID(ctx, adminUserID)
if err != nil {
return nil, err
}
if admin == nil || !admin.IsGlobalAdmin() {
return nil, errors.New("only global admin can view system stats")
}
// Здесь можно добавить реальную статистику из БД
// Например, количество пользователей, сообщений, чатов и т.д.
stats := make(map[string]interface{})
// Это заглушка - в реальном коде нужно добавить методы в репозитории
stats["version"] = "1.0.0"
stats["status"] = "healthy"
logger.Info("System stats viewed by admin", "admin_id", adminUserID)
return stats, nil
}

View File

@@ -0,0 +1,199 @@
package service
import (
"context"
"errors"
"fmt"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/pkg/validator"
"messenger/internal/repository"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
type AuthService struct {
userRepo repository.UserRepository
profileRepo repository.ProfileRepository
jwtSecret []byte
jwtExpiry int64
}
type Claims struct {
UserID int64 `json:"user_id"`
Login string `json:"login"`
Role string `json:"role"`
jwt.RegisteredClaims
}
func NewAuthService(userRepo repository.UserRepository, profileRepo repository.ProfileRepository, jwtSecret []byte, jwtExpiry int64) *AuthService {
return &AuthService{
userRepo: userRepo,
profileRepo: profileRepo,
jwtSecret: jwtSecret,
jwtExpiry: jwtExpiry,
}
}
// Register регистрирует нового пользователя
func (s *AuthService) Register(ctx context.Context, login, password string) (*models.User, string, error) {
// Валидация входных данных
if !validator.ValidateLogin(login) {
return nil, "", errors.New("invalid login: must be 3-32 characters, only letters, numbers and underscore")
}
if !validator.ValidatePassword(password) {
return nil, "", errors.New("invalid password: must be at least 6 characters")
}
// Проверка существования пользователя
exists, err := s.userRepo.Exists(ctx, login)
if err != nil {
logger.Error("Failed to check user existence", "error", err)
return nil, "", errors.New("internal server error")
}
if exists {
return nil, "", errors.New("user already exists")
}
// Хэширование пароля
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
logger.Error("Failed to hash password", "error", err)
return nil, "", errors.New("internal server error")
}
// Создание пользователя
user := &models.User{
Login: login,
PasswordHash: string(passwordHash),
Role: models.RoleUser,
CreatedAt: time.Now(),
}
if err := s.userRepo.Create(ctx, user); err != nil {
logger.Error("Failed to create user", "error", err)
return nil, "", errors.New("internal server error")
}
// Создание профиля
profile := &models.Profile{
UserID: user.ID,
}
if err := s.profileRepo.Create(ctx, profile); err != nil {
logger.Error("Failed to create profile", "error", err)
// Не фатально, профиль можно создать позже
}
// Генерация JWT токена
token, err := s.generateToken(user)
if err != nil {
logger.Error("Failed to generate token", "error", err)
return nil, "", errors.New("internal server error")
}
logger.Info("User registered successfully", "user_id", user.ID, "login", user.Login)
return user, token, nil
}
// Login аутентифицирует пользователя
func (s *AuthService) Login(ctx context.Context, login, password string) (*models.User, string, error) {
// Поиск пользователя
user, err := s.userRepo.FindByLogin(ctx, login)
if err != nil {
logger.Error("Failed to find user", "error", err)
return nil, "", errors.New("internal server error")
}
if user == nil {
return nil, "", errors.New("invalid credentials")
}
// Проверка пароля
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return nil, "", errors.New("invalid credentials")
}
// Обновление last_seen
go func() {
if err := s.userRepo.UpdateLastSeen(context.Background(), user.ID); err != nil {
logger.Error("Failed to update last_seen", "error", err)
}
}()
// Генерация JWT токена
token, err := s.generateToken(user)
if err != nil {
logger.Error("Failed to generate token", "error", err)
return nil, "", errors.New("internal server error")
}
logger.Info("User logged in", "user_id", user.ID, "login", user.Login)
return user, token, nil
}
// ValidateToken проверяет валидность JWT токена и возвращает пользователя
func (s *AuthService) ValidateToken(tokenString string) (*models.User, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return s.jwtSecret, nil
})
if err != nil {
return nil, errors.New("invalid token")
}
if !token.Valid {
return nil, errors.New("invalid token")
}
// Получаем пользователя из БД (на случай, если роль изменилась)
ctx := context.Background()
user, err := s.userRepo.FindByID(ctx, claims.UserID)
if err != nil {
return nil, errors.New("user not found")
}
if user == nil {
return nil, errors.New("user not found")
}
return user, nil
}
// generateToken генерирует JWT токен для пользователя
func (s *AuthService) generateToken(user *models.User) (string, error) {
expirationTime := time.Now().Add(time.Duration(s.jwtExpiry) * time.Hour)
claims := &Claims{
UserID: user.ID,
Login: user.Login,
Role: string(user.Role),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(s.jwtSecret)
}
// GetUserByID возвращает пользователя по ID
func (s *AuthService) GetUserByID(ctx context.Context, userID int64) (*models.User, error) {
return s.userRepo.FindByID(ctx, userID)
}
// GetUserByLogin возвращает пользователя по логину
func (s *AuthService) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
return s.userRepo.FindByLogin(ctx, login)
}

View File

@@ -0,0 +1,383 @@
package service
import (
"context"
"errors"
//"fmt"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/repository"
"time"
)
type ChatService struct {
chatRepo repository.ChatRepository
userRepo repository.UserRepository
messageRepo repository.MessageRepository
}
func NewChatService(chatRepo repository.ChatRepository, userRepo repository.UserRepository, messageRepo repository.MessageRepository) *ChatService {
return &ChatService{
chatRepo: chatRepo,
userRepo: userRepo,
messageRepo: messageRepo,
}
}
// CreatePrivateChat создает личный чат между двумя пользователями
func (s *ChatService) CreatePrivateChat(ctx context.Context, userID, targetUserID int64) (*models.Chat, error) {
// Проверка, что чат не существует
existingChat, err := s.chatRepo.FindPrivateChat(ctx, userID, targetUserID)
if err != nil {
logger.Error("Failed to find private chat", "error", err)
return nil, errors.New("internal server error")
}
if existingChat != nil {
return existingChat, nil
}
// Создаем чат
chat := &models.Chat{
Type: models.ChatTypePrivate,
CreatedAt: time.Now(),
}
if err := s.chatRepo.Create(ctx, chat); err != nil {
logger.Error("Failed to create chat", "error", err)
return nil, errors.New("failed to create chat")
}
// Добавляем участников
if err := s.chatRepo.AddMember(ctx, chat.ID, userID, models.MemberRoleMember); err != nil {
return nil, err
}
if err := s.chatRepo.AddMember(ctx, chat.ID, targetUserID, models.MemberRoleMember); err != nil {
return nil, err
}
logger.Info("Private chat created", "chat_id", chat.ID, "user1", userID, "user2", targetUserID)
return chat, nil
}
// CreateGroupChat создает групповой чат
func (s *ChatService) CreateGroupChat(ctx context.Context, creatorID int64, title string, memberLogins []string) (*models.Chat, error) {
if title == "" {
return nil, errors.New("group title is required")
}
// Создаем чат
chat := &models.Chat{
Type: models.ChatTypeGroup,
Title: &title,
CreatedAt: time.Now(),
}
if err := s.chatRepo.Create(ctx, chat); err != nil {
logger.Error("Failed to create group chat", "error", err)
return nil, errors.New("failed to create chat")
}
// Добавляем создателя как админа
if err := s.chatRepo.AddMember(ctx, chat.ID, creatorID, models.MemberRoleAdmin); err != nil {
return nil, err
}
// Добавляем остальных участников
for _, login := range memberLogins {
user, err := s.userRepo.FindByLogin(ctx, login)
if err != nil {
logger.Error("Failed to find user", "login", login, "error", err)
continue
}
if user != nil && user.ID != creatorID {
if err := s.chatRepo.AddMember(ctx, chat.ID, user.ID, models.MemberRoleMember); err != nil {
logger.Error("Failed to add member", "user_id", user.ID, "error", err)
}
}
}
logger.Info("Group chat created", "chat_id", chat.ID, "creator", creatorID, "title", title)
return chat, nil
}
// GetUserChats возвращает все чаты пользователя
func (s *ChatService) GetUserChats(ctx context.Context, userID int64) ([]*models.ChatWithDetails, error) {
chats, err := s.chatRepo.GetUserChats(ctx, userID)
if err != nil {
logger.Error("Failed to get user chats", "error", err)
return nil, errors.New("internal server error")
}
var chatsWithDetails []*models.ChatWithDetails
for _, chat := range chats {
lastMessage, _ := s.messageRepo.GetLastMessage(ctx, chat.ID)
unreadCount, _ := s.messageRepo.GetUnreadCount(ctx, chat.ID, userID)
chatDetail := &models.ChatWithDetails{
ID: chat.ID,
Type: chat.Type,
Title: chat.Title,
CreatedAt: chat.CreatedAt,
UnreadCount: unreadCount,
}
if lastMessage != nil {
chatDetail.LastMessage = lastMessage
}
chatsWithDetails = append(chatsWithDetails, chatDetail)
}
return chatsWithDetails, nil
}
// GetChatByID возвращает чат по ID с проверкой доступа
func (s *ChatService) GetChatByID(ctx context.Context, chatID, userID int64) (*models.Chat, error) {
// Проверяем, что пользователь участник чата
isMember, err := s.chatRepo.IsMember(ctx, chatID, userID)
if err != nil {
return nil, err
}
if !isMember {
return nil, errors.New("access denied")
}
return s.chatRepo.FindByID(ctx, chatID)
}
// GetChatMembers возвращает участников чата
func (s *ChatService) GetChatMembers(ctx context.Context, chatID, userID int64) ([]*models.ChatMember, error) {
// Проверяем доступ
isMember, err := s.chatRepo.IsMember(ctx, chatID, userID)
if err != nil {
return nil, err
}
if !isMember {
return nil, errors.New("access denied")
}
return s.chatRepo.GetMembers(ctx, chatID)
}
// AddMembers добавляет участников в групповой чат
func (s *ChatService) AddMembers(ctx context.Context, chatID, adminID int64, userLogins []string) error {
// Проверяем, что администратор имеет права
role, err := s.chatRepo.GetMemberRole(ctx, chatID, adminID)
if err != nil {
return err
}
if role == nil || (*role != models.MemberRoleAdmin) {
return errors.New("only admins can add members")
}
// Проверяем, что чат групповой
chat, err := s.chatRepo.FindByID(ctx, chatID)
if err != nil {
return err
}
if chat == nil || chat.Type != models.ChatTypeGroup {
return errors.New("only group chats can have members added")
}
// Добавляем участников
for _, login := range userLogins {
user, err := s.userRepo.FindByLogin(ctx, login)
if err != nil || user == nil {
logger.Error("User not found", "login", login)
continue
}
isMember, _ := s.chatRepo.IsMember(ctx, chatID, user.ID)
if !isMember {
if err := s.chatRepo.AddMember(ctx, chatID, user.ID, models.MemberRoleMember); err != nil {
logger.Error("Failed to add member", "user_id", user.ID, "error", err)
}
}
}
logger.Info("Members added to chat", "chat_id", chatID, "admin_id", adminID)
return nil
}
// RemoveMember удаляет участника из группового чата
func (s *ChatService) RemoveMember(ctx context.Context, chatID, adminID, targetID int64) error {
// Проверяем права администратора
role, err := s.chatRepo.GetMemberRole(ctx, chatID, adminID)
if err != nil {
return err
}
if role == nil || (*role != models.MemberRoleAdmin) {
return errors.New("only admins can remove members")
}
// Нельзя удалить самого себя через эту функцию (есть LeaveChat)
if adminID == targetID {
return errors.New("use leave chat instead")
}
// Проверяем, что чат групповой
chat, err := s.chatRepo.FindByID(ctx, chatID)
if err != nil {
return err
}
if chat == nil || chat.Type != models.ChatTypeGroup {
return errors.New("only group chats can have members removed")
}
return s.chatRepo.RemoveMember(ctx, chatID, targetID)
}
// LeaveChat выход из чата
func (s *ChatService) LeaveChat(ctx context.Context, chatID, userID int64) error {
chat, err := s.chatRepo.FindByID(ctx, chatID)
if err != nil {
return err
}
if chat == nil {
return errors.New("chat not found")
}
// Для приватных чатов выход означает удаление чата?
if chat.Type == models.ChatTypePrivate {
// В приватном чате выход не допускается, только удаление
return errors.New("cannot leave private chat")
}
// Проверяем, что пользователь участник
isMember, err := s.chatRepo.IsMember(ctx, chatID, userID)
if err != nil {
return err
}
if !isMember {
return errors.New("not a member of this chat")
}
return s.chatRepo.RemoveMember(ctx, chatID, userID)
}
// UpdateMemberRole обновляет роль участника в группе
func (s *ChatService) UpdateMemberRole(ctx context.Context, chatID, adminID, targetID int64, role models.MemberRole) error {
// Проверяем права администратора
adminRole, err := s.chatRepo.GetMemberRole(ctx, chatID, adminID)
if err != nil {
return err
}
if adminRole == nil || (*adminRole != models.MemberRoleAdmin) {
return errors.New("only admins can update member roles")
}
// Нельзя изменить роль администратора, если он единственный
if targetID == adminID && role != models.MemberRoleAdmin {
members, err := s.chatRepo.GetMembers(ctx, chatID)
if err != nil {
return err
}
adminCount := 0
for _, m := range members {
if m.Role == models.MemberRoleAdmin {
adminCount++
}
}
if adminCount <= 1 {
return errors.New("cannot remove the only admin")
}
}
return s.chatRepo.UpdateMemberRole(ctx, chatID, targetID, role)
}
// UpdateChatTitle обновляет название группового чата
func (s *ChatService) UpdateChatTitle(ctx context.Context, chatID, adminID int64, title string) error {
// Проверяем права администратора
role, err := s.chatRepo.GetMemberRole(ctx, chatID, adminID)
if err != nil {
return err
}
if role == nil || (*role != models.MemberRoleAdmin) {
return errors.New("only admins can update chat title")
}
if title == "" {
return errors.New("title cannot be empty")
}
return s.chatRepo.UpdateTitle(ctx, chatID, title)
}
// DeleteChat удаляет чат (только для глобального админа)
func (s *ChatService) DeleteChat(ctx context.Context, chatID int64, isGlobalAdmin bool) error {
if !isGlobalAdmin {
return errors.New("only global admin can delete chats")
}
return s.chatRepo.Delete(ctx, chatID)
}
// IsMember проверяет, является ли пользователь участником чата
func (s *ChatService) IsMember(ctx context.Context, chatID, userID int64) (bool, error) {
return s.chatRepo.IsMember(ctx, chatID, userID)
}
// GetUserChatsWithDetails возвращает чаты пользователя с информацией о собеседнике
func (s *ChatService) GetUserChatsWithDetails(ctx context.Context, userID int64) ([]*models.ChatWithDetails, error) {
chats, err := s.chatRepo.GetUserChats(ctx, userID)
if err != nil {
return nil, err
}
var chatsWithDetails []*models.ChatWithDetails
for _, chat := range chats {
lastMessage, _ := s.messageRepo.GetLastMessage(ctx, chat.ID)
unreadCount, _ := s.messageRepo.GetUnreadCount(ctx, chat.ID, userID)
chatDetail := &models.ChatWithDetails{
ID: chat.ID,
Type: chat.Type,
CreatedAt: chat.CreatedAt,
UnreadCount: unreadCount,
}
if lastMessage != nil {
chatDetail.LastMessage = lastMessage
}
// Для приватных чатов - подставляем имя собеседника
if chat.Type == models.ChatTypePrivate {
members, err := s.chatRepo.GetMembers(ctx, chat.ID)
if err == nil {
for _, member := range members {
if member.UserID != userID {
otherUser, err := s.userRepo.FindByID(ctx, member.UserID)
if err == nil && otherUser != nil {
title := otherUser.Login
chatDetail.Title = &title
}
break
}
}
}
} else {
// Для групповых чатов - используем заданное название
chatDetail.Title = chat.Title
}
chatsWithDetails = append(chatsWithDetails, chatDetail)
}
return chatsWithDetails, nil
}

View File

@@ -0,0 +1,216 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/repository"
"mime/multipart"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
)
type FileService struct {
attachmentRepo repository.AttachmentRepository
chatRepo repository.ChatRepository
storagePath string
maxFileSize int64
}
func NewFileService(
attachmentRepo repository.AttachmentRepository,
chatRepo repository.ChatRepository,
storagePath string,
maxFileSizeMB int64,
) *FileService {
return &FileService{
attachmentRepo: attachmentRepo,
chatRepo: chatRepo,
storagePath: storagePath,
maxFileSize: maxFileSizeMB * 1024 * 1024, // Convert to bytes
}
}
// UploadFile загружает файл и создает запись в БД
func (s *FileService) UploadFile(ctx context.Context, chatID, userID int64, fileHeader *multipart.FileHeader) (*models.Attachment, error) {
// Проверяем, что пользователь участник чата
isMember, err := s.chatRepo.IsMember(ctx, chatID, userID)
if err != nil {
return nil, err
}
if !isMember {
return nil, errors.New("you are not a member of this chat")
}
// Проверяем размер файла
if fileHeader.Size > s.maxFileSize {
return nil, fmt.Errorf("file too large: max %d MB", s.maxFileSize/(1024*1024))
}
// Открываем файл
file, err := fileHeader.Open()
if err != nil {
return nil, errors.New("failed to open file")
}
defer file.Close()
// Генерируем уникальное имя файла
ext := filepath.Ext(fileHeader.Filename)
uniqueID := uuid.New().String()
safeFileName := uniqueID + ext
// Создаем поддиректорию по дате
now := time.Now()
subDir := filepath.Join(now.Format("2006"), now.Format("01"))
fullDir := filepath.Join(s.storagePath, subDir)
// Создаем директорию если не существует
if err := os.MkdirAll(fullDir, 0755); err != nil {
logger.Error("Failed to create storage directory", "error", err)
return nil, errors.New("failed to save file")
}
// Полный путь к файлу
filePath := filepath.Join(fullDir, safeFileName)
// Создаем файл на диске
dst, err := os.Create(filePath)
if err != nil {
logger.Error("Failed to create file", "error", err)
return nil, errors.New("failed to save file")
}
defer dst.Close()
// Копируем содержимое
if _, err := io.Copy(dst, file); err != nil {
logger.Error("Failed to copy file", "error", err)
return nil, errors.New("failed to save file")
}
// Определяем MIME тип
mimeType := fileHeader.Header.Get("Content-Type")
if mimeType == "" {
// Простая проверка по расширению
switch strings.ToLower(ext) {
case ".jpg", ".jpeg":
mimeType = "image/jpeg"
case ".png":
mimeType = "image/png"
case ".gif":
mimeType = "image/gif"
case ".pdf":
mimeType = "application/pdf"
case ".txt":
mimeType = "text/plain"
default:
mimeType = "application/octet-stream"
}
}
// Создаем запись в БД
attachment := &models.Attachment{
FileName: fileHeader.Filename,
FileSize: fileHeader.Size,
StoragePath: filePath,
MimeType: mimeType,
UploadedAt: now,
}
if err := s.attachmentRepo.Create(ctx, attachment); err != nil {
// Если не удалось сохранить в БД, удаляем файл
os.Remove(filePath)
logger.Error("Failed to create attachment record", "error", err)
return nil, errors.New("failed to save file info")
}
logger.Info("File uploaded", "attachment_id", attachment.ID, "user_id", userID, "chat_id", chatID)
return attachment, nil
}
// GetAttachment возвращает информацию о вложении
func (s *FileService) GetAttachment(ctx context.Context, attachmentID, userID int64) (*models.Attachment, error) {
attachment, err := s.attachmentRepo.FindByID(ctx, attachmentID)
if err != nil {
return nil, err
}
if attachment == nil {
return nil, errors.New("attachment not found")
}
// Проверяем доступ пользователя к файлу
// Находим сообщение, к которому прикреплен файл
if attachment.MessageID != nil {
// Здесь нужно проверить, что пользователь участник чата
// Для этого нужен доступ к messageRepo, но пока пропустим
// В реальном коде нужно добавить проверку
}
return attachment, nil
}
// DownloadFile возвращает файл для скачивания
func (s *FileService) DownloadFile(ctx context.Context, attachmentID, userID int64) (string, *os.File, string, error) {
attachment, err := s.GetAttachment(ctx, attachmentID, userID)
if err != nil {
return "", nil, "", err
}
// Открываем файл
file, err := os.Open(attachment.StoragePath)
if err != nil {
logger.Error("Failed to open file", "error", err)
return "", nil, "", errors.New("file not found")
}
return attachment.FileName, file, attachment.MimeType, nil
}
// DeleteAttachment удаляет вложение (только если оно не привязано к сообщению)
func (s *FileService) DeleteAttachment(ctx context.Context, attachmentID, userID int64, isGlobalAdmin bool) error {
attachment, err := s.attachmentRepo.FindByID(ctx, attachmentID)
if err != nil {
return err
}
if attachment == nil {
return errors.New("attachment not found")
}
// Нельзя удалить файл, который привязан к сообщению
if attachment.MessageID != nil && !isGlobalAdmin {
return errors.New("cannot delete attachment that is linked to a message")
}
// Удаляем файл с диска
if err := os.Remove(attachment.StoragePath); err != nil {
logger.Error("Failed to delete file", "error", err)
// Не возвращаем ошибку, продолжаем удаление из БД
}
// Удаляем запись из БД
if err := s.attachmentRepo.Delete(ctx, attachmentID); err != nil {
return err
}
logger.Info("Attachment deleted", "attachment_id", attachmentID, "user_id", userID)
return nil
}
// GetMaxFileSize возвращает максимальный размер файла в байтах
func (s *FileService) GetMaxFileSize() int64 {
return s.maxFileSize
}
// GetMaxFileSizeMB возвращает максимальный размер файла в мегабайтах
func (s *FileService) GetMaxFileSizeMB() int64 {
return s.maxFileSize / (1024 * 1024)
}

View File

@@ -0,0 +1,278 @@
package service
import (
"context"
"errors"
"fmt"
"messenger/internal/crypto"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/repository"
"time"
)
type MessageService struct {
messageRepo repository.MessageRepository
chatRepo repository.ChatRepository
userRepo repository.UserRepository
attachmentRepo repository.AttachmentRepository
encryptor *crypto.Encryptor
}
func NewMessageService(
messageRepo repository.MessageRepository,
chatRepo repository.ChatRepository,
userRepo repository.UserRepository,
attachmentRepo repository.AttachmentRepository,
encryptor *crypto.Encryptor,
) *MessageService {
return &MessageService{
messageRepo: messageRepo,
chatRepo: chatRepo,
userRepo: userRepo,
attachmentRepo: attachmentRepo,
encryptor: encryptor,
}
}
// SendMessage отправляет новое сообщение в чат
func (s *MessageService) SendMessage(ctx context.Context, senderID, chatID int64, plaintext string, attachmentID *int64) (*models.MessageResponse, error) {
// Проверяем, что пользователь участник чата
isMember, err := s.chatRepo.IsMember(ctx, chatID, senderID)
if err != nil {
return nil, fmt.Errorf("failed to check membership: %w", err)
}
if !isMember {
return nil, errors.New("you are not a member of this chat")
}
// Шифруем сообщение
encryptedBody, err := s.encryptor.EncryptString(plaintext)
if err != nil {
logger.Error("Failed to encrypt message", "error", err)
return nil, errors.New("failed to encrypt message")
}
// Создаем сообщение
message := &models.Message{
ChatID: chatID,
SenderID: senderID,
EncryptedBody: []byte(encryptedBody),
AttachmentID: attachmentID,
IsRead: false,
CreatedAt: time.Now(),
}
if err := s.messageRepo.Create(ctx, message); err != nil {
logger.Error("Failed to create message", "error", err)
return nil, errors.New("failed to send message")
}
// Обновляем last_seen пользователя
go func() {
_ = s.userRepo.UpdateLastSeen(context.Background(), senderID)
}()
// Возвращаем расшифрованное сообщение
return &models.MessageResponse{
ID: message.ID,
ChatID: message.ChatID,
SenderID: message.SenderID,
Plaintext: plaintext,
IsRead: message.IsRead,
CreatedAt: message.CreatedAt,
}, nil
}
// GetMessageByID возвращает сообщение по ID с расшифровкой
func (s *MessageService) GetMessageByID(ctx context.Context, messageID int64) (*models.MessageResponse, error) {
message, err := s.messageRepo.FindByID(ctx, messageID)
if err != nil {
return nil, err
}
if message == nil {
return nil, errors.New("message not found")
}
// Расшифровываем
plaintext, err := s.encryptor.DecryptString(string(message.EncryptedBody))
if err != nil {
logger.Error("Failed to decrypt message", "error", err)
return nil, errors.New("failed to decrypt message")
}
response := &models.MessageResponse{
ID: message.ID,
ChatID: message.ChatID,
SenderID: message.SenderID,
Plaintext: plaintext,
IsRead: message.IsRead,
CreatedAt: message.CreatedAt,
}
// Добавляем информацию о вложении, если есть
if message.AttachmentID != nil {
attachment, err := s.attachmentRepo.FindByID(ctx, *message.AttachmentID)
if err == nil && attachment != nil {
response.Attachment = &models.Attachment{
ID: attachment.ID,
FileName: attachment.FileName,
FileSize: attachment.FileSize,
MimeType: attachment.MimeType,
UploadedAt: attachment.UploadedAt,
}
}
}
return response, nil
}
// GetChatHistory возвращает историю сообщений чата
func (s *MessageService) GetChatHistory(ctx context.Context, chatID, userID int64, limit int, before time.Time) ([]*models.MessageResponse, error) {
// Проверяем доступ к чату
isMember, err := s.chatRepo.IsMember(ctx, chatID, userID)
if err != nil {
return nil, err
}
if !isMember {
return nil, errors.New("access denied")
}
if limit <= 0 || limit > 100 {
limit = 50
}
messages, err := s.messageRepo.GetChatHistory(ctx, chatID, limit, before)
if err != nil {
logger.Error("Failed to get chat history", "error", err)
return nil, errors.New("failed to get messages")
}
// Расшифровываем сообщения
responses := make([]*models.MessageResponse, 0, len(messages))
for _, msg := range messages {
plaintext, err := s.encryptor.DecryptString(string(msg.EncryptedBody))
if err != nil {
logger.Error("Failed to decrypt message", "message_id", msg.ID, "error", err)
continue
}
response := &models.MessageResponse{
ID: msg.ID,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
Plaintext: plaintext,
IsRead: msg.IsRead,
CreatedAt: msg.CreatedAt,
}
responses = append(responses, response)
}
return responses, nil
}
// MarkMessageAsRead отмечает сообщение как прочитанное
func (s *MessageService) MarkMessageAsRead(ctx context.Context, messageID, userID int64) error {
message, err := s.messageRepo.FindByID(ctx, messageID)
if err != nil {
return err
}
if message == nil {
return errors.New("message not found")
}
// Нельзя отметить своё сообщение как прочитанное
if message.SenderID == userID {
return nil
}
// Проверяем, что пользователь участник чата
isMember, err := s.chatRepo.IsMember(ctx, message.ChatID, userID)
if err != nil {
return err
}
if !isMember {
return errors.New("access denied")
}
return s.messageRepo.MarkAsRead(ctx, messageID)
}
// EditMessage редактирует существующее сообщение
func (s *MessageService) EditMessage(ctx context.Context, userID, messageID int64, newPlaintext string) error {
message, err := s.messageRepo.FindByID(ctx, messageID)
if err != nil {
return err
}
if message == nil {
return errors.New("message not found")
}
// Только автор может редактировать
if message.SenderID != userID {
return errors.New("only the author can edit the message")
}
// Шифруем новое содержимое
encryptedBody, err := s.encryptor.EncryptString(newPlaintext)
if err != nil {
logger.Error("Failed to encrypt edited message", "error", err)
return errors.New("failed to encrypt message")
}
message.EncryptedBody = []byte(encryptedBody)
return s.messageRepo.Update(ctx, message)
}
// DeleteMessage удаляет сообщение
func (s *MessageService) DeleteMessage(ctx context.Context, messageID int64) error {
return s.messageRepo.Delete(ctx, messageID)
}
// CanDeleteMessage проверяет, может ли пользователь удалить сообщение
func (s *MessageService) CanDeleteMessage(ctx context.Context, userID, messageID int64) (bool, error) {
message, err := s.messageRepo.FindByID(ctx, messageID)
if err != nil {
return false, err
}
if message == nil {
return false, errors.New("message not found")
}
// Автор может удалить своё сообщение
if message.SenderID == userID {
return true, nil
}
// Проверяем, является ли пользователь глобальным админом
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
return false, err
}
if user != nil && user.IsGlobalAdmin() {
return true, nil
}
// Проверяем, является ли пользователь админом чата
role, err := s.chatRepo.GetMemberRole(ctx, message.ChatID, userID)
if err != nil {
return false, err
}
return role != nil && *role == models.MemberRoleAdmin, nil
}
// GetUnreadCount возвращает количество непрочитанных сообщений в чате
func (s *MessageService) GetUnreadCount(ctx context.Context, chatID, userID int64) (int, error) {
return s.messageRepo.GetUnreadCount(ctx, chatID, userID)
}

View File

@@ -0,0 +1,125 @@
package service
import (
"context"
"errors"
"messenger/internal/models"
"messenger/internal/pkg/logger"
"messenger/internal/pkg/validator"
"messenger/internal/repository"
)
type UserService struct {
userRepo repository.UserRepository
profileRepo repository.ProfileRepository
}
func NewUserService(userRepo repository.UserRepository, profileRepo repository.ProfileRepository) *UserService {
return &UserService{
userRepo: userRepo,
profileRepo: profileRepo,
}
}
// GetProfile возвращает профиль пользователя
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*models.ProfileWithUser, error) {
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
logger.Error("Failed to get user", "error", err)
return nil, errors.New("internal server error")
}
if user == nil {
return nil, errors.New("user not found")
}
profile, err := s.profileRepo.FindByUserID(ctx, userID)
if err != nil {
logger.Error("Failed to get profile", "error", err)
// Профиль может отсутствовать - не ошибка
profile = &models.Profile{UserID: userID}
}
return &models.ProfileWithUser{
User: user.ToSafe(),
DisplayName: profile.DisplayName,
Bio: profile.Bio,
AvatarURL: profile.AvatarURL,
}, nil
}
// UpdateProfile обновляет профиль пользователя
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, displayName, bio *string) error {
// Валидация
if displayName != nil && !validator.ValidateDisplayName(*displayName) {
return errors.New("display name too long (max 100 characters)")
}
if bio != nil && !validator.ValidateBio(*bio) {
return errors.New("bio too long (max 500 characters)")
}
profile := &models.Profile{
UserID: userID,
DisplayName: displayName,
Bio: bio,
}
if err := s.profileRepo.Update(ctx, profile); err != nil {
logger.Error("Failed to update profile", "error", err)
return errors.New("failed to update profile")
}
logger.Info("Profile updated", "user_id", userID)
return nil
}
// UpdateAvatar обновляет аватар пользователя
func (s *UserService) UpdateAvatar(ctx context.Context, userID int64, avatarURL *string) error {
if err := s.profileRepo.UpdateAvatar(ctx, userID, avatarURL); err != nil {
logger.Error("Failed to update avatar", "error", err)
return errors.New("failed to update avatar")
}
logger.Info("Avatar updated", "user_id", userID)
return nil
}
// SearchUsers ищет пользователей по логину
func (s *UserService) SearchUsers(ctx context.Context, query string, limit int) ([]*models.SafeUser, error) {
if limit <= 0 || limit > 100 {
limit = 20
}
users, err := s.userRepo.SearchByLogin(ctx, query, limit)
if err != nil {
logger.Error("Failed to search users", "error", err)
return nil, errors.New("internal server error")
}
safeUsers := make([]*models.SafeUser, len(users))
for i, user := range users {
safeUsers[i] = user.ToSafe()
}
return safeUsers, nil
}
// GetUserByID возвращает безопасное представление пользователя
func (s *UserService) GetUserByID(ctx context.Context, userID int64) (*models.SafeUser, error) {
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.New("user not found")
}
return user.ToSafe(), nil
}
// GetUserByLogin возвращает пользователя по логину
func (s *UserService) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
return s.userRepo.FindByLogin(ctx, login)
}

View File

@@ -0,0 +1,305 @@
package websocket
import (
"encoding/json"
"log"
"messenger/internal/models"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
writeWait = 10 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10
maxMessageSize = 512 * 1024
)
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
user *models.User
userID int64
rooms map[int64]bool
mu sync.RWMutex
lastPing time.Time
}
func NewClient(hub *Hub, conn *websocket.Conn, user *models.User) *Client {
return &Client{
hub: hub,
conn: conn,
send: make(chan []byte, 256),
user: user,
userID: user.ID,
rooms: make(map[int64]bool),
lastPing: time.Now(),
}
}
func (c *Client) ReadPump() {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered in ReadPump: %v", r)
}
c.hub.GetUnregisterChan() <- c
c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.mu.Lock()
c.lastPing = time.Now()
c.mu.Unlock()
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
break
}
log.Printf("Received message from client %d: %s", c.userID, string(message))
c.handleMessage(message)
}
}
func (c *Client) WritePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered in WritePump: %v", r)
}
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func (c *Client) handleMessage(raw []byte) {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered in handleMessage: %v", r)
}
}()
msg, err := ParseMessage(raw)
if err != nil {
c.sendError(400, "Invalid message format")
return
}
switch msg.Type {
case MsgTypeNewMessage:
c.handleNewMessage(msg.Data)
case MsgTypeTyping:
c.handleTyping(msg.Data)
case MsgTypeReadReceipt:
c.handleReadReceipt(msg.Data)
case MsgTypeEditMessage:
c.handleEditMessage(msg.Data)
case MsgTypeDeleteMessage:
c.handleDeleteMessage(msg.Data)
case MsgTypePing:
c.handlePing()
default:
c.sendError(400, "Unknown message type")
}
}
func (c *Client) handleNewMessage(data []byte) {
var req NewMessageRequest
if err := json.Unmarshal(data, &req); err != nil {
log.Printf("Failed to parse new_message: %v", err)
c.sendError(400, "Invalid new_message data")
return
}
log.Printf("New message from user %d to chat %d: %s", c.userID, req.ChatID, req.Plaintext)
if !c.isMemberOfChat(req.ChatID) {
log.Printf("User %d is not a member of chat %d", c.userID, req.ChatID)
c.sendError(403, "Not a member of this chat")
return
}
c.hub.GetBroadcastChan() <- &BroadcastMessage{
Type: MsgTypeNewMessage,
ChatID: req.ChatID,
SenderID: c.userID,
Data: req,
Client: c,
}
}
func (c *Client) handleTyping(data []byte) {
var req TypingRequest
if err := json.Unmarshal(data, &req); err != nil {
return
}
resp := TypingResponse{
ChatID: req.ChatID,
UserID: c.userID,
IsTyping: req.IsTyping,
}
msgBytes, _ := CreateMessage(MsgTypeUserTyping, resp)
c.hub.roomsMu.RLock()
room, exists := c.hub.rooms[req.ChatID]
c.hub.roomsMu.RUnlock()
if exists {
room.Broadcast(msgBytes, c)
}
}
func (c *Client) handleReadReceipt(data []byte) {
var req ReadReceiptRequest
if err := json.Unmarshal(data, &req); err != nil {
return
}
c.hub.broadcast <- &BroadcastMessage{
Type: MsgTypeReadReceipt,
MessageID: req.MessageID,
UserID: c.userID,
}
}
func (c *Client) handleEditMessage(data []byte) {
var req EditMessageRequest
if err := json.Unmarshal(data, &req); err != nil {
c.sendError(400, "Invalid edit_message data")
return
}
c.hub.broadcast <- &BroadcastMessage{
Type: MsgTypeEditMessage,
MessageID: req.MessageID,
SenderID: c.userID,
Data: req,
}
}
func (c *Client) handleDeleteMessage(data []byte) {
var req DeleteMessageRequest
if err := json.Unmarshal(data, &req); err != nil {
c.sendError(400, "Invalid delete_message data")
return
}
c.hub.broadcast <- &BroadcastMessage{
Type: MsgTypeDeleteMessage,
MessageID: req.MessageID,
SenderID: c.userID,
}
}
func (c *Client) handlePing() {
c.mu.Lock()
c.lastPing = time.Now()
c.mu.Unlock()
pongMsg, _ := CreateMessage(MsgTypePong, nil)
select {
case c.send <- pongMsg:
default:
}
}
func (c *Client) sendError(code int, message string) {
errResp := ErrorResponse{
Code: code,
Message: message,
}
msgBytes, _ := CreateMessage(MsgTypeError, errResp)
select {
case c.send <- msgBytes:
default:
}
}
func (c *Client) isMemberOfChat(chatID int64) bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.rooms[chatID]
}
func (c *Client) JoinRoom(chatID int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.rooms[chatID] = true
}
func (c *Client) LeaveRoom(chatID int64) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.rooms, chatID)
}
func (c *Client) GetRooms() []int64 {
c.mu.RLock()
defer c.mu.RUnlock()
rooms := make([]int64, 0, len(c.rooms))
for chatID := range c.rooms {
rooms = append(rooms, chatID)
}
return rooms
}
func (c *Client) SendMessage(msg []byte) {
select {
case c.send <- msg:
default:
log.Printf("Client %d send channel full", c.userID)
c.hub.GetUnregisterChan() <- c
c.conn.Close()
}
}
func (c *Client) GetUserID() int64 {
return c.userID
}
func (c *Client) GetUser() *models.User {
return c.user
}
func (c *Client) IsAlive() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return time.Since(c.lastPing) < pongWait
}

414
internal/websocket/hub.go Normal file
View File

@@ -0,0 +1,414 @@
package websocket
import (
"context"
"messenger/internal/pkg/logger"
"messenger/internal/service"
"sync"
"time"
)
type BroadcastMessage struct {
Type MessageType
ChatID int64
MessageID int64
UserID int64
SenderID int64
Data interface{}
Client *Client
}
type Hub struct {
register chan *Client
unregister chan *Client
broadcast chan *BroadcastMessage
clients map[int64]*Client
clientsMu sync.RWMutex
rooms map[int64]*Room
roomsMu sync.RWMutex
messageService *service.MessageService
chatService *service.ChatService
running bool
mu sync.RWMutex
}
func NewHub(messageService *service.MessageService, chatService *service.ChatService) *Hub {
return &Hub{
register: make(chan *Client),
unregister: make(chan *Client),
broadcast: make(chan *BroadcastMessage, 256),
clients: make(map[int64]*Client),
rooms: make(map[int64]*Room),
messageService: messageService,
chatService: chatService,
running: true,
}
}
func (h *Hub) Run() {
logger.Info("WebSocket Hub started")
for h.running {
select {
case client := <-h.register:
h.handleRegister(client)
case client := <-h.unregister:
h.handleUnregister(client)
case broadcast := <-h.broadcast:
h.handleBroadcast(broadcast)
}
}
logger.Info("WebSocket Hub stopped")
}
func (h *Hub) Stop() {
h.mu.Lock()
defer h.mu.Unlock()
h.running = false
close(h.register)
close(h.unregister)
close(h.broadcast)
h.clientsMu.Lock()
for _, client := range h.clients {
func() {
defer func() {
if r := recover(); r != nil {
logger.Warn("Recovered while closing client", "error", r)
}
}()
close(client.send)
client.conn.Close()
}()
}
h.clients = make(map[int64]*Client)
h.clientsMu.Unlock()
}
func (h *Hub) handleRegister(client *Client) {
h.clientsMu.Lock()
if oldClient, exists := h.clients[client.userID]; exists {
func() {
defer func() {
if r := recover(); r != nil {
logger.Warn("Recovered while closing old client", "error", r)
}
}()
close(oldClient.send)
oldClient.conn.Close()
}()
}
h.clients[client.userID] = client
h.clientsMu.Unlock()
go h.addClientToChats(client)
go h.notifyUserOnline(client.userID, true)
logger.Info("Client registered", "user_id", client.userID)
}
func (h *Hub) handleUnregister(client *Client) {
h.clientsMu.Lock()
delete(h.clients, client.userID)
h.clientsMu.Unlock()
// Закрываем канал безопасно
func() {
defer func() {
if r := recover(); r != nil {
logger.Warn("Recovered while closing send channel", "error", r)
}
}()
close(client.send)
}()
h.roomsMu.Lock()
for _, room := range h.rooms {
room.RemoveClient(client)
}
h.roomsMu.Unlock()
go h.notifyUserOnline(client.userID, false)
logger.Info("Client unregistered", "user_id", client.userID)
}
func (h *Hub) handleBroadcast(broadcast *BroadcastMessage) {
defer func() {
if r := recover(); r != nil {
logger.Error("Panic in handleBroadcast", "error", r)
}
}()
switch broadcast.Type {
case MsgTypeNewMessage:
h.handleNewMessageBroadcast(broadcast)
case MsgTypeReadReceipt:
h.handleReadReceiptBroadcast(broadcast)
case MsgTypeEditMessage:
h.handleEditMessageBroadcast(broadcast)
case MsgTypeDeleteMessage:
h.handleDeleteMessageBroadcast(broadcast)
}
}
func (h *Hub) handleNewMessageBroadcast(broadcast *BroadcastMessage) {
req, ok := broadcast.Data.(NewMessageRequest)
if !ok {
logger.Error("Invalid broadcast data for new_message")
return
}
ctx := context.Background()
message, err := h.messageService.SendMessage(ctx, broadcast.Client.userID, req.ChatID, req.Plaintext, req.AttachmentID)
if err != nil {
logger.Error("Failed to save message", "error", err)
broadcast.Client.sendError(500, "Failed to send message")
return
}
response := NewMessageResponse{
ID: message.ID,
ChatID: message.ChatID,
SenderID: message.SenderID,
Plaintext: message.Plaintext,
CreatedAt: message.CreatedAt,
TempID: req.TempID,
}
msgBytes, err := CreateMessage(MsgTypeNewMessageResp, response)
if err != nil {
logger.Error("Failed to create message", "error", err)
return
}
h.roomsMu.RLock()
room, exists := h.rooms[req.ChatID]
h.roomsMu.RUnlock()
if exists {
room.BroadcastToAll(msgBytes)
}
}
func (h *Hub) handleReadReceiptBroadcast(broadcast *BroadcastMessage) {
ctx := context.Background()
err := h.messageService.MarkMessageAsRead(ctx, broadcast.MessageID, broadcast.UserID)
if err != nil {
logger.Error("Failed to mark message as read", "error", err)
return
}
message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
if err != nil {
logger.Error("Failed to get message for read receipt", "error", err)
return
}
response := ReadReceiptResponse{
MessageID: broadcast.MessageID,
UserID: broadcast.UserID,
ReadAt: time.Now(),
}
msgBytes, err := CreateMessage(MsgTypeMessageRead, response)
if err != nil {
logger.Error("Failed to create read receipt message", "error", err)
return
}
h.roomsMu.RLock()
room, exists := h.rooms[message.ChatID]
h.roomsMu.RUnlock()
if exists {
room.BroadcastToAll(msgBytes)
}
}
func (h *Hub) handleEditMessageBroadcast(broadcast *BroadcastMessage) {
req, ok := broadcast.Data.(EditMessageRequest)
if !ok {
logger.Error("Invalid broadcast data for edit_message")
return
}
ctx := context.Background()
err := h.messageService.EditMessage(ctx, broadcast.SenderID, req.MessageID, req.Plaintext)
if err != nil {
logger.Error("Failed to edit message", "error", err)
if broadcast.Client != nil {
broadcast.Client.sendError(403, "Failed to edit message")
}
return
}
message, err := h.messageService.GetMessageByID(ctx, req.MessageID)
if err != nil {
logger.Error("Failed to get edited message", "error", err)
return
}
response := EditMessageResponse{
MessageID: req.MessageID,
NewPlaintext: req.Plaintext,
EditedAt: time.Now(),
}
msgBytes, err := CreateMessage(MsgTypeMessageEdited, response)
if err != nil {
logger.Error("Failed to create edit message", "error", err)
return
}
h.roomsMu.RLock()
room, exists := h.rooms[message.ChatID]
h.roomsMu.RUnlock()
if exists {
room.BroadcastToAll(msgBytes)
}
}
func (h *Hub) handleDeleteMessageBroadcast(broadcast *BroadcastMessage) {
ctx := context.Background()
message, err := h.messageService.GetMessageByID(ctx, broadcast.MessageID)
if err != nil {
logger.Error("Failed to get message for deletion", "error", err)
return
}
canDelete, err := h.messageService.CanDeleteMessage(ctx, broadcast.SenderID, broadcast.MessageID)
if err != nil || !canDelete {
logger.Error("User cannot delete message", "user_id", broadcast.SenderID, "message_id", broadcast.MessageID)
if broadcast.Client != nil {
broadcast.Client.sendError(403, "Cannot delete this message")
}
return
}
err = h.messageService.DeleteMessage(ctx, broadcast.MessageID)
if err != nil {
logger.Error("Failed to delete message", "error", err)
return
}
response := DeleteMessageResponse{
MessageID: broadcast.MessageID,
DeletedAt: time.Now(),
}
msgBytes, err := CreateMessage(MsgTypeMessageDeleted, response)
if err != nil {
logger.Error("Failed to create delete message", "error", err)
return
}
h.roomsMu.RLock()
room, exists := h.rooms[message.ChatID]
h.roomsMu.RUnlock()
if exists {
room.BroadcastToAll(msgBytes)
}
}
func (h *Hub) addClientToChats(client *Client) {
ctx := context.Background()
chats, err := h.chatService.GetUserChats(ctx, client.userID)
if err != nil {
logger.Error("Failed to get user chats", "user_id", client.userID, "error", err)
return
}
for _, chat := range chats {
h.roomsMu.Lock()
room, exists := h.rooms[chat.ID]
if !exists {
room = NewRoom(chat.ID)
h.rooms[chat.ID] = room
}
h.roomsMu.Unlock()
room.AddClient(client)
}
logger.Info("Client added to rooms", "user_id", client.userID, "room_count", len(chats))
}
func (h *Hub) notifyUserOnline(userID int64, isOnline bool) {
// Защита от паники
defer func() {
if r := recover(); r != nil {
logger.Warn("Recovered in notifyUserOnline", "error", r)
}
}()
response := UserOnlineResponse{
UserID: userID,
IsOnline: isOnline,
}
msgBytes, err := CreateMessage(MsgTypeUserOnline, response)
if err != nil {
logger.Error("Failed to create user online message", "error", err)
return
}
h.roomsMu.RLock()
defer h.roomsMu.RUnlock()
for _, room := range h.rooms {
if room.HasClientByUserID(userID) {
room.BroadcastToAll(msgBytes)
}
}
}
func (h *Hub) GetRoom(chatID int64) (*Room, bool) {
h.roomsMu.RLock()
defer h.roomsMu.RUnlock()
room, exists := h.rooms[chatID]
return room, exists
}
func (h *Hub) GetClient(userID int64) (*Client, bool) {
h.clientsMu.RLock()
defer h.clientsMu.RUnlock()
client, exists := h.clients[userID]
return client, exists
}
func (h *Hub) SendToUser(userID int64, message []byte) bool {
h.clientsMu.RLock()
client, exists := h.clients[userID]
h.clientsMu.RUnlock()
if exists {
client.SendMessage(message)
return true
}
return false
}
// GetRegisterChan возвращает канал для регистрации клиентов
func (h *Hub) GetRegisterChan() chan<- *Client {
return h.register
}
// GetUnregisterChan возвращает канал для отмены регистрации клиентов
func (h *Hub) GetUnregisterChan() chan<- *Client {
return h.unregister
}
// GetBroadcastChan возвращает канал для широковещательных сообщений
func (h *Hub) GetBroadcastChan() chan<- *BroadcastMessage {
return h.broadcast
}

View File

@@ -0,0 +1,136 @@
package websocket
import (
"encoding/json"
"time"
)
type MessageType string
const (
MsgTypeNewMessage MessageType = "new_message"
MsgTypeTyping MessageType = "typing"
MsgTypeReadReceipt MessageType = "read_receipt"
MsgTypeEditMessage MessageType = "edit_message"
MsgTypeDeleteMessage MessageType = "delete_message"
MsgTypePing MessageType = "ping"
MsgTypeNewMessageResp MessageType = "new_message"
MsgTypeUserTyping MessageType = "user_typing"
MsgTypeMessageRead MessageType = "message_read"
MsgTypeMessageEdited MessageType = "message_edited"
MsgTypeMessageDeleted MessageType = "message_deleted"
MsgTypeUserOnline MessageType = "user_online"
MsgTypeUserOffline MessageType = "user_offline"
MsgTypePong MessageType = "pong"
MsgTypeError MessageType = "error"
)
type WSMessage struct {
Type MessageType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
type NewMessageRequest struct {
ChatID int64 `json:"chat_id"`
Plaintext string `json:"plaintext"`
AttachmentID *int64 `json:"attachment_id,omitempty"`
TempID string `json:"temp_id,omitempty"`
}
type NewMessageResponse struct {
ID int64 `json:"id"`
ChatID int64 `json:"chat_id"`
SenderID int64 `json:"sender_id"`
Plaintext string `json:"plaintext"`
Attachment *AttachmentInfo `json:"attachment,omitempty"`
CreatedAt time.Time `json:"created_at"`
TempID string `json:"temp_id,omitempty"`
}
type AttachmentInfo struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
MimeType string `json:"mime_type"`
}
type TypingRequest struct {
ChatID int64 `json:"chat_id"`
IsTyping bool `json:"is_typing"`
}
type TypingResponse struct {
ChatID int64 `json:"chat_id"`
UserID int64 `json:"user_id"`
IsTyping bool `json:"is_typing"`
}
type ReadReceiptRequest struct {
MessageID int64 `json:"message_id"`
}
type ReadReceiptResponse struct {
MessageID int64 `json:"message_id"`
UserID int64 `json:"user_id"`
ReadAt time.Time `json:"read_at"`
}
type EditMessageRequest struct {
MessageID int64 `json:"message_id"`
Plaintext string `json:"plaintext"`
}
type EditMessageResponse struct {
MessageID int64 `json:"message_id"`
NewPlaintext string `json:"new_plaintext"`
EditedAt time.Time `json:"edited_at"`
}
type DeleteMessageRequest struct {
MessageID int64 `json:"message_id"`
}
type DeleteMessageResponse struct {
MessageID int64 `json:"message_id"`
DeletedAt time.Time `json:"deleted_at"`
}
type UserOnlineResponse struct {
UserID int64 `json:"user_id"`
IsOnline bool `json:"is_online"`
LastSeen *time.Time `json:"last_seen,omitempty"`
}
type ErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}
func ParseMessage(raw []byte) (*WSMessage, error) {
var msg WSMessage
if err := json.Unmarshal(raw, &msg); err != nil {
return nil, err
}
return &msg, nil
}
func CreateMessage(msgType MessageType, data interface{}) ([]byte, error) {
var dataBytes json.RawMessage
if data != nil {
bytes, err := json.Marshal(data)
if err != nil {
return nil, err
}
dataBytes = bytes
}
msg := WSMessage{
Type: msgType,
Data: dataBytes,
Timestamp: time.Now(),
}
return json.Marshal(msg)
}

118
internal/websocket/room.go Normal file
View File

@@ -0,0 +1,118 @@
package websocket
import (
"sync"
)
type Room struct {
ID int64
Clients map[*Client]bool
mu sync.RWMutex
}
func NewRoom(chatID int64) *Room {
return &Room{
ID: chatID,
Clients: make(map[*Client]bool),
}
}
func (r *Room) AddClient(client *Client) {
r.mu.Lock()
defer r.mu.Unlock()
r.Clients[client] = true
client.JoinRoom(r.ID)
}
func (r *Room) RemoveClient(client *Client) {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.Clients[client]; exists {
delete(r.Clients, client)
client.LeaveRoom(r.ID)
}
}
func (r *Room) Broadcast(message []byte, excludeClient *Client) {
r.mu.RLock()
defer r.mu.RUnlock()
for client := range r.Clients {
if excludeClient != nil && client == excludeClient {
continue
}
// Безопасная отправка с проверкой закрытого канала
func() {
defer func() {
if r := recover(); r != nil {
// Канал закрыт, игнорируем
}
}()
select {
case client.send <- message:
default:
// Канал заполнен или закрыт
}
}()
}
}
func (r *Room) BroadcastToAll(message []byte) {
r.Broadcast(message, nil)
}
func (r *Room) GetClientCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.Clients)
}
func (r *Room) GetClients() []*Client {
r.mu.RLock()
defer r.mu.RUnlock()
clients := make([]*Client, 0, len(r.Clients))
for client := range r.Clients {
clients = append(clients, client)
}
return clients
}
func (r *Room) IsEmpty() bool {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.Clients) == 0
}
func (r *Room) HasClient(client *Client) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.Clients[client]
return exists
}
func (r *Room) GetUserIDs() []int64 {
r.mu.RLock()
defer r.mu.RUnlock()
userIDs := make([]int64, 0, len(r.Clients))
for client := range r.Clients {
userIDs = append(userIDs, client.userID)
}
return userIDs
}
func (r *Room) HasClientByUserID(userID int64) bool {
r.mu.RLock()
defer r.mu.RUnlock()
for client := range r.Clients {
if client.userID == userID {
return true
}
}
return false
}