Files
dd0c/products/01-llm-cost-router/src/router/mod.rs

194 lines
6.1 KiB
Rust
Raw Normal View History

use serde::{Deserialize, Serialize};
/// Routing decision made by the Router Brain
///
/// Built by `RouterBrain::route` for a single request.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingDecision {
    /// Replacement model chosen by the router, or `None` when the request passes
    /// through with its original model unchanged.
    pub model: Option<String>,
    /// Provider override; `route` always leaves this `None` in V1 (same provider).
    pub provider: Option<String>,
    /// Name of the strategy applied; `route` currently emits `"cheapest"` or
    /// `"passthrough"`.
    pub strategy: String,
    /// Estimated price difference per 1K tokens, computed as original model price
    /// minus routed model price (positive = routed model is cheaper). `route` sets
    /// this to `0.0` when no replacement model is chosen.
    pub cost_delta: f64,
    /// Complexity class the request was assigned before the routing rules ran.
    pub complexity: Complexity,
}
/// Coarse complexity class assigned to a request by `RouterBrain::classify_complexity`.
///
/// Serialized in lowercase (`"low"`, `"medium"`, `"high"`). The enum is fieldless, so it
/// is `Copy` and can be used as a key (e.g. in per-class counters) via `Eq + Hash`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Complexity {
    /// Simple request; the only class V1 routing considers for a model downgrade.
    Low,
    /// Moderately sized request; passed through unchanged in V1.
    Medium,
    /// Long or analytically demanding request; passed through unchanged in V1.
    High,
}
/// The Router Brain — evaluates request complexity and applies routing rules
///
/// Currently holds no state; construct with `RouterBrain::new()` or `Default`.
#[derive(Default)]
pub struct RouterBrain {
    // In V1, rules are loaded from config. Later: from DB per org.
    // NOTE(review): this struct has no fields yet, and the keywords/thresholds used by
    // `classify_complexity` and `route` are literals — confirm whether any config
    // loading actually exists elsewhere.
}
impl RouterBrain {
    /// Creates a router using the built-in V1 heuristics.
    pub fn new() -> Self {
        Self {}
    }

    /// Classify request complexity based on heuristics over a chat-style payload.
    ///
    /// Signals:
    /// - number of entries in `messages`: more than 10 → `High`, more than 3 → `Medium`;
    /// - any high-complexity keyword (case-insensitive) in the first `system` message → `High`;
    /// - that system prompt being longer than 500 bytes → `Medium`.
    ///
    /// A missing or non-array `messages` field counts as an empty conversation. The system
    /// prompt `content` may be a plain string or an array of content parts (their `text`
    /// fields are joined with newlines); any other shape counts as an empty prompt.
    fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
        const HIGH_COMPLEXITY_KEYWORDS: [&str; 6] = [
            "analyze",
            "reason",
            "compare",
            "evaluate",
            "synthesize",
            "debate",
        ];

        let messages: &[serde_json::Value] = request
            .get("messages")
            .and_then(|m| m.as_array())
            .map(Vec::as_slice)
            .unwrap_or(&[]);

        let system_prompt = messages
            .iter()
            .find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
            .and_then(|m| m.get("content"))
            .map(Self::content_text)
            .unwrap_or_default();

        // Lower-case once rather than once per keyword.
        let lowered = system_prompt.to_lowercase();
        let has_complex_task = HIGH_COMPLEXITY_KEYWORDS
            .iter()
            .any(|&kw| lowered.contains(kw));

        if messages.len() > 10 || has_complex_task {
            Complexity::High
        } else if messages.len() > 3 || system_prompt.len() > 500 {
            Complexity::Medium
        } else {
            Complexity::Low
        }
    }

    /// Extracts the text of a chat message `content` value.
    ///
    /// A string is borrowed as-is; an array of content parts yields the `text` fields of
    /// its parts joined with `\n` (parts without `text`, e.g. images, are skipped); any
    /// other JSON value yields an empty string.
    fn content_text(content: &serde_json::Value) -> std::borrow::Cow<'_, str> {
        use std::borrow::Cow;

        match content {
            serde_json::Value::String(text) => Cow::Borrowed(text.as_str()),
            serde_json::Value::Array(parts) => Cow::Owned(
                parts
                    .iter()
                    .filter_map(|part| part.get("text").and_then(|t| t.as_str()))
                    .collect::<Vec<_>>()
                    .join("\n"),
            ),
            _ => Cow::Borrowed(""),
        }
    }

    /// Route a request — returns the routing decision.
    ///
    /// V1 policy: a `Low` complexity request whose model name contains `gpt-4` is
    /// downgraded to `gpt-4o-mini` (strategy `"cheapest"`), unless it already targets
    /// `gpt-4o-mini` (including dated snapshots of it). Everything else passes through
    /// with `model: None` and strategy `"passthrough"`. A request without a string
    /// `model` is treated as `gpt-4o`.
    ///
    /// `_org_id` is not used yet, and the body performs no awaits; the `async`
    /// signature is kept for callers.
    pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
        const DOWNGRADE_TARGET: &str = "gpt-4o-mini";

        let complexity = self.classify_complexity(request);
        let original_model = request
            .get("model")
            .and_then(|v| v.as_str())
            .unwrap_or("gpt-4o");

        // `"gpt-4o-mini".contains("gpt-4")` is true, so without the target check a mini
        // request would be reported as a "cheapest" reroute to itself with zero savings.
        let should_downgrade = matches!(complexity, Complexity::Low)
            && original_model.contains("gpt-4")
            && !original_model.starts_with(DOWNGRADE_TARGET);

        let (routed_model, strategy) = if should_downgrade {
            (Some(DOWNGRADE_TARGET.to_string()), "cheapest".to_string())
        } else {
            (None, "passthrough".to_string())
        };

        // Savings are only meaningful when the model actually changes.
        let cost_delta = routed_model
            .as_deref()
            .map_or(0.0, |routed| estimate_cost_delta(original_model, routed));

        RoutingDecision {
            model: routed_model,
            provider: None, // V1: same provider
            strategy,
            cost_delta,
            complexity,
        }
    }
}
/// Estimate cost savings per 1K tokens when downgrading models.
///
/// Returns `price(original) - price(routed)` per 1K tokens, so a positive value means
/// the routed model is cheaper. A model name is priced by the longest known prefix it
/// starts with. Dated snapshot IDs such as `claude-3-haiku-20240307` or
/// `gpt-4o-mini-2024-07-18` therefore resolve to their base model's price instead of
/// the fallback. Names matching no known prefix are priced like `gpt-4o`.
fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
    // (model-name prefix, price per 1K tokens) — static V1 table.
    const PRICES_PER_1K: [(&str, f64); 7] = [
        ("gpt-4o", 0.005),
        ("gpt-4o-mini", 0.00015),
        ("gpt-4-turbo", 0.01),
        ("gpt-3.5-turbo", 0.0005),
        ("claude-3-opus", 0.015),
        ("claude-3-sonnet", 0.003),
        ("claude-3-haiku", 0.00025),
    ];
    // Fallback for unknown models: gpt-4o pricing.
    const DEFAULT_PRICE_PER_1K: f64 = 0.005;

    let price_per_1k = |model: &str| -> f64 {
        PRICES_PER_1K
            .iter()
            .filter(|(prefix, _)| model.starts_with(prefix))
            // Longest prefix wins, so "gpt-4o-mini-…" is not priced as "gpt-4o".
            .max_by_key(|(prefix, _)| prefix.len())
            .map_or(DEFAULT_PRICE_PER_1K, |&(_, price)| price)
    };
    price_per_1k(original) - price_per_1k(routed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Builds a chat-completions style request body for `model` with `messages`.
    fn chat_request(model: &str, messages: Vec<Value>) -> Value {
        json!({ "model": model, "messages": messages })
    }

    #[test]
    fn low_complexity_simple_extraction() {
        let request = chat_request(
            "gpt-4o",
            vec![
                json!({"role": "system", "content": "Extract the name from this text"}),
                json!({"role": "user", "content": "My name is Alice"}),
            ],
        );
        let verdict = RouterBrain::new().classify_complexity(&request);
        assert_eq!(verdict, Complexity::Low);
    }

    #[test]
    fn high_complexity_multi_turn_reasoning() {
        // A system prompt followed by 12 user/assistant exchanges.
        let system = json!({"role": "system", "content": "Analyze and compare these approaches"});
        let exchanges = (0..12).flat_map(|turn| {
            [
                json!({"role": "user", "content": format!("Turn {}", turn)}),
                json!({"role": "assistant", "content": format!("Response {}", turn)}),
            ]
        });
        let conversation: Vec<Value> = std::iter::once(system).chain(exchanges).collect();

        let request = chat_request("gpt-4o", conversation);
        let verdict = RouterBrain::new().classify_complexity(&request);
        assert_eq!(verdict, Complexity::High);
    }

    #[tokio::test]
    async fn low_complexity_routes_to_mini() {
        let request = chat_request(
            "gpt-4o",
            vec![json!({"role": "user", "content": "What is 2+2?"})],
        );
        let decision = RouterBrain::new().route("org-1", &request).await;

        assert_eq!(decision.model.as_deref(), Some("gpt-4o-mini"));
        assert_eq!(decision.strategy, "cheapest");
        assert!(decision.cost_delta > 0.0);
    }

    #[tokio::test]
    async fn high_complexity_passes_through() {
        let conversation: Vec<Value> = (0..15)
            .map(|idx| json!({"role": "user", "content": format!("msg {}", idx)}))
            .collect();
        let request = chat_request("gpt-4o", conversation);
        let decision = RouterBrain::new().route("org-1", &request).await;

        assert!(decision.model.is_none());
        assert_eq!(decision.strategy, "passthrough");
    }

    #[test]
    fn cost_delta_gpt4o_to_mini() {
        let expected = 0.00485;
        let actual = estimate_cost_delta("gpt-4o", "gpt-4o-mini");
        assert!((actual - expected).abs() < 0.0001, "unexpected delta: {}", actual);
    }
}