added voice chat websocket route
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -619,12 +619,13 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "frangipane"
|
name = "frangipane"
|
||||||
version = "1.0.4"
|
version = "1.0.5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"argon2",
|
"argon2",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-extra",
|
"axum-extra",
|
||||||
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "frangipane"
|
name = "frangipane"
|
||||||
version = "1.0.4"
|
version = "1.0.5"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@@ -8,6 +8,7 @@ anyhow = "1.0.99"
|
|||||||
argon2 = "0.5.3"
|
argon2 = "0.5.3"
|
||||||
axum = { version = "0.8.4", features = ["multipart", "ws"] }
|
axum = { version = "0.8.4", features = ["multipart", "ws"] }
|
||||||
axum-extra = { version = "0.12.5", features = ["typed-header"] }
|
axum-extra = { version = "0.12.5", features = ["typed-header"] }
|
||||||
|
bytes = "1.11.0"
|
||||||
chrono = { version = "0.4.42", features = ["serde"] }
|
chrono = { version = "0.4.42", features = ["serde"] }
|
||||||
clap = { version = "4.5.53", features = ["derive"] }
|
clap = { version = "4.5.53", features = ["derive"] }
|
||||||
dashmap = "6.1.0"
|
dashmap = "6.1.0"
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
{
|
{
|
||||||
packages.default = pkgs.rustPlatform.buildRustPackage {
|
packages.default = pkgs.rustPlatform.buildRustPackage {
|
||||||
pname = "frangipane";
|
pname = "frangipane";
|
||||||
version = "1.0.0";
|
version = "1.0.5";
|
||||||
|
|
||||||
src = ./.;
|
src = ./.;
|
||||||
cargoLock = {
|
cargoLock = {
|
||||||
|
|||||||
@@ -20,7 +20,10 @@ pub async fn room_id_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result<i32, (Sta
|
|||||||
.bind(room_uuid)
|
.bind(room_uuid)
|
||||||
.fetch_one(db)
|
.fetch_one(db)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| (StatusCode::NOT_FOUND, "Failed to find room".into()))
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to convert room uuid to room id: {e}");
|
||||||
|
(StatusCode::NOT_FOUND, "Failed to find room".into())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn username_from_uuid(
|
pub async fn username_from_uuid(
|
||||||
|
|||||||
13
src/main.rs
13
src/main.rs
@@ -92,18 +92,23 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let governor_limiter = governor_conf.limiter().clone();
|
let governor_limiter = governor_conf.limiter().clone();
|
||||||
|
|
||||||
// a separate background task to clean up
|
let realtime = realtime::RealtimeMessages::new();
|
||||||
|
let voice_manager = realtime::RealTimeVoices::new();
|
||||||
|
let vm_clone = voice_manager.clone();
|
||||||
|
|
||||||
|
// A separate background task to clean up
|
||||||
let interval = Duration::from_secs(60);
|
let interval = Duration::from_secs(60);
|
||||||
std::thread::spawn(move || {
|
std::thread::spawn(move || {
|
||||||
loop {
|
loop {
|
||||||
std::thread::sleep(interval);
|
std::thread::sleep(interval);
|
||||||
|
|
||||||
// tracing::info!("rate limiting storage size: {}", governor_limiter.len());
|
// tracing::info!("rate limiting storage size: {}", governor_limiter.len());
|
||||||
governor_limiter.retain_recent();
|
governor_limiter.retain_recent();
|
||||||
|
|
||||||
|
vm_clone.retain_active();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let realtime = realtime::Realtime::new();
|
|
||||||
|
|
||||||
let data_dir = PathBuf::from(cli.data_dir);
|
let data_dir = PathBuf::from(cli.data_dir);
|
||||||
let config = Arc::new(AppConfig {
|
let config = Arc::new(AppConfig {
|
||||||
avatar_dir: data_dir.join("avatars"),
|
avatar_dir: data_dir.join("avatars"),
|
||||||
@@ -115,10 +120,12 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.merge(routes::users::routes())
|
.merge(routes::users::routes())
|
||||||
.merge(routes::rooms::routes())
|
.merge(routes::rooms::routes())
|
||||||
.merge(routes::messages::routes())
|
.merge(routes::messages::routes())
|
||||||
|
.merge(routes::voice::routes())
|
||||||
.merge(routes::friends::routes())
|
.merge(routes::friends::routes())
|
||||||
.merge(routes::ws::routes())
|
.merge(routes::ws::routes())
|
||||||
.layer(Extension(db_pool))
|
.layer(Extension(db_pool))
|
||||||
.layer(Extension(realtime))
|
.layer(Extension(realtime))
|
||||||
|
.layer(Extension(voice_manager))
|
||||||
.layer(Extension(config))
|
.layer(Extension(config))
|
||||||
.layer(GovernorLayer::new(governor_conf))
|
.layer(GovernorLayer::new(governor_conf))
|
||||||
.layer(cors)
|
.layer(cors)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use axum::body::Bytes;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
@@ -6,18 +7,25 @@ use uuid::Uuid;
|
|||||||
use crate::routes::messages::Message;
|
use crate::routes::messages::Message;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Realtime {
|
pub struct RealtimeMessages {
|
||||||
pub clients: Arc<DashMap<Uuid, broadcast::Sender<Message>>>,
|
pub clients: Arc<DashMap<Uuid, broadcast::Sender<Message>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Realtime {
|
type VoicePacket = (Uuid, Bytes);
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct RealTimeVoices {
|
||||||
|
pub rooms: Arc<DashMap<Uuid, broadcast::Sender<VoicePacket>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RealtimeMessages {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
clients: Arc::new(DashMap::new()),
|
clients: Arc::new(DashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create the channel for a specific user
|
/// Get or create the sender for a given user
|
||||||
pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender<Message> {
|
pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender<Message> {
|
||||||
self.clients
|
self.clients
|
||||||
.entry(user_uuid)
|
.entry(user_uuid)
|
||||||
@@ -25,6 +33,7 @@ impl Realtime {
|
|||||||
.clone()
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Send a message to all the recipients
|
||||||
pub fn broadcast(&self, recipient_uuids: Vec<Uuid>, message: Message) {
|
pub fn broadcast(&self, recipient_uuids: Vec<Uuid>, message: Message) {
|
||||||
for user_uuid in recipient_uuids {
|
for user_uuid in recipient_uuids {
|
||||||
if let Some(sender) = self.clients.get(&user_uuid) {
|
if let Some(sender) = self.clients.get(&user_uuid) {
|
||||||
@@ -33,3 +42,24 @@ impl Realtime {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl RealTimeVoices {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
rooms: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get or create the broadcast sender for a given room
|
||||||
|
pub fn get_or_create_room(&self, room_uuid: Uuid) -> broadcast::Sender<VoicePacket> {
|
||||||
|
self.rooms
|
||||||
|
.entry(room_uuid)
|
||||||
|
.or_insert_with(|| broadcast::channel(500).0)
|
||||||
|
.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean up empty rooms
|
||||||
|
pub fn retain_active(&self) {
|
||||||
|
self.rooms.retain(|_, sender| sender.receiver_count() > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,16 +1,28 @@
|
|||||||
|
use std::{net::SocketAddr, time::Duration};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
Extension, Json, Router,
|
Extension, Json, Router,
|
||||||
extract::{Path, Query},
|
extract::{
|
||||||
|
ConnectInfo, Path, Query, WebSocketUpgrade,
|
||||||
|
ws::{Message as WsMessage, WebSocket},
|
||||||
|
},
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::IntoResponse,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
|
use axum_extra::{TypedHeader, headers};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
use tokio::select;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member};
|
use crate::{
|
||||||
|
auth::{verify_jwt, verify_jwt_string},
|
||||||
|
db::room_id_from_uuid,
|
||||||
|
routes::{rooms::is_member, ws::WsAuthQuery},
|
||||||
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
db::{user_id_from_uuid, username_from_uuid},
|
db::{user_id_from_uuid, username_from_uuid},
|
||||||
realtime::Realtime,
|
realtime::RealtimeMessages,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
|
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
|
||||||
@@ -51,6 +63,7 @@ pub fn routes() -> Router {
|
|||||||
Router::new()
|
Router::new()
|
||||||
.route("/messages/{room_uuid}", get(list_messages))
|
.route("/messages/{room_uuid}", get(list_messages))
|
||||||
.route("/messages/{room_uuid}", post(create_message))
|
.route("/messages/{room_uuid}", post(create_message))
|
||||||
|
.route("/ws/messages", get(message_ws_handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Also resets `last_read_at`
|
/// Also resets `last_read_at`
|
||||||
@@ -155,7 +168,7 @@ async fn list_messages(
|
|||||||
async fn create_message(
|
async fn create_message(
|
||||||
Path(room_uuid): Path<Uuid>,
|
Path(room_uuid): Path<Uuid>,
|
||||||
Extension(db): Extension<PgPool>,
|
Extension(db): Extension<PgPool>,
|
||||||
Extension(realtime): Extension<Realtime>,
|
Extension(realtime): Extension<RealtimeMessages>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(payload): Json<NewMessagePayload>,
|
Json(payload): Json<NewMessagePayload>,
|
||||||
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
|
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
|
||||||
@@ -223,3 +236,106 @@ async fn create_message(
|
|||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(message)))
|
Ok((StatusCode::CREATED, Json(message)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn message_ws_handler(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||||
|
Query(query): Query<WsAuthQuery>,
|
||||||
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
|
Extension(realtime): Extension<RealtimeMessages>,
|
||||||
|
Extension(db): Extension<sqlx::PgPool>,
|
||||||
|
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||||
|
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
||||||
|
|
||||||
|
let claims = verify_jwt_string(&query.token)?;
|
||||||
|
let user_uuid = claims.sub;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
delete from ws_token_
|
||||||
|
where token = $1
|
||||||
|
and expires_at > now()
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(query.token)
|
||||||
|
.execute(&db)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to get WS token from DB: {e}");
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if result.rows_affected() == 0 {
|
||||||
|
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let receiver = realtime.get_sender(user_uuid).subscribe();
|
||||||
|
|
||||||
|
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
||||||
|
user_agent.to_string()
|
||||||
|
} else {
|
||||||
|
String::from("Unknown browser")
|
||||||
|
};
|
||||||
|
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
||||||
|
|
||||||
|
Ok(ws.on_upgrade(move |socket| handle_message_socket(socket, addr, receiver)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_message_socket(
|
||||||
|
mut socket: WebSocket,
|
||||||
|
who: SocketAddr,
|
||||||
|
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
||||||
|
) {
|
||||||
|
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
// Receive broadcast messages and send to client (any room)
|
||||||
|
msg = receiver.recv() => {
|
||||||
|
if let Ok(msg) = msg {
|
||||||
|
if let Ok(json) = serde_json::to_string(&msg) {
|
||||||
|
if socket.send(WsMessage::Text(json.into())).await.is_err() {
|
||||||
|
tracing::error!("Failed to send message to {who}, closing connection");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send Ping
|
||||||
|
_ = ping_interval.tick() => {
|
||||||
|
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
|
||||||
|
tracing::error!("Failed to send ping to {who}, closing connection");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tracing::debug!("Ping sent to {who}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get incoming messages from client
|
||||||
|
client_msg = socket.recv() => {
|
||||||
|
if let Some(Ok(msg)) = client_msg {
|
||||||
|
match msg {
|
||||||
|
// WsMessage::Pong(_) => {
|
||||||
|
// tracing::debug!("Received Pong from {who}");
|
||||||
|
// }
|
||||||
|
// WsMessage::Ping(_) => {
|
||||||
|
// tracing::info!("Received Ping from client");
|
||||||
|
// }
|
||||||
|
// WsMessage::Text(_) => {}
|
||||||
|
WsMessage::Close(_) => {
|
||||||
|
tracing::debug!("Client disconnected");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::debug!("Client {who} abruptly disconnected");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ pub mod friends;
|
|||||||
pub mod messages;
|
pub mod messages;
|
||||||
pub mod rooms;
|
pub mod rooms;
|
||||||
pub mod users;
|
pub mod users;
|
||||||
|
pub mod voice;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|||||||
126
src/routes/voice.rs
Normal file
126
src/routes/voice.rs
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
use axum::{
|
||||||
|
Extension, Router,
|
||||||
|
extract::{
|
||||||
|
ConnectInfo, Path, Query, WebSocketUpgrade,
|
||||||
|
ws::{Message, WebSocket},
|
||||||
|
},
|
||||||
|
http::StatusCode,
|
||||||
|
response::IntoResponse,
|
||||||
|
routing::get,
|
||||||
|
};
|
||||||
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use std::{net::SocketAddr, time::Duration};
|
||||||
|
use tokio::select;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::verify_jwt_string,
|
||||||
|
db::{room_id_from_uuid, user_id_from_uuid},
|
||||||
|
realtime::RealTimeVoices,
|
||||||
|
routes::rooms::is_member,
|
||||||
|
routes::ws::WsAuthQuery,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn routes() -> Router {
|
||||||
|
Router::new().route("/ws/voice/{room_uuid}", get(voice_ws_handler))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn voice_ws_handler(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
Path(room_uuid): Path<Uuid>,
|
||||||
|
Query(query): Query<WsAuthQuery>,
|
||||||
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
|
Extension(voice_manager): Extension<RealTimeVoices>,
|
||||||
|
Extension(db): Extension<PgPool>,
|
||||||
|
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||||
|
let claims = verify_jwt_string(&query.token)?;
|
||||||
|
let user_uuid = claims.sub;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
DELETE FROM ws_token_
|
||||||
|
WHERE token = $1
|
||||||
|
AND expires_at > now()
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&query.token)
|
||||||
|
.execute(&db)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to get WS token from DB: {e}");
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if result.rows_affected() == 0 {
|
||||||
|
return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_id = user_id_from_uuid(&db, user_uuid).await?;
|
||||||
|
let room_id = room_id_from_uuid(&db, room_uuid).await?;
|
||||||
|
|
||||||
|
if !is_member(user_id, room_id, &db).await {
|
||||||
|
return Err((StatusCode::FORBIDDEN, "Not a member of this room".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid);
|
||||||
|
|
||||||
|
let tx = voice_manager.get_or_create_room(room_uuid);
|
||||||
|
let rx = tx.subscribe();
|
||||||
|
|
||||||
|
Ok(ws.on_upgrade(move |socket| handle_voice_socket(socket, addr, user_uuid, tx, rx)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_voice_socket(
|
||||||
|
mut socket: WebSocket,
|
||||||
|
who: SocketAddr,
|
||||||
|
my_uuid: Uuid,
|
||||||
|
tx: tokio::sync::broadcast::Sender<(Uuid, Bytes)>,
|
||||||
|
mut rx: tokio::sync::broadcast::Receiver<(Uuid, Bytes)>,
|
||||||
|
) {
|
||||||
|
let mut ping_interval = tokio::time::interval(Duration::from_secs(15));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
// Receive audio from other users and send to client
|
||||||
|
voice_packet = rx.recv() => {
|
||||||
|
if let Ok((speaker_uuid, audio_data)) = voice_packet {
|
||||||
|
if speaker_uuid != my_uuid {
|
||||||
|
let mut msg = BytesMut::with_capacity(16 + audio_data.len());
|
||||||
|
msg.put(speaker_uuid.as_bytes().as_slice());
|
||||||
|
msg.put(audio_data);
|
||||||
|
|
||||||
|
if socket.send(Message::Binary(msg.freeze().into())).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive audio from alient and broadcast to room
|
||||||
|
client_msg = socket.recv() => {
|
||||||
|
if let Some(Ok(msg)) = client_msg {
|
||||||
|
match msg {
|
||||||
|
Message::Binary(data) => {
|
||||||
|
let _ = tx.send((my_uuid, Bytes::from(data)));
|
||||||
|
}
|
||||||
|
Message::Close(_) => {
|
||||||
|
tracing::debug!("Voice client {} disconnected", who);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keepalive
|
||||||
|
_ = ping_interval.tick() => {
|
||||||
|
if socket.send(Message::Ping(vec![].into())).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
118
src/routes/ws.rs
118
src/routes/ws.rs
@@ -1,17 +1,10 @@
|
|||||||
use axum::Json;
|
use axum::Json;
|
||||||
use axum::extract::ws::{Message as WsMessage, WebSocket};
|
|
||||||
use axum::extract::{ConnectInfo, Query};
|
|
||||||
use axum::http::HeaderMap;
|
use axum::http::HeaderMap;
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
|
use axum::{Extension, http::StatusCode};
|
||||||
use axum_extra::{TypedHeader, headers};
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::select;
|
|
||||||
|
|
||||||
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
|
use crate::auth::{create_jwt, verify_jwt};
|
||||||
use crate::realtime::Realtime;
|
|
||||||
|
|
||||||
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
||||||
pub struct WsAuthQuery {
|
pub struct WsAuthQuery {
|
||||||
@@ -19,9 +12,7 @@ pub struct WsAuthQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes() -> axum::Router {
|
pub fn routes() -> axum::Router {
|
||||||
axum::Router::new()
|
axum::Router::new().route("/ws/issue-token", get(issue_ws_token))
|
||||||
.route("/ws/messages/issue-token", get(issue_ws_token))
|
|
||||||
.route("/ws/messages", get(ws_handler))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn issue_ws_token(
|
pub async fn issue_ws_token(
|
||||||
@@ -52,106 +43,3 @@ pub async fn issue_ws_token(
|
|||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
|
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ws_handler(
|
|
||||||
ws: WebSocketUpgrade,
|
|
||||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
|
||||||
Query(query): Query<WsAuthQuery>,
|
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
||||||
Extension(realtime): Extension<Realtime>,
|
|
||||||
Extension(db): Extension<sqlx::PgPool>,
|
|
||||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
|
||||||
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
|
||||||
|
|
||||||
let claims = verify_jwt_string(&query.token)?;
|
|
||||||
let user_uuid = claims.sub;
|
|
||||||
|
|
||||||
let result = sqlx::query(
|
|
||||||
r#"
|
|
||||||
delete from ws_token_
|
|
||||||
where token = $1
|
|
||||||
and expires_at > now()
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(query.token)
|
|
||||||
.execute(&db)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!("Failed to get WS token from DB: {e}");
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
|
||||||
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let receiver = realtime.get_sender(user_uuid).subscribe();
|
|
||||||
|
|
||||||
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
|
||||||
user_agent.to_string()
|
|
||||||
} else {
|
|
||||||
String::from("Unknown browser")
|
|
||||||
};
|
|
||||||
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
|
||||||
|
|
||||||
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_socket(
|
|
||||||
mut socket: WebSocket,
|
|
||||||
who: SocketAddr,
|
|
||||||
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
|
||||||
) {
|
|
||||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
|
|
||||||
|
|
||||||
loop {
|
|
||||||
select! {
|
|
||||||
// Receive broadcast messages and send to client (any room)
|
|
||||||
msg = receiver.recv() => {
|
|
||||||
if let Ok(msg) = msg {
|
|
||||||
if let Ok(json) = serde_json::to_string(&msg) {
|
|
||||||
if socket.send(WsMessage::Text(json.into())).await.is_err() {
|
|
||||||
tracing::error!("Failed to send message to {who}, closing connection");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send Ping
|
|
||||||
_ = ping_interval.tick() => {
|
|
||||||
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
|
|
||||||
tracing::error!("Failed to send ping to {who}, closing connection");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// tracing::debug!("Ping sent to {who}");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get incoming messages from client
|
|
||||||
client_msg = socket.recv() => {
|
|
||||||
if let Some(Ok(msg)) = client_msg {
|
|
||||||
match msg {
|
|
||||||
// WsMessage::Pong(_) => {
|
|
||||||
// tracing::debug!("Received Pong from {who}");
|
|
||||||
// }
|
|
||||||
// WsMessage::Ping(_) => {
|
|
||||||
// tracing::info!("Received Ping from client");
|
|
||||||
// }
|
|
||||||
// WsMessage::Text(_) => {}
|
|
||||||
WsMessage::Close(_) => {
|
|
||||||
tracing::debug!("Client disconnected");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::debug!("Client {who} abruptly disconnected");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user