Scaffold dd0c/route core proxy engine (handler, router, auth, config)

This commit is contained in:
2026-03-01 02:23:27 +00:00
parent d038cd9c5c
commit cc003cbb1c
12 changed files with 872 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
[package]
name = "dd0c-route"
version = "0.1.0"
edition = "2021"
description = "LLM Cost Router & Dashboard — route AI requests to the cheapest capable model"
license = "MIT"

[[bin]]
name = "dd0c-proxy"
path = "src/proxy/main.rs"

[[bin]]
name = "dd0c-api"
path = "src/api/main.rs"

[[bin]]
name = "dd0c-worker"
path = "src/worker/main.rs"

[dependencies]
# Web framework
axum = { version = "0.7", features = ["ws", "macros"] }
tokio = { version = "1", features = ["full"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace", "compression-gzip"] }
hyper = { version = "1", features = ["full"] }
hyper-util = "0.1"
# Async trait objects — required by auth.rs (`use async_trait::async_trait`)
# and data.rs (`#[async_trait::async_trait]`); was missing and broke the build
async-trait = "0.1"
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Database
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "uuid", "chrono", "json"] }
deadpool-redis = "0.15"
redis = { version = "0.25", features = ["tokio-comp", "connection-manager"] }
# Auth
jsonwebtoken = "9"
bcrypt = "0.15"
uuid = { version = "1", features = ["v4", "serde"] }
# Observability
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
opentelemetry = "0.22"
opentelemetry-otlp = "0.15"
# Config
dotenvy = "0.15"
config = "0.14"
# HTTP client (upstream providers)
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"] }
reqwest-eventsource = "0.6"
futures = "0.3"
tokio-stream = "0.1"
bytes = "1"
# Crypto (provider key encryption)
aes-gcm = "0.10"
base64 = "0.22"
# Sensitive-data redaction in logs — used by proxy/middleware.rs (`regex::Regex`);
# was missing and broke the build
regex = "1"
# Feature flags
serde_yaml = "0.9"
# Misc
chrono = { version = "0.4", features = ["serde"] }
thiserror = "1"
anyhow = "1"

[dev-dependencies]
# Testing
tokio-test = "0.4"
wiremock = "0.6"
testcontainers = "0.15"
testcontainers-modules = { version = "0.3", features = ["postgres", "redis"] }
proptest = "1"
criterion = { version = "0.5", features = ["async_tokio"] }
tower-test = "0.4"

[[bench]]
name = "proxy_latency"
harness = false

View File

@@ -0,0 +1,4 @@
// Entry point for the dashboard API binary (`dd0c-api`).
// Currently a stub — the real API server lands with Epic 4.
fn main() {
    println!("dd0c/route API server — not yet implemented");
    // TODO: Dashboard API (Epic 4)
}

View File

@@ -0,0 +1,105 @@
use axum::http::HeaderMap;
use async_trait::async_trait;
use thiserror::Error;
/// Identity attached to a successfully authenticated request.
#[derive(Debug, Clone)]
pub struct AuthContext {
    // Organization the credential belongs to
    pub org_id: String,
    // None for API-key auth (see LocalAuthProvider::authenticate); presumably
    // set for user-session auth — TODO confirm once that path exists
    pub user_id: Option<String>,
    // Caller's role within the organization
    pub role: Role,
}
/// Role of the caller within an organization.
#[derive(Debug, Clone, PartialEq)]
pub enum Role {
    Owner,
    Member,
    Viewer,
}
/// Authentication failures surfaced to the proxy layer.
#[derive(Debug, Error)]
pub enum AuthError {
    #[error("Invalid API key")]
    InvalidKey,
    #[error("Expired token")]
    ExpiredToken,
    #[error("Missing authorization header")]
    MissingAuth,
}
/// Pluggable authentication backend (e.g. local bcrypt/JWT, or OAuth).
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Resolve the request headers to an authenticated context, or fail.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError>;
}
/// Local auth — bcrypt passwords + HS256 JWT (self-hosted mode)
pub struct LocalAuthProvider {
    // API-key lookups against the `api_keys` table
    pool: sqlx::PgPool,
    // JWT signing secret — not yet used by `authenticate` (API-key path only)
    jwt_secret: String,
    // Cache of previously verified keys -> org ids
    redis: deadpool_redis::Pool,
}
impl LocalAuthProvider {
    /// Build a provider from already-constructed connection pools.
    pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
        Self { pool, jwt_secret, redis }
    }
}
#[async_trait]
impl AuthProvider for LocalAuthProvider {
    /// Authenticate a request by its `Authorization: Bearer <api-key>` header.
    ///
    /// Lookup strategy:
    /// 1. Redis cache of previously verified keys (fast path, skips bcrypt).
    /// 2. PostgreSQL `api_keys` lookup by key prefix, then bcrypt-verify the
    ///    full key against the stored hash.
    /// 3. On success, cache the org id in Redis with a 5-minute TTL.
    ///
    /// # Errors
    /// - `AuthError::MissingAuth` if the header is absent or not a Bearer token.
    /// - `AuthError::InvalidKey` if the key is too short, unknown, revoked, or
    ///   fails bcrypt verification.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError> {
        let key = headers
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.strip_prefix("Bearer "))
            .ok_or(AuthError::MissingAuth)?;
        // BUGFIX: the original's `&key[..8]` panicked on keys shorter than
        // 8 bytes (or on a non-UTF-8 char boundary) — a trivially triggerable
        // DoS via a malformed Authorization header. Use a checked slice and
        // treat short keys as invalid.
        let prefix = key.get(..8).ok_or(AuthError::InvalidKey)?;
        // 1. Check Redis cache first (best effort — a cache outage falls
        //    through to PostgreSQL).
        // SECURITY FIX: the cache is keyed by the *full* key, not the 8-char
        // prefix as before — a prefix-keyed cache authenticated any key
        // sharing the first 8 characters of a previously verified one.
        if let Ok(mut conn) = self.redis.get().await {
            let cached: Option<String> = redis::cmd("GET")
                .arg(format!("apikey:{}", key))
                .query_async(&mut *conn)
                .await
                .unwrap_or(None);
            if let Some(org_id) = cached {
                return Ok(AuthContext {
                    org_id,
                    user_id: None,
                    role: Role::Member,
                });
            }
        }
        // 2. Fall back to PostgreSQL, locating the candidate row by prefix
        let row = sqlx::query_as::<_, (String, String)>(
            "SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
        )
        .bind(prefix)
        .fetch_optional(&self.pool)
        .await
        .map_err(|_| AuthError::InvalidKey)?
        .ok_or(AuthError::InvalidKey)?;
        // 3. Verify the full key against the stored bcrypt hash
        if !bcrypt::verify(key, &row.1).unwrap_or(false) {
            return Err(AuthError::InvalidKey);
        }
        // 4. Cache the verified key -> org mapping (5 min TTL); failures are
        //    non-fatal — the next request just pays the DB + bcrypt cost again
        if let Ok(mut conn) = self.redis.get().await {
            let _: Result<(), _> = redis::cmd("SETEX")
                .arg(format!("apikey:{}", key))
                .arg(300)
                .arg(&row.0)
                .query_async(&mut *conn)
                .await;
        }
        Ok(AuthContext {
            org_id: row.0,
            user_id: None,
            role: Role::Member,
        })
    }
}

View File

@@ -0,0 +1,104 @@
use serde::Deserialize;
use std::collections::HashMap;
/// Top-level application configuration, loaded from environment variables
/// (see `AppConfig::from_env`).
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
    // Port the proxy binary listens on
    pub proxy_port: u16,
    // Port the dashboard API binary listens on
    pub api_port: u16,
    // Main PostgreSQL (orgs, API keys, …)
    pub database_url: String,
    pub redis_url: String,
    // TimescaleDB used for telemetry events
    pub timescale_url: String,
    pub jwt_secret: String,
    pub auth_mode: AuthMode,
    pub governance_mode: GovernanceMode,
    // Upstream LLM providers keyed by name ("openai", "anthropic", …)
    pub providers: HashMap<String, ProviderConfig>,
    // Capacity of the bounded proxy -> worker telemetry channel
    pub telemetry_channel_size: usize,
}
/// How callers are authenticated: local bcrypt/JWT vs external OAuth.
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
    Local,
    OAuth,
}
/// Governance enforcement level. NOTE(review): Strict-vs-Audit semantics are
/// enforced elsewhere — confirm against the governance module when it lands.
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GovernanceMode {
    Strict,
    Audit,
}
/// Connection details for one upstream LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderConfig {
    pub api_key: String,
    pub base_url: String,
}
impl AppConfig {
    /// Load configuration from the process environment (reading `.env` first
    /// if present). Every setting has a development default; upstream
    /// providers are registered only when their `*_API_KEY` env var is set.
    ///
    /// # Errors
    /// Fails if `PROXY_PORT` or `API_PORT` is set to a non-u16 value.
    pub fn from_env() -> anyhow::Result<Self> {
        dotenvy::dotenv().ok();
        // Read an env var with a development fallback.
        let env_or = |var: &str, default: &str| -> String {
            std::env::var(var).unwrap_or_else(|_| default.to_string())
        };
        let mut providers = HashMap::new();
        // Register an upstream provider only when its API key is configured.
        // (Deduplicates the per-provider blocks the original repeated.)
        let mut add_provider = |name: &str, key_var: &str, url_var: &str, default_url: &str| {
            if let Ok(key) = std::env::var(key_var) {
                providers.insert(
                    name.to_string(),
                    ProviderConfig {
                        api_key: key,
                        base_url: std::env::var(url_var)
                            .unwrap_or_else(|_| default_url.to_string()),
                    },
                );
            }
        };
        add_provider("openai", "OPENAI_API_KEY", "OPENAI_BASE_URL", "https://api.openai.com");
        add_provider("anthropic", "ANTHROPIC_API_KEY", "ANTHROPIC_BASE_URL", "https://api.anthropic.com");
        Ok(Self {
            proxy_port: env_or("PROXY_PORT", "8080").parse()?,
            api_port: env_or("API_PORT", "3000").parse()?,
            database_url: env_or("DATABASE_URL", "postgres://dd0c:dd0c@localhost:5432/dd0c"),
            redis_url: env_or("REDIS_URL", "redis://localhost:6379"),
            timescale_url: env_or("TIMESCALE_URL", "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry"),
            jwt_secret: env_or("JWT_SECRET", "dev-secret-change-me"),
            auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {
                AuthMode::OAuth
            } else {
                AuthMode::Local
            },
            governance_mode: if std::env::var("GOVERNANCE_MODE").unwrap_or_default() == "strict" {
                GovernanceMode::Strict
            } else {
                GovernanceMode::Audit
            },
            providers,
            // Malformed values fall back to the default instead of erroring.
            telemetry_channel_size: env_or("TELEMETRY_CHANNEL_SIZE", "1000")
                .parse()
                .unwrap_or(1000),
        })
    }

    /// Base URL for `provider`; unknown providers fall back to OpenAI's URL.
    pub fn provider_url(&self, provider: &str) -> String {
        self.providers
            .get(provider)
            .map(|p| p.base_url.clone())
            .unwrap_or_else(|| "https://api.openai.com".to_string())
    }

    /// API key for `provider`; empty string when the provider is unknown.
    pub fn provider_key(&self, provider: &str) -> String {
        self.providers
            .get(provider)
            .map(|p| p.api_key.clone())
            .unwrap_or_default()
    }
}

View File

@@ -0,0 +1,32 @@
use serde::{Deserialize, Serialize};
/// Telemetry event emitted by the proxy on every request.
/// Sent via mpsc channel to the worker for async persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryEvent {
    pub org_id: String,
    // Model the client asked for
    pub original_model: String,
    // Model the request was actually forwarded with (may equal original_model)
    pub routed_model: String,
    pub provider: String,
    // Routing strategy label (e.g. "cheapest", "passthrough")
    pub strategy: String,
    // Wall-clock latency of the upstream call
    pub latency_ms: u32,
    pub status_code: u16,
    pub is_streaming: bool,
    // Token counts are 0 at emit time; filled in later by the worker
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Abstraction over event queues (SQS for cloud, pgmq for self-hosted)
#[async_trait::async_trait]
pub trait EventQueue: Send + Sync {
    /// Publish a single event; durability semantics are implementation-defined.
    async fn publish(&self, event: TelemetryEvent) -> anyhow::Result<()>;
    /// Pull up to `batch_size` events for processing.
    async fn consume(&self, batch_size: usize) -> anyhow::Result<Vec<TelemetryEvent>>;
}
/// Abstraction over object storage (S3 for cloud, local FS for self-hosted)
#[async_trait::async_trait]
pub trait ObjectStore: Send + Sync {
    /// Store `data` under `key`.
    async fn put(&self, key: &str, data: &[u8]) -> anyhow::Result<()>;
    /// Fetch the object stored under `key`.
    async fn get(&self, key: &str) -> anyhow::Result<Vec<u8>>;
}

View File

@@ -0,0 +1,11 @@
mod auth;
mod config;
mod data;
mod proxy;
mod router;
pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role};
pub use config::AppConfig;
pub use data::{EventQueue, ObjectStore, TelemetryEvent};
pub use proxy::{create_router, ProxyState, ProxyError};
pub use router::{RouterBrain, RoutingDecision, Complexity};

View File

@@ -0,0 +1,199 @@
use anyhow::Result;
use axum::{
body::Body,
extract::State,
http::{HeaderMap, Request, StatusCode},
response::{IntoResponse, Response, Sse},
routing::post,
Router,
};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{info, warn};
mod middleware;
use crate::auth::AuthProvider;
use crate::config::AppConfig;
use crate::data::{EventQueue, TelemetryEvent};
use crate::router::RouterBrain;
/// Shared state for all proxy routes (handed to handlers via `Arc`).
pub struct ProxyState {
    // Authentication backend (API-key verification)
    pub auth: Arc<dyn AuthProvider>,
    // Makes the model/provider routing decisions
    pub router: Arc<RouterBrain>,
    // Bounded channel to the telemetry persistence task (try_send — never blocks)
    pub telemetry_tx: mpsc::Sender<TelemetryEvent>,
    // Reused HTTP client for upstream provider requests
    pub http_client: reqwest::Client,
    pub config: Arc<AppConfig>,
}
/// Build the axum router for the proxy: OpenAI-compatible endpoints + health.
pub fn create_router(state: Arc<ProxyState>) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(proxy_chat_completions))
        .route("/v1/completions", post(proxy_completions))
        .route("/v1/embeddings", post(proxy_embeddings))
        .route("/health", axum::routing::get(health))
        .with_state(state)
}
/// Liveness probe — always returns "ok".
async fn health() -> &'static str {
    "ok"
}
async fn proxy_chat_completions(
State(state): State<Arc<ProxyState>>,
headers: HeaderMap,
body: String,
) -> Result<Response, ProxyError> {
// 1. Authenticate
let auth_ctx = state.auth.authenticate(&headers).await?;
// 2. Parse request
let mut request: serde_json::Value =
serde_json::from_str(&body).map_err(|_| ProxyError::BadRequest("Invalid JSON"))?;
let is_streaming = request
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let original_model = request
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("gpt-4o")
.to_string();
// 3. Route (pick model + provider)
let decision = state
.router
.route(&auth_ctx.org_id, &request)
.await;
// Apply routing decision
if let Some(ref routed_model) = decision.model {
request["model"] = serde_json::Value::String(routed_model.clone());
}
let provider = decision.provider.unwrap_or_default();
let upstream_url = state.config.provider_url(&provider);
// 4. Forward to upstream
let start = std::time::Instant::now();
let upstream_resp = state
.http_client
.post(format!("{}/v1/chat/completions", upstream_url))
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
.header("Content-Type", "application/json")
.body(request.to_string())
.send()
.await
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
let latency = start.elapsed();
let status = upstream_resp.status();
// 5. Emit telemetry (non-blocking)
let _ = state.telemetry_tx.try_send(TelemetryEvent {
org_id: auth_ctx.org_id.clone(),
original_model: original_model.clone(),
routed_model: decision.model.clone().unwrap_or(original_model),
provider: provider.clone(),
strategy: decision.strategy.clone(),
latency_ms: latency.as_millis() as u32,
status_code: status.as_u16(),
is_streaming,
prompt_tokens: 0, // Filled by worker from response
completion_tokens: 0,
timestamp: chrono::Utc::now(),
});
// 6. Pass through response transparently
let resp_headers = upstream_resp.headers().clone();
let resp_status = upstream_resp.status();
if is_streaming {
// Stream SSE chunks directly
let byte_stream = upstream_resp.bytes_stream();
let body = Body::from_stream(byte_stream);
let mut response = Response::builder().status(resp_status);
for (key, value) in resp_headers.iter() {
response = response.header(key, value);
}
Ok(response.body(body).unwrap())
} else {
let resp_body = upstream_resp
.bytes()
.await
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
let mut response = Response::builder().status(resp_status);
for (key, value) in resp_headers.iter() {
response = response.header(key, value);
}
Ok(response.body(Body::from(resp_body)).unwrap())
}
}
// Placeholder — same pattern as chat completions
// FIXME(review): this delegates to `proxy_chat_completions`, which forwards to
// the upstream `/v1/chat/completions` path — so a legacy `/v1/completions`
// request reaches the wrong upstream endpoint (and gets chat-format routing).
// Needs its own forwarder before this route is relied on.
async fn proxy_completions(
    State(state): State<Arc<ProxyState>>,
    headers: HeaderMap,
    body: String,
) -> Result<Response, ProxyError> {
    proxy_chat_completions(State(state), headers, body).await
}
/// Handle `POST /v1/embeddings` — authenticated pass-through to the provider.
///
/// Embeddings don't need routing, so the body is forwarded verbatim and the
/// upstream status, content-type, and body are returned to the client.
async fn proxy_embeddings(
    State(state): State<Arc<ProxyState>>,
    headers: HeaderMap,
    body: String,
) -> Result<Response, ProxyError> {
    // Auth is still enforced (the `?`); the context itself is unused here.
    let _auth_ctx = state.auth.authenticate(&headers).await?;
    let provider = "openai".to_string();
    let upstream_url = state.config.provider_url(&provider);
    let resp = state
        .http_client
        .post(format!("{}/v1/embeddings", upstream_url))
        .header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
        .header("Content-Type", "application/json")
        .body(body)
        .send()
        .await
        .map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
    let status = resp.status();
    // BUGFIX: the original dropped every upstream response header, returning
    // JSON bodies without a content type. Preserve Content-Type.
    let content_type = resp.headers().get("content-type").cloned();
    let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
    let mut builder = Response::builder().status(status);
    if let Some(ct) = content_type {
        builder = builder.header("content-type", ct);
    }
    // Status/header values come from a parsed upstream response — cannot fail.
    Ok(builder.body(Body::from(body)).expect("valid response parts"))
}
// --- Error types ---
/// Errors a proxy handler can return; mapped to HTTP responses by the
/// `IntoResponse` impl for this type.
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
    // -> 401 Unauthorized
    #[error("Authentication failed: {0}")]
    AuthError(String),
    // -> 400 Bad Request
    #[error("Bad request: {0}")]
    BadRequest(&'static str),
    // -> 502 Bad Gateway (upstream unreachable or failed mid-response)
    #[error("Upstream error: {0}")]
    UpstreamError(String),
}
impl From<crate::auth::AuthError> for ProxyError {
    // Collapse auth failures to their display string; handlers rely on this
    // conversion for the `?` operator.
    fn from(e: crate::auth::AuthError) -> Self {
        ProxyError::AuthError(e.to_string())
    }
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let (status, msg) = match &self {
ProxyError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
ProxyError::BadRequest(_) => (StatusCode::BAD_REQUEST, self.to_string()),
ProxyError::UpstreamError(_) => (StatusCode::BAD_GATEWAY, self.to_string()),
};
(status, serde_json::json!({"error": msg}).to_string()).into_response()
}
}

View File

@@ -0,0 +1,81 @@
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::info;
use dd0c_route::{
AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router,
};
// Entry point for the `dd0c-proxy` binary: wires config, DB/Redis pools, the
// telemetry pipeline, and the axum proxy server.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Init tracing: JSON logs, filter from RUST_LOG with a sane default
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "dd0c_route=info,tower_http=info".into()),
        )
        .json()
        .init();
    // Load config
    let config = Arc::new(AppConfig::from_env()?);
    info!(port = config.proxy_port, "Starting dd0c/route proxy");
    // Connect to databases
    let pg_pool = sqlx::PgPool::connect(&config.database_url).await?;
    let redis_cfg = deadpool_redis::Config::from_url(&config.redis_url);
    let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
    // Telemetry channel (bounded; the handler uses try_send, so a full channel
    // drops events instead of blocking the request path)
    let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
    // Spawn telemetry worker that persists each event to TimescaleDB.
    // Insert failures are logged and dropped so telemetry can never take the
    // proxy down.
    let ts_url = config.timescale_url.clone();
    tokio::spawn(async move {
        // NOTE(review): this `expect` aborts only the background task, not the
        // proxy — if the connection fails, events silently pile up and drop.
        let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed");
        while let Some(event) = telemetry_rx.recv().await {
            if let Err(e) = sqlx::query(
                "INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
            )
            .bind(&event.org_id)
            .bind(&event.original_model)
            .bind(&event.routed_model)
            .bind(&event.provider)
            .bind(&event.strategy)
            // Narrowing casts to match the table's signed column types
            .bind(event.latency_ms as i32)
            .bind(event.status_code as i16)
            .bind(event.is_streaming)
            .bind(event.prompt_tokens as i32)
            .bind(event.completion_tokens as i32)
            .bind(event.timestamp)
            .execute(&ts_pool)
            .await {
                tracing::warn!(error = %e, "Failed to persist telemetry event");
            }
        }
    });
    // Build proxy state shared by all handlers
    let state = Arc::new(ProxyState {
        auth: Arc::new(LocalAuthProvider::new(
            pg_pool.clone(),
            config.jwt_secret.clone(),
            redis_pool.clone(),
        )),
        router: Arc::new(RouterBrain::new()),
        telemetry_tx,
        http_client: reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300)) // 5min for long completions
            .build()?,
        config: config.clone(),
    });
    // Start server
    let app = create_router(state);
    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", config.proxy_port)).await?;
    info!(port = config.proxy_port, "Proxy listening");
    axum::serve(listener, app).await?;
    Ok(())
}

View File

@@ -0,0 +1,61 @@
// Proxy middleware — API key redaction in error traces
use axum::http::HeaderMap;
use tracing::warn;
/// Redact any Bearer tokens or API keys from a string.
/// Used in panic handlers and error logging to prevent key leakage.
pub fn redact_sensitive(input: &str) -> String {
let patterns = [
// OpenAI keys
(r"sk-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
// Anthropic keys
(r"sk-ant-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
// Bearer tokens
(r"Bearer\s+[a-zA-Z0-9_.-]+", "Bearer [REDACTED]"),
];
let mut result = input.to_string();
for (pattern, replacement) in &patterns {
if let Ok(re) = regex::Regex::new(pattern) {
result = re.replace_all(&result, *replacement).to_string();
}
}
result
}
/// Extract the raw API key from the `Authorization: Bearer …` header.
///
/// NOTE(review): despite the original doc comment, this returns the key
/// *unredacted* — never log the result directly; run it through
/// `redact_sensitive` first.
pub fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
        .map(|s| s.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    // Key bodies of 20+ chars match the `sk-…` redaction pattern.
    #[test]
    fn redacts_openai_key() {
        let input = "Error with key sk-proj-abc123xyz456def789ghi";
        let result = redact_sensitive(input);
        assert!(!result.contains("sk-proj-abc123xyz456def789ghi"));
        assert!(result.contains("[REDACTED_API_KEY]"));
    }
    // Shorter tokens are still caught by the broader `Bearer …` pattern.
    #[test]
    fn redacts_bearer_token() {
        let input = "Authorization: Bearer sk-live-abc123xyz";
        let result = redact_sensitive(input);
        assert!(!result.contains("sk-live-abc123xyz"));
        assert!(result.contains("[REDACTED]"));
    }
    // Plain text without key-like substrings must pass through untouched.
    #[test]
    fn does_not_redact_normal_text() {
        let input = "Hello world, this is a normal log message";
        let result = redact_sensitive(input);
        assert_eq!(result, input);
    }
}

View File

@@ -0,0 +1,4 @@
mod handler;
mod middleware;
pub use handler::{create_router, ProxyState, ProxyError};

View File

@@ -0,0 +1,183 @@
use serde::{Deserialize, Serialize};
use crate::auth::AuthContext;
/// Routing decision made by the Router Brain
#[derive(Debug, Clone, Serialize)]
pub struct RoutingDecision {
    // Replacement model, or None to keep the client's model (passthrough)
    pub model: Option<String>,
    // Replacement provider; always None in V1 (same provider as requested)
    pub provider: Option<String>,
    // Strategy label recorded in telemetry ("cheapest", "passthrough")
    pub strategy: String,
    // Estimated savings per 1K tokens from the downgrade (0.0 when no downgrade)
    pub cost_delta: f64,
    pub complexity: Complexity,
}
/// Coarse request-complexity bucket produced by cheap heuristics.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Complexity {
    Low,
    Medium,
    High,
}
/// The Router Brain — evaluates request complexity and applies routing rules
pub struct RouterBrain {
    // In V1, rules are loaded from config. Later: from DB per org.
}
impl RouterBrain {
    /// Create a router brain. V1 rules are hard-coded; later versions load
    /// per-org rules from the database.
    pub fn new() -> Self {
        Self {}
    }
    /// Classify request complexity from cheap heuristics: turn count,
    /// system-prompt length, and "reasoning" keywords in the system prompt.
    fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
        let messages = request
            .get("messages")
            .and_then(|m| m.as_array())
            .map(|a| a.len())
            .unwrap_or(0);
        let system_prompt = request
            .get("messages")
            .and_then(|m| m.as_array())
            .and_then(|msgs| {
                msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
            })
            .and_then(|m| m.get("content"))
            .and_then(|c| c.as_str())
            .unwrap_or("");
        let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
        // PERF: lowercase once — the original re-lowercased the whole system
        // prompt inside the `any` closure, i.e. once per keyword.
        let prompt_lower = system_prompt.to_lowercase();
        let has_complex_task = high_complexity_keywords
            .iter()
            .any(|kw| prompt_lower.contains(kw));
        if messages > 10 || has_complex_task {
            Complexity::High
        } else if messages > 3 || system_prompt.len() > 500 {
            Complexity::Medium
        } else {
            Complexity::Low
        }
    }
    /// Route a request — returns the routing decision.
    ///
    /// V1 logic: low-complexity GPT-4-family requests are downgraded to
    /// `gpt-4o-mini`; everything else passes through. `provider` is always
    /// None in V1 (same provider as the original request).
    pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
        let complexity = self.classify_complexity(request);
        let original_model = request
            .get("model")
            .and_then(|v| v.as_str())
            .unwrap_or("gpt-4o");
        // V1 routing logic: downgrade low-complexity to cheaper model
        let (routed_model, strategy) = match complexity {
            Complexity::Low => {
                if original_model.contains("gpt-4") {
                    (Some("gpt-4o-mini".to_string()), "cheapest".to_string())
                } else {
                    (None, "passthrough".to_string())
                }
            }
            Complexity::Medium => (None, "passthrough".to_string()),
            Complexity::High => (None, "passthrough".to_string()),
        };
        // Calculate savings only when a downgrade actually happened
        let cost_delta = match (&routed_model, original_model) {
            (Some(routed), orig) => estimate_cost_delta(orig, routed),
            _ => 0.0,
        };
        RoutingDecision {
            model: routed_model,
            provider: None, // V1: same provider
            strategy,
            cost_delta,
            complexity,
        }
    }
}

impl Default for RouterBrain {
    // Clippy `new_without_default`: make `RouterBrain::default()` available.
    fn default() -> Self {
        Self::new()
    }
}
/// Estimated savings per 1K tokens when swapping `original` for `routed`.
/// Positive means the routed model is cheaper; unknown models are priced
/// like `gpt-4o`.
fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
    // (model, input price per 1K tokens)
    const PRICES: [(&str, f64); 7] = [
        ("gpt-4o", 0.005),
        ("gpt-4o-mini", 0.00015),
        ("gpt-4-turbo", 0.01),
        ("gpt-3.5-turbo", 0.0005),
        ("claude-3-opus", 0.015),
        ("claude-3-sonnet", 0.003),
        ("claude-3-haiku", 0.00025),
    ];
    let lookup = |model: &str| -> f64 {
        PRICES
            .iter()
            .find(|(name, _)| *name == model)
            .map(|&(_, price)| price)
            .unwrap_or(0.005) // default to gpt-4o pricing
    };
    lookup(original) - lookup(routed)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Two short messages, no reasoning keywords -> Low
    #[test]
    fn low_complexity_simple_extraction() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "Extract the name from this text"},
                {"role": "user", "content": "My name is Alice"}
            ]
        });
        let complexity = brain.classify_complexity(&request);
        assert_eq!(complexity, Complexity::Low);
    }
    // >10 messages AND reasoning keywords in the system prompt -> High
    #[test]
    fn high_complexity_multi_turn_reasoning() {
        let brain = RouterBrain::new();
        let mut messages = vec![
            serde_json::json!({"role": "system", "content": "Analyze and compare these approaches"}),
        ];
        for i in 0..12 {
            messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
            messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}));
        }
        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
        let complexity = brain.classify_complexity(&request);
        assert_eq!(complexity, Complexity::High);
    }
    // Low-complexity gpt-4 family requests get downgraded to gpt-4o-mini
    #[tokio::test]
    async fn low_complexity_routes_to_mini() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "What is 2+2?"}]
        });
        let decision = brain.route("org-1", &request).await;
        assert_eq!(decision.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(decision.strategy, "cheapest");
        assert!(decision.cost_delta > 0.0);
    }
    // High-complexity requests keep the client's model (model = None)
    #[tokio::test]
    async fn high_complexity_passes_through() {
        let brain = RouterBrain::new();
        let mut messages = vec![];
        for i in 0..15 {
            messages.push(serde_json::json!({"role": "user", "content": format!("msg {}", i)}));
        }
        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
        let decision = brain.route("org-1", &request).await;
        assert_eq!(decision.model, None);
        assert_eq!(decision.strategy, "passthrough");
    }
    // 0.005 - 0.00015 = 0.00485 per 1K tokens
    #[test]
    fn cost_delta_gpt4o_to_mini() {
        let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini");
        assert!((delta - 0.00485).abs() < 0.0001);
    }
}

View File

@@ -0,0 +1,4 @@
// Entry point for the background worker binary (`dd0c-worker`).
// Currently a stub — digest/aggregation jobs land with Epic 7.
fn main() {
    println!("dd0c/route worker — not yet implemented");
    // TODO: Background worker for digests, aggregations (Epic 7)
}