From d20962101e894903872ea90ecc9cc8dd04b1355f Mon Sep 17 00:00:00 2001 From: eiiko6 Date: Wed, 17 Dec 2025 17:14:39 +0100 Subject: [PATCH] added global rooms in backend and fixed room membership checks --- db/init.sql | 3 +- db/mock_data.sql | 9 ++-- src/routes/messages.rs | 20 ++++----- src/routes/rooms.rs | 98 ++++++++++++++++++++++++++---------------- src/routes/ws.rs | 16 ++----- 5 files changed, 81 insertions(+), 65 deletions(-) diff --git a/db/init.sql b/db/init.sql index c3a97f4..a78dfe8 100644 --- a/db/init.sql +++ b/db/init.sql @@ -10,7 +10,8 @@ CREATE TABLE IF NOT EXISTS room_ ( id SERIAL PRIMARY KEY, uuid UUID UNIQUE, owner INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, - name TEXT NOT NULL + name TEXT NOT NULL, + global BOOLEAN DEFAULT false ); CREATE TABLE IF NOT EXISTS membership_ ( diff --git a/db/mock_data.sql b/db/mock_data.sql index f1b597f..1e08fb7 100644 --- a/db/mock_data.sql +++ b/db/mock_data.sql @@ -3,14 +3,13 @@ INSERT INTO user_ (username, email, uuid, password_hash) VALUES ('bob', 'bob@example.com', '019b1e36-3b8c-7f82-b845-6bfeb72466ce', '$argon2id$v=19$m=19456,t=2,p=1$mzO6Qx8ZH4/wrj14ZgKiuA$7bxNWCgsIVEfPgtueFbjbi8mDjbAHMYAHOGpxTJnEpQ'), ('carol', 'carol@example.com', '019b1e36-7706-76e2-b9ce-b37916ddfc99', '$argon2id$v=19$m=19456,t=2,p=1$5rw/7uIJIKMnyqNrYQt92Q$DJVEfgbaZtkflsmDEkSoR3uDQmujI4T73cWq9hOBgVI'); -INSERT INTO room_ (owner, name, uuid) VALUES -(1, 'General Discussion', '5dc599ee-1f5c-40c2-a22a-e40780d2d960'), -(2, 'Tech Talk', '6b14fe7b-2171-4464-95af-4888062b1b6d'), -(1, 'Random Memes', 'fb794f59-6b2d-4daa-8980-dc5255862657'); +INSERT INTO room_ (owner, name, global, uuid) VALUES +(1, 'General Discussion', true, '5dc599ee-1f5c-40c2-a22a-e40780d2d960'), +(2, 'Tech Talk', false, '6b14fe7b-2171-4464-95af-4888062b1b6d'), +(1, 'Random Memes', false, 'fb794f59-6b2d-4daa-8980-dc5255862657'); INSERT INTO membership_ (user_id, room) VALUES (1, 1), -- Alice in General Discussion -(2, 1), -- Bob in General Discussion (2, 2), -- Bob in Tech Talk (3, 1), -- Carol in General Discussion (1, 3); -- Alice in Random Memes diff --git a/src/routes/messages.rs b/src/routes/messages.rs index f2ba177..613d366 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -7,7 +7,7 @@ use axum::{ use sqlx::PgPool; use uuid::Uuid; -use crate::{auth::verify_jwt, db::room_id_from_uuid}; +use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member}; use crate::{ db::{user_id_from_uuid, username_from_uuid}, realtime::Realtime, @@ -51,15 +51,7 @@ async fn list_messages( let user_id = user_id_from_uuid(&db, claims.sub).await?; let room_id = room_id_from_uuid(&db, room_uuid).await?; - let membership: Vec = - sqlx::query_scalar("SELECT user_id FROM membership_ WHERE user_id = $1 AND room = $2") - .bind(user_id) - .bind(room_id) - .fetch_all(&db) - .await - .unwrap_or_else(|_| Vec::new()); - - if membership.is_empty() { + if !is_member(user_id, room_id, &db).await { return Err(( StatusCode::UNAUTHORIZED, String::from("You are not a member of this room"), @@ -115,9 +107,15 @@ async fn create_message( let claims = verify_jwt(headers)?; let user_id = user_id_from_uuid(&db, claims.sub).await?; - let room_id = room_id_from_uuid(&db, room_uuid).await?; + if !is_member(user_id, room_id, &db).await { + return Err(( + StatusCode::UNAUTHORIZED, + String::from("You are not a member of this room"), + )); + } + let sent_at: chrono::NaiveDateTime = sqlx::query_scalar( "INSERT INTO message_ (sender, room, message_type, content) VALUES ($1, $2, $3, $4) RETURNING sent_at", diff --git a/src/routes/rooms.rs b/src/routes/rooms.rs index 0eef802..c798418 100644 --- a/src/routes/rooms.rs +++ b/src/routes/rooms.rs @@ -4,7 +4,7 @@ use axum::{ http::{HeaderMap, StatusCode}, routing::{get, post}, }; -use sqlx::PgPool; +use sqlx::{PgPool, Pool, Postgres}; use uuid::Uuid; use crate::db::user_id_from_uuid; @@ -20,13 +20,35 @@ pub struct Room { #[derive(serde::Deserialize)] pub struct NewRoomPayload { pub name: String, + pub global: bool, } pub fn routes() -> Router { Router::new() .route("/rooms/{user_uuid}", get(list_rooms)) .route("/rooms", post(create_room)) - .route("/rooms/{user_uuid}/{room_id}", get(get_room)) + // .route("/rooms/{user_uuid}/{room_id}", get(get_room)) +} + +pub async fn is_member(user_id: i32, room_id: i32, db: &Pool) -> bool { + sqlx::query_scalar( + r#" + SELECT r.global + OR EXISTS ( + SELECT 1 + FROM membership_ m + WHERE m.user_id = $1 + AND m.room = r.id + ) + FROM room_ r + WHERE r.id = $2 + "#, + ) + .bind(user_id) + .bind(room_id) + .fetch_one(db) + .await + .unwrap_or(false) } async fn list_rooms( @@ -43,17 +65,19 @@ async fn list_rooms( let rooms = sqlx::query_as::<_, Room>( r#" - SELECT uuid, owner, name FROM room_ r - JOIN membership_ m ON m.user_id = $1 AND m.room = r.id + SELECT r.uuid, r.owner, r.name + FROM room_ r + WHERE r.global OR EXISTS ( + SELECT 1 + FROM membership_ m + WHERE m.user_id = $1 AND m.room = r.id + ) "#, ) .bind(user_id) .fetch_all(&db) .await - .unwrap_or_else(|e| { - tracing::error!("faied to list rooms: {e}"); - Vec::new() - }); + .unwrap_or_else(|_| Vec::new()); Ok(Json(rooms)) } @@ -70,18 +94,20 @@ async fn create_room( let room_uuid = uuid::Uuid::now_v7(); sqlx::query( - "INSERT INTO room_ (uuid, owner, name) - VALUES ($1, $2, $3)", + "INSERT INTO room_ (uuid, owner, name, global) + VALUES ($1, $2, $3, $4)", ) .bind(room_uuid) .bind(user_id) .bind(&payload.name) + .bind(&payload.global) .execute(&db) .await .map_err(|_| (StatusCode::BAD_REQUEST, format!("Could not create room")))?; let room_id = room_id_from_uuid(&db, room_uuid).await?; + // We do this even for the owner sqlx::query("INSERT INTO membership_ (user_id, room) VALUES ($1, $2)") .bind(user_id) .bind(room_id) @@ -99,29 +125,29 @@ async fn create_room( )) } -async fn get_room( - Path((user_uuid, room_uuid)): Path<(Uuid, Uuid)>, - headers: HeaderMap, - Extension(db): Extension, -) -> Result, (StatusCode, String)> { - let claims = verify_jwt(headers)?; - if claims.sub != user_uuid { - return Err((StatusCode::FORBIDDEN, "Forbidden".to_string())); - } - - let user_id = user_id_from_uuid(&db, user_uuid).await?; - - let room: Room = - sqlx::query_as("SELECT uuid, owner, name FROM room_ WHERE uuid = $1 AND owner = $2") - .bind(room_uuid) - .bind(user_id) - .fetch_one(&db) - .await - .map_err(|_| (StatusCode::NOT_FOUND, "Room not found".to_string()))?; - - Ok(Json(Room { - uuid: room_uuid, - owner: room.owner, - name: room.name, - })) -} +// async fn get_room( +// Path((user_uuid, room_uuid)): Path<(Uuid, Uuid)>, +// headers: HeaderMap, +// Extension(db): Extension, +// ) -> Result, (StatusCode, String)> { +// let claims = verify_jwt(headers)?; +// if claims.sub != user_uuid { +// return Err((StatusCode::FORBIDDEN, "Forbidden".to_string())); +// } +// +// let user_id = user_id_from_uuid(&db, user_uuid).await?; +// +// let room: Room = +// sqlx::query_as("SELECT uuid, owner, name FROM room_ WHERE uuid = $1 AND owner = $2") +// .bind(room_uuid) +// .bind(user_id) +// .fetch_one(&db) +// .await +// .map_err(|_| (StatusCode::NOT_FOUND, "Room not found".to_string()))?; +// +// Ok(Json(Room { +// uuid: room_uuid, +// owner: room.owner, +// name: room.name, +// })) +// } diff --git a/src/routes/ws.rs b/src/routes/ws.rs index 422c100..dbcc75f 100644 --- a/src/routes/ws.rs +++ b/src/routes/ws.rs @@ -14,6 +14,7 @@ use uuid::Uuid; use crate::auth::{create_jwt, verify_jwt}; use crate::db::user_id_from_uuid; +use crate::routes::rooms::is_member; use crate::{db::room_id_from_uuid, realtime::Realtime}; #[derive(sqlx::FromRow, serde::Serialize, Deserialize)] @@ -37,15 +38,7 @@ pub async fn issue_ws_token( let room_id = room_id_from_uuid(&db, room_uuid).await?; let user_id = user_id_from_uuid(&db, claims.sub).await?; - let membership: Vec = - sqlx::query_scalar("SELECT user_id FROM membership_ WHERE user_id = $1 AND room = $2") - .bind(user_id) - .bind(room_id) - .fetch_all(&db) - .await - .unwrap_or_else(|_| Vec::new()); - - if membership.is_empty() { + if !is_member(user_id, room_id, &db).await { return Err(( StatusCode::UNAUTHORIZED, String::from("You are not a member of this room"), @@ -70,11 +63,10 @@ pub async fn issue_ws_token( .bind(room_id) .execute(&db) .await - .map_err(|e| { - tracing::error!("failed to insert ws token: {e}"); + .map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to insert ws token: {e}"), + format!("failed to provide ws token"), ) })?;