diff --git a/Cargo.lock b/Cargo.lock index aeb67f4..8087edc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,17 @@ dependencies = [ "shlex", ] +[[package]] +name = "cfb" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38f2da7a0a2c4ccf0065be06397cc26a81f4e528be095826eee9d4adbb8c60f" +dependencies = [ + "byteorder", + "fnv", + "uuid", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -629,6 +640,7 @@ dependencies = [ "chrono", "clap", "dashmap", + "infer", "jsonwebtoken", "password-hash", "serde", @@ -1138,6 +1150,15 @@ dependencies = [ "hashbrown 0.16.1", ] +[[package]] +name = "infer" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a588916bfdfd92e71cacef98a63d9b1f0d74d6599980d11894290e7ddefffcf7" +dependencies = [ + "cfb", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" diff --git a/Cargo.toml b/Cargo.toml index 5fe63e6..7b85eda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ bytes = "1.11.0" chrono = { version = "0.4.42", features = ["serde"] } clap = { version = "4.5.53", features = ["derive"] } dashmap = "6.1.0" +infer = "0.19.0" jsonwebtoken = "9.3.1" password-hash = "0.5.0" serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/auth.rs b/src/auth.rs index 150def2..8270597 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -4,12 +4,11 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode} use password_hash::SaltString; use password_hash::rand_core::OsRng; -use axum::{ - Json, - http::{HeaderMap, StatusCode}, -}; +use axum::{Json, http::HeaderMap}; use uuid::Uuid; +use crate::errors::APIError; + const DEFAULT_SECRET_KEY: &str = "43aaf85b92f1ae6fbcef7732c50a0904"; #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -18,13 +17,14 @@ pub struct Claims { pub exp: usize, } -pub fn hash_password(password: &str) -> Result { +pub fn hash_password(password: &str) -> Result { let salt = SaltString::generate(OsRng); let argon2 = Argon2::default(); argon2 .hash_password(password.as_bytes(), &salt) .map_err(|e| e.to_string()) .map(|ph| ph.to_string()) + .map_err(|e| APIError::Internal(e)) } pub fn verify_password(hash: &str, password: &str) -> bool { @@ -38,7 +38,7 @@ pub fn verify_password(hash: &str, password: &str) -> bool { } } -pub fn create_jwt(user_uuid: Uuid) -> Result { +pub fn create_jwt(user_uuid: Uuid) -> Result { let expiration = Utc::now() .checked_add_signed(Duration::days(100)) .expect("valid timestamp") @@ -57,20 +57,20 @@ pub fn create_jwt(user_uuid: Uuid) -> Result { &claims, &EncodingKey::from_secret(secret.as_ref()), ) - .map_err(|_| "Token creation failed".into()) + .map_err(|e| APIError::Internal(e.to_string())) } -pub fn verify_jwt(headers: HeaderMap) -> Result { +pub fn verify_jwt(headers: HeaderMap) -> Result { let token = headers .get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ")) - .ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?; + .ok_or(APIError::MissingToken)?; verify_jwt_string(&token.to_string()) } -pub fn verify_jwt_string(token: &String) -> Result { +pub fn verify_jwt_string(token: &String) -> Result { let secret = std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string()); @@ -80,12 +80,10 @@ pub fn verify_jwt_string(token: &String) -> Result &Validation::default(), ) .map(|data| data.claims) - .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string())) + .map_err(|_| APIError::InvalidToken) } -pub async fn validate_token( - headers: HeaderMap, -) -> Result, (StatusCode, String)> { +pub async fn validate_token(headers: HeaderMap) -> Result, APIError> { let _ = verify_jwt(headers)?; Ok(Json(serde_json::json!({"valid": true}))) } diff --git a/src/db.rs b/src/db.rs index c84b511..fe039eb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,65 +1,57 @@ -use axum::http::StatusCode; use sqlx::PgPool; use uuid::Uuid; +use crate::errors::APIError; + pub async fn init_db(url: String) -> Result { let database_url = format!("postgres://frangipane:secret@{url}/frangipane"); PgPool::connect_lazy(database_url.as_str()) } -pub async fn user_id_from_uuid(db: &PgPool, user_uuid: Uuid) -> Result { +pub async fn user_id_from_uuid(db: &PgPool, user_uuid: Uuid) -> Result { sqlx::query_scalar("SELECT id FROM user_ WHERE uuid = $1") .bind(user_uuid) .fetch_one(db) .await - .map_err(|_| (StatusCode::UNAUTHORIZED, String::from("Wrong token"))) + .map_err(|_| APIError::UserNotFound) } -pub async fn room_id_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result { +pub async fn room_id_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result { sqlx::query_scalar("SELECT id FROM room_ WHERE uuid = $1") .bind(room_uuid) .fetch_one(db) .await - .map_err(|e| { - tracing::error!("Failed to convert room uuid to room id: {e}"); - (StatusCode::NOT_FOUND, "Failed to find room".into()) - }) + .map_err(|_| APIError::RoomNotFound) } -pub async fn username_from_uuid( - db: &PgPool, - user_uuid: Uuid, -) -> Result { +pub async fn username_from_uuid(db: &PgPool, user_uuid: Uuid) -> Result { sqlx::query_scalar("SELECT username FROM user_ WHERE uuid = $1") .bind(user_uuid) .fetch_one(db) .await - .map_err(|_| (StatusCode::UNAUTHORIZED, String::from("Wrong token"))) + .map_err(|_| APIError::UserNotFound) } -pub async fn username_from_id(db: &PgPool, user_id: i32) -> Result { +pub async fn username_from_id(db: &PgPool, user_id: i32) -> Result { sqlx::query_scalar("SELECT username FROM user_ WHERE id = $1") .bind(user_id) .fetch_one(db) .await - .map_err(|_| (StatusCode::UNAUTHORIZED, String::from("Wrong token"))) + .map_err(|_| APIError::UserNotFound) } -pub async fn id_from_username(db: &PgPool, username: String) -> Result { +pub async fn id_from_username(db: &PgPool, username: String) -> Result { sqlx::query_scalar("SELECT id FROM user_ WHERE username = $1") .bind(username) .fetch_one(db) .await - .map_err(|_| (StatusCode::NOT_FOUND, "User not found".into())) + .map_err(|_| APIError::UserNotFound) } -pub async fn room_name_from_uuid( - db: &PgPool, - room_uuid: Uuid, -) -> Result { +pub async fn room_name_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result { sqlx::query_scalar("SELECT name FROM room_ WHERE uuid = $1") .bind(room_uuid) .fetch_one(db) .await - .map_err(|_| (StatusCode::NOT_FOUND, "Failed to find room".into())) + .map_err(|_| APIError::RoomNotFound) } diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..e8027f6 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,239 @@ +use std::borrow::Cow; + +use axum::{ + Json, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use serde_json::json; + +use crate::{MAX_ROOM_NAME_LENGTH, MAX_USERNAME_LENGTH}; + +#[derive(Debug)] +pub enum APIError { + // Auth Errors + WrongCredentials, + MissingToken, + InvalidToken, + + // User Errors + UserNotFound, + EmailTaken, + UsernameTaken, + UsernameLength, + InvalidEmail, + PasswordTooShort, + EmptyFields, + AvatarNotFound, + + // Room Errors + RoomNotFound, + NotAMember, + AlreadyMember, + RoomOwnerCannotLeave, + GlobalRoomMemberError, + RoomNameLength, + + // Invite Errors + InviteSelf, + AlreadyInvited, + InviteNotFound, + + // Friend Errors + FriendRequestSelf, + AlreadyFriends, + FriendRequestAlreadySent, + FriendRequestNotFound, + NotFriends, + + // Uploads + WrongFileFormat, + + // Technical/Internal + DatabaseError(sqlx::Error), + Internal(String), +} + +// Allow using `?` with sqlx errors +impl From for APIError { + fn from(err: sqlx::Error) -> Self { + Self::DatabaseError(err) + } +} + +impl IntoResponse for APIError { + fn into_response(self) -> Response { + let (status, code, message): (StatusCode, &str, Cow) = match self { + // Auth + APIError::WrongCredentials => ( + StatusCode::UNAUTHORIZED, + "AUTH_INVALID_CREDENTIALS", + "Invalid email or password".into(), + ), + APIError::MissingToken => ( + StatusCode::UNAUTHORIZED, + "AUTH_MISSING_TOKEN", + "Missing authentication header".into(), + ), + APIError::InvalidToken => ( + StatusCode::UNAUTHORIZED, + "AUTH_INVALID_TOKEN", + "Invalid or expired token".into(), + ), + + // Users + APIError::UserNotFound => ( + StatusCode::NOT_FOUND, + "USER_NOT_FOUND", + "User not found".into(), + ), + APIError::EmailTaken => ( + StatusCode::CONFLICT, + "USER_EMAIL_TAKEN", + "Email already in use".into(), + ), + APIError::UsernameTaken => ( + StatusCode::CONFLICT, + "USER_USERNAME_TAKEN", + "Username already taken".into(), + ), + APIError::UsernameLength => ( + StatusCode::BAD_REQUEST, + "USERNAME_LENGTH", + format!("Username must be 1-{} characters long", MAX_USERNAME_LENGTH).into(), + ), + APIError::InvalidEmail => ( + StatusCode::BAD_REQUEST, + "USER_INVALID_EMAIL", + "Invalid email format".into(), + ), + APIError::PasswordTooShort => ( + StatusCode::BAD_REQUEST, + "USER_PASSWORD_TOO_SHORT", + "Password must be at least 8 characters".into(), + ), + APIError::EmptyFields => ( + StatusCode::BAD_REQUEST, + "USER_EMPTY_FIELDS", + "Required fields are empty".into(), + ), + APIError::AvatarNotFound => ( + StatusCode::NOT_FOUND, + "AVATAR_NOT_FOUND", + "Avatar not found".into(), + ), + + // Rooms + APIError::RoomNotFound => ( + StatusCode::NOT_FOUND, + "ROOM_NOT_FOUND", + "Room not found".into(), + ), + APIError::NotAMember => ( + StatusCode::FORBIDDEN, + "ROOM_NOT_MEMBER", + "You are not a member of this room".into(), + ), + APIError::AlreadyMember => ( + StatusCode::CONFLICT, + "ROOM_ALREADY_MEMBER", + "User is already a member".into(), + ), + APIError::RoomOwnerCannotLeave => ( + StatusCode::FORBIDDEN, + "ROOM_OWNER_CANNOT_LEAVE", + "Owner cannot leave the room without transferring ownership".into(), + ), + APIError::GlobalRoomMemberError => ( + StatusCode::FORBIDDEN, + "ROOM_GLOBAL_NO_MEMBERS", + "Cannot list members for global rooms".into(), + ), + APIError::RoomNameLength => ( + StatusCode::BAD_REQUEST, + "ROOM_NAME_LENGTH", + format!( + "Room name must be 0-{} characters long", + MAX_ROOM_NAME_LENGTH + ) + .into(), + ), + + // Invites + APIError::InviteSelf => ( + StatusCode::BAD_REQUEST, + "INVITE_SELF", + "Cannot invite yourself".into(), + ), + APIError::AlreadyInvited => ( + StatusCode::CONFLICT, + "INVITE_ALREADY_SENT", + "Invite already sent".into(), + ), + APIError::InviteNotFound => ( + StatusCode::NOT_FOUND, + "INVITE_NOT_FOUND", + "Invite not found".into(), + ), + + // Friends + APIError::FriendRequestSelf => ( + StatusCode::BAD_REQUEST, + "FRIEND_REQUEST_SELF", + "Cannot send friend request to yourself".into(), + ), + APIError::AlreadyFriends => ( + StatusCode::CONFLICT, + "FRIEND_ALREADY_EXISTS", + "You are already friends".into(), + ), + APIError::FriendRequestAlreadySent => ( + StatusCode::CONFLICT, + "FRIEND_REQUEST_ALREADY_SENT", + "Request already pending".into(), + ), + APIError::FriendRequestNotFound => ( + StatusCode::NOT_FOUND, + "FRIEND_REQUEST_NOT_FOUND", + "Friend request not found".into(), + ), + APIError::NotFriends => ( + StatusCode::NOT_FOUND, + "FRIEND_NOT_FOUND", + "User is not in your friends list".into(), + ), + + // Uploads + APIError::WrongFileFormat => ( + StatusCode::BAD_REQUEST, + "WRONG_FILE_FORMAT", + "Wrong file format".into(), + ), + + // Internal + APIError::DatabaseError(e) => { + tracing::error!("Database error: {:?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "INTERNAL_DB_ERROR", + "Database error".into(), + ) + } + APIError::Internal(msg) => { + tracing::error!("Internal error: {}", msg); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "INTERNAL_SERVER_ERROR", + "Internal server error".into(), + ) + } + }; + + let body = Json(json!({ + "code": code, + "message": message, + })); + + (status, body).into_response() + } +} diff --git a/src/main.rs b/src/main.rs index 07c050c..1a307f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use axum::{ Extension, Json, Router, + extract::DefaultBodyLimit, http::{ Method, StatusCode, header::{self, CONTENT_TYPE}, @@ -23,9 +24,14 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod auth; mod db; +mod errors; mod realtime; mod routes; +const MAX_USERNAME_LENGTH: usize = 35; +const MAX_ROOM_NAME_LENGTH: usize = 35; +const MAX_UPLOAD_SIZE: usize = 5 * 1024 * 1024; // Not actually used for now + pub struct AppConfig { pub avatar_dir: PathBuf, pub prohibit_registration: bool, @@ -129,6 +135,7 @@ async fn main() -> anyhow::Result<()> { .layer(Extension(config)) .layer(GovernorLayer::new(governor_conf)) .layer(cors) + .layer(DefaultBodyLimit::max(1024 * 5 * 100)) .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::new().level(Level::DEBUG)) diff --git a/src/routes/friends.rs b/src/routes/friends.rs index 1f2a93a..b057123 100644 --- a/src/routes/friends.rs +++ b/src/routes/friends.rs @@ -7,8 +7,11 @@ use axum::{ use sqlx::PgPool; use uuid::Uuid; -use crate::db::{user_id_from_uuid, username_from_id, username_from_uuid}; use crate::{auth::verify_jwt, db::id_from_username}; +use crate::{ + db::{user_id_from_uuid, username_from_id, username_from_uuid}, + errors::APIError, +}; #[derive(sqlx::FromRow, serde::Serialize)] pub struct Friend { @@ -56,7 +59,7 @@ pub fn routes() -> Router { async fn list_friends( headers: HeaderMap, Extension(db): Extension, -) -> Result>, (StatusCode, String)> { +) -> Result>, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -72,12 +75,7 @@ async fn list_friends( .bind(user_id) .fetch_all(&db) .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not list friends".into(), - ) - })?; + .map_err(|e| APIError::Internal(format!("Could not list friends: {e}")))?; Ok(Json(friends)) } @@ -85,7 +83,7 @@ async fn list_friends( async fn list_requests( headers: HeaderMap, Extension(db): Extension, -) -> Result>, (StatusCode, String)> { +) -> Result>, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -100,12 +98,7 @@ async fn list_requests( .bind(user_id) .fetch_all(&db) .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not list friend requests".into(), - ) - })?; + .map_err(|e| APIError::Internal(format!("Could not list friend requests: {e}")))?; Ok(Json(requests)) } @@ -114,17 +107,14 @@ async fn send_request( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; let sender_id = user_id_from_uuid(&db, claims.sub).await?; let receiver_id = id_from_username(&db, payload.receiver_username).await?; if sender_id == receiver_id { - return Err(( - StatusCode::BAD_REQUEST, - "Cannot send a friend request to yourself".into(), - )); + return Err(APIError::FriendRequestSelf); } let is_already_friend = sqlx::query_scalar::<_, bool>( @@ -139,14 +129,10 @@ async fn send_request( .bind(sender_id) .bind(receiver_id) .fetch_one(&db) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error".into()))?; + .await?; if is_already_friend { - return Err(( - StatusCode::CONFLICT, - "You are already friends with this user".into(), - )); + return Err(APIError::AlreadyFriends); } sqlx::query("INSERT INTO friend_request_ (sender, receiver) VALUES ($1, $2)") @@ -154,12 +140,7 @@ async fn send_request( .bind(receiver_id) .execute(&db) .await - .map_err(|_| { - ( - StatusCode::CONFLICT, - "You have already send a friend request to this user".into(), - ) - })?; + .map_err(|_| APIError::FriendRequestAlreadySent)?; Ok(( StatusCode::CREATED, @@ -174,7 +155,7 @@ async fn accept_request( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; let receiver_id = user_id_from_uuid(&db, claims.sub).await?; @@ -186,10 +167,7 @@ async fn accept_request( (receiver_id, sender_id) }; - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; let rows = sqlx::query( r#" @@ -201,12 +179,11 @@ async fn accept_request( .bind(sender_id) .bind(receiver_id) .execute(&mut *tx) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))? + .await? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, "No such request".into())); + return Err(APIError::FriendRequestNotFound); } sqlx::query("INSERT INTO friendship_ (user_first, user_second) VALUES ($1, $2)") @@ -214,14 +191,9 @@ async fn accept_request( .bind(second) .execute(&mut *tx) .await - .map_err(|_| (StatusCode::CONFLICT, "Already friends".into()))?; + .map_err(|_| APIError::AlreadyFriends)?; - tx.commit().await.map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not accept friendship".into(), - ) - })?; + tx.commit().await?; Ok(( StatusCode::CREATED, @@ -236,7 +208,7 @@ async fn decline_request( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let receiver_id = user_id_from_uuid(&db, claims.sub).await?; @@ -252,17 +224,11 @@ async fn decline_request( .bind(sender_id) .bind(receiver_id) .execute(&db) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not decline friend request".into(), - ) - })? + .await? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, "No such request".into())); + return Err(APIError::FriendRequestNotFound); } Ok(StatusCode::CREATED) @@ -272,7 +238,7 @@ async fn is_friend( headers: HeaderMap, Path(target_uuid): Path, Extension(db): Extension, -) -> Result, (StatusCode, String)> { +) -> Result, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -290,8 +256,7 @@ async fn is_friend( .bind(user_id) .bind(target_id) .fetch_one(&db) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error".into()))?; + .await?; Ok(Json(is_friend)) } @@ -300,7 +265,7 @@ async fn remove_friend( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -316,20 +281,11 @@ async fn remove_friend( .bind(user_id) .bind(friend_id) .execute(&db) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not remove friend".into(), - ) - })? + .await? .rows_affected(); if rows == 0 { - return Err(( - StatusCode::NOT_FOUND, - "User is not in your friends list".into(), - )); + return Err(APIError::NotFriends); } Ok(StatusCode::OK) diff --git a/src/routes/messages.rs b/src/routes/messages.rs index f781810..e96dbfc 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -18,6 +18,7 @@ use uuid::Uuid; use crate::{ auth::{verify_jwt, verify_jwt_string}, db::room_id_from_uuid, + errors::APIError, routes::{rooms::is_member, ws::WsAuthQuery}, }; use crate::{ @@ -72,25 +73,19 @@ async fn list_messages( Query(query): Query, headers: HeaderMap, Extension(db): Extension, -) -> Result>, (StatusCode, String)> { +) -> Result>, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::FORBIDDEN, - String::from("You are not a member of this room"), - )); + return Err(APIError::NotAMember); } let limit: i32 = query.limit.unwrap_or(30).abs().min(80); - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; let messages = sqlx::query_as::<_, MessageRow>( r#" @@ -115,13 +110,7 @@ async fn list_messages( .bind(query.before) .bind(limit) .fetch_all(&mut *tx) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to list messages: {e}"), - ) - })?; + .await?; let mut messages: Vec = messages .into_iter() @@ -149,18 +138,9 @@ async fn list_messages( .bind(user_id) .bind(room_id) .execute(&mut *tx) - .await - .map_err(|e| { - tracing::error!("Error updating membership timestamp: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to ")) - })?; + .await?; - tx.commit().await.map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not list messages".into(), - ) - })?; + tx.commit().await?; Ok(Json(messages)) } @@ -171,17 +151,14 @@ async fn create_message( Extension(realtime): Extension, headers: HeaderMap, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::UNAUTHORIZED, - String::from("You are not a member of this room"), - )); + return Err(APIError::NotAMember); } let uuid = Uuid::now_v7(); @@ -196,8 +173,7 @@ async fn create_message( .bind(&payload.content) .bind(&uuid) .fetch_one(&db) - .await - .map_err(|_| (StatusCode::BAD_REQUEST, "Could not create message".into()))?; + .await?; let sender_name = username_from_uuid(&db, claims.sub).await?; @@ -221,11 +197,7 @@ async fn create_message( ) .bind(room_id) .fetch_all(&db) - .await - .map_err(|e| { - tracing::error!("Error fetching message recipients: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) - })?; + .await?; let rt = realtime.clone(); let msg_clone = message.clone(); @@ -244,7 +216,7 @@ async fn message_ws_handler( ConnectInfo(addr): ConnectInfo, Extension(realtime): Extension, Extension(db): Extension, -) -> Result { +) -> Result { // tracing::info!("recieved ws handshake: {}", room_uuid); let claims = verify_jwt_string(&query.token)?; @@ -260,13 +232,11 @@ async fn message_ws_handler( .bind(query.token) .execute(&db) .await - .map_err(|e| { - tracing::error!("Failed to get WS token from DB: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) - })?; + // NOTE: Maybe wrong type of error + .map_err(|e| APIError::Internal(format!("Failed to get WS token from DB: {e}")))?; if result.rows_affected() == 0 { - return Err((StatusCode::UNAUTHORIZED, "Wrong token".into())); + return Err(APIError::InvalidToken); } let receiver = realtime.get_sender(user_uuid).subscribe(); diff --git a/src/routes/rooms.rs b/src/routes/rooms.rs index 0157e97..1e0b6c5 100644 --- a/src/routes/rooms.rs +++ b/src/routes/rooms.rs @@ -7,7 +7,7 @@ use axum::{ use sqlx::{PgPool, Pool, Postgres}; use uuid::Uuid; -use crate::{auth::verify_jwt, db::room_id_from_uuid}; +use crate::{MAX_ROOM_NAME_LENGTH, auth::verify_jwt, db::room_id_from_uuid, errors::APIError}; use crate::{ db::{id_from_username, room_name_from_uuid, user_id_from_uuid, username_from_id}, routes::users::UserProfile, @@ -93,10 +93,10 @@ pub async fn is_member(user_id: i32, room_id: i32, db: &Pool) -> bool async fn list_rooms( headers: HeaderMap, Extension(db): Extension, -) -> Result>, (StatusCode, String)> { +) -> Result>, APIError> { let claims = verify_jwt(headers)?; if claims.sub != claims.sub { - return Err((StatusCode::FORBIDDEN, "Forbidden".to_string())); + return Err(APIError::InvalidToken); } let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -137,9 +137,16 @@ async fn create_room( Extension(db): Extension, headers: HeaderMap, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; + { + let room_name_length = payload.name.len(); + if room_name_length > MAX_ROOM_NAME_LENGTH || room_name_length < 1 { + return Err(APIError::RoomNameLength); + } + } + let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_uuid = uuid::Uuid::now_v7(); @@ -154,8 +161,7 @@ async fn create_room( // .bind(&payload.global) .bind(false) // We do not allow global rooms .execute(&db) - .await - .map_err(|_| (StatusCode::BAD_REQUEST, format!("Could not create room")))?; + .await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; @@ -164,14 +170,9 @@ async fn create_room( .bind(user_id) .bind(room_id) .execute(&db) - .await - .map_err(|_| (StatusCode::BAD_REQUEST, format!("Could not create room")))?; + .await?; - let owner_name = sqlx::query_scalar("SELECT username FROM user_ WHERE id = $1") - .bind(user_id) - .fetch_one(&db) - .await - .map_err(|_| (StatusCode::BAD_REQUEST, format!("Could not create room")))?; + let owner_name = username_from_id(&db, user_id).await?; Ok(( StatusCode::CREATED, @@ -190,7 +191,7 @@ async fn get_room( Path(room_uuid): Path, headers: HeaderMap, Extension(db): Extension, -) -> Result, (StatusCode, String)> { +) -> Result, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -232,16 +233,10 @@ async fn get_room( .bind(user_id) .fetch_one(&db) .await - .map_err(|e| { - tracing::error!("Failed getting room: {e}"); - (StatusCode::NOT_FOUND, "Room not found".to_string()) - })?; + .map_err(|_| APIError::RoomNotFound)?; if !row.is_member.unwrap_or(false) { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room".to_string(), - )); + return Err(APIError::NotAMember); } Ok(Json(Room { @@ -257,7 +252,7 @@ async fn get_room( async fn list_invites( headers: HeaderMap, Extension(db): Extension, -) -> Result>, (StatusCode, String)> { +) -> Result>, APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -276,14 +271,7 @@ async fn list_invites( ) .bind(user_id) .fetch_all(&db) - .await - .map_err(|e| { - tracing::error!("{e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not list room invites".into(), - ) - })?; + .await?; Ok(Json(requests)) } @@ -292,7 +280,7 @@ async fn send_invite( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; let sender_id = user_id_from_uuid(&db, claims.sub).await?; @@ -300,10 +288,7 @@ async fn send_invite( let room_id = room_id_from_uuid(&db, payload.room_uuid).await?; if sender_id == receiver_id { - return Err(( - StatusCode::BAD_REQUEST, - "Cannot send a room invite to yourself".into(), - )); + return Err(APIError::InviteSelf); } let is_already_member = sqlx::query_scalar::<_, bool>( @@ -318,14 +303,10 @@ async fn send_invite( .bind(receiver_id) .bind(room_id) .fetch_one(&db) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error".into()))?; + .await?; if is_already_member { - return Err(( - StatusCode::CONFLICT, - "This user is already a member of this room".into(), - )); + return Err(APIError::AlreadyMember); } sqlx::query("INSERT INTO room_invite_ (sender, receiver, room) VALUES ($1, $2, $3)") @@ -334,12 +315,7 @@ async fn send_invite( .bind(room_id) .execute(&db) .await - .map_err(|_| { - ( - StatusCode::CONFLICT, - "You have already invited this user".into(), - ) - })?; + .map_err(|_| APIError::AlreadyInvited)?; let room_name = room_name_from_uuid(&db, payload.room_uuid).await?; @@ -358,16 +334,13 @@ async fn accept_request( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; let receiver_id = user_id_from_uuid(&db, claims.sub).await?; let sender_id = user_id_from_uuid(&db, payload.sender_uuid).await?; - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; let rows = sqlx::query( r#" @@ -378,12 +351,11 @@ async fn accept_request( .bind(sender_id) .bind(receiver_id) .execute(&mut *tx) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))? + .await? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, "No such invite".into())); + return Err(APIError::InviteNotFound); } let room_id = room_id_from_uuid(&db, payload.room_uuid).await?; @@ -393,12 +365,7 @@ async fn accept_request( .bind(room_id) .execute(&mut *tx) .await - .map_err(|_| { - ( - StatusCode::CONFLICT, - "Error creating room membership".into(), - ) - })?; + .map_err(|_| APIError::AlreadyMember)?; let room: Room = sqlx::query_as( r#" @@ -418,19 +385,9 @@ async fn accept_request( .bind(sender_id) .fetch_one(&db) .await - .map_err(|_| { - ( - StatusCode::NOT_FOUND, - "Room not found or wrong owner".into(), - ) - })?; + .map_err(|_| APIError::RoomNotFound)?; - tx.commit().await.map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not accept room invite".into(), - ) - })?; + tx.commit().await?; Ok(( StatusCode::CREATED, @@ -449,7 +406,7 @@ async fn decline_request( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let receiver_id = user_id_from_uuid(&db, claims.sub).await?; @@ -464,17 +421,11 @@ async fn decline_request( .bind(sender_id) .bind(receiver_id) .execute(&db) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not decline the room invite".into(), - ) - })? + .await? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, "No such invite".into())); + return Err(APIError::InviteNotFound); } Ok(StatusCode::CREATED) @@ -484,110 +435,25 @@ async fn leave_room( headers: HeaderMap, Path(room_uuid): Path, Extension(db): Extension, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; let owner: i32 = sqlx::query_scalar(r#"SELECT owner FROM room_ WHERE id = $1"#) .bind(room_id) .fetch_one(&mut *tx) - .await - .map_err(|e| { - tracing::error!("Failed to get room owner: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to get room owner".into(), - ) - })?; + .await?; if owner == user_id { - return Err(( - StatusCode::FORBIDDEN, - "You cannot leave a room that you own".into(), - )); - // let member_count: i64 = - // sqlx::query_scalar(r#"SELECT count(*) FROM membership_ WHERE room = $1"#) - // .bind(room_id) - // .fetch_one(&mut *tx) - // .await - // .map_err(|e| { - // tracing::error!("Failed to get member count: {e}"); - // ( - // StatusCode::INTERNAL_SERVER_ERROR, - // "Failed to get member count".into(), - // ) - // })?; - // - // if member_count > 0 { - // if let Some(new_owner) = payload.new_owner_uuid { - // let exists: bool = - // sqlx::query_scalar(r#"SELECT EXISTS (SELECT 1 FROM user_ WHERE uuid = $1)"#) - // .bind(new_owner) - // .fetch_one(&mut *tx) - // .await - // .map_err(|e| { - // tracing::error!("Failed to check user existence: {e}"); - // ( - // StatusCode::INTERNAL_SERVER_ERROR, - // "Failed to check user existence".into(), - // ) - // })?; - // - // if !exists { - // tracing::debug!( - // "User {user_id} tried to leave a room without transfering ownership" - // ); - // return Err(( - // StatusCode::FORBIDDEN, - // "Tried to transfer ownership to nonexistant user".into(), - // )); - // } - // - // sqlx::query("UPDATE room_ SET owner = $1 WHERE id = $2") - // .bind(new_owner) - // .bind(room_id) - // .execute(&mut *tx) - // .await - // .map_err(|e| { - // tracing::error!("Failed to set new owner: {e}"); - // ( - // StatusCode::INTERNAL_SERVER_ERROR, - // "Failed to set new owner".into(), - // ) - // })?; - // } else { - // return Err(( - // StatusCode::BAD_REQUEST, - // "Please provide a new owner for a non-empty room".into(), - // )); - // } - // } else { - // sqlx::query("DELETE FROM room_ WHERE id = $1") - // .bind(room_id) - // .execute(&mut *tx) - // .await - // .map_err(|e| { - // tracing::error!("Failed to delete room: {e}"); - // ( - // StatusCode::INTERNAL_SERVER_ERROR, - // "Failed to delete room".into(), - // ) - // })?; - // } + return Err(APIError::RoomOwnerCannotLeave); } sqlx::query( @@ -599,15 +465,9 @@ async fn leave_room( .bind(user_id) .bind(room_id) .execute(&mut *tx) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + .await?; - tx.commit().await.map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not accept room invite".into(), - ) - })?; + tx.commit().await?; Ok(StatusCode::OK) } @@ -616,58 +476,35 @@ async fn transfer_ownership( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, payload.room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } let owner: i32 = sqlx::query_scalar(r#"SELECT owner FROM room_ WHERE id = $1"#) .bind(room_id) .fetch_one(&db) - .await - .map_err(|e| { - tracing::error!("Failed to get owner for room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to get room owner".into(), - ) - })?; + .await?; if owner != user_id { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } let exists: bool = sqlx::query_scalar(r#"SELECT EXISTS (SELECT 1 FROM user_ WHERE uuid = $1)"#) .bind(payload.new_owner_uuid) .fetch_one(&db) - .await - .map_err(|e| { - tracing::error!("Failed to check user existence: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to check user existence".into(), - ) - })?; + .await?; if !exists { tracing::debug!( "User {user_id} tried to leave room {room_id} without transfering ownership" ); - return Err(( - StatusCode::FORBIDDEN, - "Tried to transfer ownership to nonexistant user".into(), - )); + return Err(APIError::UserNotFound); } let new_owner_id = user_id_from_uuid(&db, payload.new_owner_uuid).await?; @@ -676,14 +513,7 @@ async fn transfer_ownership( .bind(new_owner_id) .bind(room_id) .execute(&db) - .await - .map_err(|e| { - tracing::error!("Failed to set new owner for room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to set new owner".into(), - ) - })?; + .await?; Ok(StatusCode::OK) } @@ -692,36 +522,23 @@ async fn list_members( headers: HeaderMap, Path(room_uuid): Path, Extension(db): Extension, -) -> Result<(StatusCode, Json>), (StatusCode, String)> { +) -> Result<(StatusCode, Json>), APIError> { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } let is_global: bool = sqlx::query_scalar("SELECT global FROM room_ WHERE id = $1") .bind(room_id) .fetch_one(&db) - .await - .map_err(|e| { - tracing::error!("Failed to get global boolean {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to fetch room".into(), - ) - })?; + .await?; if is_global { - return Err(( - StatusCode::FORBIDDEN, - "Cannot get member list for global rooms".into(), - )); + return Err(APIError::GlobalRoomMemberError); } let members = sqlx::query_as::<_, UserProfile>( @@ -735,14 +552,7 @@ async fn list_members( ) .bind(room_id) .fetch_all(&db) - .await - .map_err(|e| { - tracing::error!("Failed to get member list for room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to get member list".into(), - ) - })?; + .await?; Ok((StatusCode::OK, Json(members))) } @@ -751,86 +561,43 @@ async fn delete_room( headers: HeaderMap, Path(room_uuid): Path, Extension(db): Extension, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } let owner: i32 = sqlx::query_scalar(r#"SELECT owner FROM room_ WHERE id = $1"#) .bind(room_id) .fetch_one(&db) - .await - .map_err(|e| { - tracing::error!("Failed to get owner for room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to get room owner".into(), - ) - })?; + .await?; if owner != user_id { - return Err(( - StatusCode::FORBIDDEN, - "You are not a member of this room.".into(), - )); + return Err(APIError::NotAMember); } - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; sqlx::query(r#"DELETE FROM message_ WHERE room = $1"#) .bind(room_id) .execute(&mut *tx) - .await - .map_err(|e| { - tracing::error!("Failed to delete messages on room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to delete messages".into(), - ) - })?; + .await?; sqlx::query(r#"DELETE FROM membership_ WHERE room = $1"#) .bind(room_id) .execute(&mut *tx) - .await - .map_err(|e| { - tracing::error!("Failed to delete room memberships {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to delete room memberships".into(), - ) - })?; + .await?; sqlx::query(r#"DELETE FROM room_ WHERE id = $1"#) .bind(room_id) .execute(&mut *tx) - .await - .map_err(|e| { - tracing::error!("Failed to delete room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to delete room".into(), - ) - })?; + .await?; - tx.commit().await.map_err(|e| { - tracing::error!("Failed to delete room {room_id}: {e}"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to delete room".into(), - ) - })?; + tx.commit().await?; Ok(StatusCode::OK) } diff --git a/src/routes/users.rs b/src/routes/users.rs index 65c96c9..22f57b4 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -12,13 +12,12 @@ use uuid::Uuid; use validator::ValidateEmail; use crate::{ - AppConfig, + AppConfig, MAX_USERNAME_LENGTH, auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password}, db::{user_id_from_uuid, username_from_uuid}, + errors::APIError, }; -const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; - #[derive(sqlx::FromRow, serde::Serialize)] pub struct User { pub uuid: Uuid, @@ -92,14 +91,15 @@ async fn registration_guard( pub async fn login( Extension(db): Extension, Json(payload): Json, -) -> Result, (StatusCode, String)> { +) -> Result, APIError> { + const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + let user = sqlx::query_as::<_, User>( "SELECT uuid, email, username, password_hash FROM user_ WHERE email = $1", ) .bind(&payload.email) .fetch_optional(&db) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + .await?; let (user_uuid, password_hash) = if let Some(u) = user { (u.uuid, u.password_hash) @@ -109,10 +109,10 @@ pub async fn login( }; if !verify_password(&password_hash, &payload.password) { - return Err((StatusCode::UNAUTHORIZED, "Invalid credentials".into())); + return Err(APIError::WrongCredentials); } - let token = create_jwt(user_uuid).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + let token = create_jwt(user_uuid)?; let username = username_from_uuid(&db, user_uuid).await?; Ok(Json(LoginResponse { @@ -126,31 +126,27 @@ pub async fn login( pub async fn register_user( Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { if payload.email.is_empty() || payload.username.is_empty() || payload.password.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - "Cannot create a user with empty fields".into(), - )); + return Err(APIError::EmptyFields); } if !ValidateEmail::validate_email(&payload.email) { - return Err((StatusCode::BAD_REQUEST, "Invalid email format".into())); + return Err(APIError::InvalidEmail); + } + + { + let username_length = payload.username.len(); + if username_length > MAX_USERNAME_LENGTH || username_length < 1 { + return Err(APIError::UsernameLength); + } } if payload.password.len() < 8 { - return Err(( - StatusCode::BAD_REQUEST, - "Password must be at least 8 characters long".into(), - )); + return Err(APIError::PasswordTooShort); } - let password_hash = hash_password(&payload.password).map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to hash password".into(), - ) - })?; + let password_hash = hash_password(&payload.password)?; let user_uuid = uuid::Uuid::now_v7(); @@ -167,16 +163,17 @@ pub async fn register_user( .map_err(|e| { if let Some(db_err) = e.as_database_error() { if db_err.code().map(|c| c == "23505").unwrap_or(false) { - return ( - StatusCode::CONFLICT, - "Email or username already taken".into(), - ); + match db_err.constraint() { + Some("user__username_key") => return APIError::UsernameTaken, + Some("user__email_key") => return APIError::EmailTaken, + _ => return APIError::Internal("".to_string()), // TODO: handle this case + } } } - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) + APIError::DatabaseError(e) })?; - let token = create_jwt(user_uuid).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + let token = create_jwt(user_uuid)?; Ok(( StatusCode::CREATED, @@ -193,53 +190,33 @@ pub async fn update_user( headers: HeaderMap, Extension(db): Extension, Json(payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; if payload.email.is_empty() || payload.username.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - "Missing username or email fields".into(), - )); + return Err(APIError::EmptyFields); } if !ValidateEmail::validate_email(&payload.email) { - return Err((StatusCode::BAD_REQUEST, "Invalid email format".into())); + return Err(APIError::InvalidEmail); } let user_id = user_id_from_uuid(&db, claims.sub).await?; - let mut tx = db - .begin() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + let mut tx = db.begin().await?; if !payload.password.is_empty() { if payload.password.len() < 8 { - return Err(( - StatusCode::BAD_REQUEST, - "Password must be at least 8 characters long".into(), - )); + return Err(APIError::PasswordTooShort); } - let password_hash = hash_password(&payload.password).map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to hash password".into(), - ) - })?; + let password_hash = hash_password(&payload.password)?; sqlx::query("UPDATE user_ SET password_hash = $1 WHERE id = $2") .bind(password_hash) .bind(user_id) .execute(&mut *tx) - .await - .map_err(|e| { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to update password: {e}"), - ); - })?; + .await?; } sqlx::query("UPDATE user_ SET username = $1, email = $2 WHERE id = $3") @@ -251,21 +228,17 @@ pub async fn update_user( .map_err(|e| { if let Some(db_err) = e.as_database_error() { if db_err.code().map(|c| c == "23505").unwrap_or(false) { - return ( - StatusCode::CONFLICT, - "Email or username already taken".into(), - ); + match db_err.constraint() { + Some("user__username_key") => return APIError::UsernameTaken, + Some("user__email_key") => return APIError::EmailTaken, + _ => return APIError::Internal("".to_string()), // TODO: handle this case + } } } - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) + APIError::DatabaseError(e) })?; - tx.commit().await.map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not update account".into(), - ) - })?; + tx.commit().await?; Ok(( StatusCode::CREATED, @@ -281,7 +254,7 @@ async fn upload_avatar( Extension(db): Extension, Extension(config): Extension>, body: axum::body::Bytes, -) -> Result { +) -> Result { let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; @@ -292,31 +265,19 @@ async fn upload_avatar( let filename = format!("{}.{}", claims.sub, file_extension); let full_path = std::path::Path::new(&base_dir).join(&filename); - tokio::fs::create_dir_all(&base_dir).await.map_err(|e| { - tracing::error!("Failed to create storage: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to upload file".into(), - ) - })?; + tokio::fs::create_dir_all(&base_dir) + .await + .map_err(|e| APIError::Internal(format!("Failed to create storage: {e}")))?; - tokio::fs::write(&full_path, body).await.map_err(|e| { - tracing::error!("Failed to save file: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to upload file".into(), - ) - })?; + tokio::fs::write(&full_path, body) + .await + .map_err(|e| APIError::Internal(format!("Failed to save file: {e}")))?; sqlx::query("UPDATE user_ SET avatar_url = $1 WHERE id = $2") .bind(filename) .bind(user_id) .execute(&db) - .await - .map_err(|e| { - tracing::error!("DB error: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, "DB erorr".into()) - })?; + .await?; Ok(StatusCode::OK) } @@ -325,22 +286,18 @@ async fn upload_avatar( async fn get_avatar( Path(uuid): Path, Extension(config): Extension>, -) -> Result { +) -> Result { let base_dir = &config.avatar_dir; let filename = format!("{}.png", uuid); let full_path = std::path::Path::new(&base_dir).join(filename); if !full_path.exists() { - return Err((StatusCode::NOT_FOUND, "Avatar not found".into())); + return Err(APIError::AvatarNotFound); } - let file_contents = tokio::fs::read(&full_path).await.map_err(|e| { - tracing::error!("Could not read avatar file: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not read file".into(), - ) - })?; + let file_contents = tokio::fs::read(&full_path) + .await + .map_err(|e| APIError::Internal(format!("Could not read avatar file: {e}")))?; Ok(Response::builder() .header("Content-Type", "image/png") diff --git a/src/routes/voice.rs b/src/routes/voice.rs index 830a572..346f0ad 100644 --- a/src/routes/voice.rs +++ b/src/routes/voice.rs @@ -4,7 +4,6 @@ use axum::{ ConnectInfo, Path, Query, WebSocketUpgrade, ws::{Message, WebSocket}, }, - http::StatusCode, response::IntoResponse, routing::get, }; @@ -17,9 +16,9 @@ use uuid::Uuid; use crate::{ auth::verify_jwt_string, db::{room_id_from_uuid, user_id_from_uuid}, + errors::APIError, realtime::RealTimeVoices, - routes::rooms::is_member, - routes::ws::WsAuthQuery, + routes::{rooms::is_member, ws::WsAuthQuery}, }; pub fn routes() -> Router { @@ -33,7 +32,7 @@ async fn voice_ws_handler( ConnectInfo(addr): ConnectInfo, Extension(voice_manager): Extension, Extension(db): Extension, -) -> Result { +) -> Result { let claims = verify_jwt_string(&query.token)?; let user_uuid = claims.sub; @@ -47,20 +46,18 @@ async fn voice_ws_handler( .bind(&query.token) .execute(&db) .await - .map_err(|e| { - tracing::error!("Failed to get WS token from DB: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) - })?; + // NOTE: Maybe wrong type of error + .map_err(|e| APIError::Internal(format!("Failed to get WS token from DB: {e}")))?; if result.rows_affected() == 0 { - return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into())); + return Err(APIError::InvalidToken); } let user_id = user_id_from_uuid(&db, user_uuid).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; if !is_member(user_id, room_id, &db).await { - return Err((StatusCode::FORBIDDEN, "Not a member of this room".into())); + return Err(APIError::NotAMember); } tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid); diff --git a/src/routes/ws.rs b/src/routes/ws.rs index 730bb82..a2de3a5 100644 --- a/src/routes/ws.rs +++ b/src/routes/ws.rs @@ -5,6 +5,7 @@ use axum::{Extension, http::StatusCode}; use serde::Deserialize; use crate::auth::{create_jwt, verify_jwt}; +use crate::errors::APIError; #[derive(sqlx::FromRow, serde::Serialize, Deserialize)] pub struct WsAuthQuery { @@ -18,12 +19,12 @@ pub fn routes() -> axum::Router { pub async fn issue_ws_token( Extension(db): Extension, headers: HeaderMap, -) -> Result<(StatusCode, Json), (StatusCode, String)> { +) -> Result<(StatusCode, Json), APIError> { let claims = verify_jwt(headers)?; tracing::debug!("Recieved token issue request from user {}", claims.sub); - let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + let token = create_jwt(claims.sub)?; sqlx::query( r#" @@ -33,13 +34,7 @@ pub async fn issue_ws_token( ) .bind(&token) .execute(&db) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to provide ws token"), - ) - })?; + .await?; Ok((StatusCode::CREATED, Json(WsAuthQuery { token }))) }