cargo fmt: format all Rust source files
All checks were successful
CI — P1 Route (Rust) / test (push) Successful in 6m35s

This commit is contained in:
2026-03-01 17:53:28 +00:00
parent 00db59ff83
commit a8a8c53917
12 changed files with 305 additions and 167 deletions

View File

@@ -5,13 +5,13 @@
//! Run: cargo bench --bench proxy_latency
//! CI gate: P99 must be < 5ms
use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use std::sync::Arc;
use tokio::sync::mpsc;
use dd0c_route::{
AppConfig, TelemetryEvent, RouterBrain,
proxy::{create_router, ProxyState},
AppConfig, RouterBrain, TelemetryEvent,
};
struct NoOpAuth;
@@ -51,10 +51,13 @@ fn bench_proxy_overhead(c: &mut Criterion) {
let (tx, _rx) = mpsc::channel::<TelemetryEvent>(10000);
let mut providers = std::collections::HashMap::new();
providers.insert("openai".to_string(), dd0c_route::config::ProviderConfig {
providers.insert(
"openai".to_string(),
dd0c_route::config::ProviderConfig {
api_key: "bench-key".to_string(),
base_url: mock_url.clone(),
});
},
);
let config = Arc::new(AppConfig {
proxy_port: 0,
@@ -90,7 +93,8 @@ fn bench_proxy_overhead(c: &mut Criterion) {
let body = serde_json::json!({
"model": "gpt-4o",
"messages": messages,
}).to_string();
})
.to_string();
group.bench_with_input(
BenchmarkId::new("chat_completions", format!("{}_msgs", msg_count)),

View File

@@ -2,7 +2,7 @@ use axum::{
extract::{Path, Query, State},
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, put, delete},
routing::{delete, get, put},
Json, Router,
};
use serde::{Deserialize, Serialize};
@@ -23,7 +23,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
Router::new()
// Dashboard analytics
.route("/api/v1/analytics/summary", get(get_analytics_summary))
.route("/api/v1/analytics/timeseries", get(get_analytics_timeseries))
.route(
"/api/v1/analytics/timeseries",
get(get_analytics_timeseries),
)
.route("/api/v1/analytics/models", get(get_model_breakdown))
// Routing rules
.route("/api/v1/rules", get(list_rules).post(create_rule))
@@ -32,7 +35,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
.route("/api/v1/keys", get(list_keys).post(create_key))
.route("/api/v1/keys/:id", delete(revoke_key))
// Provider configs
.route("/api/v1/providers", get(list_providers).post(upsert_provider))
.route(
"/api/v1/providers",
get(list_providers).post(upsert_provider),
)
// Org settings
.route("/api/v1/org", get(get_org))
// Health
@@ -40,7 +46,9 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
.with_state(state)
}
async fn health() -> &'static str { "ok" }
async fn health() -> &'static str {
"ok"
}
// --- Analytics Endpoints ---
@@ -76,7 +84,9 @@ async fn get_analytics_summary(
Query(range): Query<TimeRange>,
) -> Result<Json<AnalyticsSummary>, ApiError> {
let auth = state.auth.authenticate(&headers).await?;
let _from = range.from.unwrap_or_else(|| "now() - interval '7 days'".to_string());
let _from = range
.from
.unwrap_or_else(|| "now() - interval '7 days'".to_string());
let row = sqlx::query_as::<_, (i64, f64, f64, f64, i32, i32, i64, i64, i64)>(
"SELECT
@@ -90,14 +100,18 @@ async fn get_analytics_summary(
COUNT(*) FILTER (WHERE strategy = 'cheapest'),
COUNT(*) FILTER (WHERE strategy = 'cascading')
FROM request_events
WHERE org_id = $1 AND time >= now() - interval '7 days'"
WHERE org_id = $1 AND time >= now() - interval '7 days'",
)
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_one(&state.ts_pool)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
let savings_pct = if row.1 > 0.0 { (row.3 / row.1) * 100.0 } else { 0.0 };
let savings_pct = if row.1 > 0.0 {
(row.3 / row.1) * 100.0
} else {
0.0
};
Ok(Json(AnalyticsSummary {
total_requests: row.0,
@@ -131,7 +145,11 @@ async fn get_analytics_timeseries(
let auth = state.auth.authenticate(&headers).await?;
let interval = range.interval.unwrap_or_else(|| "hour".to_string());
let view = if interval == "day" { "request_events_daily" } else { "request_events_hourly" };
let view = if interval == "day" {
"request_events_daily"
} else {
"request_events_hourly"
};
let rows = sqlx::query_as::<_, (chrono::DateTime<chrono::Utc>, i64, f64, i32)>(
&format!(
@@ -147,12 +165,16 @@ async fn get_analytics_timeseries(
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| TimeseriesPoint {
Ok(Json(
rows.iter()
.map(|r| TimeseriesPoint {
bucket: r.0.to_rfc3339(),
request_count: r.1,
cost_saved: r.2,
avg_latency_ms: r.3,
}).collect()))
})
.collect(),
))
}
#[derive(Serialize)]
@@ -181,12 +203,16 @@ async fn get_model_breakdown(
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ModelBreakdown {
Ok(Json(
rows.iter()
.map(|r| ModelBreakdown {
model: r.0.clone(),
request_count: r.1,
total_tokens: r.2,
total_cost: r.3,
}).collect()))
})
.collect(),
))
}
// --- Routing Rules CRUD ---
@@ -222,7 +248,9 @@ async fn list_rules(
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| RoutingRuleDto {
Ok(Json(
rows.iter()
.map(|r| RoutingRuleDto {
id: Some(r.0),
priority: r.1,
name: r.2.clone(),
@@ -235,7 +263,9 @@ async fn list_rules(
target_provider: r.9.clone(),
fallback_models: r.10.clone(),
enabled: r.11,
}).collect()))
})
.collect(),
))
}
async fn create_rule(
@@ -361,23 +391,37 @@ async fn list_keys(
) -> Result<Json<Vec<ApiKeyDto>>, ApiError> {
let auth = state.auth.authenticate(&headers).await?;
let rows = sqlx::query_as::<_, (Uuid, String, String, Vec<String>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(
let rows = sqlx::query_as::<
_,
(
Uuid,
String,
String,
Vec<String>,
Option<chrono::DateTime<chrono::Utc>>,
chrono::DateTime<chrono::Utc>,
),
>(
"SELECT id, name, key_prefix, scopes, last_used_at, created_at
FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC",
)
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_all(&state.pg_pool)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ApiKeyDto {
Ok(Json(
rows.iter()
.map(|r| ApiKeyDto {
id: r.0,
name: r.1.clone(),
key_prefix: r.2.clone(),
scopes: r.3.clone(),
last_used_at: r.4,
created_at: r.5,
}).collect()))
})
.collect(),
))
}
async fn create_key(
@@ -409,11 +453,14 @@ async fn create_key(
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok((StatusCode::CREATED, Json(ApiKeyCreated {
Ok((
StatusCode::CREATED,
Json(ApiKeyCreated {
id,
key: raw_key,
name: req.name,
})))
}),
))
}
async fn revoke_key(
@@ -456,19 +503,23 @@ async fn list_providers(
let auth = state.auth.authenticate(&headers).await?;
let rows = sqlx::query_as::<_, (String, Option<String>, bool)>(
"SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1"
"SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1",
)
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_all(&state.pg_pool)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
Ok(Json(rows.iter().map(|r| ProviderDto {
Ok(Json(
rows.iter()
.map(|r| ProviderDto {
provider: r.0.clone(),
base_url: r.1.clone(),
is_default: r.2,
has_key: true,
}).collect()))
})
.collect(),
))
}
#[derive(Deserialize)]
@@ -488,13 +539,12 @@ async fn upsert_provider(
require_role(&auth, Role::Owner)?;
// Encrypt API key with AES-256-GCM before storing
let encryption_key = std::env::var("PROVIDER_KEY_ENCRYPTION_KEY")
.unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key
let key_bytes = hex::decode(&encryption_key)
.unwrap_or_else(|_| vec![0u8; 32]);
let encryption_key =
std::env::var("PROVIDER_KEY_ENCRYPTION_KEY").unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key
let key_bytes = hex::decode(&encryption_key).unwrap_or_else(|_| vec![0u8; 32]);
use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
use aes_gcm::Nonce;
use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit};
let cipher = Aes256Gcm::new_from_slice(&key_bytes)
.map_err(|e| ApiError::Internal(format!("Encryption key error: {}", e)))?;
@@ -502,7 +552,8 @@ async fn upsert_provider(
getrandom::getrandom(&mut nonce_bytes)
.map_err(|e| ApiError::Internal(format!("RNG error: {}", e)))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher.encrypt(nonce, req.api_key.as_bytes())
let ciphertext = cipher
.encrypt(nonce, req.api_key.as_bytes())
.map_err(|e| ApiError::Internal(format!("Encryption error: {}", e)))?;
// Store as nonce || ciphertext
@@ -544,7 +595,7 @@ async fn get_org(
let auth = state.auth.authenticate(&headers).await?;
let row = sqlx::query_as::<_, (Uuid, String, String, String)>(
"SELECT id, name, slug, tier FROM organizations WHERE id = $1"
"SELECT id, name, slug, tier FROM organizations WHERE id = $1",
)
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
.fetch_optional(&state.pg_pool)
@@ -606,7 +657,10 @@ impl IntoResponse for ApiError {
ApiError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
ApiError::Forbidden => (StatusCode::FORBIDDEN, self.to_string()),
ApiError::NotFound => (StatusCode::NOT_FOUND, self.to_string()),
ApiError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string()),
ApiError::Internal(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal error".to_string(),
),
};
(status, serde_json::json!({"error": msg}).to_string()).into_response()
}

View File

@@ -1,5 +1,5 @@
use axum::http::HeaderMap;
use async_trait::async_trait;
use axum::http::HeaderMap;
use thiserror::Error;
#[derive(Debug, Clone)]
@@ -40,7 +40,11 @@ pub struct LocalAuthProvider {
impl LocalAuthProvider {
pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
Self { pool, _jwt_secret: jwt_secret, redis }
Self {
pool,
_jwt_secret: jwt_secret,
redis,
}
}
}
@@ -72,7 +76,7 @@ impl AuthProvider for LocalAuthProvider {
// 2. Fall back to PostgreSQL
let row = sqlx::query_as::<_, (String, String)>(
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL",
)
.bind(&key[..8])
.fetch_optional(&self.pool)

View File

@@ -41,18 +41,24 @@ impl AppConfig {
let mut providers = HashMap::new();
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
providers.insert("openai".to_string(), ProviderConfig {
providers.insert(
"openai".to_string(),
ProviderConfig {
api_key: key,
base_url: std::env::var("OPENAI_BASE_URL")
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
});
},
);
}
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
providers.insert("anthropic".to_string(), ProviderConfig {
providers.insert(
"anthropic".to_string(),
ProviderConfig {
api_key: key,
base_url: std::env::var("ANTHROPIC_BASE_URL")
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
});
},
);
}
Ok(Self {
@@ -66,8 +72,9 @@ impl AppConfig {
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()),
redis_url: std::env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://localhost:6379".to_string()),
timescale_url: std::env::var("TIMESCALE_URL")
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()),
timescale_url: std::env::var("TIMESCALE_URL").unwrap_or_else(|_| {
"postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()
}),
jwt_secret: std::env::var("JWT_SECRET")
.unwrap_or_else(|_| "dev-secret-change-me".to_string()),
auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {

View File

@@ -4,8 +4,8 @@ pub mod data;
pub mod proxy;
pub mod router;
pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role};
pub use auth::{AuthContext, AuthError, AuthProvider, LocalAuthProvider, Role};
pub use config::AppConfig;
pub use data::{EventQueue, ObjectStore, TelemetryEvent};
pub use proxy::{create_router, ProxyState, ProxyError};
pub use router::{RouterBrain, RoutingDecision, Complexity};
pub use proxy::{create_router, ProxyError, ProxyState};
pub use router::{Complexity, RouterBrain, RoutingDecision};

View File

@@ -10,7 +10,6 @@ use axum::{
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::auth::AuthProvider;
use crate::config::AppConfig;
use crate::data::TelemetryEvent;
@@ -61,10 +60,7 @@ async fn proxy_chat_completions(
.to_string();
// 3. Route (pick model + provider)
let decision = state
.router
.route(&auth_ctx.org_id, &request)
.await;
let decision = state.router.route(&auth_ctx.org_id, &request).await;
// Apply routing decision
if let Some(ref routed_model) = decision.model {
@@ -80,7 +76,10 @@ async fn proxy_chat_completions(
let upstream_resp = state
.http_client
.post(format!("{}/v1/chat/completions", upstream_url))
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
.header(
"Authorization",
format!("Bearer {}", state.config.provider_key(&provider)),
)
.header("Content-Type", "application/json")
.body(request.to_string())
.send()
@@ -155,7 +154,10 @@ async fn proxy_embeddings(
let resp = state
.http_client
.post(format!("{}/v1/embeddings", upstream_url))
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
.header(
"Authorization",
format!("Bearer {}", state.config.provider_key(&provider)),
)
.header("Content-Type", "application/json")
.body(body)
.send()
@@ -163,8 +165,14 @@ async fn proxy_embeddings(
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
let status = resp.status();
let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
Ok(Response::builder().status(status).body(Body::from(body)).unwrap())
let body = resp
.bytes()
.await
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
Ok(Response::builder()
.status(status)
.body(Body::from(body))
.unwrap())
}
// --- Error types ---

View File

@@ -3,7 +3,7 @@ use tokio::sync::mpsc;
use tracing::info;
use dd0c_route::{
AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router,
create_router, AppConfig, LocalAuthProvider, ProxyState, RouterBrain, TelemetryEvent,
};
#[tokio::main]
@@ -27,12 +27,15 @@ async fn main() -> anyhow::Result<()> {
let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
// Telemetry channel (bounded, non-blocking)
let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
let (telemetry_tx, mut telemetry_rx) =
mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
// Spawn telemetry worker (writes to TimescaleDB)
let ts_url = config.timescale_url.clone();
tokio::spawn(async move {
let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed");
let ts_pool = sqlx::PgPool::connect(&ts_url)
.await
.expect("TimescaleDB connection failed");
while let Some(event) = telemetry_rx.recv().await {
if let Err(e) = sqlx::query(
"INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)

View File

@@ -1,4 +1,4 @@
pub mod handler;
pub mod middleware;
pub use handler::{create_router, ProxyState, ProxyError};
pub use handler::{create_router, ProxyError, ProxyState};

View File

@@ -41,13 +41,21 @@ impl RouterBrain {
.get("messages")
.and_then(|m| m.as_array())
.and_then(|msgs| {
msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
msgs.iter()
.find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
})
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("");
let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
let high_complexity_keywords = [
"analyze",
"reason",
"compare",
"evaluate",
"synthesize",
"debate",
];
let has_complex_task = high_complexity_keywords
.iter()
.any(|kw| system_prompt.to_lowercase().contains(kw));
@@ -142,7 +150,9 @@ mod tests {
];
for i in 0..12 {
messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}));
messages.push(
serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}),
);
}
let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
let complexity = brain.classify_complexity(&request);

View File

@@ -4,7 +4,7 @@ use uuid::Uuid;
async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyhow::Error> {
let row = sqlx::query_as::<_, (String,)>(
"SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1"
"SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1",
)
.bind(org_id)
.fetch_one(pool)
@@ -18,7 +18,7 @@ async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyh
pub async fn check_anomalies(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow::Result<()> {
// Get orgs with recent activity
let orgs = sqlx::query_as::<_, (Uuid,)>(
"SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'"
"SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'",
)
.fetch_all(ts_pool)
.await?;
@@ -35,7 +35,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
let current = sqlx::query_as::<_, (f64,)>(
"SELECT COALESCE(SUM(cost_actual), 0)::float8
FROM request_events
WHERE org_id = $1 AND time >= now() - interval '1 hour'"
WHERE org_id = $1 AND time >= now() - interval '1 hour'",
)
.bind(org_id)
.fetch_one(ts_pool)
@@ -88,11 +88,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
]
});
let client = reqwest::Client::new();
if let Err(e) = client.post(&slack_url)
.json(&payload)
.send()
.await
{
if let Err(e) = client.post(&slack_url).json(&payload).send().await {
warn!(error = %e, "Failed to send Slack anomaly alert");
}
}
@@ -114,7 +110,8 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
)
});
let client = reqwest::Client::new();
if let Err(e) = client.post("https://api.resend.com/emails")
if let Err(e) = client
.post("https://api.resend.com/emails")
.bearer_auth(&resend_key)
.json(&email_body)
.send()

View File

@@ -6,7 +6,8 @@ use uuid::Uuid;
/// Calculate next Monday 9 AM UTC from a given time
pub fn next_monday_9am(from: DateTime<Utc>) -> DateTime<Utc> {
let days_until_monday = (7 - from.weekday().num_days_from_monday()) % 7;
let days_until_monday = if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() {
let days_until_monday =
if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() {
7 // Already past Monday 9 AM, go to next week
} else if days_until_monday == 0 {
0 // It's Monday but before 9 AM
@@ -15,10 +16,7 @@ pub fn next_monday_9am(from: DateTime<Utc>) -> DateTime<Utc> {
};
let target_date = from.date_naive() + chrono::Duration::days(days_until_monday as i64);
target_date
.and_hms_opt(9, 0, 0)
.unwrap()
.and_utc()
target_date.and_hms_opt(9, 0, 0).unwrap().and_utc()
}
#[derive(Debug)]
@@ -58,7 +56,7 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
WHERE o.id IN (
SELECT DISTINCT org_id FROM request_events
WHERE time >= now() - interval '7 days'
)"
)",
)
.fetch_all(pg_pool)
.await?;
@@ -70,13 +68,27 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
Ok(digest) => {
// Send weekly digest via Resend
if let Ok(resend_key) = std::env::var("RESEND_API_KEY") {
let models_html: String = digest.top_models.iter().map(|m| {
format!("<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>", m.model, m.request_count, m.cost)
}).collect();
let models_html: String = digest
.top_models
.iter()
.map(|m| {
format!(
"<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>",
m.model, m.request_count, m.cost
)
})
.collect();
let savings_html: String = digest.top_savings.iter().map(|s| {
format!("<tr><td>{}{}</td><td>{}</td><td>${:.4}</td></tr>", s.original_model, s.routed_model, s.requests_routed, s.cost_saved)
}).collect();
let savings_html: String = digest
.top_savings
.iter()
.map(|s| {
format!(
"<tr><td>{}{}</td><td>{}</td><td>${:.4}</td></tr>",
s.original_model, s.routed_model, s.requests_routed, s.cost_saved
)
})
.collect();
let html = format!(
"<h2>Weekly Cost Digest: {}</h2>\
@@ -93,10 +105,14 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
<table style='border-collapse:collapse;width:100%'>\
<tr><th>Route</th><th>Requests</th><th>Saved</th></tr>{}</table>\
<p><a href='https://route.dd0c.dev/dashboard'>View Dashboard →</a></p>",
digest.org_name, digest.total_requests,
digest.total_cost_original, digest.total_cost_actual,
digest.total_cost_saved, digest.savings_pct,
models_html, savings_html
digest.org_name,
digest.total_requests,
digest.total_cost_original,
digest.total_cost_actual,
digest.total_cost_saved,
digest.savings_pct,
models_html,
savings_html
);
let email_body = serde_json::json!({
@@ -107,7 +123,8 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
});
let client = reqwest::Client::new();
match client.post("https://api.resend.com/emails")
match client
.post("https://api.resend.com/emails")
.bearer_auth(&resend_key)
.json(&email_body)
.send()
@@ -134,7 +151,11 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
Ok(())
}
async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyhow::Result<DigestData> {
async fn generate_digest(
ts_pool: &PgPool,
org_id: Uuid,
org_name: &str,
) -> anyhow::Result<DigestData> {
// Summary stats
let summary = sqlx::query_as::<_, (i64, f64, f64, f64)>(
"SELECT COUNT(*),
@@ -142,7 +163,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
COALESCE(SUM(cost_actual), 0)::float8,
COALESCE(SUM(cost_saved), 0)::float8
FROM request_events
WHERE org_id = $1 AND time >= now() - interval '7 days'"
WHERE org_id = $1 AND time >= now() - interval '7 days'",
)
.bind(org_id)
.fetch_one(ts_pool)
@@ -155,7 +176,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
WHERE org_id = $1 AND time >= now() - interval '7 days'
GROUP BY original_model
ORDER BY SUM(cost_actual) DESC
LIMIT 5"
LIMIT 5",
)
.bind(org_id)
.fetch_all(ts_pool)
@@ -168,13 +189,17 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
WHERE org_id = $1 AND time >= now() - interval '7 days' AND strategy != 'passthrough'
GROUP BY original_model, routed_model
ORDER BY SUM(cost_saved) DESC
LIMIT 5"
LIMIT 5",
)
.bind(org_id)
.fetch_all(ts_pool)
.await?;
let savings_pct = if summary.1 > 0.0 { (summary.3 / summary.1) * 100.0 } else { 0.0 };
let savings_pct = if summary.1 > 0.0 {
(summary.3 / summary.1) * 100.0
} else {
0.0
};
Ok(DigestData {
_org_id: org_id,
@@ -184,17 +209,23 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
total_cost_actual: summary.2,
total_cost_saved: summary.3,
savings_pct,
top_models: top_models.iter().map(|r| ModelUsage {
top_models: top_models
.iter()
.map(|r| ModelUsage {
model: r.0.clone(),
request_count: r.1,
cost: r.2,
}).collect(),
top_savings: top_savings.iter().map(|r| RoutingSaving {
})
.collect(),
top_savings: top_savings
.iter()
.map(|r| RoutingSaving {
original_model: r.0.clone(),
routed_model: r.1.clone(),
requests_routed: r.2,
cost_saved: r.3,
}).collect(),
})
.collect(),
})
}
@@ -205,27 +236,45 @@ mod tests {
#[test]
fn next_monday_from_wednesday() {
let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4).unwrap() // Wednesday
.and_hms_opt(14, 0, 0).unwrap().and_utc();
let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4)
.unwrap() // Wednesday
.and_hms_opt(14, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(wed);
assert_eq!(next.weekday(), Weekday::Mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap());
assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
);
assert_eq!(next.time(), NaiveTime::from_hms_opt(9, 0, 0).unwrap());
}
#[test]
fn next_monday_from_monday_before_9am() {
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap() // Monday
.and_hms_opt(8, 0, 0).unwrap().and_utc();
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
.unwrap() // Monday
.and_hms_opt(8, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap());
assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()
);
}
#[test]
fn next_monday_from_monday_after_9am() {
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()
.and_hms_opt(10, 0, 0).unwrap().and_utc();
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
.unwrap()
.and_hms_opt(10, 0, 0)
.unwrap()
.and_utc();
let next = next_monday_9am(mon);
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap());
assert_eq!(
next.date_naive(),
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
);
}
}

View File

@@ -1,11 +1,11 @@
use std::sync::Arc;
use tracing::{info, error};
use chrono::Utc;
use std::sync::Arc;
use tracing::{error, info};
use dd0c_route::AppConfig;
mod digest;
mod anomaly;
mod digest;
/// Refresh model pricing from known provider pricing pages.
/// Falls back to hardcoded defaults if fetch fails.
@@ -83,7 +83,9 @@ async fn main() -> anyhow::Result<()> {
loop {
let now = Utc::now();
let next_monday = digest::next_monday_9am(now);
let sleep_duration = (next_monday - now).to_std().unwrap_or(std::time::Duration::from_secs(3600));
let sleep_duration = (next_monday - now)
.to_std()
.unwrap_or(std::time::Duration::from_secs(3600));
tokio::time::sleep(sleep_duration).await;
info!("Generating weekly digests");