348 lines
9.8 KiB
Rust
348 lines
9.8 KiB
Rust
use axum::{
|
|
Extension, Json, Router,
|
|
extract::{Path, Request},
|
|
http::{HeaderMap, StatusCode, header},
|
|
middleware::Next,
|
|
response::Response,
|
|
routing::{get, post, put},
|
|
};
|
|
use sqlx::PgPool;
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
use validator::ValidateEmail;
|
|
|
|
use crate::{
|
|
AppConfig, MAX_UPLOAD_SIZE, MAX_USERNAME_LENGTH,
|
|
auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password},
|
|
db::{user_id_from_uuid, username_from_uuid},
|
|
errors::APIError,
|
|
};
|
|
|
|
#[derive(sqlx::FromRow, serde::Serialize)]
|
|
pub struct User {
|
|
pub uuid: Uuid,
|
|
pub username: String,
|
|
pub password_hash: String,
|
|
pub email: String,
|
|
}
|
|
|
|
#[derive(sqlx::FromRow, serde::Serialize)]
|
|
pub struct UserProfile {
|
|
pub uuid: Uuid,
|
|
pub username: String,
|
|
}
|
|
|
|
#[derive(serde::Deserialize)]
|
|
pub struct LoginPayload {
|
|
pub email: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(serde::Serialize)]
|
|
pub struct LoginResponse {
|
|
pub uuid: Uuid,
|
|
pub username: String,
|
|
pub email: String,
|
|
pub token: String,
|
|
}
|
|
|
|
#[derive(serde::Deserialize)]
|
|
pub struct NewUserPayload {
|
|
pub email: String,
|
|
pub username: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(serde::Deserialize)]
|
|
pub struct UpdateUserPayoad {
|
|
pub email: String,
|
|
pub username: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(serde::Serialize)]
|
|
pub struct UpdateUserResponse {
|
|
pub email: String,
|
|
pub username: String,
|
|
}
|
|
|
|
pub fn routes() -> Router {
|
|
Router::new()
|
|
.route("/login", post(login))
|
|
.route("/register", post(register_user))
|
|
.route("/validate-token", get(validate_token))
|
|
.route("/account/settings", put(update_user))
|
|
.route("/account/upload-avatar", post(upload_avatar))
|
|
.route("/account/get-avatar/{uuid}", get(get_avatar))
|
|
.layer(axum::middleware::from_fn(registration_guard))
|
|
}
|
|
|
|
async fn registration_guard(
|
|
Extension(config): Extension<Arc<AppConfig>>,
|
|
req: Request,
|
|
next: Next,
|
|
) -> Result<Response, StatusCode> {
|
|
if req.uri().path() == "/register" && config.prohibit_registration {
|
|
return Err(StatusCode::FORBIDDEN);
|
|
}
|
|
Ok(next.run(req).await)
|
|
}
|
|
|
|
pub async fn login(
|
|
Extension(db): Extension<PgPool>,
|
|
Json(payload): Json<LoginPayload>,
|
|
) -> Result<Json<LoginResponse>, APIError> {
|
|
const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
|
|
|
|
let user = sqlx::query_as::<_, User>(
|
|
"SELECT uuid, email, username, password_hash FROM user_ WHERE email = $1",
|
|
)
|
|
.bind(&payload.email)
|
|
.fetch_optional(&db)
|
|
.await?;
|
|
|
|
let (user_uuid, password_hash) = if let Some(u) = user {
|
|
(u.uuid, u.password_hash)
|
|
} else {
|
|
// timing shield
|
|
(uuid::Uuid::now_v7(), DUMMY_HASH.to_string())
|
|
};
|
|
|
|
if !verify_password(&password_hash, &payload.password) {
|
|
return Err(APIError::WrongCredentials);
|
|
}
|
|
|
|
let token = create_jwt(user_uuid)?;
|
|
let username = username_from_uuid(&db, user_uuid).await?;
|
|
|
|
Ok(Json(LoginResponse {
|
|
uuid: user_uuid,
|
|
username,
|
|
email: payload.email,
|
|
token,
|
|
}))
|
|
}
|
|
|
|
pub async fn register_user(
|
|
Extension(db): Extension<PgPool>,
|
|
Json(payload): Json<NewUserPayload>,
|
|
) -> Result<(StatusCode, Json<LoginResponse>), APIError> {
|
|
if payload.email.is_empty() || payload.username.is_empty() || payload.password.is_empty() {
|
|
return Err(APIError::EmptyFields);
|
|
}
|
|
|
|
if !ValidateEmail::validate_email(&payload.email) {
|
|
return Err(APIError::InvalidEmail);
|
|
}
|
|
|
|
{
|
|
let username_length = payload.username.len();
|
|
if username_length > MAX_USERNAME_LENGTH || username_length < 1 {
|
|
return Err(APIError::UsernameLength);
|
|
}
|
|
}
|
|
|
|
if payload.password.len() < 8 {
|
|
return Err(APIError::PasswordTooShort);
|
|
}
|
|
|
|
let password_hash = hash_password(&payload.password)?;
|
|
|
|
let user_uuid = uuid::Uuid::now_v7();
|
|
|
|
sqlx::query(
|
|
"INSERT INTO user_ (uuid, username, email, password_hash)
|
|
VALUES ($1, $2, $3, $4)",
|
|
)
|
|
.bind(user_uuid)
|
|
.bind(&payload.username)
|
|
.bind(&payload.email)
|
|
.bind(&password_hash)
|
|
.execute(&db)
|
|
.await
|
|
.map_err(|e| {
|
|
if let Some(db_err) = e.as_database_error() {
|
|
if db_err.code().map(|c| c == "23505").unwrap_or(false) {
|
|
match db_err.constraint() {
|
|
Some("user__username_key") => return APIError::UsernameTaken,
|
|
Some("user__email_key") => return APIError::EmailTaken,
|
|
_ => return APIError::Internal("".to_string()), // TODO: handle this case
|
|
}
|
|
}
|
|
}
|
|
APIError::DatabaseError(e)
|
|
})?;
|
|
|
|
let token = create_jwt(user_uuid)?;
|
|
|
|
Ok((
|
|
StatusCode::CREATED,
|
|
Json(LoginResponse {
|
|
uuid: user_uuid,
|
|
username: payload.username,
|
|
email: payload.email,
|
|
token,
|
|
}),
|
|
))
|
|
}
|
|
|
|
pub async fn update_user(
|
|
headers: HeaderMap,
|
|
Extension(db): Extension<PgPool>,
|
|
Json(payload): Json<UpdateUserPayoad>,
|
|
) -> Result<(StatusCode, Json<UpdateUserResponse>), APIError> {
|
|
let claims = verify_jwt(headers)?;
|
|
|
|
if payload.email.is_empty() || payload.username.is_empty() {
|
|
return Err(APIError::EmptyFields);
|
|
}
|
|
|
|
if !ValidateEmail::validate_email(&payload.email) {
|
|
return Err(APIError::InvalidEmail);
|
|
}
|
|
|
|
let user_id = user_id_from_uuid(&db, claims.sub).await?;
|
|
|
|
let mut tx = db.begin().await?;
|
|
|
|
if !payload.password.is_empty() {
|
|
if payload.password.len() < 8 {
|
|
return Err(APIError::PasswordTooShort);
|
|
}
|
|
|
|
let password_hash = hash_password(&payload.password)?;
|
|
|
|
sqlx::query("UPDATE user_ SET password_hash = $1 WHERE id = $2")
|
|
.bind(password_hash)
|
|
.bind(user_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
}
|
|
|
|
sqlx::query("UPDATE user_ SET username = $1, email = $2 WHERE id = $3")
|
|
.bind(&payload.username)
|
|
.bind(&payload.email)
|
|
.bind(user_id)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.map_err(|e| {
|
|
if let Some(db_err) = e.as_database_error() {
|
|
if db_err.code().map(|c| c == "23505").unwrap_or(false) {
|
|
match db_err.constraint() {
|
|
Some("user__username_key") => return APIError::UsernameTaken,
|
|
Some("user__email_key") => return APIError::EmailTaken,
|
|
_ => return APIError::Internal("".to_string()), // TODO: handle this case
|
|
}
|
|
}
|
|
}
|
|
APIError::DatabaseError(e)
|
|
})?;
|
|
|
|
tx.commit().await?;
|
|
|
|
Ok((
|
|
StatusCode::CREATED,
|
|
Json(UpdateUserResponse {
|
|
username: payload.username,
|
|
email: payload.email,
|
|
}),
|
|
))
|
|
}
|
|
|
|
async fn upload_avatar(
|
|
headers: HeaderMap,
|
|
Extension(db): Extension<PgPool>,
|
|
Extension(config): Extension<Arc<AppConfig>>,
|
|
body: axum::body::Bytes,
|
|
) -> Result<StatusCode, APIError> {
|
|
let claims = verify_jwt(headers)?;
|
|
|
|
if body.len() > MAX_UPLOAD_SIZE {
|
|
// TODO: FileTooLarge error
|
|
return Err(APIError::WrongFileFormat);
|
|
}
|
|
|
|
let kind = infer::get(&body).ok_or(APIError::WrongFileFormat)?;
|
|
|
|
let ("image/png" | "image/jpeg" | "image/webp") = kind.mime_type() else {
|
|
return Err(APIError::WrongFileFormat);
|
|
};
|
|
|
|
let user_id = user_id_from_uuid(&db, claims.sub).await?;
|
|
tracing::debug!(
|
|
"User ID {} is uploading {} bytes ({})",
|
|
user_id,
|
|
body.len(),
|
|
kind.mime_type()
|
|
);
|
|
|
|
let base_dir = &config.avatar_dir;
|
|
|
|
let supported_extensions = ["png", "jpg", "jpeg", "webp"];
|
|
|
|
tokio::fs::create_dir_all(&base_dir)
|
|
.await
|
|
.map_err(|e| APIError::Internal(format!("Failed to create storage: {e}")))?;
|
|
|
|
// Delete all other files for this user first
|
|
for ext in supported_extensions {
|
|
let old_filename = format!("{}.{}", claims.sub, ext);
|
|
let old_path = std::path::Path::new(base_dir).join(&old_filename);
|
|
|
|
match tokio::fs::remove_file(old_path).await {
|
|
Ok(_) => tracing::debug!("Deleted old avatar: {}", old_filename),
|
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
|
Err(e) => tracing::warn!("Failed to delete old avatar {}: {}", old_filename, e),
|
|
}
|
|
}
|
|
|
|
let file_extension = kind.extension();
|
|
let filename = format!("{}.{}", claims.sub, file_extension);
|
|
let full_path = std::path::Path::new(&base_dir).join(&filename);
|
|
|
|
tokio::fs::write(&full_path, body)
|
|
.await
|
|
.map_err(|e| APIError::Internal(format!("Failed to save file: {e}")))?;
|
|
|
|
sqlx::query("UPDATE user_ SET avatar_url = $1 WHERE id = $2")
|
|
.bind(filename)
|
|
.bind(user_id)
|
|
.execute(&db)
|
|
.await?;
|
|
|
|
Ok(StatusCode::OK)
|
|
}
|
|
|
|
// Public route
|
|
async fn get_avatar(
|
|
Path(uuid): Path<Uuid>,
|
|
Extension(config): Extension<Arc<AppConfig>>,
|
|
) -> Result<Response, APIError> {
|
|
let base_dir = &config.avatar_dir;
|
|
|
|
// Helper to try finding the file with allowed extensions
|
|
let mut file_path = None;
|
|
for ext in ["png", "jpg", "jpeg", "webp"] {
|
|
let path = std::path::Path::new(&base_dir).join(format!("{}.{}", uuid, ext));
|
|
if path.exists() {
|
|
file_path = Some(path);
|
|
break;
|
|
}
|
|
}
|
|
|
|
let full_path = file_path.ok_or(APIError::AvatarNotFound)?;
|
|
|
|
let file_contents = tokio::fs::read(&full_path)
|
|
.await
|
|
.map_err(|e| APIError::Internal(format!("Could not read avatar file: {e}")))?;
|
|
|
|
let mime_type = infer::get(&file_contents)
|
|
.map(|k| k.mime_type())
|
|
.unwrap_or("application/octet-stream"); // Fallback
|
|
|
|
Ok(Response::builder()
|
|
.header(header::CONTENT_TYPE, mime_type)
|
|
.body(axum::body::Body::from(file_contents))
|
|
.unwrap())
|
|
}
|