diff --git a/src/routes/users.rs b/src/routes/users.rs index aa3ba41..2ff7324 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -1,10 +1,10 @@ use axum::{ Extension, Json, Router, extract::Request, - http::StatusCode, + http::{HeaderMap, StatusCode}, middleware::Next, response::Response, - routing::{get, post}, + routing::{get, post, put}, }; use sqlx::PgPool; use std::env; @@ -12,8 +12,8 @@ use uuid::Uuid; use validator::ValidateEmail; use crate::{ - auth::{create_jwt, hash_password, validate_token, verify_password}, - db::username_from_uuid, + auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password}, + db::{user_id_from_uuid, username_from_uuid}, }; const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; @@ -47,11 +47,25 @@ pub struct NewUserPayload { pub password: String, } +#[derive(serde::Deserialize)] +pub struct UpdateUserPayoad { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(serde::Serialize)] +pub struct UpdateUserResponse { + pub email: String, + pub username: String, +} + pub fn routes() -> Router { Router::new() .route("/login", post(login)) .route("/register", post(register_user)) .route("/validate-token", get(validate_token)) + .route("/account", put(update_user)) .layer(axum::middleware::from_fn(registration_guard)) } @@ -163,3 +177,90 @@ pub async fn register_user( }), )) } + +pub async fn update_user( + headers: HeaderMap, + Extension(db): Extension, + Json(payload): Json, +) -> Result<(StatusCode, Json), (StatusCode, String)> { + let claims = verify_jwt(headers)?; + + if payload.email.is_empty() || payload.username.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + "Missing username or email fields".into(), + )); + } + + if !ValidateEmail::validate_email(&payload.email) { + return Err((StatusCode::BAD_REQUEST, "Invalid email format".into())); + } + + let user_id = user_id_from_uuid(&db, claims.sub).await?; + + let mut tx = db + .begin() + .await + .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + + if !payload.password.is_empty() { + if payload.password.len() < 8 { + return Err(( + StatusCode::BAD_REQUEST, + "Password must be at least 8 characters long".into(), + )); + } + + let password_hash = hash_password(&payload.password).map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to hash password".into(), + ) + })?; + + sqlx::query("UPDATE user_ SET password_hash = $1 WHERE id = $2") + .bind(password_hash) + .bind(user_id) + .execute(&mut *tx) + .await + .map_err(|e| { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to update password: {e}"), + ); + })?; + } + + sqlx::query("UPDATE user_ SET username = $1, email = $2 WHERE id = $3") + .bind(&payload.username) + .bind(&payload.email) + .bind(user_id) + .execute(&mut *tx) + .await + .map_err(|e| { + if let Some(db_err) = e.as_database_error() { + if db_err.code().map(|c| c == "23505").unwrap_or(false) { + return ( + StatusCode::CONFLICT, + "Email or username already taken".into(), + ); + } + } + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) + })?; + + tx.commit().await.map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Could not update account".into(), + ) + })?; + + Ok(( + StatusCode::CREATED, + Json(UpdateUserResponse { + username: payload.username, + email: payload.email, + }), + )) +}