Files
slimes/server/src/main.rs
2026-05-07 23:44:03 +02:00

229 lines
6.2 KiB
Rust

use std::{net::SocketAddr, sync::Arc, time::Duration};
use axum::{
Json, Router,
extract::{Path, Query, State},
http::{Method, StatusCode, header},
routing::{get, post},
};
use clap::Parser;
use serde::{Deserialize, Serialize};
use sqlx::{PgPool, postgres::PgPoolOptions};
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
use tower_http::cors::{Any, CorsLayer};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(short, long)]
database_url: String,
/// Port to listen on
#[arg(short, long, default_value_t = 8081)]
port: u16,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct BenchmarkResults {
pub duration: Duration,
pub primes_found: u64,
pub score: u64,
pub batch_count: u64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct BenchmarkReport {
pub prime_limit: u64,
pub logical_cores: usize,
pub single_thread: BenchmarkResults,
pub multi_thread: BenchmarkResults,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FullReport {
#[serde(skip_deserializing)]
pub id: Option<i32>,
pub mac_address: String,
pub timestamp: String,
pub slimes: Option<serde_json::Value>,
pub benchmark: Option<BenchmarkReport>,
pub client_version: String,
pub signature: String,
#[serde(default)]
pub gpu_score: i32,
}
#[derive(Deserialize)]
pub struct Pagination {
pub limit: Option<i64>,
pub offset: Option<i64>,
}
pub struct AppState {
db: PgPool,
}
async fn submit(
State(state): State<Arc<AppState>>,
Json(payload): Json<FullReport>,
) -> Result<(StatusCode, Json<i32>), (StatusCode, String)> {
let score = payload
.benchmark
.as_ref()
.map(|b| (b.multi_thread.score + b.single_thread.score) / 2)
.unwrap_or(0);
let raw_json = serde_json::to_value(&payload)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let row: (i32,) = sqlx::query_as(
"INSERT INTO reports (mac_address, score, timestamp, client_version, signature, data)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id",
)
.bind(&payload.mac_address)
.bind(score as i64)
.bind(&payload.timestamp)
.bind(&payload.client_version)
.bind(&payload.signature)
.bind(raw_json)
.fetch_one(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok((StatusCode::CREATED, Json(row.0)))
}
fn parse_report_row(id: i32, gpu_score: i32, data: serde_json::Value) -> Option<FullReport> {
let mut report: FullReport = serde_json::from_value(data).ok()?;
report.id = Some(id);
report.gpu_score = gpu_score;
Some(report)
}
async fn get_leaderboard(
State(state): State<Arc<AppState>>,
Query(pagination): Query<Pagination>,
) -> Result<Json<Vec<FullReport>>, (StatusCode, String)> {
let limit = pagination.limit.unwrap_or(10);
let offset = pagination.offset.unwrap_or(0);
let rows: Vec<(i32, i32, serde_json::Value)> = sqlx::query_as(
r#"
SELECT id, gpu_score, data FROM (
SELECT DISTINCT ON (mac_address) id, data, score
FROM reports
ORDER BY mac_address, score DESC
) sub
ORDER BY score DESC
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let results = rows
.into_iter()
.filter_map(|(id, gpu_score, data)| parse_report_row(id, gpu_score, data))
.collect();
Ok(Json(results))
}
async fn get_report_by_id(
State(state): State<Arc<AppState>>,
Path(id): Path<i32>,
) -> Result<Json<FullReport>, (StatusCode, String)> {
let row: (i32, i32, serde_json::Value) =
sqlx::query_as("SELECT id, gpu_score, data FROM reports WHERE id = $1")
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Report not found".to_string()))?;
let report = parse_report_row(row.0, row.1, row.2).ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to parse stored data".to_string(),
))?;
Ok(Json(report))
}
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/", post(submit).get(get_leaderboard))
.route("/{id}", get(get_report_by_id))
.with_state(state)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).unwrap();
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&args.database_url)
.await?;
sqlx::query(
"CREATE TABLE IF NOT EXISTS reports (
id SERIAL PRIMARY KEY,
mac_address TEXT NOT NULL,
score BIGINT NOT NULL,
gpu_score INTEGER NOT NULL DEFAULT 0,
timestamp TEXT NOT NULL,
client_version TEXT NOT NULL,
signature TEXT,
data JSONB NOT NULL
);",
)
.execute(&pool)
.await?;
let shared_state = Arc::new(AppState { db: pool });
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST])
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]);
let governor_conf = GovernorConfigBuilder::default()
.per_second(5)
.burst_size(10)
.finish()
.unwrap();
let governor_limiter = governor_conf.limiter().clone();
std::thread::spawn(move || {
loop {
std::thread::sleep(Duration::from_secs(60));
governor_limiter.retain_recent();
}
});
let app = Router::new()
.layer(cors)
.layer(GovernorLayer::new(governor_conf))
.merge(routes(shared_state));
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
tracing::info!("Listening on {}", addr);
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
Ok(())
}