added an authenticated room websocket for messages
This commit is contained in:
@@ -8,6 +8,7 @@ use tower_http::cors::{Any, CorsLayer};
|
||||
|
||||
mod auth;
|
||||
mod db;
|
||||
mod realtime;
|
||||
mod routes;
|
||||
|
||||
#[tokio::main]
|
||||
@@ -24,8 +25,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]);
|
||||
|
||||
let governor_conf = GovernorConfigBuilder::default()
|
||||
.per_second(25)
|
||||
.burst_size(50)
|
||||
.per_second(50)
|
||||
.burst_size(200)
|
||||
.finish()
|
||||
.unwrap();
|
||||
|
||||
@@ -41,11 +42,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
let realtime = realtime::Realtime::new();
|
||||
|
||||
let app = Router::new()
|
||||
.merge(routes::users::routes())
|
||||
.merge(routes::rooms::routes())
|
||||
.merge(routes::messages::routes())
|
||||
.merge(routes::ws::routes())
|
||||
.layer(Extension(db_pool))
|
||||
.layer(Extension(realtime))
|
||||
.layer(cors)
|
||||
.layer(GovernorLayer::new(governor_conf));
|
||||
|
||||
|
||||
27
src/realtime.rs
Normal file
27
src/realtime.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::routes::messages::Message;
|
||||
|
||||
pub type RoomId = i32;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Realtime {
|
||||
pub rooms: Arc<DashMap<RoomId, broadcast::Sender<Message>>>,
|
||||
}
|
||||
|
||||
impl Realtime {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
rooms: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sender_for(&self, room: RoomId) -> broadcast::Sender<Message> {
|
||||
self.rooms
|
||||
.entry(room)
|
||||
.or_insert_with(|| broadcast::channel(100).0)
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,11 @@ use axum::{
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::db::{user_id_from_uuid, username_from_uuid};
|
||||
use crate::{auth::verify_jwt, db::room_id_from_uuid};
|
||||
use crate::{
|
||||
db::{user_id_from_uuid, username_from_uuid},
|
||||
realtime::Realtime,
|
||||
};
|
||||
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
|
||||
pub struct MessageRow {
|
||||
@@ -18,7 +21,7 @@ pub struct MessageRow {
|
||||
pub sent_at: chrono::NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub sender: String,
|
||||
pub message_type: String,
|
||||
@@ -105,6 +108,7 @@ async fn list_messages(
|
||||
async fn create_message(
|
||||
Path(room_uuid): Path<Uuid>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
Extension(realtime): Extension<Realtime>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<NewMessagePayload>,
|
||||
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
|
||||
@@ -133,13 +137,15 @@ async fn create_message(
|
||||
|
||||
let sender_name = username_from_uuid(&db, claims.sub).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(Message {
|
||||
sender: sender_name,
|
||||
message_type: payload.message_type,
|
||||
content: payload.content,
|
||||
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
}),
|
||||
))
|
||||
let message = Message {
|
||||
sender: sender_name,
|
||||
message_type: payload.message_type,
|
||||
content: payload.content,
|
||||
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
};
|
||||
|
||||
let rt_sender = realtime.sender_for(room_id);
|
||||
let _ = rt_sender.send(message.clone());
|
||||
|
||||
Ok((StatusCode::CREATED, Json(message)))
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod messages;
|
||||
pub mod rooms;
|
||||
pub mod users;
|
||||
pub mod ws;
|
||||
|
||||
135
src/routes/ws.rs
Normal file
135
src/routes/ws.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use axum::Json;
|
||||
use axum::extract::Query;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket};
|
||||
use axum::http::HeaderMap;
|
||||
use axum::routing::get;
|
||||
use axum::{
|
||||
Extension,
|
||||
extract::{Path, WebSocketUpgrade},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::{create_jwt, verify_jwt};
|
||||
use crate::db::user_id_from_uuid;
|
||||
use crate::{db::room_id_from_uuid, realtime::Realtime};
|
||||
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
||||
pub struct WsAuthQuery {
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token))
|
||||
.route("/ws/rooms/{room_uuid}", get(ws_handler))
|
||||
}
|
||||
|
||||
pub async fn issue_ws_token(
|
||||
Extension(db): Extension<sqlx::PgPool>,
|
||||
headers: HeaderMap,
|
||||
Path(room_uuid): Path<Uuid>,
|
||||
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
|
||||
let claims = verify_jwt(headers)?;
|
||||
|
||||
let room_id = room_id_from_uuid(&db, room_uuid).await?;
|
||||
let user_id = user_id_from_uuid(&db, claims.sub).await?;
|
||||
|
||||
let membership: Vec<i32> =
|
||||
sqlx::query_scalar("SELECT user_id FROM membership_ WHERE user_id = $1 AND room = $2")
|
||||
.bind(user_id)
|
||||
.bind(room_id)
|
||||
.fetch_all(&db)
|
||||
.await
|
||||
.unwrap_or_else(|_| Vec::new());
|
||||
|
||||
if membership.is_empty() {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
String::from("You are not a member of this room"),
|
||||
));
|
||||
}
|
||||
|
||||
// tracing::info!(
|
||||
// "recieved token issue request from user {} for room {}",
|
||||
// claims.sub,
|
||||
// room_uuid
|
||||
// );
|
||||
|
||||
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
insert into ws_token_ (token, room_id, expires_at)
|
||||
values ($1, $2, now() + interval '30 seconds')
|
||||
"#,
|
||||
)
|
||||
.bind(&token)
|
||||
.bind(room_id)
|
||||
.execute(&db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("failed to insert ws token: {e}");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("failed to insert ws token: {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Path(room_uuid): Path<Uuid>,
|
||||
Query(query): Query<WsAuthQuery>,
|
||||
Extension(realtime): Extension<Realtime>,
|
||||
Extension(db): Extension<sqlx::PgPool>,
|
||||
) -> Result<impl IntoResponse, axum::http::StatusCode> {
|
||||
tracing::info!("recieved ws handshake: {}", room_uuid);
|
||||
|
||||
let room_id = room_id_from_uuid(&db, room_uuid)
|
||||
.await
|
||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
|
||||
let valid: Option<i32> = sqlx::query_scalar(
|
||||
r#"
|
||||
delete from ws_token_
|
||||
where token = $1
|
||||
and room_id = $2
|
||||
and expires_at > now()
|
||||
returning room_id
|
||||
"#,
|
||||
)
|
||||
.bind(query.token)
|
||||
.bind(room_id)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if valid.is_none() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let sender = realtime.sender_for(room_id);
|
||||
let receiver = sender.subscribe();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_socket(socket, receiver)))
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
||||
) {
|
||||
while let Ok(msg) = receiver.recv().await {
|
||||
if socket
|
||||
.send(WsMessage::Text(serde_json::to_string(&msg).unwrap().into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user