Scaffold dd0c/route core proxy engine (handler, router, auth, config)
This commit is contained in:
183
products/01-llm-cost-router/src/router/mod.rs
Normal file
183
products/01-llm-cost-router/src/router/mod.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::auth::AuthContext;
|
||||
|
||||
/// Routing decision made by the Router Brain
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RoutingDecision {
|
||||
pub model: Option<String>,
|
||||
pub provider: Option<String>,
|
||||
pub strategy: String,
|
||||
pub cost_delta: f64,
|
||||
pub complexity: Complexity,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Complexity {
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
/// The Router Brain — evaluates request complexity and applies routing rules.
///
/// The type currently carries no state, so `RouterBrain::default()` and
/// `RouterBrain::new()` are equivalent.
#[derive(Debug, Default)]
pub struct RouterBrain {
    // In V1, rules are loaded from config. Later: from DB per org.
}
|
||||
|
||||
impl RouterBrain {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Classify request complexity based on heuristics
|
||||
fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
|
||||
let messages = request
|
||||
.get("messages")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0);
|
||||
|
||||
let system_prompt = request
|
||||
.get("messages")
|
||||
.and_then(|m| m.as_array())
|
||||
.and_then(|msgs| {
|
||||
msgs.iter().find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
|
||||
})
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let high_complexity_keywords = ["analyze", "reason", "compare", "evaluate", "synthesize", "debate"];
|
||||
let has_complex_task = high_complexity_keywords
|
||||
.iter()
|
||||
.any(|kw| system_prompt.to_lowercase().contains(kw));
|
||||
|
||||
if messages > 10 || has_complex_task {
|
||||
Complexity::High
|
||||
} else if messages > 3 || system_prompt.len() > 500 {
|
||||
Complexity::Medium
|
||||
} else {
|
||||
Complexity::Low
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a request — returns the routing decision
|
||||
pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
|
||||
let complexity = self.classify_complexity(request);
|
||||
let original_model = request
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("gpt-4o");
|
||||
|
||||
// V1 routing logic: downgrade low-complexity to cheaper model
|
||||
let (routed_model, strategy) = match complexity {
|
||||
Complexity::Low => {
|
||||
if original_model.contains("gpt-4") {
|
||||
(Some("gpt-4o-mini".to_string()), "cheapest".to_string())
|
||||
} else {
|
||||
(None, "passthrough".to_string())
|
||||
}
|
||||
}
|
||||
Complexity::Medium => (None, "passthrough".to_string()),
|
||||
Complexity::High => (None, "passthrough".to_string()),
|
||||
};
|
||||
|
||||
// Calculate cost delta
|
||||
let cost_delta = match (&routed_model, original_model) {
|
||||
(Some(routed), orig) => estimate_cost_delta(orig, routed),
|
||||
_ => 0.0,
|
||||
};
|
||||
|
||||
RoutingDecision {
|
||||
model: routed_model,
|
||||
provider: None, // V1: same provider
|
||||
strategy,
|
||||
cost_delta,
|
||||
complexity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate cost savings per 1K tokens when downgrading models.
///
/// Looks both model names up in a built-in price table and returns
/// `price(original) - price(routed)`, so the result is positive when
/// `routed` is the cheaper model. A model name that is not in the table is
/// priced like `gpt-4o`.
fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
    /// Per-1K-token prices for the models the router knows about.
    const KNOWN_PRICES: [(&str, f64); 7] = [
        ("gpt-4o", 0.005),
        ("gpt-4o-mini", 0.00015),
        ("gpt-4-turbo", 0.01),
        ("gpt-3.5-turbo", 0.0005),
        ("claude-3-opus", 0.015),
        ("claude-3-sonnet", 0.003),
        ("claude-3-haiku", 0.00025),
    ];
    /// Price assumed for unlisted models (default to gpt-4o pricing).
    const FALLBACK_PRICE: f64 = 0.005;

    /// Exact-match lookup of a model's price, falling back for unknown names.
    fn lookup(model: &str) -> f64 {
        KNOWN_PRICES
            .iter()
            .find(|(name, _)| *name == model)
            .map_or(FALLBACK_PRICE, |&(_, price)| price)
    }

    lookup(original) - lookup(routed)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A two-message exchange with no reasoning keywords is classified `Low`.
    #[test]
    fn low_complexity_simple_extraction() {
        let request = json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "Extract the name from this text"},
                {"role": "user", "content": "My name is Alice"}
            ]
        });

        assert_eq!(
            RouterBrain::new().classify_complexity(&request),
            Complexity::Low
        );
    }

    /// A long conversation with a reasoning-style system prompt is classified `High`.
    #[test]
    fn high_complexity_multi_turn_reasoning() {
        let system = json!({"role": "system", "content": "Analyze and compare these approaches"});
        // Twelve user/assistant exchanges following the system prompt.
        let turns = (0..12).flat_map(|i| {
            [
                json!({"role": "user", "content": format!("Turn {}", i)}),
                json!({"role": "assistant", "content": format!("Response {}", i)}),
            ]
        });
        let messages: Vec<serde_json::Value> = std::iter::once(system).chain(turns).collect();
        let request = json!({ "model": "gpt-4o", "messages": messages });

        assert_eq!(
            RouterBrain::new().classify_complexity(&request),
            Complexity::High
        );
    }

    /// A trivial gpt-4o request is downgraded to gpt-4o-mini with a positive saving.
    #[tokio::test]
    async fn low_complexity_routes_to_mini() {
        let request = json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "What is 2+2?"}]
        });

        let decision = RouterBrain::new().route("org-1", &request).await;

        assert_eq!(decision.model.as_deref(), Some("gpt-4o-mini"));
        assert_eq!(decision.strategy, "cheapest");
        assert!(decision.cost_delta > 0.0);
    }

    /// A request with more than ten messages keeps its original model.
    #[tokio::test]
    async fn high_complexity_passes_through() {
        let messages: Vec<serde_json::Value> = (0..15)
            .map(|i| json!({"role": "user", "content": format!("msg {}", i)}))
            .collect();
        let request = json!({ "model": "gpt-4o", "messages": messages });

        let decision = RouterBrain::new().route("org-1", &request).await;

        assert!(decision.model.is_none());
        assert_eq!(decision.strategy, "passthrough");
    }

    /// gpt-4o to gpt-4o-mini saves roughly 0.00485 per 1K tokens.
    #[test]
    fn cost_delta_gpt4o_to_mini() {
        let expected = 0.00485;
        let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini");

        assert!((delta - expected).abs() < 0.0001);
    }
}
|
||||
Reference in New Issue
Block a user