added an authenticated room websocket for messages

This commit is contained in:
2025-12-15 19:51:31 +01:00
parent 391a0d3f2e
commit ffc2e99cc7
8 changed files with 240 additions and 14 deletions

45
Cargo.lock generated
View File

@@ -83,6 +83,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
@@ -102,8 +103,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@@ -210,6 +213,7 @@ dependencies = [
"argon2", "argon2",
"axum", "axum",
"chrono", "chrono",
"dashmap",
"jsonwebtoken", "jsonwebtoken",
"password-hash", "password-hash",
"serde", "serde",
@@ -332,6 +336,12 @@ dependencies = [
"parking_lot_core", "parking_lot_core",
] ]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]] [[package]]
name = "der" name = "der"
version = "0.7.10" version = "0.7.10"
@@ -2266,6 +2276,18 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.17" version = "0.7.17"
@@ -2436,6 +2458,23 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror 2.0.17",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.19.0" version = "1.19.0"
@@ -2487,6 +2526,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf8_iter" name = "utf8_iter"
version = "1.0.4" version = "1.0.4"

View File

@@ -6,8 +6,9 @@ edition = "2024"
[dependencies] [dependencies]
anyhow = "1.0.99" anyhow = "1.0.99"
argon2 = "0.5.3" argon2 = "0.5.3"
axum = { version = "0.8.4", features = ["multipart"] } axum = { version = "0.8.4", features = ["multipart", "ws"] }
chrono = { version = "0.4.42", features = ["serde"] } chrono = { version = "0.4.42", features = ["serde"] }
dashmap = "6.1.0"
jsonwebtoken = "9.3.1" jsonwebtoken = "9.3.1"
password-hash = "0.5.0" password-hash = "0.5.0"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }

View File

@@ -28,6 +28,12 @@ CREATE TABLE IF NOT EXISTS message_ (
sent_at TIMESTAMP sent_at TIMESTAMP
); );
CREATE TABLE ws_token_ (
token TEXT PRIMARY KEY,
room_id INT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL
);
-- Message timestamp creation -- Message timestamp creation
CREATE OR REPLACE FUNCTION create_message_timestamp() CREATE OR REPLACE FUNCTION create_message_timestamp()
RETURNS trigger RETURNS trigger

View File

@@ -8,6 +8,7 @@ use tower_http::cors::{Any, CorsLayer};
mod auth; mod auth;
mod db; mod db;
mod realtime;
mod routes; mod routes;
#[tokio::main] #[tokio::main]
@@ -24,8 +25,8 @@ async fn main() -> anyhow::Result<()> {
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]); .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]);
let governor_conf = GovernorConfigBuilder::default() let governor_conf = GovernorConfigBuilder::default()
.per_second(25) .per_second(50)
.burst_size(50) .burst_size(200)
.finish() .finish()
.unwrap(); .unwrap();
@@ -41,11 +42,15 @@ async fn main() -> anyhow::Result<()> {
} }
}); });
let realtime = realtime::Realtime::new();
let app = Router::new() let app = Router::new()
.merge(routes::users::routes()) .merge(routes::users::routes())
.merge(routes::rooms::routes()) .merge(routes::rooms::routes())
.merge(routes::messages::routes()) .merge(routes::messages::routes())
.merge(routes::ws::routes())
.layer(Extension(db_pool)) .layer(Extension(db_pool))
.layer(Extension(realtime))
.layer(cors) .layer(cors)
.layer(GovernorLayer::new(governor_conf)); .layer(GovernorLayer::new(governor_conf));

27
src/realtime.rs Normal file
View File

@@ -0,0 +1,27 @@
use dashmap::DashMap;
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::routes::messages::Message;
pub type RoomId = i32;
#[derive(Clone)]
pub struct Realtime {
pub rooms: Arc<DashMap<RoomId, broadcast::Sender<Message>>>,
}
impl Realtime {
pub fn new() -> Self {
Self {
rooms: Arc::new(DashMap::new()),
}
}
pub fn sender_for(&self, room: RoomId) -> broadcast::Sender<Message> {
self.rooms
.entry(room)
.or_insert_with(|| broadcast::channel(100).0)
.clone()
}
}

View File

@@ -7,8 +7,11 @@ use axum::{
use sqlx::PgPool; use sqlx::PgPool;
use uuid::Uuid; use uuid::Uuid;
use crate::db::{user_id_from_uuid, username_from_uuid};
use crate::{auth::verify_jwt, db::room_id_from_uuid}; use crate::{auth::verify_jwt, db::room_id_from_uuid};
use crate::{
db::{user_id_from_uuid, username_from_uuid},
realtime::Realtime,
};
#[derive(sqlx::FromRow, serde::Serialize, Debug)] #[derive(sqlx::FromRow, serde::Serialize, Debug)]
pub struct MessageRow { pub struct MessageRow {
@@ -18,7 +21,7 @@ pub struct MessageRow {
pub sent_at: chrono::NaiveDateTime, pub sent_at: chrono::NaiveDateTime,
} }
#[derive(sqlx::FromRow, serde::Serialize, Debug)] #[derive(sqlx::FromRow, serde::Serialize, Debug, Clone)]
pub struct Message { pub struct Message {
pub sender: String, pub sender: String,
pub message_type: String, pub message_type: String,
@@ -105,6 +108,7 @@ async fn list_messages(
async fn create_message( async fn create_message(
Path(room_uuid): Path<Uuid>, Path(room_uuid): Path<Uuid>,
Extension(db): Extension<PgPool>, Extension(db): Extension<PgPool>,
Extension(realtime): Extension<Realtime>,
headers: HeaderMap, headers: HeaderMap,
Json(payload): Json<NewMessagePayload>, Json(payload): Json<NewMessagePayload>,
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> { ) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
@@ -133,13 +137,15 @@ async fn create_message(
let sender_name = username_from_uuid(&db, claims.sub).await?; let sender_name = username_from_uuid(&db, claims.sub).await?;
Ok(( let message = Message {
StatusCode::CREATED, sender: sender_name,
Json(Message { message_type: payload.message_type,
sender: sender_name, content: payload.content,
message_type: payload.message_type, sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
content: payload.content, };
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
}), let rt_sender = realtime.sender_for(room_id);
)) let _ = rt_sender.send(message.clone());
Ok((StatusCode::CREATED, Json(message)))
} }

View File

@@ -1,3 +1,4 @@
pub mod messages; pub mod messages;
pub mod rooms; pub mod rooms;
pub mod users; pub mod users;
pub mod ws;

135
src/routes/ws.rs Normal file
View File

@@ -0,0 +1,135 @@
use axum::Json;
use axum::extract::Query;
use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::http::HeaderMap;
use axum::routing::get;
use axum::{
Extension,
extract::{Path, WebSocketUpgrade},
http::StatusCode,
response::IntoResponse,
};
use serde::Deserialize;
use uuid::Uuid;
use crate::auth::{create_jwt, verify_jwt};
use crate::db::user_id_from_uuid;
use crate::{db::room_id_from_uuid, realtime::Realtime};
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
pub struct WsAuthQuery {
pub token: String,
}
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token))
.route("/ws/rooms/{room_uuid}", get(ws_handler))
}
pub async fn issue_ws_token(
Extension(db): Extension<sqlx::PgPool>,
headers: HeaderMap,
Path(room_uuid): Path<Uuid>,
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
let claims = verify_jwt(headers)?;
let room_id = room_id_from_uuid(&db, room_uuid).await?;
let user_id = user_id_from_uuid(&db, claims.sub).await?;
let membership: Vec<i32> =
sqlx::query_scalar("SELECT user_id FROM membership_ WHERE user_id = $1 AND room = $2")
.bind(user_id)
.bind(room_id)
.fetch_all(&db)
.await
.unwrap_or_else(|_| Vec::new());
if membership.is_empty() {
return Err((
StatusCode::UNAUTHORIZED,
String::from("You are not a member of this room"),
));
}
// tracing::info!(
// "recieved token issue request from user {} for room {}",
// claims.sub,
// room_uuid
// );
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
sqlx::query(
r#"
insert into ws_token_ (token, room_id, expires_at)
values ($1, $2, now() + interval '30 seconds')
"#,
)
.bind(&token)
.bind(room_id)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("failed to insert ws token: {e}");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to insert ws token: {e}"),
)
})?;
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
}
async fn ws_handler(
ws: WebSocketUpgrade,
Path(room_uuid): Path<Uuid>,
Query(query): Query<WsAuthQuery>,
Extension(realtime): Extension<Realtime>,
Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, axum::http::StatusCode> {
tracing::info!("recieved ws handshake: {}", room_uuid);
let room_id = room_id_from_uuid(&db, room_uuid)
.await
.map_err(|_| StatusCode::NOT_FOUND)?;
let valid: Option<i32> = sqlx::query_scalar(
r#"
delete from ws_token_
where token = $1
and room_id = $2
and expires_at > now()
returning room_id
"#,
)
.bind(query.token)
.bind(room_id)
.fetch_optional(&db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if valid.is_none() {
return Err(StatusCode::UNAUTHORIZED);
}
let sender = realtime.sender_for(room_id);
let receiver = sender.subscribe();
Ok(ws.on_upgrade(move |socket| handle_socket(socket, receiver)))
}
async fn handle_socket(
mut socket: WebSocket,
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
while let Ok(msg) = receiver.recv().await {
if socket
.send(WsMessage::Text(serde_json::to_string(&msg).unwrap().into()))
.await
.is_err()
{
break;
}
}
}