use axum::Json; use axum::extract::ws::{Message as WsMessage, WebSocket}; use axum::extract::{ConnectInfo, Query}; use axum::http::HeaderMap; use axum::routing::get; use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse}; use axum_extra::{TypedHeader, headers}; use serde::Deserialize; use std::net::SocketAddr; use std::time::Duration; use tokio::select; use crate::auth::{create_jwt, verify_jwt, verify_jwt_string}; use crate::realtime::Realtime; #[derive(sqlx::FromRow, serde::Serialize, Deserialize)] pub struct WsAuthQuery { pub token: String, } pub fn routes() -> axum::Router { axum::Router::new() .route("/ws/messages/issue-token", get(issue_ws_token)) .route("/ws/messages", get(ws_handler)) } pub async fn issue_ws_token( Extension(db): Extension, headers: HeaderMap, ) -> Result<(StatusCode, Json), (StatusCode, String)> { let claims = verify_jwt(headers)?; tracing::debug!("Recieved token issue request from user {}", claims.sub); let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; sqlx::query( r#" insert into ws_token_ (token, expires_at) values ($1, now() + interval '30 seconds') "#, ) .bind(&token) .execute(&db) .await .map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("failed to provide ws token"), ) })?; Ok((StatusCode::CREATED, Json(WsAuthQuery { token }))) } async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option>, Query(query): Query, ConnectInfo(addr): ConnectInfo, Extension(realtime): Extension, Extension(db): Extension, ) -> Result { // tracing::info!("recieved ws handshake: {}", room_uuid); let claims = verify_jwt_string(&query.token)?; let user_uuid = claims.sub; let result = sqlx::query( r#" delete from ws_token_ where token = $1 and expires_at > now() "#, ) .bind(query.token) .execute(&db) .await .map_err(|e| { tracing::error!("Failed to get WS token from DB: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) })?; if result.rows_affected() == 0 { return Err((StatusCode::UNAUTHORIZED, "Wrong token".into())); } let receiver = realtime.get_sender(user_uuid).subscribe(); let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { user_agent.to_string() } else { String::from("Unknown browser") }; tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected."); Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver))) } async fn handle_socket( mut socket: WebSocket, who: SocketAddr, mut receiver: tokio::sync::broadcast::Receiver, ) { let mut ping_interval = tokio::time::interval(Duration::from_secs(30)); loop { select! { // Receive broadcast messages and send to client (any room) msg = receiver.recv() => { if let Ok(msg) = msg { if let Ok(json) = serde_json::to_string(&msg) { if socket.send(WsMessage::Text(json.into())).await.is_err() { tracing::error!("Failed to send message to {who}, closing connection"); break; } } } else { break; } } // Send Ping _ = ping_interval.tick() => { if socket.send(WsMessage::Ping(vec![].into())).await.is_err() { tracing::error!("Failed to send ping to {who}, closing connection"); break; } // tracing::debug!("Ping sent to {who}"); } // Get incoming messages from client client_msg = socket.recv() => { if let Some(Ok(msg)) = client_msg { match msg { // WsMessage::Pong(_) => { // tracing::debug!("Received Pong from {who}"); // } // WsMessage::Ping(_) => { // tracing::info!("Received Ping from client"); // } // WsMessage::Text(_) => {} WsMessage::Close(_) => { tracing::debug!("Client disconnected"); break; } _ => {} } } else { tracing::debug!("Client {who} abruptly disconnected"); break; } } } } }