refactor: clients now have a single websocket that handles all rooms the user is in
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -619,7 +619,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "frangipane"
|
name = "frangipane"
|
||||||
version = "1.0.0"
|
version = "1.0.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"argon2",
|
"argon2",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "frangipane"
|
name = "frangipane"
|
||||||
version = "1.0.0"
|
version = "1.0.1"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ CREATE TABLE IF NOT EXISTS message_ (
|
|||||||
|
|
||||||
CREATE TABLE ws_token_ (
|
CREATE TABLE ws_token_ (
|
||||||
token TEXT PRIMARY KEY,
|
token TEXT PRIMARY KEY,
|
||||||
room_id INT NOT NULL,
|
|
||||||
expires_at TIMESTAMPTZ NOT NULL
|
expires_at TIMESTAMPTZ NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ INSERT INTO friendship_ (user_first, user_second) VALUES
|
|||||||
INSERT INTO friend_request_ (sender, receiver) VALUES
|
INSERT INTO friend_request_ (sender, receiver) VALUES
|
||||||
(2, 1); -- Bob sent a friend request to Alice
|
(2, 1); -- Bob sent a friend request to Alice
|
||||||
|
|
||||||
INSERT INTO ws_token_ (token, room_id, expires_at) VALUES
|
INSERT INTO ws_token_ (token, expires_at) VALUES
|
||||||
('random_token_1', 1, '2025-12-31T23:59:59Z'),
|
('random_token_1', '2025-12-31T23:59:59Z'),
|
||||||
('random_token_2', 2, '2025-12-31T23:59:59Z');
|
('random_token_2', '2025-12-31T23:59:59Z');
|
||||||
|
|
||||||
INSERT INTO room_invite_ (sender, receiver, room) VALUES
|
INSERT INTO room_invite_ (sender, receiver, room) VALUES
|
||||||
(2, 1, 2);
|
(2, 1, 2);
|
||||||
|
|||||||
@@ -67,11 +67,15 @@ pub fn verify_jwt(headers: HeaderMap) -> Result<Claims, (StatusCode, String)> {
|
|||||||
.and_then(|s| s.strip_prefix("Bearer "))
|
.and_then(|s| s.strip_prefix("Bearer "))
|
||||||
.ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?;
|
.ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?;
|
||||||
|
|
||||||
|
verify_jwt_string(&token.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify_jwt_string(token: &String) -> Result<Claims, (StatusCode, String)> {
|
||||||
let secret =
|
let secret =
|
||||||
std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string());
|
std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string());
|
||||||
|
|
||||||
decode::<Claims>(
|
decode::<Claims>(
|
||||||
token,
|
token.as_str(),
|
||||||
&DecodingKey::from_secret(secret.as_ref()),
|
&DecodingKey::from_secret(secret.as_ref()),
|
||||||
&Validation::default(),
|
&Validation::default(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,27 +1,35 @@
|
|||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::routes::messages::Message;
|
use crate::routes::messages::Message;
|
||||||
|
|
||||||
pub type RoomId = i32;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Realtime {
|
pub struct Realtime {
|
||||||
pub rooms: Arc<DashMap<RoomId, broadcast::Sender<Message>>>,
|
pub clients: Arc<DashMap<Uuid, broadcast::Sender<Message>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Realtime {
|
impl Realtime {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
rooms: Arc::new(DashMap::new()),
|
clients: Arc::new(DashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sender_for(&self, room: RoomId) -> broadcast::Sender<Message> {
|
/// Get or create the channel for a specific user
|
||||||
self.rooms
|
pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender<Message> {
|
||||||
.entry(room)
|
self.clients
|
||||||
|
.entry(user_uuid)
|
||||||
.or_insert_with(|| broadcast::channel(100).0)
|
.or_insert_with(|| broadcast::channel(100).0)
|
||||||
.clone()
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn broadcast(&self, recipient_uuids: Vec<Uuid>, message: Message) {
|
||||||
|
for user_uuid in recipient_uuids {
|
||||||
|
if let Some(sender) = self.clients.get(&user_uuid) {
|
||||||
|
let _ = sender.send(message.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ pub struct MessageRow {
|
|||||||
pub uuid: Uuid,
|
pub uuid: Uuid,
|
||||||
pub sender: String,
|
pub sender: String,
|
||||||
pub sender_uuid: Uuid,
|
pub sender_uuid: Uuid,
|
||||||
|
pub room_uuid: Uuid,
|
||||||
pub message_type: String,
|
pub message_type: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub sent_at: chrono::NaiveDateTime,
|
pub sent_at: chrono::NaiveDateTime,
|
||||||
@@ -26,6 +27,7 @@ pub struct MessageRow {
|
|||||||
#[derive(serde::Serialize, Debug, Clone)]
|
#[derive(serde::Serialize, Debug, Clone)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub uuid: Uuid,
|
pub uuid: Uuid,
|
||||||
|
pub room_uuid: Uuid,
|
||||||
pub sender: String,
|
pub sender: String,
|
||||||
pub sender_uuid: Uuid,
|
pub sender_uuid: Uuid,
|
||||||
pub message_type: String,
|
pub message_type: String,
|
||||||
@@ -77,7 +79,7 @@ async fn list_messages(
|
|||||||
m.uuid,
|
m.uuid,
|
||||||
u.username AS sender,
|
u.username AS sender,
|
||||||
u.uuid AS sender_uuid,
|
u.uuid AS sender_uuid,
|
||||||
r.uuid AS room,
|
r.uuid AS room_uuid,
|
||||||
m.message_type,
|
m.message_type,
|
||||||
m.content,
|
m.content,
|
||||||
m.sent_at
|
m.sent_at
|
||||||
@@ -106,6 +108,7 @@ async fn list_messages(
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|m| Message {
|
.map(|m| Message {
|
||||||
uuid: m.uuid,
|
uuid: m.uuid,
|
||||||
|
room_uuid: m.room_uuid,
|
||||||
sender: m.sender,
|
sender: m.sender,
|
||||||
sender_uuid: m.sender_uuid,
|
sender_uuid: m.sender_uuid,
|
||||||
message_type: m.message_type,
|
message_type: m.message_type,
|
||||||
@@ -157,6 +160,7 @@ async fn create_message(
|
|||||||
|
|
||||||
let message = Message {
|
let message = Message {
|
||||||
uuid: uuid,
|
uuid: uuid,
|
||||||
|
room_uuid,
|
||||||
sender: sender_name,
|
sender: sender_name,
|
||||||
sender_uuid: claims.sub,
|
sender_uuid: claims.sub,
|
||||||
message_type: payload.message_type,
|
message_type: payload.message_type,
|
||||||
@@ -164,8 +168,28 @@ async fn create_message(
|
|||||||
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
|
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let rt_sender = realtime.sender_for(room_id);
|
let recipients: Vec<Uuid> = sqlx::query_scalar(
|
||||||
let _ = rt_sender.send(message.clone());
|
r#"
|
||||||
|
SELECT u.uuid
|
||||||
|
FROM membership_ m
|
||||||
|
JOIN user_ u ON u.id = m.user_id
|
||||||
|
WHERE m.room = $1
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(room_id)
|
||||||
|
.fetch_all(&db)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Error fetching message recipients: {e}");
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let rt = realtime.clone();
|
||||||
|
let msg_clone = message.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
rt.broadcast(recipients, msg_clone);
|
||||||
|
});
|
||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(message)))
|
Ok((StatusCode::CREATED, Json(message)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,23 +3,15 @@ use axum::extract::ws::{Message as WsMessage, WebSocket};
|
|||||||
use axum::extract::{ConnectInfo, Query};
|
use axum::extract::{ConnectInfo, Query};
|
||||||
use axum::http::HeaderMap;
|
use axum::http::HeaderMap;
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use axum::{
|
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
|
||||||
Extension,
|
|
||||||
extract::{Path, WebSocketUpgrade},
|
|
||||||
http::StatusCode,
|
|
||||||
response::IntoResponse,
|
|
||||||
};
|
|
||||||
use axum_extra::{TypedHeader, headers};
|
use axum_extra::{TypedHeader, headers};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::auth::{create_jwt, verify_jwt};
|
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
|
||||||
use crate::db::user_id_from_uuid;
|
use crate::realtime::Realtime;
|
||||||
use crate::routes::rooms::is_member;
|
|
||||||
use crate::{db::room_id_from_uuid, realtime::Realtime};
|
|
||||||
|
|
||||||
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
||||||
pub struct WsAuthQuery {
|
pub struct WsAuthQuery {
|
||||||
@@ -28,43 +20,27 @@ pub struct WsAuthQuery {
|
|||||||
|
|
||||||
pub fn routes() -> axum::Router {
|
pub fn routes() -> axum::Router {
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token))
|
.route("/ws/messages/issue-token", get(issue_ws_token))
|
||||||
.route("/ws/rooms/{room_uuid}", get(ws_handler))
|
.route("/ws/messages", get(ws_handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn issue_ws_token(
|
pub async fn issue_ws_token(
|
||||||
Extension(db): Extension<sqlx::PgPool>,
|
Extension(db): Extension<sqlx::PgPool>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Path(room_uuid): Path<Uuid>,
|
|
||||||
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
|
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
|
||||||
let claims = verify_jwt(headers)?;
|
let claims = verify_jwt(headers)?;
|
||||||
|
|
||||||
let room_id = room_id_from_uuid(&db, room_uuid).await?;
|
tracing::debug!("Recieved token issue request from user {}", claims.sub);
|
||||||
let user_id = user_id_from_uuid(&db, claims.sub).await?;
|
|
||||||
|
|
||||||
if !is_member(user_id, room_id, &db).await {
|
|
||||||
return Err((
|
|
||||||
StatusCode::UNAUTHORIZED,
|
|
||||||
String::from("You are not a member of this room"),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// tracing::info!(
|
|
||||||
// "recieved token issue request from user {} for room {}",
|
|
||||||
// claims.sub,
|
|
||||||
// room_uuid
|
|
||||||
// );
|
|
||||||
|
|
||||||
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
insert into ws_token_ (token, room_id, expires_at)
|
insert into ws_token_ (token, expires_at)
|
||||||
values ($1, $2, now() + interval '30 seconds')
|
values ($1, now() + interval '30 seconds')
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(&token)
|
.bind(&token)
|
||||||
.bind(room_id)
|
|
||||||
.execute(&db)
|
.execute(&db)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| {
|
.map_err(|_| {
|
||||||
@@ -80,46 +56,43 @@ pub async fn issue_ws_token(
|
|||||||
async fn ws_handler(
|
async fn ws_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||||
Path(room_uuid): Path<Uuid>,
|
|
||||||
Query(query): Query<WsAuthQuery>,
|
Query(query): Query<WsAuthQuery>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
Extension(realtime): Extension<Realtime>,
|
Extension(realtime): Extension<Realtime>,
|
||||||
Extension(db): Extension<sqlx::PgPool>,
|
Extension(db): Extension<sqlx::PgPool>,
|
||||||
) -> Result<impl IntoResponse, axum::http::StatusCode> {
|
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||||
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
||||||
|
|
||||||
let room_id = room_id_from_uuid(&db, room_uuid)
|
let claims = verify_jwt_string(&query.token)?;
|
||||||
.await
|
let user_uuid = claims.sub;
|
||||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
|
||||||
|
|
||||||
let valid: Option<i32> = sqlx::query_scalar(
|
let result = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
delete from ws_token_
|
delete from ws_token_
|
||||||
where token = $1
|
where token = $1
|
||||||
and room_id = $2
|
|
||||||
and expires_at > now()
|
and expires_at > now()
|
||||||
returning room_id
|
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(query.token)
|
.bind(query.token)
|
||||||
.bind(room_id)
|
.execute(&db)
|
||||||
.fetch_optional(&db)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to get WS token from DB: {e}");
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
if valid.is_none() {
|
if result.rows_affected() == 0 {
|
||||||
return Err(StatusCode::UNAUTHORIZED);
|
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let sender = realtime.sender_for(room_id);
|
let receiver = realtime.get_sender(user_uuid).subscribe();
|
||||||
let receiver = sender.subscribe();
|
|
||||||
|
|
||||||
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
||||||
user_agent.to_string()
|
user_agent.to_string()
|
||||||
} else {
|
} else {
|
||||||
String::from("Unknown browser")
|
String::from("Unknown browser")
|
||||||
};
|
};
|
||||||
tracing::debug!("`{user_agent}` at {addr} connected.");
|
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
||||||
|
|
||||||
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
|
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
|
||||||
}
|
}
|
||||||
@@ -133,7 +106,7 @@ async fn handle_socket(
|
|||||||
|
|
||||||
loop {
|
loop {
|
||||||
select! {
|
select! {
|
||||||
// Receive broadcast messages and send to client
|
// Receive broadcast messages and send to client (any room)
|
||||||
msg = receiver.recv() => {
|
msg = receiver.recv() => {
|
||||||
if let Ok(msg) = msg {
|
if let Ok(msg) = msg {
|
||||||
if let Ok(json) = serde_json::to_string(&msg) {
|
if let Ok(json) = serde_json::to_string(&msg) {
|
||||||
|
|||||||
Reference in New Issue
Block a user