diff --git a/src/main.rs b/src/main.rs index 8b29461..d02eab2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,12 @@ use axum::{ Extension, Router, - http::{Method, header}, + http::{ + Method, + header::{self, CONTENT_TYPE}, + }, + middleware, }; +use axum::{body::Body, extract::Request, middleware::Next, response::Response}; use clap::Parser; use std::{net::SocketAddr, time::Duration}; use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder}; @@ -10,6 +15,7 @@ use tower_http::{ trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, }; use tracing::Level; +use tracing::info; mod auth; mod db; @@ -46,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let governor_conf = GovernorConfigBuilder::default() .burst_size(15) - .per_millisecond(500) + .per_millisecond(250) .finish() .unwrap(); @@ -76,11 +82,13 @@ async fn main() -> anyhow::Result<()> { .layer(cors); if cli.verbose { - app = app.layer( - TraceLayer::new_for_http() - .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) - .on_response(DefaultOnResponse::new().level(Level::INFO)), - ); + app = app + .layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) + .on_response(DefaultOnResponse::new().level(Level::INFO)), + ) + .layer(middleware::from_fn(log_json_body)); } let port = cli.port; @@ -98,3 +106,38 @@ async fn main() -> anyhow::Result<()> { Ok(()) } + +#[cfg(debug_assertions)] +async fn log_json_body(req: Request, next: Next) -> Response { + let (parts, body) = req.into_parts(); + + // Check if the content type is JSON + let is_json = parts + .headers + .get(CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .map_or(false, |v| v.contains("application/json")); + + let bytes = if is_json { + // Read the body bytes + let bytes = axum::body::to_bytes(body, usize::MAX) + .await + .unwrap_or_default(); + + // Log the body (converting to string) + if let Ok(body_str) = std::str::from_utf8(&bytes) { + info!("JSON Request Body: {}", body_str); + } + bytes + } else { + // If not JSON, we still need to collect it or just pass it through + axum::body::to_bytes(body, usize::MAX) + .await + .unwrap_or_default() + }; + + // Reconstruct the request with the bytes we read + let req = Request::from_parts(parts, Body::from(bytes)); + + next.run(req).await +} diff --git a/src/routes/friends.rs b/src/routes/friends.rs index 0b46d7a..4db20af 100644 --- a/src/routes/friends.rs +++ b/src/routes/friends.rs @@ -140,7 +140,12 @@ async fn send_request( .bind(receiver_id) .execute(&db) .await - .map_err(|_| (StatusCode::CONFLICT, "Request already exists".into()))?; + .map_err(|_| { + ( + StatusCode::CONFLICT, + "You have already send a friend request to this user".into(), + ) + })?; Ok(( StatusCode::CREATED, diff --git a/src/routes/messages.rs b/src/routes/messages.rs index 3eb563a..d098ba2 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -54,7 +54,7 @@ async fn list_messages( if !is_member(user_id, room_id, &db).await { return Err(( - StatusCode::UNAUTHORIZED, + StatusCode::FORBIDDEN, String::from("You are not a member of this room"), )); } diff --git a/src/routes/rooms.rs b/src/routes/rooms.rs index ee95669..6ca22e1 100644 --- a/src/routes/rooms.rs +++ b/src/routes/rooms.rs @@ -13,9 +13,10 @@ use crate::{auth::verify_jwt, db::room_id_from_uuid}; #[derive(sqlx::FromRow, serde::Serialize)] pub struct Room { pub uuid: Uuid, - pub owner_name: String, pub name: String, pub global: bool, + pub owner_name: String, + pub owner_uuid: Uuid, } #[derive(serde::Deserialize)] @@ -89,6 +90,7 @@ async fn list_rooms( r#" SELECT r.uuid, u.username AS owner_name, + u.uuid AS owner_uuid, r.name, r.global FROM room_ r @@ -152,6 +154,7 @@ async fn create_room( Json(Room { uuid: room_uuid, owner_name, + owner_uuid: claims.sub, name: payload.name, global: payload.global, }), @@ -166,24 +169,40 @@ async fn get_room( let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; + let room_id = room_id_from_uuid(&db, room_uuid).await?; + + if !is_member(user_id, room_id, &db).await { + return Err(( + StatusCode::FORBIDDEN, + String::from("You are not a member of this room"), + )); + } let room: Room = sqlx::query_as( r#" - SELECT uuid, u.name AS owner_name, r.name, r.global + SELECT + r.uuid, + u.username AS owner_name, + u.uuid AS owner_uuid, + r.name, + r.global FROM room_ r - JOIN user u ON u.id = r.owner - WHERE uuid = $1 AND owner = $2 + JOIN user_ u ON u.id = r.owner + WHERE r.uuid = $1 "#, ) .bind(room_uuid) - .bind(user_id) .fetch_one(&db) .await - .map_err(|_| (StatusCode::NOT_FOUND, "Room not found".to_string()))?; + .map_err(|e| { + tracing::error!("{e}"); + (StatusCode::NOT_FOUND, "Room not found".to_string()) + })?; Ok(Json(Room { uuid: room_uuid, owner_name: room.owner_name, + owner_uuid: room.owner_uuid, name: room.name, global: room.global, })) @@ -269,7 +288,12 @@ async fn send_invite( .bind(room_id) .execute(&db) .await - .map_err(|_| (StatusCode::CONFLICT, "Request already exists".into()))?; + .map_err(|_| { + ( + StatusCode::CONFLICT, + "You have already invited this user".into(), + ) + })?; tracing::info!("bro"); @@ -334,7 +358,12 @@ async fn accept_request( let room: Room = sqlx::query_as( r#" - SELECT r.uuid, u.username AS owner_name, r.name, r.global + SELECT + r.uuid, + u.username AS owner_name, + u.uuid AS owner_uuid, + r.name, + r.global FROM room_ r JOIN user_ u ON u.id = r.owner WHERE r.id = $1 AND r.owner = $2 @@ -344,7 +373,12 @@ async fn accept_request( .bind(sender_id) .fetch_one(&db) .await - .map_err(|_| (StatusCode::NOT_FOUND, "Room not found".into()))?; + .map_err(|_| { + ( + StatusCode::NOT_FOUND, + "Room not found or wrong owner".into(), + ) + })?; tx.commit().await.map_err(|_| { ( @@ -358,6 +392,7 @@ async fn accept_request( Json(Room { uuid: payload.room_uuid, owner_name: room.owner_name, + owner_uuid: room.owner_uuid, name: room.name, global: room.global, }),