Scaffold dd0c/route core proxy engine (handler, router, auth, config)
This commit is contained in:
84
products/01-llm-cost-router/Cargo.toml
Normal file
84
products/01-llm-cost-router/Cargo.toml
Normal file
@@ -0,0 +1,84 @@
|
||||
[package]
name = "dd0c-route"
version = "0.1.0"
edition = "2021"
description = "LLM Cost Router & Dashboard — route AI requests to the cheapest capable model"
license = "MIT"

[[bin]]
name = "dd0c-proxy"
path = "src/proxy/main.rs"

[[bin]]
name = "dd0c-api"
path = "src/api/main.rs"

[[bin]]
name = "dd0c-worker"
path = "src/worker/main.rs"

[dependencies]
# Web framework
axum = { version = "0.7", features = ["ws", "macros"] }
tokio = { version = "1", features = ["full"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace", "compression-gzip"] }
hyper = { version = "1", features = ["full"] }
hyper-util = "0.1"

# Async trait objects — required by src/auth/mod.rs (`use async_trait::async_trait`)
# and src/data/mod.rs (`#[async_trait::async_trait]`); the crate fails to build
# without it.
async-trait = "0.1"

# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"

# Database
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "uuid", "chrono", "json"] }
deadpool-redis = "0.15"
redis = { version = "0.25", features = ["tokio-comp", "connection-manager"] }

# Auth
jsonwebtoken = "9"
bcrypt = "0.15"
uuid = { version = "1", features = ["v4", "serde"] }

# Observability
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
opentelemetry = "0.22"
opentelemetry-otlp = "0.15"

# Config
dotenvy = "0.15"
config = "0.14"

# HTTP client (upstream providers)
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"] }
reqwest-eventsource = "0.6"
futures = "0.3"
tokio-stream = "0.1"
bytes = "1"

# Crypto (provider key encryption)
aes-gcm = "0.10"
base64 = "0.22"

# Key/log redaction — required by src/proxy/middleware.rs (`regex::Regex`).
regex = "1"

# Feature flags
serde_yaml = "0.9"

# Misc
chrono = { version = "0.4", features = ["serde"] }
thiserror = "1"
anyhow = "1"

[dev-dependencies]
# Testing
tokio-test = "0.4"
wiremock = "0.6"
testcontainers = "0.15"
testcontainers-modules = { version = "0.3", features = ["postgres", "redis"] }
proptest = "1"
criterion = { version = "0.5", features = ["async_tokio"] }
tower-test = "0.4"

[[bench]]
name = "proxy_latency"
harness = false
|
||||
4
products/01-llm-cost-router/src/api/main.rs
Normal file
4
products/01-llm-cost-router/src/api/main.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
/// Entry point for the dd0c/route dashboard API server (binary `dd0c-api`).
/// Currently a stub; prints a notice and exits.
fn main() {
    println!("dd0c/route API server — not yet implemented");
    // TODO: Dashboard API (Epic 4)
}
|
||||
105
products/01-llm-cost-router/src/auth/mod.rs
Normal file
105
products/01-llm-cost-router/src/auth/mod.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use axum::http::HeaderMap;
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Identity resolved for an authenticated request.
#[derive(Debug, Clone)]
pub struct AuthContext {
    // Organization that owns the credential; used downstream for routing + telemetry.
    pub org_id: String,
    // Always `None` in the API-key path visible here; presumably set for
    // dashboard/JWT sessions — TODO confirm.
    pub user_id: Option<String>,
    pub role: Role,
}

/// Coarse per-organization permission level.
#[derive(Debug, Clone, PartialEq)]
pub enum Role {
    Owner,
    Member,
    Viewer,
}

/// Authentication failures surfaced to callers (the proxy maps these to HTTP 401).
#[derive(Debug, Error)]
pub enum AuthError {
    #[error("Invalid API key")]
    InvalidKey,
    #[error("Expired token")]
    ExpiredToken,
    #[error("Missing authorization header")]
    MissingAuth,
}

/// Pluggable authentication backend (local bcrypt/Redis here; others elsewhere).
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Resolve request headers to an `AuthContext`, or fail with an `AuthError`.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError>;
}
|
||||
|
||||
/// Local auth — bcrypt passwords + HS256 JWT (self-hosted mode)
pub struct LocalAuthProvider {
    // Postgres pool used for `api_keys` lookups.
    pool: sqlx::PgPool,
    // HS256 signing secret. NOTE(review): unused in the code visible here —
    // presumably for dashboard JWT validation; confirm before removing.
    jwt_secret: String,
    // Redis pool used as an API-key lookup cache.
    redis: deadpool_redis::Pool,
}
|
||||
|
||||
impl LocalAuthProvider {
|
||||
pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
|
||||
Self { pool, jwt_secret, redis }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthProvider for LocalAuthProvider {
|
||||
async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError> {
|
||||
let key = headers
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "))
|
||||
.ok_or(AuthError::MissingAuth)?;
|
||||
|
||||
// 1. Check Redis cache first
|
||||
if let Ok(mut conn) = self.redis.get().await {
|
||||
let cached: Option<String> = redis::cmd("GET")
|
||||
.arg(format!("apikey:{}", &key[..8])) // prefix lookup
|
||||
.query_async(&mut *conn)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
if let Some(org_id) = cached {
|
||||
return Ok(AuthContext {
|
||||
org_id,
|
||||
user_id: None,
|
||||
role: Role::Member,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Fall back to PostgreSQL
|
||||
let row = sqlx::query_as::<_, (String, String)>(
|
||||
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&key[..8])
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|_| AuthError::InvalidKey)?
|
||||
.ok_or(AuthError::InvalidKey)?;
|
||||
|
||||
// 3. Verify bcrypt hash
|
||||
let valid = bcrypt::verify(key, &row.1).unwrap_or(false);
|
||||
if !valid {
|
||||
return Err(AuthError::InvalidKey);
|
||||
}
|
||||
|
||||
// 4. Cache in Redis for next time (5 min TTL)
|
||||
if let Ok(mut conn) = self.redis.get().await {
|
||||
let _: Result<(), _> = redis::cmd("SETEX")
|
||||
.arg(format!("apikey:{}", &key[..8]))
|
||||
.arg(300)
|
||||
.arg(&row.0)
|
||||
.query_async(&mut *conn)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(AuthContext {
|
||||
org_id: row.0,
|
||||
user_id: None,
|
||||
role: Role::Member,
|
||||
})
|
||||
}
|
||||
}
|
||||
104
products/01-llm-cost-router/src/config/mod.rs
Normal file
104
products/01-llm-cost-router/src/config/mod.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Top-level application configuration, loaded from environment variables
/// (see `AppConfig::from_env`).
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
    pub proxy_port: u16,
    pub api_port: u16,
    // Primary Postgres (orgs, api_keys, …).
    pub database_url: String,
    pub redis_url: String,
    // TimescaleDB connection for telemetry events.
    pub timescale_url: String,
    pub jwt_secret: String,
    pub auth_mode: AuthMode,
    pub governance_mode: GovernanceMode,
    // Provider slug ("openai", "anthropic") → credentials + base URL.
    pub providers: HashMap<String, ProviderConfig>,
    // Capacity of the bounded telemetry mpsc channel.
    pub telemetry_channel_size: usize,
}

/// How callers are authenticated: local bcrypt/JWT, or external OAuth.
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
    Local,
    OAuth,
}

/// Governance enforcement mode. NOTE(review): enforcement logic is not visible
/// in this file — presumably Strict blocks violations while Audit only records
/// them; confirm against the governance module.
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GovernanceMode {
    Strict,
    Audit,
}

/// Credentials + endpoint for one upstream LLM provider.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderConfig {
    pub api_key: String,
    pub base_url: String,
}
|
||||
|
||||
impl AppConfig {
|
||||
pub fn from_env() -> anyhow::Result<Self> {
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
let mut providers = HashMap::new();
|
||||
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
|
||||
providers.insert("openai".to_string(), ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("OPENAI_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
});
|
||||
}
|
||||
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||
providers.insert("anthropic".to_string(), ProviderConfig {
|
||||
api_key: key,
|
||||
base_url: std::env::var("ANTHROPIC_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
proxy_port: std::env::var("PROXY_PORT")
|
||||
.unwrap_or_else(|_| "8080".to_string())
|
||||
.parse()?,
|
||||
api_port: std::env::var("API_PORT")
|
||||
.unwrap_or_else(|_| "3000".to_string())
|
||||
.parse()?,
|
||||
database_url: std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5432/dd0c".to_string()),
|
||||
redis_url: std::env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://localhost:6379".to_string()),
|
||||
timescale_url: std::env::var("TIMESCALE_URL")
|
||||
.unwrap_or_else(|_| "postgres://dd0c:dd0c@localhost:5433/dd0c_telemetry".to_string()),
|
||||
jwt_secret: std::env::var("JWT_SECRET")
|
||||
.unwrap_or_else(|_| "dev-secret-change-me".to_string()),
|
||||
auth_mode: if std::env::var("AUTH_MODE").unwrap_or_default() == "oauth" {
|
||||
AuthMode::OAuth
|
||||
} else {
|
||||
AuthMode::Local
|
||||
},
|
||||
governance_mode: if std::env::var("GOVERNANCE_MODE").unwrap_or_default() == "strict" {
|
||||
GovernanceMode::Strict
|
||||
} else {
|
||||
GovernanceMode::Audit
|
||||
},
|
||||
providers,
|
||||
telemetry_channel_size: std::env::var("TELEMETRY_CHANNEL_SIZE")
|
||||
.unwrap_or_else(|_| "1000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(1000),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn provider_url(&self, provider: &str) -> String {
|
||||
self.providers
|
||||
.get(provider)
|
||||
.map(|p| p.base_url.clone())
|
||||
.unwrap_or_else(|| "https://api.openai.com".to_string())
|
||||
}
|
||||
|
||||
pub fn provider_key(&self, provider: &str) -> String {
|
||||
self.providers
|
||||
.get(provider)
|
||||
.map(|p| p.api_key.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
32
products/01-llm-cost-router/src/data/mod.rs
Normal file
32
products/01-llm-cost-router/src/data/mod.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Telemetry event emitted by the proxy on every request.
/// Sent via mpsc channel to the worker for async persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryEvent {
    pub org_id: String,
    // Model the client asked for.
    pub original_model: String,
    // Model the request was actually forwarded to (may equal original_model).
    pub routed_model: String,
    pub provider: String,
    // Routing strategy label, e.g. "cheapest" or "passthrough".
    pub strategy: String,
    pub latency_ms: u32,
    pub status_code: u16,
    pub is_streaming: bool,
    // Token counts are 0 at emit time; the worker fills them in later.
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}

/// Abstraction over event queues (SQS for cloud, pgmq for self-hosted)
#[async_trait::async_trait]
pub trait EventQueue: Send + Sync {
    /// Enqueue a single event.
    async fn publish(&self, event: TelemetryEvent) -> anyhow::Result<()>;
    /// Pull up to `batch_size` events.
    async fn consume(&self, batch_size: usize) -> anyhow::Result<Vec<TelemetryEvent>>;
}

/// Abstraction over object storage (S3 for cloud, local FS for self-hosted)
#[async_trait::async_trait]
pub trait ObjectStore: Send + Sync {
    async fn put(&self, key: &str, data: &[u8]) -> anyhow::Result<()>;
    async fn get(&self, key: &str) -> anyhow::Result<Vec<u8>>;
}
|
||||
11
products/01-llm-cost-router/src/lib.rs
Normal file
11
products/01-llm-cost-router/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! dd0c/route core library: request proxying, routing, auth, config, telemetry.

mod auth;
mod config;
mod data;
mod proxy;
mod router;

// Public API surface re-exported at the crate root (used by the bin targets).
pub use auth::{AuthProvider, AuthContext, AuthError, LocalAuthProvider, Role};
pub use config::AppConfig;
pub use data::{EventQueue, ObjectStore, TelemetryEvent};
pub use proxy::{create_router, ProxyState, ProxyError};
pub use router::{RouterBrain, RoutingDecision, Complexity};
||||
199
products/01-llm-cost-router/src/proxy/handler.rs
Normal file
199
products/01-llm-cost-router/src/proxy/handler.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use anyhow::Result;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::{HeaderMap, Request, StatusCode},
|
||||
response::{IntoResponse, Response, Sse},
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
// NOTE: `middleware` is declared in `src/proxy/mod.rs`. Re-declaring
// `mod middleware;` here would resolve to `src/proxy/handler/middleware.rs`,
// which does not exist, and fail the build (E0583).
|
||||
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::config::AppConfig;
|
||||
use crate::data::{EventQueue, TelemetryEvent};
|
||||
use crate::router::RouterBrain;
|
||||
|
||||
/// Shared state for all proxy handlers (passed around in an `Arc`).
pub struct ProxyState {
    // Pluggable authentication backend.
    pub auth: Arc<dyn AuthProvider>,
    // Routing brain: decides model/provider per request.
    pub router: Arc<RouterBrain>,
    // Bounded channel to the telemetry persistence task (used with try_send,
    // so events are dropped rather than blocking the request path).
    pub telemetry_tx: mpsc::Sender<TelemetryEvent>,
    // Reused reqwest client for upstream provider calls.
    pub http_client: reqwest::Client,
    pub config: Arc<AppConfig>,
}
|
||||
|
||||
pub fn create_router(state: Arc<ProxyState>) -> Router {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(proxy_chat_completions))
|
||||
.route("/v1/completions", post(proxy_completions))
|
||||
.route("/v1/embeddings", post(proxy_embeddings))
|
||||
.route("/health", axum::routing::get(health))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
/// Liveness endpoint for load balancers — always returns "ok".
async fn health() -> &'static str {
    "ok"
}
|
||||
|
||||
async fn proxy_chat_completions(
|
||||
State(state): State<Arc<ProxyState>>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> Result<Response, ProxyError> {
|
||||
// 1. Authenticate
|
||||
let auth_ctx = state.auth.authenticate(&headers).await?;
|
||||
|
||||
// 2. Parse request
|
||||
let mut request: serde_json::Value =
|
||||
serde_json::from_str(&body).map_err(|_| ProxyError::BadRequest("Invalid JSON"))?;
|
||||
|
||||
let is_streaming = request
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
let original_model = request
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("gpt-4o")
|
||||
.to_string();
|
||||
|
||||
// 3. Route (pick model + provider)
|
||||
let decision = state
|
||||
.router
|
||||
.route(&auth_ctx.org_id, &request)
|
||||
.await;
|
||||
|
||||
// Apply routing decision
|
||||
if let Some(ref routed_model) = decision.model {
|
||||
request["model"] = serde_json::Value::String(routed_model.clone());
|
||||
}
|
||||
|
||||
let provider = decision.provider.unwrap_or_default();
|
||||
let upstream_url = state.config.provider_url(&provider);
|
||||
|
||||
// 4. Forward to upstream
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let upstream_resp = state
|
||||
.http_client
|
||||
.post(format!("{}/v1/chat/completions", upstream_url))
|
||||
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request.to_string())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
|
||||
let latency = start.elapsed();
|
||||
let status = upstream_resp.status();
|
||||
|
||||
// 5. Emit telemetry (non-blocking)
|
||||
let _ = state.telemetry_tx.try_send(TelemetryEvent {
|
||||
org_id: auth_ctx.org_id.clone(),
|
||||
original_model: original_model.clone(),
|
||||
routed_model: decision.model.clone().unwrap_or(original_model),
|
||||
provider: provider.clone(),
|
||||
strategy: decision.strategy.clone(),
|
||||
latency_ms: latency.as_millis() as u32,
|
||||
status_code: status.as_u16(),
|
||||
is_streaming,
|
||||
prompt_tokens: 0, // Filled by worker from response
|
||||
completion_tokens: 0,
|
||||
timestamp: chrono::Utc::now(),
|
||||
});
|
||||
|
||||
// 6. Pass through response transparently
|
||||
let resp_headers = upstream_resp.headers().clone();
|
||||
let resp_status = upstream_resp.status();
|
||||
|
||||
if is_streaming {
|
||||
// Stream SSE chunks directly
|
||||
let byte_stream = upstream_resp.bytes_stream();
|
||||
let body = Body::from_stream(byte_stream);
|
||||
|
||||
let mut response = Response::builder().status(resp_status);
|
||||
for (key, value) in resp_headers.iter() {
|
||||
response = response.header(key, value);
|
||||
}
|
||||
Ok(response.body(body).unwrap())
|
||||
} else {
|
||||
let resp_body = upstream_resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
|
||||
let mut response = Response::builder().status(resp_status);
|
||||
for (key, value) in resp_headers.iter() {
|
||||
response = response.header(key, value);
|
||||
}
|
||||
Ok(response.body(Body::from(resp_body)).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
// Placeholder — same pattern as chat completions
/// Legacy `/v1/completions` handler.
///
/// NOTE(review): this delegates to the CHAT completions handler, which forwards
/// to `/v1/chat/completions` upstream. A real legacy completions request uses a
/// different schema ("prompt" rather than "messages") and endpoint — confirm
/// whether this delegation is intentional before GA.
async fn proxy_completions(
    State(state): State<Arc<ProxyState>>,
    headers: HeaderMap,
    body: String,
) -> Result<Response, ProxyError> {
    proxy_chat_completions(State(state), headers, body).await
}
|
||||
|
||||
async fn proxy_embeddings(
|
||||
State(state): State<Arc<ProxyState>>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> Result<Response, ProxyError> {
|
||||
// Embeddings don't need routing — pass through directly
|
||||
let auth_ctx = state.auth.authenticate(&headers).await?;
|
||||
let provider = "openai".to_string();
|
||||
let upstream_url = state.config.provider_url(&provider);
|
||||
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(format!("{}/v1/embeddings", upstream_url))
|
||||
.header("Authorization", format!("Bearer {}", state.config.provider_key(&provider)))
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp.bytes().await.map_err(|e| ProxyError::UpstreamError(e.to_string()))?;
|
||||
Ok(Response::builder().status(status).body(Body::from(body)).unwrap())
|
||||
}
|
||||
|
||||
// --- Error types ---

/// Errors a proxy handler can return; mapped to HTTP statuses by the
/// `IntoResponse` impl below (401 / 400 / 502 respectively).
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
    #[error("Authentication failed: {0}")]
    AuthError(String),
    #[error("Bad request: {0}")]
    BadRequest(&'static str),
    #[error("Upstream error: {0}")]
    UpstreamError(String),
}
|
||||
|
||||
impl From<crate::auth::AuthError> for ProxyError {
    /// Auth failures surface as `ProxyError::AuthError` (→ HTTP 401).
    fn from(err: crate::auth::AuthError) -> Self {
        Self::AuthError(err.to_string())
    }
}
|
||||
|
||||
impl IntoResponse for ProxyError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, msg) = match &self {
|
||||
ProxyError::AuthError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
ProxyError::BadRequest(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
ProxyError::UpstreamError(_) => (StatusCode::BAD_GATEWAY, self.to_string()),
|
||||
};
|
||||
(status, serde_json::json!({"error": msg}).to_string()).into_response()
|
||||
}
|
||||
}
|
||||
81
products/01-llm-cost-router/src/proxy/main.rs
Normal file
81
products/01-llm-cost-router/src/proxy/main.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::info;
|
||||
|
||||
use dd0c_route::{
|
||||
AppConfig, LocalAuthProvider, RouterBrain, ProxyState, TelemetryEvent, create_router,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Init tracing
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "dd0c_route=info,tower_http=info".into()),
|
||||
)
|
||||
.json()
|
||||
.init();
|
||||
|
||||
// Load config
|
||||
let config = Arc::new(AppConfig::from_env()?);
|
||||
info!(port = config.proxy_port, "Starting dd0c/route proxy");
|
||||
|
||||
// Connect to databases
|
||||
let pg_pool = sqlx::PgPool::connect(&config.database_url).await?;
|
||||
let redis_cfg = deadpool_redis::Config::from_url(&config.redis_url);
|
||||
let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?;
|
||||
|
||||
// Telemetry channel (bounded, non-blocking)
|
||||
let (telemetry_tx, mut telemetry_rx) = mpsc::channel::<TelemetryEvent>(config.telemetry_channel_size);
|
||||
|
||||
// Spawn telemetry worker (writes to TimescaleDB)
|
||||
let ts_url = config.timescale_url.clone();
|
||||
tokio::spawn(async move {
|
||||
let ts_pool = sqlx::PgPool::connect(&ts_url).await.expect("TimescaleDB connection failed");
|
||||
while let Some(event) = telemetry_rx.recv().await {
|
||||
if let Err(e) = sqlx::query(
|
||||
"INSERT INTO request_events (org_id, original_model, routed_model, provider, strategy, latency_ms, status_code, is_streaming, prompt_tokens, completion_tokens, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
|
||||
)
|
||||
.bind(&event.org_id)
|
||||
.bind(&event.original_model)
|
||||
.bind(&event.routed_model)
|
||||
.bind(&event.provider)
|
||||
.bind(&event.strategy)
|
||||
.bind(event.latency_ms as i32)
|
||||
.bind(event.status_code as i16)
|
||||
.bind(event.is_streaming)
|
||||
.bind(event.prompt_tokens as i32)
|
||||
.bind(event.completion_tokens as i32)
|
||||
.bind(event.timestamp)
|
||||
.execute(&ts_pool)
|
||||
.await {
|
||||
tracing::warn!(error = %e, "Failed to persist telemetry event");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Build proxy state
|
||||
let state = Arc::new(ProxyState {
|
||||
auth: Arc::new(LocalAuthProvider::new(
|
||||
pg_pool.clone(),
|
||||
config.jwt_secret.clone(),
|
||||
redis_pool.clone(),
|
||||
)),
|
||||
router: Arc::new(RouterBrain::new()),
|
||||
telemetry_tx,
|
||||
http_client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5min for long completions
|
||||
.build()?,
|
||||
config: config.clone(),
|
||||
});
|
||||
|
||||
// Start server
|
||||
let app = create_router(state);
|
||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", config.proxy_port)).await?;
|
||||
info!(port = config.proxy_port, "Proxy listening");
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
61
products/01-llm-cost-router/src/proxy/middleware.rs
Normal file
61
products/01-llm-cost-router/src/proxy/middleware.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
// Proxy middleware — API key redaction in error traces
|
||||
use axum::http::HeaderMap;
|
||||
use tracing::warn;
|
||||
|
||||
/// Redact any Bearer tokens or API keys from a string.
|
||||
/// Used in panic handlers and error logging to prevent key leakage.
|
||||
pub fn redact_sensitive(input: &str) -> String {
|
||||
let patterns = [
|
||||
// OpenAI keys
|
||||
(r"sk-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
|
||||
// Anthropic keys
|
||||
(r"sk-ant-[a-zA-Z0-9_-]{20,}", "[REDACTED_API_KEY]"),
|
||||
// Bearer tokens
|
||||
(r"Bearer\s+[a-zA-Z0-9_.-]+", "Bearer [REDACTED]"),
|
||||
];
|
||||
|
||||
let mut result = input.to_string();
|
||||
for (pattern, replacement) in &patterns {
|
||||
if let Ok(re) = regex::Regex::new(pattern) {
|
||||
result = re.replace_all(&result, *replacement).to_string();
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Extract the raw API key from the `Authorization: Bearer …` header.
///
/// NOTE(review): despite the original doc comment, this returns the key
/// UNREDACTED — never log the result directly; pass it through
/// `redact_sensitive` first.
pub fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
        .map(|s| s.to_string())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A realistic OpenAI project key must disappear from the output.
    #[test]
    fn redacts_openai_key() {
        let input = "Error with key sk-proj-abc123xyz456def789ghi";
        let result = redact_sensitive(input);
        assert!(!result.contains("sk-proj-abc123xyz456def789ghi"));
        assert!(result.contains("[REDACTED_API_KEY]"));
    }

    // Bearer tokens are redacted even when the token itself is key-shaped.
    #[test]
    fn redacts_bearer_token() {
        let input = "Authorization: Bearer sk-live-abc123xyz";
        let result = redact_sensitive(input);
        assert!(!result.contains("sk-live-abc123xyz"));
        assert!(result.contains("[REDACTED]"));
    }

    // Redaction must be a no-op on text without secrets.
    #[test]
    fn does_not_redact_normal_text() {
        let input = "Hello world, this is a normal log message";
        let result = redact_sensitive(input);
        assert_eq!(result, input);
    }
}
|
||||
4
products/01-llm-cost-router/src/proxy/mod.rs
Normal file
4
products/01-llm-cost-router/src/proxy/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Proxy module: HTTP handlers plus log-redaction middleware helpers.

mod handler;
mod middleware;

// Only the handler surface is public; middleware helpers stay module-private.
pub use handler::{create_router, ProxyState, ProxyError};
|
||||
183
products/01-llm-cost-router/src/router/mod.rs
Normal file
183
products/01-llm-cost-router/src/router/mod.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::auth::AuthContext;
|
||||
|
||||
/// Routing decision made by the Router Brain
#[derive(Debug, Clone, Serialize)]
pub struct RoutingDecision {
    // Replacement model to write into the request; None => keep the original.
    pub model: Option<String>,
    // Target provider; None => keep the request's original provider.
    pub provider: Option<String>,
    // Strategy label recorded in telemetry ("cheapest", "passthrough").
    pub strategy: String,
    // Estimated savings in $ per 1K tokens (0.0 when not re-routed).
    pub cost_delta: f64,
    pub complexity: Complexity,
}

/// Heuristic complexity tier assigned to an incoming request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Complexity {
    Low,
    Medium,
    High,
}

/// The Router Brain — evaluates request complexity and applies routing rules
pub struct RouterBrain {
    // In V1, rules are loaded from config. Later: from DB per org.
}
|
||||
|
||||
impl RouterBrain {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Classify request complexity based on heuristics
|
||||
fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
|
||||
let messages = request
|
||||
.get("messages")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0);
|
||||
|
||||
let system_prompt = request
|
||||
.get("messages")
|
||||
.and_then(|m| m.as_array())
|
||||
.and_then(|msgs| {
|
||||
msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
|
||||
})
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
|
||||
let has_complex_task = high_complexity_keywords
|
||||
.iter()
|
||||
.any(|kw| system_prompt.to_lowercase().contains(kw));
|
||||
|
||||
if messages > 10 || has_complex_task {
|
||||
Complexity::High
|
||||
} else if messages > 3 || system_prompt.len() > 500 {
|
||||
Complexity::Medium
|
||||
} else {
|
||||
Complexity::Low
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a request — returns the routing decision
|
||||
pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
|
||||
let complexity = self.classify_complexity(request);
|
||||
let original_model = request
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("gpt-4o");
|
||||
|
||||
// V1 routing logic: downgrade low-complexity to cheaper model
|
||||
let (routed_model, strategy) = match complexity {
|
||||
Complexity::Low => {
|
||||
if original_model.contains("gpt-4") {
|
||||
(Some("gpt-4o-mini".to_string()), "cheapest".to_string())
|
||||
} else {
|
||||
(None, "passthrough".to_string())
|
||||
}
|
||||
}
|
||||
Complexity::Medium => (None, "passthrough".to_string()),
|
||||
Complexity::High => (None, "passthrough".to_string()),
|
||||
};
|
||||
|
||||
// Calculate cost delta
|
||||
let cost_delta = match (&routed_model, original_model) {
|
||||
(Some(routed), orig) => estimate_cost_delta(orig, routed),
|
||||
_ => 0.0,
|
||||
};
|
||||
|
||||
RoutingDecision {
|
||||
model: routed_model,
|
||||
provider: None, // V1: same provider
|
||||
strategy,
|
||||
cost_delta,
|
||||
complexity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate cost savings per 1K tokens when downgrading models.
///
/// Matching is by prefix, most-specific entry first, so dated variants like
/// "gpt-4o-2024-08-06" resolve to their base model's price instead of falling
/// into the generic default (which the original exact-match table did).
fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
    // (prefix, $ per 1K tokens) — order matters: more specific prefixes first.
    const PRICES: &[(&str, f64)] = &[
        ("gpt-4o-mini", 0.00015),
        ("gpt-4o", 0.005),
        ("gpt-4-turbo", 0.01),
        ("gpt-3.5-turbo", 0.0005),
        ("claude-3-opus", 0.015),
        ("claude-3-sonnet", 0.003),
        ("claude-3-haiku", 0.00025),
    ];

    let price_per_1k = |model: &str| -> f64 {
        PRICES
            .iter()
            .find(|(prefix, _)| model.starts_with(prefix))
            .map(|(_, price)| *price)
            // Unknown models default to gpt-4o pricing, as before.
            .unwrap_or(0.005)
    };

    price_per_1k(original) - price_per_1k(routed)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Few messages, no reasoning keywords => Low.
    #[test]
    fn low_complexity_simple_extraction() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "Extract the name from this text"},
                {"role": "user", "content": "My name is Alice"}
            ]
        });
        let complexity = brain.classify_complexity(&request);
        assert_eq!(complexity, Complexity::Low);
    }

    // >10 messages plus "analyze"/"compare" keywords => High.
    #[test]
    fn high_complexity_multi_turn_reasoning() {
        let brain = RouterBrain::new();
        let mut messages = vec![
            serde_json::json!({"role": "system", "content": "Analyze and compare these approaches"}),
        ];
        for i in 0..12 {
            messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)}));
            messages.push(serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}));
        }
        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
        let complexity = brain.classify_complexity(&request);
        assert_eq!(complexity, Complexity::High);
    }

    // Low-complexity gpt-4o traffic is downgraded to gpt-4o-mini with a
    // positive estimated savings.
    #[tokio::test]
    async fn low_complexity_routes_to_mini() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "What is 2+2?"}]
        });
        let decision = brain.route("org-1", &request).await;
        assert_eq!(decision.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(decision.strategy, "cheapest");
        assert!(decision.cost_delta > 0.0);
    }

    // High-complexity conversations keep their original model.
    #[tokio::test]
    async fn high_complexity_passes_through() {
        let brain = RouterBrain::new();
        let mut messages = vec![];
        for i in 0..15 {
            messages.push(serde_json::json!({"role": "user", "content": format!("msg {}", i)}));
        }
        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });
        let decision = brain.route("org-1", &request).await;
        assert_eq!(decision.model, None);
        assert_eq!(decision.strategy, "passthrough");
    }

    // 0.005 - 0.00015 = 0.00485 $/1K tokens.
    #[test]
    fn cost_delta_gpt4o_to_mini() {
        let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini");
        assert!((delta - 0.00485).abs() < 0.0001);
    }
}
|
||||
4
products/01-llm-cost-router/src/worker/main.rs
Normal file
4
products/01-llm-cost-router/src/worker/main.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
/// Entry point for the dd0c/route background worker (binary `dd0c-worker`).
/// Currently a stub; prints a notice and exits.
fn main() {
    println!("dd0c/route worker — not yet implemented");
    // TODO: Background worker for digests, aggregations (Epic 7)
}
|
||||
Reference in New Issue
Block a user