cargo fmt: format all Rust source files
All checks were successful
CI — P1 Route (Rust) / test (push) Successful in 6m35s

This commit is contained in:
2026-03-01 17:53:28 +00:00
parent 00db59ff83
commit a8a8c53917
12 changed files with 305 additions and 167 deletions

View File

@@ -5,13 +5,13 @@
//! Run: cargo bench --bench proxy_latency //! Run: cargo bench --bench proxy_latency
//! CI gate: P99 must be < 5ms //! CI gate: P99 must be < 5ms
use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use dd0c_route::{ use dd0c_route::{
AppConfig, TelemetryEvent, RouterBrain,
proxy::{create_router, ProxyState}, proxy::{create_router, ProxyState},
AppConfig, RouterBrain, TelemetryEvent,
}; };
struct NoOpAuth; struct NoOpAuth;
@@ -51,10 +51,13 @@ fn bench_proxy_overhead(c: &mut Criterion) {
let (tx, _rx) = mpsc::channel::<TelemetryEvent>(10000); let (tx, _rx) = mpsc::channel::<TelemetryEvent>(10000);
let mut providers = std::collections::HashMap::new(); let mut providers = std::collections::HashMap::new();
providers.insert("openai".to_string(), dd0c_route::config::ProviderConfig { providers.insert(
api_key: "bench-key".to_string(), "openai".to_string(),
base_url: mock_url.clone(), dd0c_route::config::ProviderConfig {
}); api_key: "bench-key".to_string(),
base_url: mock_url.clone(),
},
);
let config = Arc::new(AppConfig { let config = Arc::new(AppConfig {
proxy_port: 0, proxy_port: 0,
@@ -90,7 +93,8 @@ fn bench_proxy_overhead(c: &mut Criterion) {
let body = serde_json::json!({ let body = serde_json::json!({
"model": "gpt-4o", "model": "gpt-4o",
"messages": messages, "messages": messages,
}).to_string(); })
.to_string();
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("chat_completions", format!("{}_msgs", msg_count)), BenchmarkId::new("chat_completions", format!("{}_msgs", msg_count)),

View File

@@ -2,7 +2,7 @@ use axum::{
extract::{Path, Query, State}, extract::{Path, Query, State},
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::IntoResponse, response::IntoResponse,
routing::{get, put, delete}, routing::{delete, get, put},
Json, Router, Json, Router,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -23,7 +23,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
Router::new() Router::new()
// Dashboard analytics // Dashboard analytics
.route("/api/v1/analytics/summary", get(get_analytics_summary)) .route("/api/v1/analytics/summary", get(get_analytics_summary))
.route("/api/v1/analytics/timeseries", get(get_analytics_timeseries)) .route(
"/api/v1/analytics/timeseries",
get(get_analytics_timeseries),
)
.route("/api/v1/analytics/models", get(get_model_breakdown)) .route("/api/v1/analytics/models", get(get_model_breakdown))
// Routing rules // Routing rules
.route("/api/v1/rules", get(list_rules).post(create_rule)) .route("/api/v1/rules", get(list_rules).post(create_rule))
@@ -32,7 +35,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
.route("/api/v1/keys", get(list_keys).post(create_key)) .route("/api/v1/keys", get(list_keys).post(create_key))
.route("/api/v1/keys/:id", delete(revoke_key)) .route("/api/v1/keys/:id", delete(revoke_key))
// Provider configs // Provider configs
.route("/api/v1/providers", get(list_providers).post(upsert_provider)) .route(
"/api/v1/providers",
get(list_providers).post(upsert_provider),
)
// Org settings // Org settings
.route("/api/v1/org", get(get_org)) .route("/api/v1/org", get(get_org))
// Health // Health
@@ -40,7 +46,9 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
.with_state(state) .with_state(state)
} }
async fn health() -> &'static str { "ok" } async fn health() -> &'static str {
"ok"
}
// --- Analytics Endpoints --- // --- Analytics Endpoints ---
@@ -76,7 +84,9 @@ async fn get_analytics_summary(
Query(range): Query<TimeRange>, Query(range): Query<TimeRange>,
) -> Result<Json<AnalyticsSummary>, ApiError> { ) -> Result<Json<AnalyticsSummary>, ApiError> {
let auth = state.auth.authenticate(&headers).await?; let auth = state.auth.authenticate(&headers).await?;
let _from = range.from.unwrap_or_else(|| "now() - interval '7 days'".to_string()); let _from = range
.from
.unwrap_or_else(|| "now() - interval '7 days'".to_string());
let row = sqlx::query_as::<_, (i64, f64, f64, f64, i32, i32, i64, i64, i64)>( let row = sqlx::query_as::<_, (i64, f64, f64, f64, i32, i32, i64, i64, i64)>(
"SELECT "SELECT
@@ -90,14 +100,18 @@ async fn get_analytics_summary(
COUNT(*) FILTER (WHERE strategy = 'cheapest'), COUNT(*) FILTER (WHERE strategy = 'cheapest'),
COUNT(*) FILTER (WHERE strategy = 'cascading') COUNT(*) FILTER (WHERE strategy = 'cascading')
FROM request_events FROM request_events
WHERE org_id = $1 AND time >= now() - interval '7 days'" WHERE org_id = $1 AND time >= now() - interval '7 days'",
) )
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default()) .bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_one(&state.ts_pool) .fetch_one(&state.ts_pool)
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
let savings_pct = if row.1 > 0.0 { (row.3 / row.1) * 100.0 } else { 0.0 }; let savings_pct = if row.1 > 0.0 {
(row.3 / row.1) * 100.0
} else {
0.0
};
Ok(Json(AnalyticsSummary { Ok(Json(AnalyticsSummary {
total_requests: row.0, total_requests: row.0,
@@ -131,7 +145,11 @@ async fn get_analytics_timeseries(
let auth = state.auth.authenticate(&headers).await?; let auth = state.auth.authenticate(&headers).await?;
let interval = range.interval.unwrap_or_else(|| "hour".to_string()); let interval = range.interval.unwrap_or_else(|| "hour".to_string());
let view = if interval == "day" { "request_events_daily" } else { "request_events_hourly" }; let view = if interval == "day" {
"request_events_daily"
} else {
"request_events_hourly"
};
let rows = sqlx::query_as::<_, (chrono::DateTime<chrono::Utc>, i64, f64, i32)>( let rows = sqlx::query_as::<_, (chrono::DateTime<chrono::Utc>, i64, f64, i32)>(
&format!( &format!(
@@ -147,12 +165,16 @@ async fn get_analytics_timeseries(
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| TimeseriesPoint { Ok(Json(
bucket: r.0.to_rfc3339(), rows.iter()
request_count: r.1, .map(|r| TimeseriesPoint {
cost_saved: r.2, bucket: r.0.to_rfc3339(),
avg_latency_ms: r.3, request_count: r.1,
}).collect())) cost_saved: r.2,
avg_latency_ms: r.3,
})
.collect(),
))
} }
#[derive(Serialize)] #[derive(Serialize)]
@@ -181,12 +203,16 @@ async fn get_model_breakdown(
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ModelBreakdown { Ok(Json(
model: r.0.clone(), rows.iter()
request_count: r.1, .map(|r| ModelBreakdown {
total_tokens: r.2, model: r.0.clone(),
total_cost: r.3, request_count: r.1,
}).collect())) total_tokens: r.2,
total_cost: r.3,
})
.collect(),
))
} }
// --- Routing Rules CRUD --- // --- Routing Rules CRUD ---
@@ -222,20 +248,24 @@ async fn list_rules(
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| RoutingRuleDto { Ok(Json(
id: Some(r.0), rows.iter()
priority: r.1, .map(|r| RoutingRuleDto {
name: r.2.clone(), id: Some(r.0),
match_model: r.3.clone(), priority: r.1,
match_feature: r.4.clone(), name: r.2.clone(),
match_team: r.5.clone(), match_model: r.3.clone(),
match_complexity: r.6.clone(), match_feature: r.4.clone(),
strategy: r.7.clone(), match_team: r.5.clone(),
target_model: r.8.clone(), match_complexity: r.6.clone(),
target_provider: r.9.clone(), strategy: r.7.clone(),
fallback_models: r.10.clone(), target_model: r.8.clone(),
enabled: r.11, target_provider: r.9.clone(),
}).collect())) fallback_models: r.10.clone(),
enabled: r.11,
})
.collect(),
))
} }
async fn create_rule( async fn create_rule(
@@ -361,23 +391,37 @@ async fn list_keys(
) -> Result<Json<Vec<ApiKeyDto>>, ApiError> { ) -> Result<Json<Vec<ApiKeyDto>>, ApiError> {
let auth = state.auth.authenticate(&headers).await?; let auth = state.auth.authenticate(&headers).await?;
let rows = sqlx::query_as::<_, (Uuid, String, String, Vec<String>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>( let rows = sqlx::query_as::<
_,
(
Uuid,
String,
String,
Vec<String>,
Option<chrono::DateTime<chrono::Utc>>,
chrono::DateTime<chrono::Utc>,
),
>(
"SELECT id, name, key_prefix, scopes, last_used_at, created_at "SELECT id, name, key_prefix, scopes, last_used_at, created_at
FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC" FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC",
) )
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default()) .bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_all(&state.pg_pool) .fetch_all(&state.pg_pool)
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ApiKeyDto { Ok(Json(
id: r.0, rows.iter()
name: r.1.clone(), .map(|r| ApiKeyDto {
key_prefix: r.2.clone(), id: r.0,
scopes: r.3.clone(), name: r.1.clone(),
last_used_at: r.4, key_prefix: r.2.clone(),
created_at: r.5, scopes: r.3.clone(),
}).collect())) last_used_at: r.4,
created_at: r.5,
})
.collect(),
))
} }
async fn create_key( async fn create_key(
@@ -409,11 +453,14 @@ async fn create_key(
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok((StatusCode::CREATED, Json(ApiKeyCreated { Ok((
id, StatusCode::CREATED,
key: raw_key, Json(ApiKeyCreated {
name: req.name, id,
}))) key: raw_key,
name: req.name,
}),
))
} }
async fn revoke_key( async fn revoke_key(
@@ -456,19 +503,23 @@ async fn list_providers(
let auth = state.auth.authenticate(&headers).await?; let auth = state.auth.authenticate(&headers).await?;
let rows = sqlx::query_as::<_, (String, Option<String>, bool)>( let rows = sqlx::query_as::<_, (String, Option<String>, bool)>(
"SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1" "SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1",
) )
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default()) .bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_all(&state.pg_pool) .fetch_all(&state.pg_pool)
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ProviderDto { Ok(Json(
provider: r.0.clone(), rows.iter()
base_url: r.1.clone(), .map(|r| ProviderDto {
is_default: r.2, provider: r.0.clone(),
has_key: true, base_url: r.1.clone(),
}).collect())) is_default: r.2,
has_key: true,
})
.collect(),
))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -488,13 +539,12 @@ async fn upsert_provider(
require_role(&auth, Role::Owner)?; require_role(&auth, Role::Owner)?;
// Encrypt API key with AES-256-GCM before storing // Encrypt API key with AES-256-GCM before storing
let encryption_key = std::env::var("PROVIDER_KEY_ENCRYPTION_KEY") let encryption_key =
.unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key std::env::var("PROVIDER_KEY_ENCRYPTION_KEY").unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key
let key_bytes = hex::decode(&encryption_key) let key_bytes = hex::decode(&encryption_key).unwrap_or_else(|_| vec![0u8; 32]);
.unwrap_or_else(|_| vec![0u8; 32]);
use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
use aes_gcm::Nonce; use aes_gcm::Nonce;
use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit};
let cipher = Aes256Gcm::new_from_slice(&key_bytes) let cipher = Aes256Gcm::new_from_slice(&key_bytes)
.map_err(|e| ApiError::Internal(format!("Encryption key error: {}", e)))?; .map_err(|e| ApiError::Internal(format!("Encryption key error: {}", e)))?;
@@ -502,7 +552,8 @@ async fn upsert_provider(
getrandom::getrandom(&mut nonce_bytes) getrandom::getrandom(&mut nonce_bytes)
.map_err(|e| ApiError::Internal(format!("RNG error: {}", e)))?; .map_err(|e| ApiError::Internal(format!("RNG error: {}", e)))?;
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher.encrypt(nonce, req.api_key.as_bytes()) let ciphertext = cipher
.encrypt(nonce, req.api_key.as_bytes())
.map_err(|e| ApiError::Internal(format!("Encryption error: {}", e)))?; .map_err(|e| ApiError::Internal(format!("Encryption error: {}", e)))?;
// Store as nonce || ciphertext // Store as nonce || ciphertext
@@ -544,7 +595,7 @@ async fn get_org(
let auth = state.auth.authenticate(&headers).await?; let auth = state.auth.authenticate(&headers).await?;
let row = sqlx::query_as::<_, (Uuid, String, String, String)>( let row = sqlx::query_as::<_, (Uuid, String, String, String)>(
"SELECT id, name, slug, tier FROM organizations WHERE id = $1" "SELECT id, name, slug, tier FROM organizations WHERE id = $1",
) )
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default()) .bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_optional(&state.pg_pool) .fetch_optional(&state.pg_pool)
@@ -606,7 +657,10 @@ impl IntoResponse for ApiError {
ApiError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()), ApiError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
ApiError::Forbidden => (StatusCode::FORBIDDEN, self.to_string()), ApiError::Forbidden => (StatusCode::FORBIDDEN, self.to_string()),
ApiError::NotFound => (StatusCode::NOT_FOUND, self.to_string()), ApiError::NotFound => (StatusCode::NOT_FOUND, self.to_string()),
ApiError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string()), ApiError::Internal(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal error".to_string(),
),
}; };
(status, serde_json::json!({"error": msg}).to_string()).into_response() (status, serde_json::json!({"error": msg}).to_string()).into_response()
} }

View File

@@ -1,5 +1,5 @@
use axum::http::HeaderMap;
use async_trait::async_trait; use async_trait::async_trait;
use axum::http::HeaderMap;
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -40,7 +40,11 @@ pub struct LocalAuthProvider {
impl LocalAuthProvider { impl LocalAuthProvider {
pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self { pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
Self { pool, _jwt_secret: jwt_secret, redis } Self {
pool,
_jwt_secret: jwt_secret,
redis,
}
} }
} }
@@ -72,7 +76,7 @@ impl AuthProvider for LocalAuthProvider {
// 2. Fall back to PostgreSQL // 2. Fall back to PostgreSQL
let row = sqlx::query_as::<_, (String, String)>( let row = sqlx::query_as::<_, (String, String)>(
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL" "SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL",
) )
.bind(&key[..8]) .bind(&key[..8])
.fetch_optional(&self.pool) .fetch_optional(&self.pool)

View File

@@ -41,18 +41,24 @@ impl AppConfig {
let mut providers = HashMap::new(); let mut providers = HashMap::new();
if let Ok(key) = std::env::var("OPENAI_API_KEY") { if let Ok(key) = std::env::var("OPENAI_API_KEY") {
providers.insert("openai".to_string(), ProviderConfig { providers.insert(
api_key: key, "openai".to_string(),
base_url: std::env::var("OPENAI_BASE_URL") ProviderConfig {
.unwrap_or_else(|_| "https://api.openai.com".to_string()), api_key: key,
}); base_url: std::env::var("OPENAI_BASE_URL")
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
},
);
} }
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
providers.insert("anthropic".to_string(), ProviderConfig { providers.insert(
api_key: key, "anthropic".to_string(),
base_url: std::env::var("ANTHROPIC_BASE_URL") ProviderConfig {
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()), api_key: key,
}); base_url: std::env::var("ANTHROPIC_BASE_URL")
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
},
);
} }
Ok(Self { Ok(Self {
@@ -66,8 +72,9 @@ impl AppConfig {
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()), .unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()),
redis_url: std::env::var("REDIS_URL") redis_url: std::env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://localhost:6379".to_string()), .unwrap_or_else(|_| "redis://localhost:6379".to_string()),
timescale_url: std::env::var("TIMESCALE_URL") timescale_url: std::env::var("TIMESCALE_URL").unwrap_or_else(|_| {
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()), "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()
}),
jwt_secret: std::env::var("JWT_SECRET") jwt_secret: std::env::var("JWT_SECRET")
.unwrap_or_else(|_| "dev-secret-change-me".to_string()), .unwrap_or_else(|_| "dev-secret-change-me".to_string()),
auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" { auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {

View File

@@ -4,8 +4,8 @@ pub mod data;
pub mod proxy; pub mod proxy;
pub mod router; pub mod router;
pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role}; pub use auth::{AuthContext, AuthError, AuthProvider, LocalAuthProvider, Role};
pub use config::AppConfig; pub use config::AppConfig;
pub use data::{EventQueue, ObjectStore, TelemetryEvent}; pub use data::{EventQueue, ObjectStore, TelemetryEvent};
pub use proxy::{create_router, ProxyState, ProxyError}; pub use proxy::{create_router, ProxyError, ProxyState};
pub use router::{RouterBrain, RoutingDecision, Complexity}; pub use router::{Complexity, RouterBrain, RoutingDecision};

View File

@@ -10,7 +10,6 @@ use axum::{
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::auth::AuthProvider; use crate::auth::AuthProvider;
use crate::config::AppConfig; use crate::config::AppConfig;
use crate::data::TelemetryEvent; use crate::data::TelemetryEvent;
@@ -61,10 +60,7 @@ async fn proxy_chat_completions(
.to_string(); .to_string();
// 3. Route (pick model + provider) // 3. Route (pick model + provider)
let decision = state let decision = state.router.route(&auth_ctx.org_id, &request).await;
.router
.route(&auth_ctx.org_id, &request)
.await;
// Apply routing decision // Apply routing decision
if let Some(ref routed_model) = decision.model { if let Some(ref routed_model) = decision.model {
@@ -80,7 +76,10 @@ async fn proxy_chat_completions(
let upstream_resp = state let upstream_resp = state
.http_client .http_client
.post(format!("{}/v1/chat/completions", upstream_url)) .post(format!("{}/v1/chat/completions", upstream_url))
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider))) .header(
"Authorization",
format!("Bearer {}", state.config.provider_key(&provider)),
)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(request.to_string()) .body(request.to_string())
.send() .send()
@@ -100,7 +99,7 @@ async fn proxy_chat_completions(
latency_ms: latency.as_millis() as u32, latency_ms: latency.as_millis() as u32,
status_code: status.as_u16(), status_code: status.as_u16(),
is_streaming, is_streaming,
prompt_tokens: 0, // Filled by worker from response prompt_tokens: 0, // Filled by worker from response
completion_tokens: 0, completion_tokens: 0,
timestamp: chrono::Utc::now(), timestamp: chrono::Utc::now(),
}); });
@@ -155,7 +154,10 @@ async fn proxy_embeddings(
let resp = state let resp = state
.http_client .http_client
.post(format!("{}/v1/embeddings", upstream_url)) .post(format!("{}/v1/embeddings", upstream_url))
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider))) .header(
"Authorization",
format!("Bearer {}", state.config.provider_key(&provider)),
)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(body) .body(body)
.send() .send()
@@ -163,8 +165,14 @@ async fn proxy_embeddings(
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?; .map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
let status = resp.status(); let status = resp.status();
let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?; let body = resp
Ok(Response::builder().status(status).body(Body::from(body)).unwrap()) .bytes()
.await
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
Ok(Response::builder()
.status(status)
.body(Body::from(body))
.unwrap())
} }
// --- Error types --- // --- Error types ---

View File

@@ -3,7 +3,7 @@ use tokio::sync::mpsc;
use tracing::info; use tracing::info;
use dd0c_route::{ use dd0c_route::{
AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router, create_router, AppConfig, LocalAuthProvider, ProxyState, RouterBrain, TelemetryEvent,
}; };
#[tokio::main] #[tokio::main]
@@ -27,12 +27,15 @@ async fn main() -> anyhow::Result<()> {
let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
// Telemetry channel (bounded, non-blocking) // Telemetry channel (bounded, non-blocking)
let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size); let (telemetry_tx, mut telemetry_rx) =
mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
// Spawn telemetry worker (writes to TimescaleDB) // Spawn telemetry worker (writes to TimescaleDB)
let ts_url = config.timescale_url.clone(); let ts_url = config.timescale_url.clone();
tokio::spawn(async move { tokio::spawn(async move {
let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed"); let ts_pool = sqlx::PgPool::connect(&ts_url)
.await
.expect("TimescaleDB connection failed");
while let Some(event) = telemetry_rx.recv().await { while let Some(event) = telemetry_rx.recv().await {
if let Err(e) = sqlx::query( if let Err(e) = sqlx::query(
"INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at) "INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)

View File

@@ -1,4 +1,4 @@
pub mod handler; pub mod handler;
pub mod middleware; pub mod middleware;
pub use handler::{create_router, ProxyState, ProxyError}; pub use handler::{create_router, ProxyError, ProxyState};

View File

@@ -41,13 +41,21 @@ impl RouterBrain {
.get("messages") .get("messages")
.and_then(|m| m.as_array()) .and_then(|m| m.as_array())
.and_then(|msgs| { .and_then(|msgs| {
msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system")) msgs.iter()
.find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
}) })
.and_then(|m| m.get("content")) .and_then(|m| m.get("content"))
.and_then(|c| c.as_str()) .and_then(|c| c.as_str())
.unwrap_or(""); .unwrap_or("");
let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"]; let high_complexity_keywords = [
"analyze",
"reason",
"compare",
"evaluate",
"synthesize",
"debate",
];
let has_complex_task = high_complexity_keywords let has_complex_task = high_complexity_keywords
.iter() .iter()
.any(|kw| system_prompt.to_lowercase().contains(kw)); .any(|kw| system_prompt.to_lowercase().contains(kw));
@@ -142,7 +150,9 @@ mod tests {
]; ];
for i in 0..12 { for i in 0..12 {
messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)})); messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)})); messages.push(
serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}),
);
} }
let request = serde_json::json!({ "model": "gpt-4o", "messages": messages }); let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
let complexity = brain.classify_complexity(&request); let complexity = brain.classify_complexity(&request);

View File

@@ -4,7 +4,7 @@ use uuid::Uuid;
async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyhow::Error> { async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyhow::Error> {
let row = sqlx::query_as::<_, (String,)>( let row = sqlx::query_as::<_, (String,)>(
"SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1" "SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1",
) )
.bind(org_id) .bind(org_id)
.fetch_one(pool) .fetch_one(pool)
@@ -18,7 +18,7 @@ async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyh
pub async fn check_anomalies(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow::Result<()> { pub async fn check_anomalies(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow::Result<()> {
// Get orgs with recent activity // Get orgs with recent activity
let orgs = sqlx::query_as::<_, (Uuid,)>( let orgs = sqlx::query_as::<_, (Uuid,)>(
"SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'" "SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'",
) )
.fetch_all(ts_pool) .fetch_all(ts_pool)
.await?; .await?;
@@ -35,7 +35,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
let current = sqlx::query_as::<_, (f64,)>( let current = sqlx::query_as::<_, (f64,)>(
"SELECT COALESCE(SUM(cost_actual), 0)::float8 "SELECT COALESCE(SUM(cost_actual), 0)::float8
FROM request_events FROM request_events
WHERE org_id = $1 AND time >= now() - interval '1 hour'" WHERE org_id = $1 AND time >= now() - interval '1 hour'",
) )
.bind(org_id) .bind(org_id)
.fetch_one(ts_pool) .fetch_one(ts_pool)
@@ -88,11 +88,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
] ]
}); });
let client = reqwest::Client::new(); let client = reqwest::Client::new();
if let Err(e) = client.post(&slack_url) if let Err(e) = client.post(&slack_url).json(&payload).send().await {
.json(&payload)
.send()
.await
{
warn!(error = %e, "Failed to send Slack anomaly alert"); warn!(error = %e, "Failed to send Slack anomaly alert");
} }
} }
@@ -114,7 +110,8 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
) )
}); });
let client = reqwest::Client::new(); let client = reqwest::Client::new();
if let Err(e) = client.post("https://api.resend.com/emails") if let Err(e) = client
.post("https://api.resend.com/emails")
.bearer_auth(&resend_key) .bearer_auth(&resend_key)
.json(&email_body) .json(&email_body)
.send() .send()

View File

@@ -6,19 +6,17 @@ use uuid::Uuid;
/// Calculate next Monday 9 AM UTC from a given time /// Calculate next Monday 9 AM UTC from a given time
pub fn next_monday_9am(from: DateTime<Utc>) -> DateTime<Utc> { pub fn next_monday_9am(from: DateTime<Utc>) -> DateTime<Utc> {
let days_until_monday = (7 - from.weekday().num_days_from_monday()) % 7; let days_until_monday = (7 - from.weekday().num_days_from_monday()) % 7;
let days_until_monday = if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() { let days_until_monday =
7 // Already past Monday 9 AM, go to next week if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() {
} else if days_until_monday == 0 { 7 // Already past Monday 9 AM, go to next week
0 // It's Monday but before 9 AM } else if days_until_monday == 0 {
} else { 0 // It's Monday but before 9 AM
days_until_monday } else {
}; days_until_monday
};
let target_date = from.date_naive() + chrono::Duration::days(days_until_monday as i64); let target_date = from.date_naive() + chrono::Duration::days(days_until_monday as i64);
target_date target_date.and_hms_opt(9, 0, 0).unwrap().and_utc()
.and_hms_opt(9, 0, 0)
.unwrap()
.and_utc()
} }
#[derive(Debug)] #[derive(Debug)]
@@ -58,7 +56,7 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
WHERE o.id IN ( WHERE o.id IN (
SELECT DISTINCT org_id FROM request_events SELECT DISTINCT org_id FROM request_events
WHERE time >= now() - interval '7 days' WHERE time >= now() - interval '7 days'
)" )",
) )
.fetch_all(pg_pool) .fetch_all(pg_pool)
.await?; .await?;
@@ -70,13 +68,27 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
Ok(digest) => { Ok(digest) => {
// Send weekly digest via Resend // Send weekly digest via Resend
if let Ok(resend_key) = std::env::var("RESEND_API_KEY") { if let Ok(resend_key) = std::env::var("RESEND_API_KEY") {
let models_html: String = digest.top_models.iter().map(|m| { let models_html: String = digest
format!("<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>", m.model, m.request_count, m.cost) .top_models
}).collect(); .iter()
.map(|m| {
format!(
"<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>",
m.model, m.request_count, m.cost
)
})
.collect();
let savings_html: String = digest.top_savings.iter().map(|s| { let savings_html: String = digest
format!("<tr><td>{}{}</td><td>{}</td><td>${:.4}</td></tr>", s.original_model, s.routed_model, s.requests_routed, s.cost_saved) .top_savings
}).collect(); .iter()
.map(|s| {
format!(
"<tr><td>{}{}</td><td>{}</td><td>${:.4}</td></tr>",
s.original_model, s.routed_model, s.requests_routed, s.cost_saved
)
})
.collect();
let html = format!( let html = format!(
"<h2>Weekly Cost Digest: {}</h2>\ "<h2>Weekly Cost Digest: {}</h2>\
@@ -93,10 +105,14 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
<table style='border-collapse:collapse;width:100%'>\ <table style='border-collapse:collapse;width:100%'>\
<tr><th>Route</th><th>Requests</th><th>Saved</th></tr>{}</table>\ <tr><th>Route</th><th>Requests</th><th>Saved</th></tr>{}</table>\
<p><a href='https://route.dd0c.dev/dashboard'>View Dashboard →</a></p>", <p><a href='https://route.dd0c.dev/dashboard'>View Dashboard →</a></p>",
digest.org_name, digest.total_requests, digest.org_name,
digest.total_cost_original, digest.total_cost_actual, digest.total_requests,
digest.total_cost_saved, digest.savings_pct, digest.total_cost_original,
models_html, savings_html digest.total_cost_actual,
digest.total_cost_saved,
digest.savings_pct,
models_html,
savings_html
); );
let email_body = serde_json::json!({ let email_body = serde_json::json!({
@@ -107,7 +123,8 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
}); });
let client = reqwest::Client::new(); let client = reqwest::Client::new();
match client.post("https://api.resend.com/emails") match client
.post("https://api.resend.com/emails")
.bearer_auth(&resend_key) .bearer_auth(&resend_key)
.json(&email_body) .json(&email_body)
.send() .send()
@@ -134,7 +151,11 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
Ok(()) Ok(())
} }
async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyhow::Result<DigestData> { async fn generate_digest(
ts_pool: &PgPool,
org_id: Uuid,
org_name: &str,
) -> anyhow::Result<DigestData> {
// Summary stats // Summary stats
let summary = sqlx::query_as::<_, (i64, f64, f64, f64)>( let summary = sqlx::query_as::<_, (i64, f64, f64, f64)>(
"SELECT COUNT(*), "SELECT COUNT(*),
@@ -142,7 +163,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
COALESCE(SUM(cost_actual), 0)::float8, COALESCE(SUM(cost_actual), 0)::float8,
COALESCE(SUM(cost_saved), 0)::float8 COALESCE(SUM(cost_saved), 0)::float8
FROM request_events FROM request_events
WHERE org_id = $1 AND time >= now() - interval '7 days'" WHERE org_id = $1 AND time >= now() - interval '7 days'",
) )
.bind(org_id) .bind(org_id)
.fetch_one(ts_pool) .fetch_one(ts_pool)
@@ -155,7 +176,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
WHERE org_id = $1 AND time >= now() - interval '7 days' WHERE org_id = $1 AND time >= now() - interval '7 days'
GROUP BY original_model GROUP BY original_model
ORDER BY SUM(cost_actual) DESC ORDER BY SUM(cost_actual) DESC
LIMIT 5" LIMIT 5",
) )
.bind(org_id) .bind(org_id)
.fetch_all(ts_pool) .fetch_all(ts_pool)
@@ -168,13 +189,17 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
WHERE org_id = $1 AND time >= now() - interval '7 days' AND strategy != 'passthrough' WHERE org_id = $1 AND time >= now() - interval '7 days' AND strategy != 'passthrough'
GROUP BY original_model, routed_model GROUP BY original_model, routed_model
ORDER BY SUM(cost_saved) DESC ORDER BY SUM(cost_saved) DESC
LIMIT 5" LIMIT 5",
) )
.bind(org_id) .bind(org_id)
.fetch_all(ts_pool) .fetch_all(ts_pool)
.await?; .await?;
let savings_pct = if summary.1 > 0.0 { (summary.3 / summary.1) * 100.0 } else { 0.0 }; let savings_pct = if summary.1 > 0.0 {
(summary.3 / summary.1) * 100.0
} else {
0.0
};
Ok(DigestData { Ok(DigestData {
_org_id: org_id, _org_id: org_id,
@@ -184,17 +209,23 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
total_cost_actual: summary.2, total_cost_actual: summary.2,
total_cost_saved: summary.3, total_cost_saved: summary.3,
savings_pct, savings_pct,
top_models: top_models.iter().map(|r| ModelUsage { top_models: top_models
model: r.0.clone(), .iter()
request_count: r.1, .map(|r| ModelUsage {
cost: r.2, model: r.0.clone(),
}).collect(), request_count: r.1,
top_savings: top_savings.iter().map(|r| RoutingSaving { cost: r.2,
original_model: r.0.clone(), })
routed_model: r.1.clone(), .collect(),
requests_routed: r.2, top_savings: top_savings
cost_saved: r.3, .iter()
}).collect(), .map(|r| RoutingSaving {
original_model: r.0.clone(),
routed_model: r.1.clone(),
requests_routed: r.2,
cost_saved: r.3,
})
.collect(),
}) })
} }
@@ -205,27 +236,45 @@ mod tests {
#[test] #[test]
fn next_monday_from_wednesday() { fn next_monday_from_wednesday() {
let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4).unwrap() // Wednesday let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4)
.and_hms_opt(14, 0, 0).unwrap().and_utc(); .unwrap() // Wednesday
.and_hms_opt(14, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(wed); let next = next_monday_9am(wed);
assert_eq!(next.weekday(), Weekday::Mon); assert_eq!(next.weekday(), Weekday::Mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()); assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
);
assert_eq!(next.time(), NaiveTime::from_hms_opt(9, 0, 0).unwrap()); assert_eq!(next.time(), NaiveTime::from_hms_opt(9, 0, 0).unwrap());
} }
#[test] #[test]
fn next_monday_from_monday_before_9am() { fn next_monday_from_monday_before_9am() {
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap() // Monday let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
.and_hms_opt(8, 0, 0).unwrap().and_utc(); .unwrap() // Monday
.and_hms_opt(8, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(mon); let next = next_monday_9am(mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()); assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()
);
} }
#[test] #[test]
fn next_monday_from_monday_after_9am() { fn next_monday_from_monday_after_9am() {
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap() let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
.and_hms_opt(10, 0, 0).unwrap().and_utc(); .unwrap()
.and_hms_opt(10, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(mon); let next = next_monday_9am(mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()); assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
);
} }
} }

View File

@@ -1,11 +1,11 @@
use std::sync::Arc;
use tracing::{info, error};
use chrono::Utc; use chrono::Utc;
use std::sync::Arc;
use tracing::{error, info};
use dd0c_route::AppConfig; use dd0c_route::AppConfig;
mod digest;
mod anomaly; mod anomaly;
mod digest;
/// Refresh model pricing from known provider pricing pages. /// Refresh model pricing from known provider pricing pages.
/// Falls back to hardcoded defaults if fetch fails. /// Falls back to hardcoded defaults if fetch fails.
@@ -83,7 +83,9 @@ async fn main() -> anyhow::Result<()> {
loop { loop {
let now = Utc::now(); let now = Utc::now();
let next_monday = digest::next_monday_9am(now); let next_monday = digest::next_monday_9am(now);
let sleep_duration = (next_monday - now).to_std().unwrap_or(std::time::Duration::from_secs(3600)); let sleep_duration = (next_monday - now)
.to_std()
.unwrap_or(std::time::Duration::from_secs(3600));
tokio::time::sleep(sleep_duration).await; tokio::time::sleep(sleep_duration).await;
info!("Generating weekly digests"); info!("Generating weekly digests");