cargo fmt: format all Rust source files
All checks were successful
CI — P1 Route (Rust) / test (push) Successful in 6m35s
All checks were successful
CI — P1 Route (Rust) / test (push) Successful in 6m35s
This commit is contained in:
@@ -5,13 +5,13 @@
|
||||
//! Run: cargo bench --bench proxy_latency
|
||||
//! CI gate: P99 must be < 5ms
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId};
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use dd0c_route::{
|
||||
AppConfig, TelemetryEvent, RouterBrain,
|
||||
proxy::{create_router, ProxyState},
|
||||
AppConfig, RouterBrain, TelemetryEvent,
|
||||
};
|
||||
|
||||
struct NoOpAuth;
|
||||
@@ -51,10 +51,13 @@ fn bench_proxy_overhead(c: &mut Criterion) {
|
||||
let (tx, _rx) = mpsc::channel::<TelemetryEvent>(10000);
|
||||
|
||||
let mut providers = std::collections::HashMap::new();
|
||||
providers.insert("openai".to_string(), dd0c_route::config::ProviderConfig {
|
||||
api_key: "bench-key".to_string(),
|
||||
base_url: mock_url.clone(),
|
||||
});
|
||||
providers.insert(
|
||||
"openai".to_string(),
|
||||
dd0c_route::config::ProviderConfig {
|
||||
api_key: "bench-key".to_string(),
|
||||
base_url: mock_url.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
let config = Arc::new(AppConfig {
|
||||
proxy_port: 0,
|
||||
@@ -90,7 +93,8 @@ fn bench_proxy_overhead(c: &mut Criterion) {
|
||||
let body = serde_json::json!({
|
||||
"model": "gpt-4o",
|
||||
"messages": messages,
|
||||
}).to_string();
|
||||
})
|
||||
.to_string();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("chat_completions", format!("{}_msgs", msg_count)),
|
||||
|
||||
@@ -2,7 +2,7 @@ use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{get, put, delete},
|
||||
routing::{delete, get, put},
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -23,7 +23,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
|
||||
Router::new()
|
||||
// Dashboard analytics
|
||||
.route("/api/v1/analytics/summary", get(get_analytics_summary))
|
||||
.route("/api/v1/analytics/timeseries", get(get_analytics_timeseries))
|
||||
.route(
|
||||
"/api/v1/analytics/timeseries",
|
||||
get(get_analytics_timeseries),
|
||||
)
|
||||
.route("/api/v1/analytics/models", get(get_model_breakdown))
|
||||
// Routing rules
|
||||
.route("/api/v1/rules", get(list_rules).post(create_rule))
|
||||
@@ -32,7 +35,10 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
|
||||
.route("/api/v1/keys", get(list_keys).post(create_key))
|
||||
.route("/api/v1/keys/:id", delete(revoke_key))
|
||||
// Provider configs
|
||||
.route("/api/v1/providers", get(list_providers).post(upsert_provider))
|
||||
.route(
|
||||
"/api/v1/providers",
|
||||
get(list_providers).post(upsert_provider),
|
||||
)
|
||||
// Org settings
|
||||
.route("/api/v1/org", get(get_org))
|
||||
// Health
|
||||
@@ -40,7 +46,9 @@ pub fn create_api_router(state: Arc<ApiState>) -> Router {
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn health() -> &'static str { "ok" }
|
||||
async fn health() -> &'static str {
|
||||
"ok"
|
||||
}
|
||||
|
||||
// --- Analytics Endpoints ---
|
||||
|
||||
@@ -76,7 +84,9 @@ async fn get_analytics_summary(
|
||||
Query(range): Query<TimeRange>,
|
||||
) -> Result<Json<AnalyticsSummary>, ApiError> {
|
||||
let auth = state.auth.authenticate(&headers).await?;
|
||||
let _from = range.from.unwrap_or_else(|| "now() - interval '7 days'".to_string());
|
||||
let _from = range
|
||||
.from
|
||||
.unwrap_or_else(|| "now() - interval '7 days'".to_string());
|
||||
|
||||
let row = sqlx::query_as::<_, (i64, f64, f64, f64, i32, i32, i64, i64, i64)>(
|
||||
"SELECT
|
||||
@@ -90,14 +100,18 @@ async fn get_analytics_summary(
|
||||
COUNT(*) FILTER (WHERE strategy = 'cheapest'),
|
||||
COUNT(*) FILTER (WHERE strategy = 'cascading')
|
||||
FROM request_events
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days'"
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days'",
|
||||
)
|
||||
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
|
||||
.fetch_one(&state.ts_pool)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
let savings_pct = if row.1 > 0.0 { (row.3 / row.1) * 100.0 } else { 0.0 };
|
||||
let savings_pct = if row.1 > 0.0 {
|
||||
(row.3 / row.1) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Ok(Json(AnalyticsSummary {
|
||||
total_requests: row.0,
|
||||
@@ -131,7 +145,11 @@ async fn get_analytics_timeseries(
|
||||
let auth = state.auth.authenticate(&headers).await?;
|
||||
let interval = range.interval.unwrap_or_else(|| "hour".to_string());
|
||||
|
||||
let view = if interval == "day" { "request_events_daily" } else { "request_events_hourly" };
|
||||
let view = if interval == "day" {
|
||||
"request_events_daily"
|
||||
} else {
|
||||
"request_events_hourly"
|
||||
};
|
||||
|
||||
let rows = sqlx::query_as::<_, (chrono::DateTime<chrono::Utc>, i64, f64, i32)>(
|
||||
&format!(
|
||||
@@ -147,12 +165,16 @@ async fn get_analytics_timeseries(
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(rows.iter().map(|r| TimeseriesPoint {
|
||||
bucket: r.0.to_rfc3339(),
|
||||
request_count: r.1,
|
||||
cost_saved: r.2,
|
||||
avg_latency_ms: r.3,
|
||||
}).collect()))
|
||||
Ok(Json(
|
||||
rows.iter()
|
||||
.map(|r| TimeseriesPoint {
|
||||
bucket: r.0.to_rfc3339(),
|
||||
request_count: r.1,
|
||||
cost_saved: r.2,
|
||||
avg_latency_ms: r.3,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -181,12 +203,16 @@ async fn get_model_breakdown(
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(rows.iter().map(|r| ModelBreakdown {
|
||||
model: r.0.clone(),
|
||||
request_count: r.1,
|
||||
total_tokens: r.2,
|
||||
total_cost: r.3,
|
||||
}).collect()))
|
||||
Ok(Json(
|
||||
rows.iter()
|
||||
.map(|r| ModelBreakdown {
|
||||
model: r.0.clone(),
|
||||
request_count: r.1,
|
||||
total_tokens: r.2,
|
||||
total_cost: r.3,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
// --- Routing Rules CRUD ---
|
||||
@@ -222,20 +248,24 @@ async fn list_rules(
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(rows.iter().map(|r| RoutingRuleDto {
|
||||
id: Some(r.0),
|
||||
priority: r.1,
|
||||
name: r.2.clone(),
|
||||
match_model: r.3.clone(),
|
||||
match_feature: r.4.clone(),
|
||||
match_team: r.5.clone(),
|
||||
match_complexity: r.6.clone(),
|
||||
strategy: r.7.clone(),
|
||||
target_model: r.8.clone(),
|
||||
target_provider: r.9.clone(),
|
||||
fallback_models: r.10.clone(),
|
||||
enabled: r.11,
|
||||
}).collect()))
|
||||
Ok(Json(
|
||||
rows.iter()
|
||||
.map(|r| RoutingRuleDto {
|
||||
id: Some(r.0),
|
||||
priority: r.1,
|
||||
name: r.2.clone(),
|
||||
match_model: r.3.clone(),
|
||||
match_feature: r.4.clone(),
|
||||
match_team: r.5.clone(),
|
||||
match_complexity: r.6.clone(),
|
||||
strategy: r.7.clone(),
|
||||
target_model: r.8.clone(),
|
||||
target_provider: r.9.clone(),
|
||||
fallback_models: r.10.clone(),
|
||||
enabled: r.11,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_rule(
|
||||
@@ -361,23 +391,37 @@ async fn list_keys(
|
||||
) -> Result<Json<Vec<ApiKeyDto>>, ApiError> {
|
||||
let auth = state.auth.authenticate(&headers).await?;
|
||||
|
||||
let rows = sqlx::query_as::<_, (Uuid, String, String, Vec<String>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(
|
||||
let rows = sqlx::query_as::<
|
||||
_,
|
||||
(
|
||||
Uuid,
|
||||
String,
|
||||
String,
|
||||
Vec<String>,
|
||||
Option<chrono::DateTime<chrono::Utc>>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
),
|
||||
>(
|
||||
"SELECT id, name, key_prefix, scopes, last_used_at, created_at
|
||||
FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
FROM api_keys WHERE org_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC",
|
||||
)
|
||||
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
|
||||
.fetch_all(&state.pg_pool)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(rows.iter().map(|r| ApiKeyDto {
|
||||
id: r.0,
|
||||
name: r.1.clone(),
|
||||
key_prefix: r.2.clone(),
|
||||
scopes: r.3.clone(),
|
||||
last_used_at: r.4,
|
||||
created_at: r.5,
|
||||
}).collect()))
|
||||
Ok(Json(
|
||||
rows.iter()
|
||||
.map(|r| ApiKeyDto {
|
||||
id: r.0,
|
||||
name: r.1.clone(),
|
||||
key_prefix: r.2.clone(),
|
||||
scopes: r.3.clone(),
|
||||
last_used_at: r.4,
|
||||
created_at: r.5,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_key(
|
||||
@@ -409,11 +453,14 @@ async fn create_key(
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(ApiKeyCreated {
|
||||
id,
|
||||
key: raw_key,
|
||||
name: req.name,
|
||||
})))
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiKeyCreated {
|
||||
id,
|
||||
key: raw_key,
|
||||
name: req.name,
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
async fn revoke_key(
|
||||
@@ -456,19 +503,23 @@ async fn list_providers(
|
||||
let auth = state.auth.authenticate(&headers).await?;
|
||||
|
||||
let rows = sqlx::query_as::<_, (String, Option<String>, bool)>(
|
||||
"SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1"
|
||||
"SELECT provider, base_url, is_default FROM provider_configs WHERE org_id = $1",
|
||||
)
|
||||
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
|
||||
.fetch_all(&state.pg_pool)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(rows.iter().map(|r| ProviderDto {
|
||||
provider: r.0.clone(),
|
||||
base_url: r.1.clone(),
|
||||
is_default: r.2,
|
||||
has_key: true,
|
||||
}).collect()))
|
||||
Ok(Json(
|
||||
rows.iter()
|
||||
.map(|r| ProviderDto {
|
||||
provider: r.0.clone(),
|
||||
base_url: r.1.clone(),
|
||||
is_default: r.2,
|
||||
has_key: true,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -488,13 +539,12 @@ async fn upsert_provider(
|
||||
require_role(&auth, Role::Owner)?;
|
||||
|
||||
// Encrypt API key with AES-256-GCM before storing
|
||||
let encryption_key = std::env::var("PROVIDER_KEY_ENCRYPTION_KEY")
|
||||
.unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key
|
||||
let key_bytes = hex::decode(&encryption_key)
|
||||
.unwrap_or_else(|_| vec![0u8; 32]);
|
||||
let encryption_key =
|
||||
std::env::var("PROVIDER_KEY_ENCRYPTION_KEY").unwrap_or_else(|_| "0".repeat(64)); // 32-byte hex key
|
||||
let key_bytes = hex::decode(&encryption_key).unwrap_or_else(|_| vec![0u8; 32]);
|
||||
|
||||
use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
|
||||
use aes_gcm::Nonce;
|
||||
use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit};
|
||||
|
||||
let cipher = Aes256Gcm::new_from_slice(&key_bytes)
|
||||
.map_err(|e| ApiError::Internal(format!("Encryption key error: {}", e)))?;
|
||||
@@ -502,7 +552,8 @@ async fn upsert_provider(
|
||||
getrandom::getrandom(&mut nonce_bytes)
|
||||
.map_err(|e| ApiError::Internal(format!("RNG error: {}", e)))?;
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
let ciphertext = cipher.encrypt(nonce, req.api_key.as_bytes())
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, req.api_key.as_bytes())
|
||||
.map_err(|e| ApiError::Internal(format!("Encryption error: {}", e)))?;
|
||||
|
||||
// Store as nonce || ciphertext
|
||||
@@ -544,7 +595,7 @@ async fn get_org(
|
||||
let auth = state.auth.authenticate(&headers).await?;
|
||||
|
||||
let row = sqlx::query_as::<_, (Uuid, String, String, String)>(
|
||||
"SELECT id, name, slug, tier FROM organizations WHERE id = $1"
|
||||
"SELECT id, name, slug, tier FROM organizations WHERE id = $1",
|
||||
)
|
||||
.bind(auth.org_id.parse::<Uuid>().unwrap_or_default())
|
||||
.fetch_optional(&state.pg_pool)
|
||||
@@ -606,7 +657,10 @@ impl IntoResponse for ApiError {
|
||||
ApiError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
ApiError::Forbidden => (StatusCode::FORBIDDEN, self.to_string()),
|
||||
ApiError::NotFound => (StatusCode::NOT_FOUND, self.to_string()),
|
||||
ApiError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string()),
|
||||
ApiError::Internal(_) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Internal error".to_string(),
|
||||
),
|
||||
};
|
||||
(status, serde_json::json!({"error": msg}).to_string()).into_response()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use axum::http::HeaderMap;
|
||||
use async_trait::async_trait;
|
||||
use axum::http::HeaderMap;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -40,7 +40,11 @@ pub struct LocalAuthProvider {
|
||||
|
||||
impl LocalAuthProvider {
|
||||
pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
|
||||
Self { pool, _jwt_secret: jwt_secret, redis }
|
||||
Self {
|
||||
pool,
|
||||
_jwt_secret: jwt_secret,
|
||||
redis,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +76,7 @@ impl AuthProvider for LocalAuthProvider {
|
||||
|
||||
// 2. Fall back to PostgreSQL
|
||||
let row = sqlx::query_as::<_, (String, String)>(
|
||||
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
|
||||
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL",
|
||||
)
|
||||
.bind(&key[..8])
|
||||
.fetch_optional(&self.pool)
|
||||
|
||||
@@ -41,18 +41,24 @@ impl AppConfig {
|
||||
|
||||
let mut providers = HashMap::new();
|
||||
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
|
||||
providers.insert("openai".to_string(), ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("OPENAI_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
});
|
||||
providers.insert(
|
||||
"openai".to_string(),
|
||||
ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("OPENAI_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||
providers.insert("anthropic".to_string(), ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("ANTHROPIC_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
|
||||
});
|
||||
providers.insert(
|
||||
"anthropic".to_string(),
|
||||
ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("ANTHROPIC_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
@@ -66,8 +72,9 @@ impl AppConfig {
|
||||
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()),
|
||||
redis_url: std::env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://localhost:6379".to_string()),
|
||||
timescale_url: std::env::var("TIMESCALE_URL")
|
||||
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()),
|
||||
timescale_url: std::env::var("TIMESCALE_URL").unwrap_or_else(|_| {
|
||||
"postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()
|
||||
}),
|
||||
jwt_secret: std::env::var("JWT_SECRET")
|
||||
.unwrap_or_else(|_| "dev-secret-change-me".to_string()),
|
||||
auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {
|
||||
|
||||
@@ -4,8 +4,8 @@ pub mod data;
|
||||
pub mod proxy;
|
||||
pub mod router;
|
||||
|
||||
pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role};
|
||||
pub use auth::{AuthContext, AuthError, AuthProvider, LocalAuthProvider, Role};
|
||||
pub use config::AppConfig;
|
||||
pub use data::{EventQueue, ObjectStore, TelemetryEvent};
|
||||
pub use proxy::{create_router, ProxyState, ProxyError};
|
||||
pub use router::{RouterBrain, RoutingDecision, Complexity};
|
||||
pub use proxy::{create_router, ProxyError, ProxyState};
|
||||
pub use router::{Complexity, RouterBrain, RoutingDecision};
|
||||
|
||||
@@ -10,7 +10,6 @@ use axum::{
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::config::AppConfig;
|
||||
use crate::data::TelemetryEvent;
|
||||
@@ -61,10 +60,7 @@ async fn proxy_chat_completions(
|
||||
.to_string();
|
||||
|
||||
// 3. Route (pick model + provider)
|
||||
let decision = state
|
||||
.router
|
||||
.route(&auth_ctx.org_id, &request)
|
||||
.await;
|
||||
let decision = state.router.route(&auth_ctx.org_id, &request).await;
|
||||
|
||||
// Apply routing decision
|
||||
if let Some(ref routed_model) = decision.model {
|
||||
@@ -80,7 +76,10 @@ async fn proxy_chat_completions(
|
||||
let upstream_resp = state
|
||||
.http_client
|
||||
.post(format!("{}/v1/chat/completions", upstream_url))
|
||||
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("Bearer {}", state.config.provider_key(&provider)),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request.to_string())
|
||||
.send()
|
||||
@@ -100,7 +99,7 @@ async fn proxy_chat_completions(
|
||||
latency_ms: latency.as_millis() as u32,
|
||||
status_code: status.as_u16(),
|
||||
is_streaming,
|
||||
prompt_tokens: 0, // Filled by worker from response
|
||||
prompt_tokens: 0, // Filled by worker from response
|
||||
completion_tokens: 0,
|
||||
timestamp: chrono::Utc::now(),
|
||||
});
|
||||
@@ -155,7 +154,10 @@ async fn proxy_embeddings(
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(format!("{}/v1/embeddings", upstream_url))
|
||||
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("Bearer {}", state.config.provider_key(&provider)),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.send()
|
||||
@@ -163,8 +165,14 @@ async fn proxy_embeddings(
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
Ok(Response::builder().status(status).body(Body::from(body)).unwrap())
|
||||
let body = resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
Ok(Response::builder()
|
||||
.status(status)
|
||||
.body(Body::from(body))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
// --- Error types ---
|
||||
|
||||
@@ -3,7 +3,7 @@ use tokio::sync::mpsc;
|
||||
use tracing::info;
|
||||
|
||||
use dd0c_route::{
|
||||
AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router,
|
||||
create_router, AppConfig, LocalAuthProvider, ProxyState, RouterBrain, TelemetryEvent,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
@@ -27,12 +27,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
|
||||
|
||||
// Telemetry channel (bounded, non-blocking)
|
||||
let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
|
||||
let (telemetry_tx, mut telemetry_rx) =
|
||||
mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
|
||||
|
||||
// Spawn telemetry worker (writes to TimescaleDB)
|
||||
let ts_url = config.timescale_url.clone();
|
||||
tokio::spawn(async move {
|
||||
let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed");
|
||||
let ts_pool = sqlx::PgPool::connect(&ts_url)
|
||||
.await
|
||||
.expect("TimescaleDB connection failed");
|
||||
while let Some(event) = telemetry_rx.recv().await {
|
||||
if let Err(e) = sqlx::query(
|
||||
"INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pub mod handler;
|
||||
pub mod middleware;
|
||||
|
||||
pub use handler::{create_router, ProxyState, ProxyError};
|
||||
pub use handler::{create_router, ProxyError, ProxyState};
|
||||
|
||||
@@ -41,13 +41,21 @@ impl RouterBrain {
|
||||
.get("messages")
|
||||
.and_then(|m| m.as_array())
|
||||
.and_then(|msgs| {
|
||||
msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
|
||||
msgs.iter()
|
||||
.find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
|
||||
})
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
|
||||
let high_complexity_keywords = [
|
||||
"analyze",
|
||||
"reason",
|
||||
"compare",
|
||||
"evaluate",
|
||||
"synthesize",
|
||||
"debate",
|
||||
];
|
||||
let has_complex_task = high_complexity_keywords
|
||||
.iter()
|
||||
.any(|kw| system_prompt.to_lowercase().contains(kw));
|
||||
@@ -142,7 +150,9 @@ mod tests {
|
||||
];
|
||||
for i in 0..12 {
|
||||
messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
|
||||
messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}));
|
||||
messages.push(
|
||||
serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}),
|
||||
);
|
||||
}
|
||||
let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
|
||||
let complexity = brain.classify_complexity(&request);
|
||||
|
||||
@@ -4,7 +4,7 @@ use uuid::Uuid;
|
||||
|
||||
async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyhow::Error> {
|
||||
let row = sqlx::query_as::<_, (String,)>(
|
||||
"SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1"
|
||||
"SELECT email FROM users WHERE org_id = $1 AND role = 'owner' LIMIT 1",
|
||||
)
|
||||
.bind(org_id)
|
||||
.fetch_one(pool)
|
||||
@@ -18,7 +18,7 @@ async fn get_org_owner_email(pool: &PgPool, org_id: Uuid) -> Result<String, anyh
|
||||
pub async fn check_anomalies(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow::Result<()> {
|
||||
// Get orgs with recent activity
|
||||
let orgs = sqlx::query_as::<_, (Uuid,)>(
|
||||
"SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'"
|
||||
"SELECT DISTINCT org_id FROM request_events WHERE time >= now() - interval '1 hour'",
|
||||
)
|
||||
.fetch_all(ts_pool)
|
||||
.await?;
|
||||
@@ -35,7 +35,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
|
||||
let current = sqlx::query_as::<_, (f64,)>(
|
||||
"SELECT COALESCE(SUM(cost_actual), 0)::float8
|
||||
FROM request_events
|
||||
WHERE org_id = $1 AND time >= now() - interval '1 hour'"
|
||||
WHERE org_id = $1 AND time >= now() - interval '1 hour'",
|
||||
)
|
||||
.bind(org_id)
|
||||
.fetch_one(ts_pool)
|
||||
@@ -88,11 +88,7 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
|
||||
]
|
||||
});
|
||||
let client = reqwest::Client::new();
|
||||
if let Err(e) = client.post(&slack_url)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
if let Err(e) = client.post(&slack_url).json(&payload).send().await {
|
||||
warn!(error = %e, "Failed to send Slack anomaly alert");
|
||||
}
|
||||
}
|
||||
@@ -114,7 +110,8 @@ async fn check_org_anomaly(ts_pool: &PgPool, pg_pool: &PgPool, org_id: Uuid) ->
|
||||
)
|
||||
});
|
||||
let client = reqwest::Client::new();
|
||||
if let Err(e) = client.post("https://api.resend.com/emails")
|
||||
if let Err(e) = client
|
||||
.post("https://api.resend.com/emails")
|
||||
.bearer_auth(&resend_key)
|
||||
.json(&email_body)
|
||||
.send()
|
||||
|
||||
@@ -6,19 +6,17 @@ use uuid::Uuid;
|
||||
/// Calculate next Monday 9 AM UTC from a given time
|
||||
pub fn next_monday_9am(from: DateTime<Utc>) -> DateTime<Utc> {
|
||||
let days_until_monday = (7 - from.weekday().num_days_from_monday()) % 7;
|
||||
let days_until_monday = if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() {
|
||||
7 // Already past Monday 9 AM, go to next week
|
||||
} else if days_until_monday == 0 {
|
||||
0 // It's Monday but before 9 AM
|
||||
} else {
|
||||
days_until_monday
|
||||
};
|
||||
let days_until_monday =
|
||||
if days_until_monday == 0 && from.time() >= NaiveTime::from_hms_opt(9, 0, 0).unwrap() {
|
||||
7 // Already past Monday 9 AM, go to next week
|
||||
} else if days_until_monday == 0 {
|
||||
0 // It's Monday but before 9 AM
|
||||
} else {
|
||||
days_until_monday
|
||||
};
|
||||
|
||||
let target_date = from.date_naive() + chrono::Duration::days(days_until_monday as i64);
|
||||
target_date
|
||||
.and_hms_opt(9, 0, 0)
|
||||
.unwrap()
|
||||
.and_utc()
|
||||
target_date.and_hms_opt(9, 0, 0).unwrap().and_utc()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -58,7 +56,7 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
|
||||
WHERE o.id IN (
|
||||
SELECT DISTINCT org_id FROM request_events
|
||||
WHERE time >= now() - interval '7 days'
|
||||
)"
|
||||
)",
|
||||
)
|
||||
.fetch_all(pg_pool)
|
||||
.await?;
|
||||
@@ -70,13 +68,27 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
|
||||
Ok(digest) => {
|
||||
// Send weekly digest via Resend
|
||||
if let Ok(resend_key) = std::env::var("RESEND_API_KEY") {
|
||||
let models_html: String = digest.top_models.iter().map(|m| {
|
||||
format!("<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>", m.model, m.request_count, m.cost)
|
||||
}).collect();
|
||||
let models_html: String = digest
|
||||
.top_models
|
||||
.iter()
|
||||
.map(|m| {
|
||||
format!(
|
||||
"<tr><td>{}</td><td>{}</td><td>${:.4}</td></tr>",
|
||||
m.model, m.request_count, m.cost
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let savings_html: String = digest.top_savings.iter().map(|s| {
|
||||
format!("<tr><td>{} → {}</td><td>{}</td><td>${:.4}</td></tr>", s.original_model, s.routed_model, s.requests_routed, s.cost_saved)
|
||||
}).collect();
|
||||
let savings_html: String = digest
|
||||
.top_savings
|
||||
.iter()
|
||||
.map(|s| {
|
||||
format!(
|
||||
"<tr><td>{} → {}</td><td>{}</td><td>${:.4}</td></tr>",
|
||||
s.original_model, s.routed_model, s.requests_routed, s.cost_saved
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let html = format!(
|
||||
"<h2>Weekly Cost Digest: {}</h2>\
|
||||
@@ -93,10 +105,14 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
|
||||
<table style='border-collapse:collapse;width:100%'>\
|
||||
<tr><th>Route</th><th>Requests</th><th>Saved</th></tr>{}</table>\
|
||||
<p><a href='https://route.dd0c.dev/dashboard'>View Dashboard →</a></p>",
|
||||
digest.org_name, digest.total_requests,
|
||||
digest.total_cost_original, digest.total_cost_actual,
|
||||
digest.total_cost_saved, digest.savings_pct,
|
||||
models_html, savings_html
|
||||
digest.org_name,
|
||||
digest.total_requests,
|
||||
digest.total_cost_original,
|
||||
digest.total_cost_actual,
|
||||
digest.total_cost_saved,
|
||||
digest.savings_pct,
|
||||
models_html,
|
||||
savings_html
|
||||
);
|
||||
|
||||
let email_body = serde_json::json!({
|
||||
@@ -107,7 +123,8 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
match client.post("https://api.resend.com/emails")
|
||||
match client
|
||||
.post("https://api.resend.com/emails")
|
||||
.bearer_auth(&resend_key)
|
||||
.json(&email_body)
|
||||
.send()
|
||||
@@ -134,7 +151,11 @@ pub async fn generate_all_digests(ts_pool: &PgPool, pg_pool: &PgPool) -> anyhow:
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyhow::Result<DigestData> {
|
||||
async fn generate_digest(
|
||||
ts_pool: &PgPool,
|
||||
org_id: Uuid,
|
||||
org_name: &str,
|
||||
) -> anyhow::Result<DigestData> {
|
||||
// Summary stats
|
||||
let summary = sqlx::query_as::<_, (i64, f64, f64, f64)>(
|
||||
"SELECT COUNT(*),
|
||||
@@ -142,7 +163,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
|
||||
COALESCE(SUM(cost_actual), 0)::float8,
|
||||
COALESCE(SUM(cost_saved), 0)::float8
|
||||
FROM request_events
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days'"
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days'",
|
||||
)
|
||||
.bind(org_id)
|
||||
.fetch_one(ts_pool)
|
||||
@@ -155,7 +176,7 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days'
|
||||
GROUP BY original_model
|
||||
ORDER BY SUM(cost_actual) DESC
|
||||
LIMIT 5"
|
||||
LIMIT 5",
|
||||
)
|
||||
.bind(org_id)
|
||||
.fetch_all(ts_pool)
|
||||
@@ -168,13 +189,17 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
|
||||
WHERE org_id = $1 AND time >= now() - interval '7 days' AND strategy != 'passthrough'
|
||||
GROUP BY original_model, routed_model
|
||||
ORDER BY SUM(cost_saved) DESC
|
||||
LIMIT 5"
|
||||
LIMIT 5",
|
||||
)
|
||||
.bind(org_id)
|
||||
.fetch_all(ts_pool)
|
||||
.await?;
|
||||
|
||||
let savings_pct = if summary.1 > 0.0 { (summary.3 / summary.1) * 100.0 } else { 0.0 };
|
||||
let savings_pct = if summary.1 > 0.0 {
|
||||
(summary.3 / summary.1) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Ok(DigestData {
|
||||
_org_id: org_id,
|
||||
@@ -184,17 +209,23 @@ async fn generate_digest(ts_pool: &PgPool, org_id: Uuid, org_name: &str) -> anyh
|
||||
total_cost_actual: summary.2,
|
||||
total_cost_saved: summary.3,
|
||||
savings_pct,
|
||||
top_models: top_models.iter().map(|r| ModelUsage {
|
||||
model: r.0.clone(),
|
||||
request_count: r.1,
|
||||
cost: r.2,
|
||||
}).collect(),
|
||||
top_savings: top_savings.iter().map(|r| RoutingSaving {
|
||||
original_model: r.0.clone(),
|
||||
routed_model: r.1.clone(),
|
||||
requests_routed: r.2,
|
||||
cost_saved: r.3,
|
||||
}).collect(),
|
||||
top_models: top_models
|
||||
.iter()
|
||||
.map(|r| ModelUsage {
|
||||
model: r.0.clone(),
|
||||
request_count: r.1,
|
||||
cost: r.2,
|
||||
})
|
||||
.collect(),
|
||||
top_savings: top_savings
|
||||
.iter()
|
||||
.map(|r| RoutingSaving {
|
||||
original_model: r.0.clone(),
|
||||
routed_model: r.1.clone(),
|
||||
requests_routed: r.2,
|
||||
cost_saved: r.3,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -205,27 +236,45 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn next_monday_from_wednesday() {
|
||||
let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4).unwrap() // Wednesday
|
||||
.and_hms_opt(14, 0, 0).unwrap().and_utc();
|
||||
let wed = chrono::NaiveDate::from_ymd_opt(2026, 3, 4)
|
||||
.unwrap() // Wednesday
|
||||
.and_hms_opt(14, 0, 0)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
let next = next_monday_9am(wed);
|
||||
assert_eq!(next.weekday(), Weekday::Mon);
|
||||
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap());
|
||||
assert_eq!(
|
||||
next.date_naive(),
|
||||
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
|
||||
);
|
||||
assert_eq!(next.time(), NaiveTime::from_hms_opt(9, 0, 0).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_monday_from_monday_before_9am() {
|
||||
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap() // Monday
|
||||
.and_hms_opt(8, 0, 0).unwrap().and_utc();
|
||||
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
|
||||
.unwrap() // Monday
|
||||
.and_hms_opt(8, 0, 0)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
let next = next_monday_9am(mon);
|
||||
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap());
|
||||
assert_eq!(
|
||||
next.date_naive(),
|
||||
chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_monday_from_monday_after_9am() {
|
||||
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2).unwrap()
|
||||
.and_hms_opt(10, 0, 0).unwrap().and_utc();
|
||||
let mon = chrono::NaiveDate::from_ymd_opt(2026, 3, 2)
|
||||
.unwrap()
|
||||
.and_hms_opt(10, 0, 0)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
let next = next_monday_9am(mon);
|
||||
assert_eq!(next.date_naive(), chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap());
|
||||
assert_eq!(
|
||||
next.date_naive(),
|
||||
chrono::NaiveDate::from_ymd_opt(2026, 3, 9).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, error};
|
||||
use chrono::Utc;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
use dd0c_route::AppConfig;
|
||||
|
||||
mod digest;
|
||||
mod anomaly;
|
||||
mod digest;
|
||||
|
||||
/// Refresh model pricing from known provider pricing pages.
|
||||
/// Falls back to hardcoded defaults if fetch fails.
|
||||
@@ -83,7 +83,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
loop {
|
||||
let now = Utc::now();
|
||||
let next_monday = digest::next_monday_9am(now);
|
||||
let sleep_duration = (next_monday - now).to_std().unwrap_or(std::time::Duration::from_secs(3600));
|
||||
let sleep_duration = (next_monday - now)
|
||||
.to_std()
|
||||
.unwrap_or(std::time::Duration::from_secs(3600));
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
|
||||
info!("Generating weekly digests");
|
||||
|
||||
Reference in New Issue
Block a user