Files
frangipane-backend/src/routes/users.rs

348 lines
9.8 KiB
Rust

use axum::{
Extension, Json, Router,
extract::{Path, Request},
http::{HeaderMap, StatusCode, header},
middleware::Next,
response::Response,
routing::{get, post, put},
};
use sqlx::PgPool;
use std::sync::Arc;
use uuid::Uuid;
use validator::ValidateEmail;
use crate::{
AppConfig, MAX_UPLOAD_SIZE, MAX_USERNAME_LENGTH,
auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password},
db::{user_id_from_uuid, username_from_uuid},
errors::APIError,
};
#[derive(sqlx::FromRow, serde::Serialize)]
pub struct User {
pub uuid: Uuid,
pub username: String,
pub password_hash: String,
pub email: String,
}
/// Public view of a user: identifier and username only (no email, no password hash).
///
/// Derives `FromRow` so it can be loaded directly from a query selecting `uuid, username`.
#[derive(sqlx::FromRow, serde::Serialize)]
pub struct UserProfile {
    pub uuid: Uuid,
    pub username: String,
}
/// JSON request body for the `/login` route (handled by `login`).
// Deliberately no `Debug` derive: the struct carries a plaintext password.
#[derive(serde::Deserialize)]
pub struct LoginPayload {
    /// Email address used to look up the account.
    pub email: String,
    /// Plaintext password, verified against the stored hash.
    pub password: String,
}
/// JSON response returned by both `login` and `register_user`: the account's identity plus a
/// freshly issued token from `create_jwt`.
#[derive(serde::Serialize)]
pub struct LoginResponse {
    pub uuid: Uuid,
    pub username: String,
    pub email: String,
    /// Token created by `create_jwt` for `uuid`.
    pub token: String,
}
/// JSON request body for the `/register` route (handled by `register_user`).
///
/// `register_user` rejects the payload when any field is empty, when the email is invalid,
/// when the username exceeds `MAX_USERNAME_LENGTH`, or when the password is shorter than 8 bytes.
#[derive(serde::Deserialize)]
pub struct NewUserPayload {
    pub email: String,
    pub username: String,
    /// Plaintext password; only its hash is stored.
    pub password: String,
}
/// JSON request body for the `/account/settings` route (handled by `update_user`).
// NOTE(review): the type name is misspelled ("Payoad"); it is `pub`, so renaming it needs a
// check of external users (or a `pub type` alias) first.
#[derive(serde::Deserialize)]
pub struct UpdateUserPayoad {
    /// New email address; must be non-empty and valid.
    pub email: String,
    /// New username; must be non-empty.
    pub username: String,
    /// New plaintext password. An empty string means "keep the current password".
    pub password: String,
}
/// JSON response of `update_user`: echoes the username and email that were written.
#[derive(serde::Serialize)]
pub struct UpdateUserResponse {
    pub email: String,
    pub username: String,
}
/// Builds the router for the authentication and account endpoints.
///
/// `registration_guard` is added as a layer after every route is registered, so it wraps all
/// of them; it only intervenes for `/register`.
pub fn routes() -> Router {
    let router = Router::new()
        // Authentication
        .route("/login", post(login))
        .route("/register", post(register_user))
        .route("/validate-token", get(validate_token))
        // Account management
        .route("/account/settings", put(update_user))
        .route("/account/upload-avatar", post(upload_avatar))
        // Public: serves any user's avatar by UUID
        .route("/account/get-avatar/{uuid}", get(get_avatar));

    router.layer(axum::middleware::from_fn(registration_guard))
}
/// Middleware that blocks account creation while registration is disabled.
///
/// Requests whose path is exactly `/register` receive `403 Forbidden` when
/// `AppConfig::prohibit_registration` is set; every other request is forwarded to `next`.
///
/// NOTE(review): the comparison uses the path as seen by this router; if the router is nested
/// under a prefix, confirm the prefix is stripped before this middleware runs.
async fn registration_guard(
    Extension(config): Extension<Arc<AppConfig>>,
    req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let registration_blocked = config.prohibit_registration && req.uri().path() == "/register";

    if registration_blocked {
        Err(StatusCode::FORBIDDEN)
    } else {
        Ok(next.run(req).await)
    }
}
/// Authenticates a user by email and password and issues a token via `create_jwt`.
///
/// # Errors
/// - [`APIError::WrongCredentials`] when no account has the given email or the password does
///   not match. Both cases go through the same code path, including a full hash verification,
///   so the response does not reveal which one occurred.
/// - Database and token-creation errors are propagated.
pub async fn login(
    Extension(db): Extension<PgPool>,
    Json(payload): Json<LoginPayload>,
) -> Result<Json<LoginResponse>, APIError> {
    // Argon2id-formatted hash verified against when the email is unknown, so the request still
    // pays the cost of one password verification (timing shield).
    const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

    let user = sqlx::query_as::<_, User>(
        "SELECT uuid, email, username, password_hash FROM user_ WHERE email = $1",
    )
    .bind(&payload.email)
    .fetch_optional(&db)
    .await?;

    // Always run exactly one verification, whether or not the account exists.
    let stored_hash = user
        .as_ref()
        .map_or(DUMMY_HASH, |u| u.password_hash.as_str());
    let password_matches = verify_password(stored_hash, &payload.password);

    // Require both a real account and a matching password: a password that happened to match
    // DUMMY_HASH must never produce a token.
    let Some(user) = user.filter(|_| password_matches) else {
        return Err(APIError::WrongCredentials);
    };

    // The fetched row already carries the username, so no second query is needed.
    let token = create_jwt(user.uuid)?;
    Ok(Json(LoginResponse {
        uuid: user.uuid,
        username: user.username,
        email: user.email,
        token,
    }))
}
/// Creates a new account and returns a token for it, with status `201 Created`.
///
/// # Errors
/// - [`APIError::EmptyFields`] if email, username or password is empty.
/// - [`APIError::InvalidEmail`] if the email fails syntactic validation.
/// - [`APIError::UsernameLength`] if the username is longer than `MAX_USERNAME_LENGTH` bytes.
/// - [`APIError::PasswordTooShort`] if the password is shorter than 8 bytes.
/// - [`APIError::UsernameTaken`] / [`APIError::EmailTaken`] if the insert violates the
///   corresponding unique constraint; any other unique violation becomes
///   [`APIError::Internal`] naming the constraint.
/// - Other database, hashing and token errors are propagated.
pub async fn register_user(
    Extension(db): Extension<PgPool>,
    Json(payload): Json<NewUserPayload>,
) -> Result<(StatusCode, Json<LoginResponse>), APIError> {
    if payload.email.is_empty() || payload.username.is_empty() || payload.password.is_empty() {
        return Err(APIError::EmptyFields);
    }
    if !ValidateEmail::validate_email(&payload.email) {
        return Err(APIError::InvalidEmail);
    }
    // Emptiness was rejected above, so only the upper bound remains to check.
    // NOTE(review): `len()` counts UTF-8 bytes, not characters; confirm that matches the
    // intent of MAX_USERNAME_LENGTH.
    if payload.username.len() > MAX_USERNAME_LENGTH {
        return Err(APIError::UsernameLength);
    }
    if payload.password.len() < 8 {
        return Err(APIError::PasswordTooShort);
    }

    let password_hash = hash_password(&payload.password)?;
    let user_uuid = uuid::Uuid::now_v7();
    sqlx::query(
        "INSERT INTO user_ (uuid, username, email, password_hash)
VALUES ($1, $2, $3, $4)",
    )
    .bind(user_uuid)
    .bind(&payload.username)
    .bind(&payload.email)
    .bind(&password_hash)
    .execute(&db)
    .await
    .map_err(|e| {
        if let Some(db_err) = e.as_database_error() {
            // SQLSTATE 23505 is PostgreSQL's unique_violation.
            if db_err.code().is_some_and(|c| c == "23505") {
                return match db_err.constraint() {
                    Some("user__username_key") => APIError::UsernameTaken,
                    Some("user__email_key") => APIError::EmailTaken,
                    // Keep the constraint name so unexpected clashes can be diagnosed.
                    other => APIError::Internal(format!(
                        "Unexpected unique constraint violation: {other:?}"
                    )),
                };
            }
        }
        APIError::DatabaseError(e)
    })?;

    let token = create_jwt(user_uuid)?;
    Ok((
        StatusCode::CREATED,
        Json(LoginResponse {
            uuid: user_uuid,
            username: payload.username,
            email: payload.email,
            token,
        }),
    ))
}
pub async fn update_user(
headers: HeaderMap,
Extension(db): Extension<PgPool>,
Json(payload): Json<UpdateUserPayoad>,
) -> Result<(StatusCode, Json<UpdateUserResponse>), APIError> {
let claims = verify_jwt(headers)?;
if payload.email.is_empty() || payload.username.is_empty() {
return Err(APIError::EmptyFields);
}
if !ValidateEmail::validate_email(&payload.email) {
return Err(APIError::InvalidEmail);
}
let user_id = user_id_from_uuid(&db, claims.sub).await?;
let mut tx = db.begin().await?;
if !payload.password.is_empty() {
if payload.password.len() < 8 {
return Err(APIError::PasswordTooShort);
}
let password_hash = hash_password(&payload.password)?;
sqlx::query("UPDATE user_ SET password_hash = $1 WHERE id = $2")
.bind(password_hash)
.bind(user_id)
.execute(&mut *tx)
.await?;
}
sqlx::query("UPDATE user_ SET username = $1, email = $2 WHERE id = $3")
.bind(&payload.username)
.bind(&payload.email)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| {
if let Some(db_err) = e.as_database_error() {
if db_err.code().map(|c| c == "23505").unwrap_or(false) {
match db_err.constraint() {
Some("user__username_key") => return APIError::UsernameTaken,
Some("user__email_key") => return APIError::EmailTaken,
_ => return APIError::Internal("".to_string()), // TODO: handle this case
}
}
}
APIError::DatabaseError(e)
})?;
tx.commit().await?;
Ok((
StatusCode::CREATED,
Json(UpdateUserResponse {
username: payload.username,
email: payload.email,
}),
))
}
/// Stores the authenticated user's avatar and records its filename in the database.
///
/// The raw request body is the image. Only PNG, JPEG and WebP are accepted, as detected from
/// the content itself by `infer`. The file is saved as `<avatar_dir>/<user uuid>.<ext>`.
/// Avatars previously saved under a different extension are removed once the new file is in
/// place.
///
/// # Errors
/// - [`APIError::WrongFileFormat`] if the body exceeds `MAX_UPLOAD_SIZE` or is not an accepted
///   image type.
/// - [`APIError::Internal`] if the avatar directory or file cannot be written.
/// - Errors from `verify_jwt`, `user_id_from_uuid` and the database update are propagated.
async fn upload_avatar(
    headers: HeaderMap,
    Extension(db): Extension<PgPool>,
    Extension(config): Extension<Arc<AppConfig>>,
    body: axum::body::Bytes,
) -> Result<StatusCode, APIError> {
    let claims = verify_jwt(headers)?;
    if body.len() > MAX_UPLOAD_SIZE {
        // TODO: FileTooLarge error
        return Err(APIError::WrongFileFormat);
    }
    let kind = infer::get(&body).ok_or(APIError::WrongFileFormat)?;
    let ("image/png" | "image/jpeg" | "image/webp") = kind.mime_type() else {
        return Err(APIError::WrongFileFormat);
    };
    let user_id = user_id_from_uuid(&db, claims.sub).await?;
    tracing::debug!(
        "User ID {} is uploading {} bytes ({})",
        user_id,
        body.len(),
        kind.mime_type()
    );

    let base_dir = std::path::Path::new(&config.avatar_dir);
    tokio::fs::create_dir_all(base_dir)
        .await
        .map_err(|e| APIError::Internal(format!("Failed to create storage: {e}")))?;

    let file_extension = kind.extension();
    let filename = format!("{}.{}", claims.sub, file_extension);
    let full_path = base_dir.join(&filename);

    // Write to a uniquely named temporary file, then rename it into place. The previous avatar
    // stays intact if the write fails, and readers never observe a half-written file (a
    // rename within one directory replaces the target atomically on POSIX systems).
    let tmp_path = base_dir.join(format!("{filename}.{}.tmp", Uuid::now_v7()));
    if let Err(e) = tokio::fs::write(&tmp_path, &body).await {
        // Best-effort cleanup; the write error is what gets reported.
        let _ = tokio::fs::remove_file(&tmp_path).await;
        return Err(APIError::Internal(format!("Failed to save file: {e}")));
    }
    if let Err(e) = tokio::fs::rename(&tmp_path, &full_path).await {
        let _ = tokio::fs::remove_file(&tmp_path).await;
        return Err(APIError::Internal(format!("Failed to save file: {e}")));
    }

    // The new avatar is in place. Drop copies stored under any other extension so lookups
    // cannot keep finding a stale image.
    for ext in ["png", "jpg", "jpeg", "webp"] {
        if ext == file_extension {
            continue;
        }
        let old_filename = format!("{}.{}", claims.sub, ext);
        match tokio::fs::remove_file(base_dir.join(&old_filename)).await {
            Ok(()) => tracing::debug!("Deleted old avatar: {}", old_filename),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => tracing::warn!("Failed to delete old avatar {}: {}", old_filename, e),
        }
    }

    sqlx::query("UPDATE user_ SET avatar_url = $1 WHERE id = $2")
        .bind(filename)
        .bind(user_id)
        .execute(&db)
        .await?;
    Ok(StatusCode::OK)
}
// Public route
/// Serves a user's avatar image.
///
/// Probes `<avatar_dir>/<uuid>.{png,jpg,jpeg,webp}` in that order and returns the first file
/// found. `Content-Type` is sniffed from the file contents, falling back to
/// `application/octet-stream`.
///
/// # Errors
/// - [`APIError::AvatarNotFound`] if no file exists under any supported extension.
/// - [`APIError::Internal`] if a candidate file exists but cannot be read.
async fn get_avatar(
    Path(uuid): Path<Uuid>,
    Extension(config): Extension<Arc<AppConfig>>,
) -> Result<Response, APIError> {
    let base_dir = std::path::Path::new(&config.avatar_dir);

    // Read directly rather than calling `exists()` first. That check blocked the async runtime,
    // and a concurrent upload could delete the file between check and read, turning an
    // ordinary miss into an internal error. `NotFound` now simply means "try the next one".
    let mut found = None;
    for ext in ["png", "jpg", "jpeg", "webp"] {
        match tokio::fs::read(base_dir.join(format!("{uuid}.{ext}"))).await {
            Ok(bytes) => {
                found = Some(bytes);
                break;
            }
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                return Err(APIError::Internal(format!(
                    "Could not read avatar file: {e}"
                )));
            }
        }
    }
    let file_contents = found.ok_or(APIError::AvatarNotFound)?;

    let mime_type = infer::get(&file_contents)
        .map(|k| k.mime_type())
        .unwrap_or("application/octet-stream"); // Fallback

    let mut response = Response::new(axum::body::Body::from(file_contents));
    // Both possible values are static MIME strings (from `infer` or the literal fallback), so
    // building the header value cannot fail the way the old `builder().unwrap()` path could.
    response.headers_mut().insert(
        header::CONTENT_TYPE,
        axum::http::HeaderValue::from_static(mime_type),
    );
    Ok(response)
}