use axum::{ Extension, Json, Router, extract::{Path, Request}, http::{HeaderMap, StatusCode, header}, middleware::Next, response::Response, routing::{get, post, put}, }; use sqlx::PgPool; use std::sync::Arc; use uuid::Uuid; use validator::ValidateEmail; use crate::{ AppConfig, MAX_UPLOAD_SIZE, MAX_USERNAME_LENGTH, auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password}, db::{user_id_from_uuid, username_from_uuid}, errors::APIError, }; #[derive(sqlx::FromRow, serde::Serialize)] pub struct User { pub uuid: Uuid, pub username: String, pub password_hash: String, pub email: String, } #[derive(sqlx::FromRow, serde::Serialize)] pub struct UserProfile { pub uuid: Uuid, pub username: String, } #[derive(serde::Deserialize)] pub struct LoginPayload { pub email: String, pub password: String, } #[derive(serde::Serialize)] pub struct LoginResponse { pub uuid: Uuid, pub username: String, pub email: String, pub token: String, } #[derive(serde::Deserialize)] pub struct NewUserPayload { pub email: String, pub username: String, pub password: String, } #[derive(serde::Deserialize)] pub struct UpdateUserPayoad { pub email: String, pub username: String, pub password: String, } #[derive(serde::Serialize)] pub struct UpdateUserResponse { pub email: String, pub username: String, } pub fn routes() -> Router { Router::new() .route("/login", post(login)) .route("/register", post(register_user)) .route("/validate-token", get(validate_token)) .route("/account/settings", put(update_user)) .route("/account/upload-avatar", post(upload_avatar)) .route("/account/get-avatar/{uuid}", get(get_avatar)) .layer(axum::middleware::from_fn(registration_guard)) } async fn registration_guard( Extension(config): Extension>, req: Request, next: Next, ) -> Result { if req.uri().path() == "/register" && config.prohibit_registration { return Err(StatusCode::FORBIDDEN); } Ok(next.run(req).await) } pub async fn login( Extension(db): Extension, Json(payload): Json, ) -> Result, APIError> { const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; let user = sqlx::query_as::<_, User>( "SELECT uuid, email, username, password_hash FROM user_ WHERE email = $1", ) .bind(&payload.email) .fetch_optional(&db) .await?; let (user_uuid, password_hash) = if let Some(u) = user { (u.uuid, u.password_hash) } else { // timing shield (uuid::Uuid::now_v7(), DUMMY_HASH.to_string()) }; if !verify_password(&password_hash, &payload.password) { return Err(APIError::WrongCredentials); } let token = create_jwt(user_uuid)?; let username = username_from_uuid(&db, user_uuid).await?; Ok(Json(LoginResponse { uuid: user_uuid, username, email: payload.email, token, })) } pub async fn register_user( Extension(db): Extension, Json(payload): Json, ) -> Result<(StatusCode, Json), APIError> { if payload.email.is_empty() || payload.username.is_empty() || payload.password.is_empty() { return Err(APIError::EmptyFields); } if !ValidateEmail::validate_email(&payload.email) { return Err(APIError::InvalidEmail); } { let username_length = payload.username.len(); if username_length > MAX_USERNAME_LENGTH || username_length < 1 { return Err(APIError::UsernameLength); } } if payload.password.len() < 8 { return Err(APIError::PasswordTooShort); } let password_hash = hash_password(&payload.password)?; let user_uuid = uuid::Uuid::now_v7(); sqlx::query( "INSERT INTO user_ (uuid, username, email, password_hash) VALUES ($1, $2, $3, $4)", ) .bind(user_uuid) .bind(&payload.username) .bind(&payload.email) .bind(&password_hash) .execute(&db) .await .map_err(|e| { if let Some(db_err) = e.as_database_error() { if db_err.code().map(|c| c == "23505").unwrap_or(false) { match db_err.constraint() { Some("user__username_key") => return APIError::UsernameTaken, Some("user__email_key") => return APIError::EmailTaken, _ => return APIError::Internal("".to_string()), // TODO: handle this case } } } APIError::DatabaseError(e) })?; let token = create_jwt(user_uuid)?; Ok(( StatusCode::CREATED, Json(LoginResponse { uuid: user_uuid, username: payload.username, email: payload.email, token, }), )) } pub async fn update_user( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, ) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; if payload.email.is_empty() || payload.username.is_empty() { return Err(APIError::EmptyFields); } if !ValidateEmail::validate_email(&payload.email) { return Err(APIError::InvalidEmail); } let user_id = user_id_from_uuid(&db, claims.sub).await?; let mut tx = db.begin().await?; if !payload.password.is_empty() { if payload.password.len() < 8 { return Err(APIError::PasswordTooShort); } let password_hash = hash_password(&payload.password)?; sqlx::query("UPDATE user_ SET password_hash = $1 WHERE id = $2") .bind(password_hash) .bind(user_id) .execute(&mut *tx) .await?; } sqlx::query("UPDATE user_ SET username = $1, email = $2 WHERE id = $3") .bind(&payload.username) .bind(&payload.email) .bind(user_id) .execute(&mut *tx) .await .map_err(|e| { if let Some(db_err) = e.as_database_error() { if db_err.code().map(|c| c == "23505").unwrap_or(false) { match db_err.constraint() { Some("user__username_key") => return APIError::UsernameTaken, Some("user__email_key") => return APIError::EmailTaken, _ => return APIError::Internal("".to_string()), // TODO: handle this case } } } APIError::DatabaseError(e) })?; tx.commit().await?; Ok(( StatusCode::CREATED, Json(UpdateUserResponse { username: payload.username, email: payload.email, }), )) } async fn upload_avatar( headers: HeaderMap, Extension(db): Extension, Extension(config): Extension>, body: axum::body::Bytes, ) -> Result { let claims = verify_jwt(headers)?; if body.len() > MAX_UPLOAD_SIZE { // TODO: FileTooLarge error return Err(APIError::WrongFileFormat); } let kind = infer::get(&body).ok_or(APIError::WrongFileFormat)?; let ("image/png" | "image/jpeg" | "image/webp") = kind.mime_type() else { return Err(APIError::WrongFileFormat); }; let user_id = user_id_from_uuid(&db, claims.sub).await?; tracing::debug!( "User ID {} is uploading {} bytes ({})", user_id, body.len(), kind.mime_type() ); let base_dir = &config.avatar_dir; let supported_extensions = ["png", "jpg", "jpeg", "webp"]; tokio::fs::create_dir_all(&base_dir) .await .map_err(|e| APIError::Internal(format!("Failed to create storage: {e}")))?; // Delete all other files for this user first for ext in supported_extensions { let old_filename = format!("{}.{}", claims.sub, ext); let old_path = std::path::Path::new(base_dir).join(&old_filename); match tokio::fs::remove_file(old_path).await { Ok(_) => tracing::debug!("Deleted old avatar: {}", old_filename), Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} Err(e) => tracing::warn!("Failed to delete old avatar {}: {}", old_filename, e), } } let file_extension = kind.extension(); let filename = format!("{}.{}", claims.sub, file_extension); let full_path = std::path::Path::new(&base_dir).join(&filename); tokio::fs::write(&full_path, body) .await .map_err(|e| APIError::Internal(format!("Failed to save file: {e}")))?; sqlx::query("UPDATE user_ SET avatar_url = $1 WHERE id = $2") .bind(filename) .bind(user_id) .execute(&db) .await?; Ok(StatusCode::OK) } // Public route async fn get_avatar( Path(uuid): Path, Extension(config): Extension>, ) -> Result { let base_dir = &config.avatar_dir; // Helper to try finding the file with allowed extensions let mut file_path = None; for ext in ["png", "jpg", "jpeg", "webp"] { let path = std::path::Path::new(&base_dir).join(format!("{}.{}", uuid, ext)); if path.exists() { file_path = Some(path); break; } } let full_path = file_path.ok_or(APIError::AvatarNotFound)?; let file_contents = tokio::fs::read(&full_path) .await .map_err(|e| APIError::Internal(format!("Could not read avatar file: {e}")))?; let mime_type = infer::get(&file_contents) .map(|k| k.mime_type()) .unwrap_or("application/octet-stream"); // Fallback Ok(Response::builder() .header(header::CONTENT_TYPE, mime_type) .body(axum::body::Body::from(file_contents)) .unwrap()) }