diff --git a/products/01-llm-cost-router/Cargo.toml b/products/01-llm-cost-router/Cargo.toml
new file mode 100644
index 0000000..d28e3ce
--- /dev/null
+++ b/products/01-llm-cost-router/Cargo.toml
@@ -0,0 +1,88 @@
+[package]
+name = "dd0c-route"
+version = "0.1.0"
+edition = "2021"
+description = "LLM Cost Router & Dashboard — route AI requests to the cheapest capable model"
+license = "MIT"
+
+[[bin]]
+name = "dd0c-proxy"
+path = "src/proxy/main.rs"
+
+[[bin]]
+name = "dd0c-api"
+path = "src/api/main.rs"
+
+[[bin]]
+name = "dd0c-worker"
+path = "src/worker/main.rs"
+
+[dependencies]
+# Web framework
+axum = { version = "0.7", features = ["ws", "macros"] }
+tokio = { version = "1", features = ["full"] }
+tower = "0.4"
+tower-http = { version = "0.5", features = ["cors", "trace", "compression-gzip"] }
+hyper = { version = "1", features = ["full"] }
+hyper-util = "0.1"
+
+# Serialization
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"
+
+# Database
+sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "uuid", "chrono", "json"] }
+deadpool-redis = "0.15"
+redis = { version = "0.25", features = ["tokio-comp", "connection-manager"] }
+
+# Auth
+jsonwebtoken = "9"
+bcrypt = "0.15"
+uuid = { version = "1", features = ["v4", "serde"] }
+async-trait = "0.1"
+
+# Observability
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
+opentelemetry = "0.22"
+opentelemetry-otlp = "0.15"
+
+# Config
+dotenvy = "0.15"
+config = "0.14"
+
+# HTTP client (upstream providers)
+reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"] }
+reqwest-eventsource = "0.6"
+futures = "0.3"
+tokio-stream = "0.1"
+bytes = "1"
+
+# Crypto (provider key encryption)
+aes-gcm = "0.10"
+base64 = "0.22"
+
+# Feature flags
+serde_yaml = "0.9"
+
+# Key redaction in logs
+regex = "1"
+
+# Misc
+chrono = { version = "0.4", features = ["serde"] }
+thiserror = "1"
+anyhow = "1"
+
+[dev-dependencies]
+# Testing
+tokio-test = "0.4"
+wiremock = "0.6"
+testcontainers = "0.15"
+testcontainers-modules = { version = "0.3", features = ["postgres", "redis"] }
+proptest = "1"
+criterion = { version = "0.5", features = ["async_tokio"] }
+tower-test = "0.4"
+
+[[bench]]
+name = "proxy_latency"
+harness = false
diff --git a/products/01-llm-cost-router/src/api/main.rs b/products/01-llm-cost-router/src/api/main.rs
new file mode 100644
index 0000000..d0a64e4
--- /dev/null
+++ b/products/01-llm-cost-router/src/api/main.rs
@@ -0,0 +1,4 @@
+fn main() {
+    println!("dd0c/route API server — not yet implemented");
+    // TODO: Dashboard API (Epic 4)
+}
diff --git a/products/01-llm-cost-router/src/auth/mod.rs b/products/01-llm-cost-router/src/auth/mod.rs
new file mode 100644
index 0000000..f71a7dc
--- /dev/null
+++ b/products/01-llm-cost-router/src/auth/mod.rs
@@ -0,0 +1,116 @@
+use async_trait::async_trait;
+use axum::http::HeaderMap;
+use thiserror::Error;
+
+/// Authenticated request context resolved from an API key.
+#[derive(Debug, Clone)]
+pub struct AuthContext {
+    pub org_id: String,
+    pub user_id: Option<String>,
+    pub role: Role,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum Role {
+    Owner,
+    Member,
+    Viewer,
+}
+
+#[derive(Debug, Error)]
+pub enum AuthError {
+    #[error("Invalid API key")]
+    InvalidKey,
+    #[error("Expired token")]
+    ExpiredToken,
+    #[error("Missing authorization header")]
+    MissingAuth,
+}
+
+#[async_trait]
+pub trait AuthProvider: Send + Sync {
+    async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError>;
+}
+
+/// Local auth — bcrypt passwords + HS256 JWT (self-hosted mode)
+pub struct LocalAuthProvider {
+    pool: sqlx::PgPool,
+    jwt_secret: String,
+    redis: deadpool_redis::Pool,
+}
+
+impl LocalAuthProvider {
+    pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
+        Self { pool, jwt_secret, redis }
+    }
+}
+
+#[async_trait]
+impl AuthProvider for LocalAuthProvider {
+    async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError> {
+        let key = headers
+            .get("authorization")
+            .and_then(|v| v.to_str().ok())
+            .and_then(|v| v.strip_prefix("Bearer "))
+            .ok_or(AuthError::MissingAuth)?;
+
+        // Guard: a key shorter than the 8-char prefix (or with a non-ASCII
+        // char boundary) would make `key[..8]` panic — reject it instead.
+        let prefix = key.get(..8).ok_or(AuthError::InvalidKey)?;
+
+        // 1. Check Redis cache first. The cache stores "org_id\nkey_hash" so
+        //    the bcrypt hash is still verified on a hit — caching by prefix
+        //    alone would authenticate any key sharing the first 8 characters.
+        if let Ok(mut conn) = self.redis.get().await {
+            let cached: Option<String> = redis::cmd("GET")
+                .arg(format!("apikey:{}", prefix))
+                .query_async(&mut *conn)
+                .await
+                .unwrap_or(None);
+
+            if let Some(entry) = cached {
+                if let Some((org_id, key_hash)) = entry.split_once('\n') {
+                    if bcrypt::verify(key, key_hash).unwrap_or(false) {
+                        return Ok(AuthContext {
+                            org_id: org_id.to_string(),
+                            user_id: None,
+                            role: Role::Member,
+                        });
+                    }
+                    return Err(AuthError::InvalidKey);
+                }
+            }
+        }
+
+        // 2. Fall back to PostgreSQL
+        let row = sqlx::query_as::<_, (String, String)>(
+            "SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
+        )
+        .bind(prefix)
+        .fetch_optional(&self.pool)
+        .await
+        .map_err(|_| AuthError::InvalidKey)?
+        .ok_or(AuthError::InvalidKey)?;
+
+        // 3. Verify bcrypt hash
+        if !bcrypt::verify(key, &row.1).unwrap_or(false) {
+            return Err(AuthError::InvalidKey);
+        }
+
+        // 4. Cache org_id + hash in Redis for next time (5 min TTL)
+        if let Ok(mut conn) = self.redis.get().await {
+            let _: Result<(), _> = redis::cmd("SETEX")
+                .arg(format!("apikey:{}", prefix))
+                .arg(300)
+                .arg(format!("{}\n{}", row.0, row.1))
+                .query_async(&mut *conn)
+                .await;
+        }
+
+        Ok(AuthContext {
+            org_id: row.0,
+            user_id: None,
+            role: Role::Member,
+        })
+    }
+}
diff --git a/products/01-llm-cost-router/src/config/mod.rs b/products/01-llm-cost-router/src/config/mod.rs
new file mode 100644
index 0000000..7ae383b
--- /dev/null
+++ b/products/01-llm-cost-router/src/config/mod.rs
@@ -0,0 +1,104 @@
+use serde::Deserialize;
+use std::collections::HashMap;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AppConfig {
+    pub proxy_port: u16,
+    pub api_port: u16,
+    pub database_url: String,
+    pub redis_url: String,
+    pub timescale_url: String,
+    pub jwt_secret: String,
+    pub auth_mode: AuthMode,
+    pub governance_mode: GovernanceMode,
+    pub providers: HashMap<String, ProviderConfig>,
+    pub telemetry_channel_size: usize,
+}
+
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum AuthMode {
+    Local,
+    OAuth,
+}
+
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum GovernanceMode {
+    Strict,
+    Audit,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ProviderConfig {
+    pub api_key: String,
+    pub base_url: String,
+}
+
+impl AppConfig {
+    pub fn from_env() -> anyhow::Result<Self> {
+        dotenvy::dotenv().ok();
+
+        let mut providers = HashMap::new();
+        if let Ok(key) = std::env::var("OPENAI_API_KEY") {
+            providers.insert("openai".to_string(), ProviderConfig {
+                api_key: key,
+                base_url: std::env::var("OPENAI_BASE_URL")
+                    .unwrap_or_else(|_| "https://api.openai.com".to_string()),
+            });
+        }
+        if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
+            providers.insert("anthropic".to_string(), ProviderConfig {
+                api_key: key,
+                base_url: std::env::var("ANTHROPIC_BASE_URL")
+                    .unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
+            });
+        }
+
+        Ok(Self {
+            proxy_port: std::env::var("PROXY_PORT")
+                .unwrap_or_else(|_| "8080".to_string())
+                .parse()?,
+            api_port: std::env::var("API_PORT")
+                .unwrap_or_else(|_| "3000".to_string())
+                .parse()?,
+            database_url: std::env::var("DATABASE_URL")
+                .unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()),
+            redis_url: std::env::var("REDIS_URL")
+                .unwrap_or_else(|_| "redis://localhost:6379".to_string()),
+            timescale_url: std::env::var("TIMESCALE_URL")
+                .unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()),
+            jwt_secret: std::env::var("JWT_SECRET")
+                .unwrap_or_else(|_| "dev-secret-change-me".to_string()),
+            auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {
+                AuthMode::OAuth
+            } else {
+                AuthMode::Local
+            },
+            governance_mode: if std::env::var("GOVERNANCE_MODE").unwrap_or_default() == "strict" {
+                GovernanceMode::Strict
+            } else {
+                GovernanceMode::Audit
+            },
+            providers,
+            telemetry_channel_size: std::env::var("TELEMETRY_CHANNEL_SIZE")
+                .unwrap_or_else(|_| "1000".to_string())
+                .parse()
+                .unwrap_or(1000),
+        })
+    }
+
+    pub fn provider_url(&self, provider: &str) -> String {
+        self.providers
+            .get(provider)
+            .map(|p| p.base_url.clone())
+            .unwrap_or_else(|| "https://api.openai.com".to_string())
+    }
+
+    pub fn provider_key(&self, provider: &str) -> String {
+        self.providers
+            .get(provider)
+            .map(|p| p.api_key.clone())
+            .unwrap_or_default()
+    }
+}
diff --git a/products/01-llm-cost-router/src/data/mod.rs b/products/01-llm-cost-router/src/data/mod.rs
new file mode 100644
index 0000000..346fc16
--- /dev/null
+++ b/products/01-llm-cost-router/src/data/mod.rs
@@ -0,0 +1,32 @@
+use serde::{Deserialize, Serialize};
+
+/// Telemetry event emitted by the proxy on every request.
+/// Sent via mpsc channel to the worker for async persistence.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TelemetryEvent {
+    pub org_id: String,
+    pub original_model: String,
+    pub routed_model: String,
+    pub provider: String,
+    pub strategy: String,
+    pub latency_ms: u32,
+    pub status_code: u16,
+    pub is_streaming: bool,
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub timestamp: chrono::DateTime<chrono::Utc>,
+}
+
+/// Abstraction over event queues (SQS for cloud, pgmq for self-hosted)
+#[async_trait::async_trait]
+pub trait EventQueue: Send + Sync {
+    async fn publish(&self, event: TelemetryEvent) -> anyhow::Result<()>;
+    async fn consume(&self, batch_size: usize) -> anyhow::Result<Vec<TelemetryEvent>>;
+}
+
+/// Abstraction over object storage (S3 for cloud, local FS for self-hosted)
+#[async_trait::async_trait]
+pub trait ObjectStore: Send + Sync {
+    async fn put(&self, key: &str, data: &[u8]) -> anyhow::Result<()>;
+    async fn get(&self, key: &str) -> anyhow::Result<Vec<u8>>;
+}
diff --git a/products/01-llm-cost-router/src/lib.rs b/products/01-llm-cost-router/src/lib.rs
new file mode 100644
index 0000000..b8d5bcf
--- /dev/null
+++ b/products/01-llm-cost-router/src/lib.rs
@@ -0,0 +1,11 @@
+mod auth;
+mod config;
+mod data;
+mod proxy;
+mod router;
+
+pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role};
+pub use config::AppConfig;
+pub use data::{EventQueue, ObjectStore, TelemetryEvent};
+pub use proxy::{create_router, ProxyState, ProxyError};
+pub use router::{RouterBrain, RoutingDecision, Complexity};
diff --git a/products/01-llm-cost-router/src/proxy/handler.rs b/products/01-llm-cost-router/src/proxy/handler.rs
new file mode 100644
index 0000000..e144865
--- /dev/null
+++ b/products/01-llm-cost-router/src/proxy/handler.rs
@@ -0,0 +1,198 @@
+use axum::{
+    body::Body,
+    extract::State,
+    http::{HeaderMap, StatusCode},
+    response::{IntoResponse, Response},
+    routing::post,
+    Router,
+};
+use std::sync::Arc;
+use tokio::sync::mpsc;
+
+use crate::auth::AuthProvider;
+use crate::config::AppConfig;
+use crate::data::TelemetryEvent;
+use crate::router::RouterBrain;
+
+/// Shared state for all proxy routes.
+pub struct ProxyState {
+    pub auth: Arc<dyn AuthProvider>,
+    pub router: Arc<RouterBrain>,
+    pub telemetry_tx: mpsc::Sender<TelemetryEvent>,
+    pub http_client: reqwest::Client,
+    pub config: Arc<AppConfig>,
+}
+
+pub fn create_router(state: Arc<ProxyState>) -> Router {
+    Router::new()
+        .route("/v1/chat/completions", post(proxy_chat_completions))
+        .route("/v1/completions", post(proxy_completions))
+        .route("/v1/embeddings", post(proxy_embeddings))
+        .route("/health", axum::routing::get(health))
+        .with_state(state)
+}
+
+async fn health() -> &'static str {
+    "ok"
+}
+
+async fn proxy_chat_completions(
+    State(state): State<Arc<ProxyState>>,
+    headers: HeaderMap,
+    body: String,
+) -> Result<Response, ProxyError> {
+    // 1. Authenticate
+    let auth_ctx = state.auth.authenticate(&headers).await?;
+
+    // 2. Parse request
+    let mut request: serde_json::Value =
+        serde_json::from_str(&body).map_err(|_| ProxyError::BadRequest("Invalid JSON"))?;
+
+    let is_streaming = request
+        .get("stream")
+        .and_then(|v| v.as_bool())
+        .unwrap_or(false);
+
+    let original_model = request
+        .get("model")
+        .and_then(|v| v.as_str())
+        .unwrap_or("gpt-4o")
+        .to_string();
+
+    // 3. Route (pick model + provider)
+    let decision = state
+        .router
+        .route(&auth_ctx.org_id, &request)
+        .await;
+
+    // Apply routing decision
+    if let Some(ref routed_model) = decision.model {
+        request["model"] = serde_json::Value::String(routed_model.clone());
+    }
+
+    let provider = decision.provider.unwrap_or_default();
+    let upstream_url = state.config.provider_url(&provider);
+
+    // 4. Forward to upstream
+    let start = std::time::Instant::now();
+
+    let upstream_resp = state
+        .http_client
+        .post(format!("{}/v1/chat/completions", upstream_url))
+        .header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
+        .header("Content-Type", "application/json")
+        .body(request.to_string())
+        .send()
+        .await
+        .map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
+
+    let latency = start.elapsed();
+    let status = upstream_resp.status();
+
+    // 5. Emit telemetry (non-blocking)
+    let _ = state.telemetry_tx.try_send(TelemetryEvent {
+        org_id: auth_ctx.org_id.clone(),
+        original_model: original_model.clone(),
+        routed_model: decision.model.clone().unwrap_or(original_model),
+        provider: provider.clone(),
+        strategy: decision.strategy.clone(),
+        latency_ms: latency.as_millis() as u32,
+        status_code: status.as_u16(),
+        is_streaming,
+        prompt_tokens: 0, // Filled by worker from response
+        completion_tokens: 0,
+        timestamp: chrono::Utc::now(),
+    });
+
+    // 6. Pass through response transparently
+    let resp_headers = upstream_resp.headers().clone();
+    let resp_status = upstream_resp.status();
+
+    if is_streaming {
+        // Stream SSE chunks directly
+        let byte_stream = upstream_resp.bytes_stream();
+        let body = Body::from_stream(byte_stream);
+
+        let mut response = Response::builder().status(resp_status);
+        for (key, value) in resp_headers.iter() {
+            response = response.header(key, value);
+        }
+        Ok(response.body(body).unwrap())
+    } else {
+        let resp_body = upstream_resp
+            .bytes()
+            .await
+            .map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
+
+        let mut response = Response::builder().status(resp_status);
+        for (key, value) in resp_headers.iter() {
+            response = response.header(key, value);
+        }
+        Ok(response.body(Body::from(resp_body)).unwrap())
+    }
+}
+
+// Placeholder — same pattern as chat completions.
+// TODO: forward to /v1/completions once the legacy endpoint is implemented;
+// currently delegates to the chat-completions handler (wrong upstream path).
+async fn proxy_completions(
+    State(state): State<Arc<ProxyState>>,
+    headers: HeaderMap,
+    body: String,
+) -> Result<Response, ProxyError> {
+    proxy_chat_completions(State(state), headers, body).await
+}
+
+async fn proxy_embeddings(
+    State(state): State<Arc<ProxyState>>,
+    headers: HeaderMap,
+    body: String,
+) -> Result<Response, ProxyError> {
+    // Embeddings don't need routing — pass through directly
+    let _auth_ctx = state.auth.authenticate(&headers).await?;
+    let provider = "openai".to_string();
+    let upstream_url = state.config.provider_url(&provider);
+
+    let resp = state
+        .http_client
+        .post(format!("{}/v1/embeddings", upstream_url))
+        .header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
+        .header("Content-Type", "application/json")
+        .body(body)
+        .send()
+        .await
+        .map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
+
+    let status = resp.status();
+    let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
+    Ok(Response::builder().status(status).body(Body::from(body)).unwrap())
+}
+
+// --- Error types ---
+
+#[derive(Debug, thiserror::Error)]
+pub enum ProxyError {
+    #[error("Authentication failed: {0}")]
+    AuthError(String),
+    #[error("Bad request: {0}")]
+    BadRequest(&'static str),
+    #[error("Upstream error: {0}")]
+    UpstreamError(String),
+}
+
+impl From<crate::auth::AuthError> for ProxyError {
+    fn from(e: crate::auth::AuthError) -> Self {
+        ProxyError::AuthError(e.to_string())
+    }
+}
+
+impl IntoResponse for ProxyError {
+    fn into_response(self) -> Response {
+        let (status, msg) = match &self {
+            ProxyError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
+            ProxyError::BadRequest(_) => (StatusCode::BAD_REQUEST, self.to_string()),
+            ProxyError::UpstreamError(_) => (StatusCode::BAD_GATEWAY, self.to_string()),
+        };
+        (status, serde_json::json!({"error": msg}).to_string()).into_response()
+    }
+}
diff --git a/products/01-llm-cost-router/src/proxy/main.rs b/products/01-llm-cost-router/src/proxy/main.rs
new file mode 100644
index 0000000..7ecbc63
--- /dev/null
+++ b/products/01-llm-cost-router/src/proxy/main.rs
@@ -0,0 +1,81 @@
+use std::sync::Arc;
+use tokio::sync::mpsc;
+use tracing::info;
+
+use dd0c_route::{
+    AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router,
+};
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    // Init tracing
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            tracing_subscriber::EnvFilter::try_from_default_env()
+                .unwrap_or_else(|_| "dd0c_route=info,tower_http=info".into()),
+        )
+        .json()
+        .init();
+
+    // Load config
+    let config = Arc::new(AppConfig::from_env()?);
+    info!(port = config.proxy_port, "Starting dd0c/route proxy");
+
+    // Connect to databases
+    let pg_pool = sqlx::PgPool::connect(&config.database_url).await?;
+    let redis_cfg = deadpool_redis::Config::from_url(&config.redis_url);
+    let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
+
+    // Telemetry channel (bounded, non-blocking)
+    let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
+
+    // Spawn telemetry worker (writes to TimescaleDB)
+    let ts_url = config.timescale_url.clone();
+    tokio::spawn(async move {
+        let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed");
+        while let Some(event) = telemetry_rx.recv().await {
+            if let Err(e) = sqlx::query(
+                "INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)
+                 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
+            )
+            .bind(&event.org_id)
+            .bind(&event.original_model)
+            .bind(&event.routed_model)
+            .bind(&event.provider)
+            .bind(&event.strategy)
+            .bind(event.latency_ms as i32)
+            .bind(event.status_code as i16)
+            .bind(event.is_streaming)
+            .bind(event.prompt_tokens as i32)
+            .bind(event.completion_tokens as i32)
+            .bind(event.timestamp)
+            .execute(&ts_pool)
+            .await {
+                tracing::warn!(error = %e, "Failed to persist telemetry event");
+            }
+        }
+    });
+
+    // Build proxy state
+    let state = Arc::new(ProxyState {
+        auth: Arc::new(LocalAuthProvider::new(
+            pg_pool.clone(),
+            config.jwt_secret.clone(),
+            redis_pool.clone(),
+        )),
+        router: Arc::new(RouterBrain::new()),
+        telemetry_tx,
+        http_client: reqwest::Client::builder()
+            .timeout(std::time::Duration::from_secs(300)) // 5min for long completions
+            .build()?,
+        config: config.clone(),
+    });
+
+    // Start server
+    let app = create_router(state);
+    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", config.proxy_port)).await?;
+    info!(port = config.proxy_port, "Proxy listening");
+    axum::serve(listener, app).await?;
+
+    Ok(())
+}
diff --git a/products/01-llm-cost-router/src/proxy/middleware.rs b/products/01-llm-cost-router/src/proxy/middleware.rs
new file mode 100644
index 0000000..aa2f720
--- /dev/null
+++ b/products/01-llm-cost-router/src/proxy/middleware.rs
@@ -0,0 +1,62 @@
+// Proxy middleware — API key redaction in error traces
+use axum::http::HeaderMap;
+
+/// Redact any Bearer tokens or API keys from a string.
+/// Used in panic handlers and error logging to prevent key leakage.
+pub fn redact_sensitive(input: &str) -> String {
+    let patterns = [
+        // OpenAI keys
+        (r"sk-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
+        // Anthropic keys
+        (r"sk-ant-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
+        // Bearer tokens
+        (r"Bearer\s+[a-zA-Z0-9_.-]+", "Bearer [REDACTED]"),
+    ];
+
+    let mut result = input.to_string();
+    for (pattern, replacement) in &patterns {
+        if let Ok(re) = regex::Regex::new(pattern) {
+            result = re.replace_all(&result, *replacement).to_string();
+        }
+    }
+    result
+}
+
+/// Extract the raw API key from the Authorization header.
+/// NOTE: the returned key is NOT redacted — pass it through
+/// `redact_sensitive` before logging.
+pub fn extract_api_key(headers: &HeaderMap) -> Option<String> {
+    headers
+        .get("authorization")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.strip_prefix("Bearer "))
+        .map(|s| s.to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn redacts_openai_key() {
+        let input = "Error with key sk-proj-abc123xyz456def789ghi";
+        let result = redact_sensitive(input);
+        assert!(!result.contains("sk-proj-abc123xyz456def789ghi"));
+        assert!(result.contains("[REDACTED_API_KEY]"));
+    }
+
+    #[test]
+    fn redacts_bearer_token() {
+        let input = "Authorization: Bearer sk-live-abc123xyz";
+        let result = redact_sensitive(input);
+        assert!(!result.contains("sk-live-abc123xyz"));
+        assert!(result.contains("[REDACTED]"));
+    }
+
+    #[test]
+    fn does_not_redact_normal_text() {
+        let input = "Hello world, this is a normal log message";
+        let result = redact_sensitive(input);
+        assert_eq!(result, input);
+    }
+}
diff --git a/products/01-llm-cost-router/src/proxy/mod.rs b/products/01-llm-cost-router/src/proxy/mod.rs
new file mode 100644
index 0000000..f516748
--- /dev/null
+++ b/products/01-llm-cost-router/src/proxy/mod.rs
@@ -0,0 +1,4 @@
+mod handler;
+mod middleware;
+
+pub use handler::{create_router, ProxyState, ProxyError};
diff --git a/products/01-llm-cost-router/src/router/mod.rs b/products/01-llm-cost-router/src/router/mod.rs
new file mode 100644
index 0000000..35d707c
--- /dev/null
+++ b/products/01-llm-cost-router/src/router/mod.rs
@@ -0,0 +1,182 @@
+use serde::{Deserialize, Serialize};
+
+/// Routing decision made by the Router Brain
+#[derive(Debug, Clone, Serialize)]
+pub struct RoutingDecision {
+    pub model: Option<String>,
+    pub provider: Option<String>,
+    pub strategy: String,
+    pub cost_delta: f64,
+    pub complexity: Complexity,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Complexity {
+    Low,
+    Medium,
+    High,
+}
+
+/// The Router Brain — evaluates request complexity and applies routing rules
+pub struct RouterBrain {
+    // In V1, rules are loaded from config. Later: from DB per org.
+}
+
+impl RouterBrain {
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    /// Classify request complexity based on heuristics
+    fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
+        let messages = request
+            .get("messages")
+            .and_then(|m| m.as_array())
+            .map(|a| a.len())
+            .unwrap_or(0);
+
+        let system_prompt = request
+            .get("messages")
+            .and_then(|m| m.as_array())
+            .and_then(|msgs| {
+                msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
+            })
+            .and_then(|m| m.get("content"))
+            .and_then(|c| c.as_str())
+            .unwrap_or("");
+
+        let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
+        let has_complex_task = high_complexity_keywords
+            .iter()
+            .any(|kw| system_prompt.to_lowercase().contains(kw));
+
+        if messages > 10 || has_complex_task {
+            Complexity::High
+        } else if messages > 3 || system_prompt.len() > 500 {
+            Complexity::Medium
+        } else {
+            Complexity::Low
+        }
+    }
+
+    /// Route a request — returns the routing decision
+    pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
+        let complexity = self.classify_complexity(request);
+        let original_model = request
+            .get("model")
+            .and_then(|v| v.as_str())
+            .unwrap_or("gpt-4o");
+
+        // V1 routing logic: downgrade low-complexity to cheaper model
+        let (routed_model, strategy) = match complexity {
+            Complexity::Low => {
+                if original_model.contains("gpt-4") {
+                    (Some("gpt-4o-mini".to_string()), "cheapest".to_string())
+                } else {
+                    (None, "passthrough".to_string())
+                }
+            }
+            Complexity::Medium => (None, "passthrough".to_string()),
+            Complexity::High => (None, "passthrough".to_string()),
+        };
+
+        // Calculate cost delta
+        let cost_delta = match (&routed_model, original_model) {
+            (Some(routed), orig) => estimate_cost_delta(orig, routed),
+            _ => 0.0,
+        };
+
+        RoutingDecision {
+            model: routed_model,
+            provider: None, // V1: same provider
+            strategy,
+            cost_delta,
+            complexity,
+        }
+    }
+}
+
+/// Estimate cost savings per 1K tokens when downgrading models
+fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
+    let price_per_1k = |model: &str| -> f64 {
+        match model {
+            "gpt-4o" => 0.005,
+            "gpt-4o-mini" => 0.00015,
+            "gpt-4-turbo" => 0.01,
+            "gpt-3.5-turbo" => 0.0005,
+            "claude-3-opus" => 0.015,
+            "claude-3-sonnet" => 0.003,
+            "claude-3-haiku" => 0.00025,
+            _ => 0.005, // default to gpt-4o pricing
+        }
+    };
+
+    price_per_1k(original) - price_per_1k(routed)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn low_complexity_simple_extraction() {
+        let brain = RouterBrain::new();
+        let request = serde_json::json!({
+            "model": "gpt-4o",
+            "messages": [
+                {"role": "system", "content": "Extract the name from this text"},
+                {"role": "user", "content": "My name is Alice"}
+            ]
+        });
+        let complexity = brain.classify_complexity(&request);
+        assert_eq!(complexity, Complexity::Low);
+    }
+
+    #[test]
+    fn high_complexity_multi_turn_reasoning() {
+        let brain = RouterBrain::new();
+        let mut messages = vec![
+            serde_json::json!({"role": "system", "content": "Analyze and compare these approaches"}),
+        ];
+        for i in 0..12 {
+            messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
+            messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}));
+        }
+        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
+        let complexity = brain.classify_complexity(&request);
+        assert_eq!(complexity, Complexity::High);
+    }
+
+    #[tokio::test]
+    async fn low_complexity_routes_to_mini() {
+        let brain = RouterBrain::new();
+        let request = serde_json::json!({
+            "model": "gpt-4o",
+            "messages": [{"role": "user", "content": "What is 2+2?"}]
+        });
+        let decision = brain.route("org-1", &request).await;
+        assert_eq!(decision.model, Some("gpt-4o-mini".to_string()));
+        assert_eq!(decision.strategy, "cheapest");
+        assert!(decision.cost_delta > 0.0);
+    }
+
+    #[tokio::test]
+    async fn high_complexity_passes_through() {
+        let brain = RouterBrain::new();
+        let mut messages = vec![];
+        for i in 0..15 {
+            messages.push(serde_json::json!({"role": "user", "content": format!("msg {}", i)}));
+        }
+        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
+        let decision = brain.route("org-1", &request).await;
+        assert_eq!(decision.model, None);
+        assert_eq!(decision.strategy, "passthrough");
+    }
+
+    #[test]
+    fn cost_delta_gpt4o_to_mini() {
+        let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini");
+        assert!((delta - 0.00485).abs() < 0.0001);
+    }
+}
diff --git a/products/01-llm-cost-router/src/worker/main.rs b/products/01-llm-cost-router/src/worker/main.rs
new file mode 100644
index 0000000..57fcb5f
--- /dev/null
+++ b/products/01-llm-cost-router/src/worker/main.rs
@@ -0,0 +1,4 @@
+fn main() {
+    println!("dd0c/route worker — not yet implemented");
+    // TODO: Background worker for digests, aggregations (Epic 7)
+}