diff --git a/db/init.sql b/db/init.sql index a78dfe8..f09b1c5 100644 --- a/db/init.sql +++ b/db/init.sql @@ -1,17 +1,32 @@ CREATE TABLE IF NOT EXISTS user_ ( - id SERIAL PRIMARY KEY, - uuid UUID UNIQUE, - email TEXT UNIQUE, - username TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL + id SERIAL PRIMARY KEY, + uuid UUID UNIQUE, + email TEXT UNIQUE, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS friendship_ ( + user_first INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, + user_second INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (user_first, user_second), +); + +CREATE TABLE IF NOT EXISTS friend_request_ ( + sender INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, + receiver INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, + sent_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (sender, receiver), + CHECK (sender <> receiver) ); CREATE TABLE IF NOT EXISTS room_ ( - id SERIAL PRIMARY KEY, - uuid UUID UNIQUE, - owner INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, - name TEXT NOT NULL, - global BOOLEAN DEFAULT false + id SERIAL PRIMARY KEY, + uuid UUID UNIQUE, + owner INT NOT NULL REFERENCES user_(id) ON DELETE CASCADE, + name TEXT NOT NULL, + global BOOLEAN DEFAULT false ); CREATE TABLE IF NOT EXISTS membership_ ( @@ -30,9 +45,9 @@ CREATE TABLE IF NOT EXISTS message_ ( ); CREATE TABLE ws_token_ ( - token TEXT PRIMARY KEY, - room_id INT NOT NULL, - expires_at TIMESTAMPTZ NOT NULL + token TEXT PRIMARY KEY, + room_id INT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL ); -- Message timestamp creation diff --git a/src/auth.rs b/src/auth.rs index 125958c..7712d00 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -4,7 +4,10 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode} use password_hash::SaltString; use password_hash::rand_core::OsRng; -use axum::http::{HeaderMap, StatusCode}; +use axum::{ + Json, + http::{HeaderMap, StatusCode}, +}; use uuid::Uuid; const DEFAULT_SECRET_KEY: &str = "43aaf85b92f1ae6fbcef7732c50a0904"; @@ -75,3 +78,10 @@ pub fn verify_jwt(headers: HeaderMap) -> Result { .map(|data| data.claims) .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string())) } + +pub async fn validate_token( + headers: HeaderMap, +) -> Result, (StatusCode, String)> { + let _ = verify_jwt(headers)?; + Ok(Json(serde_json::json!({"valid": true}))) +} diff --git a/src/db.rs b/src/db.rs index cf2269c..1faff3c 100644 --- a/src/db.rs +++ b/src/db.rs @@ -34,3 +34,11 @@ pub async fn username_from_uuid( .await .map_err(|_| (StatusCode::UNAUTHORIZED, String::from("Wrong token"))) } + +pub async fn id_from_username(db: &PgPool, username: String) -> Result { + sqlx::query_scalar("SELECT id FROM user_ WHERE username = $1") + .bind(username) + .fetch_one(db) + .await + .map_err(|_| (StatusCode::NOT_FOUND, "User not found".into())) +} diff --git a/src/main.rs b/src/main.rs index a2ee9d4..199b75b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,6 +48,7 @@ async fn main() -> anyhow::Result<()> { .merge(routes::users::routes()) .merge(routes::rooms::routes()) .merge(routes::messages::routes()) + .merge(routes::friends::routes()) .merge(routes::ws::routes()) .layer(Extension(db_pool)) .layer(Extension(realtime)) diff --git a/src/routes/friends.rs b/src/routes/friends.rs new file mode 100644 index 0000000..144d12c --- /dev/null +++ b/src/routes/friends.rs @@ -0,0 +1,174 @@ +use axum::{ + Extension, Json, Router, + http::{HeaderMap, StatusCode}, + routing::{get, post}, +}; +use sqlx::PgPool; +use uuid::Uuid; + +use crate::db::user_id_from_uuid; +use crate::{auth::verify_jwt, db::id_from_username}; + +#[derive(sqlx::FromRow, serde::Serialize)] +pub struct Friend { + pub uuid: Uuid, + pub username: String, +} + +#[derive(sqlx::FromRow, serde::Serialize)] +pub struct FriendRequest { + pub sender_uuid: Uuid, + pub sender_username: String, +} + +#[derive(serde::Deserialize)] +pub struct SendFriendRequestPayload { + pub receiver_username: String, +} + +#[derive(serde::Deserialize)] +pub struct AcceptFriendRequestPayload { + pub sender_uuid: Uuid, +} + +pub fn routes() -> Router { + Router::new() + .route("/friends", get(list_friends)) + .route("/friends/requests", get(list_requests)) + .route("/friends/request", post(send_request)) + .route("/friends/accept", post(accept_request)) +} + +async fn list_friends( + headers: HeaderMap, + Extension(db): Extension, +) -> Result>, (StatusCode, String)> { + let claims = verify_jwt(headers)?; + let user_id = user_id_from_uuid(&db, claims.sub).await?; + + let friends = sqlx::query_as::<_, Friend>( + r#" + SELECT u.uuid, u.username + FROM friendship_ + JOIN user_ u + ON (u.id = friendship_.user_first AND friendship_.user_second = $1) + OR (u.id = friendship_.user_second AND friendship_.user_first = $1) + "#, + ) + .bind(user_id) + .fetch_all(&db) + .await + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Could not list friends".into(), + ) + })?; + + Ok(Json(friends)) +} + +async fn list_requests( + headers: HeaderMap, + Extension(db): Extension, +) -> Result>, (StatusCode, String)> { + let claims = verify_jwt(headers)?; + let user_id = user_id_from_uuid(&db, claims.sub).await?; + + let requests = sqlx::query_as::<_, FriendRequest>( + r#" + SELECT u.uuid AS sender_uuid, u.username AS sender_username + FROM friend_request_ + JOIN user_ u ON u.id = friend_request_.sender + WHERE friend_request_.receiver = $1 + "#, + ) + .bind(user_id) + .fetch_all(&db) + .await + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Could not list friend requests".into(), + ) + })?; + + Ok(Json(requests)) +} + +async fn send_request( + headers: HeaderMap, + Extension(db): Extension, + Json(payload): Json, +) -> Result { + let claims = verify_jwt(headers)?; + + let sender_id = user_id_from_uuid(&db, claims.sub).await?; + let receiver_id = id_from_username(&db, payload.receiver_username).await?; + + if sender_id == receiver_id { + return Err(( + StatusCode::BAD_REQUEST, + "Cannot send a friend request to yourself".into(), + )); + } + + sqlx::query("INSERT INTO friend_request_ (sender, receiver) VALUES ($1, $2)") + .bind(sender_id) + .bind(receiver_id) + .execute(&db) + .await + .map_err(|_| (StatusCode::CONFLICT, "Request already exists".into()))?; + + Ok(StatusCode::CREATED) +} + +async fn accept_request( + headers: HeaderMap, + Extension(db): Extension, + Json(payload): Json, +) -> Result { + let claims = verify_jwt(headers)?; + + let receiver_id = user_id_from_uuid(&db, claims.sub).await?; + let sender_id = user_id_from_uuid(&db, payload.sender_uuid).await?; + + let (first, second) = if sender_id < receiver_id { + (sender_id, receiver_id) + } else { + (receiver_id, sender_id) + }; + + let mut tx = db + .begin() + .await + .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))?; + + let rows = sqlx::query("DELETE FROM friend_request_ WHERE sender = $1 AND receiver = $2") + .bind(sender_id) + .bind(receiver_id) + .execute(&mut *tx) + .await + .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()))? + .rows_affected(); + + if rows == 0 { + return Err((StatusCode::NOT_FOUND, "No such request".into())); + } + + sqlx::query("INSERT INTO friendship_ (user_first, user_second) VALUES ($1, $2)") + .bind(first) + .bind(second) + .execute(&mut *tx) + .await + .map_err(|_| (StatusCode::CONFLICT, "Already friends".into()))?; + + tx.commit().await.map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Could not accept friendship".into(), + ) + })?; + + Ok(StatusCode::CREATED) +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 8283042..1c5b7ca 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,3 +1,4 @@ +pub mod friends; pub mod messages; pub mod rooms; pub mod users; diff --git a/src/routes/rooms.rs b/src/routes/rooms.rs index b1ef12f..97491b3 100644 --- a/src/routes/rooms.rs +++ b/src/routes/rooms.rs @@ -81,7 +81,7 @@ async fn list_rooms( .bind(user_id) .fetch_all(&db) .await - .unwrap_or_else(|_| Vec::new()); + .unwrap_or(Vec::new()); Ok(Json(rooms)) } diff --git a/src/routes/users.rs b/src/routes/users.rs index e304ebc..ca87a80 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -1,7 +1,7 @@ use axum::{ Extension, Json, Router, extract::Request, - http::{HeaderMap, StatusCode}, + http::StatusCode, middleware::Next, response::Response, routing::{get, post}, @@ -11,7 +11,7 @@ use std::env; use uuid::Uuid; use validator::ValidateEmail; -use crate::auth::{create_jwt, hash_password, verify_jwt, verify_password}; +use crate::auth::{create_jwt, hash_password, validate_token, verify_password}; const DUMMY_HASH: &str = "$argon2id$v=19$m=4096,t=3,p=1$YWFhYWFhYWFhYWFhYWFhYQ$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; @@ -149,10 +149,3 @@ pub async fn register_user( }), )) } - -async fn validate_token( - headers: HeaderMap, -) -> Result, (StatusCode, String)> { - let _ = verify_jwt(headers)?; - Ok(Json(serde_json::json!({"valid": true}))) -}