use serde::{Deserialize, Serialize}; /// Routing decision made by the Router Brain #[derive(Debug, Clone, Serialize)] pub struct RoutingDecision { pub model: Option, pub provider: Option, pub strategy: String, pub cost_delta: f64, pub complexity: Complexity, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Complexity { Low, Medium, High, } /// The Router Brain — evaluates request complexity and applies routing rules #[derive(Default)] pub struct RouterBrain { // In V1, rules are loaded from config. Later: from DB per org. } impl RouterBrain { pub fn new() -> Self { Self {} } /// Classify request complexity based on heuristics fn classify_complexity(&self, request: &serde_json::Value) -> Complexity { let messages = request .get("messages") .and_then(|m| m.as_array()) .map(|a| a.len()) .unwrap_or(0); let system_prompt = request .get("messages") .and_then(|m| m.as_array()) .and_then(|msgs| { msgs.iter() .find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system")) }) .and_then(|m| m.get("content")) .and_then(|c| c.as_str()) .unwrap_or(""); let high_complexity_keywords = [ "analyze", "reason", "compare", "evaluate", "synthesize", "debate", ]; let has_complex_task = high_complexity_keywords .iter() .any(|kw| system_prompt.to_lowercase().contains(kw)); if messages > 10 || has_complex_task { Complexity::High } else if messages > 3 || system_prompt.len() > 500 { Complexity::Medium } else { Complexity::Low } } /// Route a request — returns the routing decision pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision { let complexity = self.classify_complexity(request); let original_model = request .get("model") .and_then(|v| v.as_str()) .unwrap_or("gpt-4o"); // V1 routing logic: downgrade low-complexity to cheaper model let (routed_model, strategy) = match complexity { Complexity::Low => { if original_model.contains("gpt-4") { (Some("gpt-4o-mini".to_string()), "cheapest".to_string()) } else { (None, "passthrough".to_string()) } } Complexity::Medium => (None, "passthrough".to_string()), Complexity::High => (None, "passthrough".to_string()), }; // Calculate cost delta let cost_delta = match (&routed_model, original_model) { (Some(routed), orig) => estimate_cost_delta(orig, routed), _ => 0.0, }; RoutingDecision { model: routed_model, provider: None, // V1: same provider strategy, cost_delta, complexity, } } } /// Estimate cost savings per 1K tokens when downgrading models fn estimate_cost_delta(original: &str, routed: &str) -> f64 { let price_per_1k = |model: &str| -> f64 { match model { "gpt-4o" => 0.005, "gpt-4o-mini" => 0.00015, "gpt-4-turbo" => 0.01, "gpt-3.5-turbo" => 0.0005, "claude-3-opus" => 0.015, "claude-3-sonnet" => 0.003, "claude-3-haiku" => 0.00025, _ => 0.005, // default to gpt-4o pricing } }; price_per_1k(original) - price_per_1k(routed) } #[cfg(test)] mod tests { use super::*; #[test] fn low_complexity_simple_extraction() { let brain = RouterBrain::new(); let request = serde_json::json!({ "model": "gpt-4o", "messages": [ {"role": "system", "content": "Extract the name from this text"}, {"role": "user", "content": "My name is Alice"} ] }); let complexity = brain.classify_complexity(&request); assert_eq!(complexity, Complexity::Low); } #[test] fn high_complexity_multi_turn_reasoning() { let brain = RouterBrain::new(); let mut messages = vec![ serde_json::json!({"role": "system", "content": "Analyze and compare these approaches"}), ]; for i in 0..12 { messages.push(serde_json::json!({"role": "user", "content": format!("Turn {}", i)})); messages.push( serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}), ); } let request = serde_json::json!({ "model": "gpt-4o", "messages": messages }); let complexity = brain.classify_complexity(&request); assert_eq!(complexity, Complexity::High); } #[tokio::test] async fn low_complexity_routes_to_mini() { let brain = RouterBrain::new(); let request = serde_json::json!({ "model": "gpt-4o", "messages": [{"role": "user", "content": "What is 2+2?"}] }); let decision = brain.route("org-1", &request).await; assert_eq!(decision.model, Some("gpt-4o-mini".to_string())); assert_eq!(decision.strategy, "cheapest"); assert!(decision.cost_delta > 0.0); } #[tokio::test] async fn high_complexity_passes_through() { let brain = RouterBrain::new(); let mut messages = vec![]; for i in 0..15 { messages.push(serde_json::json!({"role": "user", "content": format!("msg {}", i)})); } let request = serde_json::json!({ "model": "gpt-4o", "messages": messages }); let decision = brain.route("org-1", &request).await; assert_eq!(decision.model, None); assert_eq!(decision.strategy, "passthrough"); } #[test] fn cost_delta_gpt4o_to_mini() { let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini"); assert!((delta - 0.00485).abs() < 0.0001); } }