All checks were successful
CI — P1 Route (Rust) / test (push) Successful in 6m35s
194 lines
6.1 KiB
Rust
194 lines
6.1 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
|
|
/// Routing decision made by the Router Brain
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct RoutingDecision {
|
|
pub model: Option<String>,
|
|
pub provider: Option<String>,
|
|
pub strategy: String,
|
|
pub cost_delta: f64,
|
|
pub complexity: Complexity,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum Complexity {
|
|
Low,
|
|
Medium,
|
|
High,
|
|
}
|
|
|
|
/// The Router Brain — evaluates request complexity and applies routing rules.
///
/// The struct currently has no fields: the V1 rules are hard-coded in the
/// `impl` block, so an instance carries no state.
#[derive(Default)]
pub struct RouterBrain {
    // Intentionally empty for now.
    // Plan: in V1, rules are loaded from config. Later: from DB per org.
}
|
|
|
|
impl RouterBrain {
|
|
pub fn new() -> Self {
|
|
Self {}
|
|
}
|
|
|
|
/// Classify request complexity based on heuristics
|
|
fn classify_complexity(&self, request: &serde_json::Value) -> Complexity {
|
|
let messages = request
|
|
.get("messages")
|
|
.and_then(|m| m.as_array())
|
|
.map(|a| a.len())
|
|
.unwrap_or(0);
|
|
|
|
let system_prompt = request
|
|
.get("messages")
|
|
.and_then(|m| m.as_array())
|
|
.and_then(|msgs| {
|
|
msgs.iter()
|
|
.find(|m| m.get("role").and_then(|r| r.as_str()) == Some("system"))
|
|
})
|
|
.and_then(|m| m.get("content"))
|
|
.and_then(|c| c.as_str())
|
|
.unwrap_or("");
|
|
|
|
let high_complexity_keywords = [
|
|
"analyze",
|
|
"reason",
|
|
"compare",
|
|
"evaluate",
|
|
"synthesize",
|
|
"debate",
|
|
];
|
|
let has_complex_task = high_complexity_keywords
|
|
.iter()
|
|
.any(|kw| system_prompt.to_lowercase().contains(kw));
|
|
|
|
if messages > 10 || has_complex_task {
|
|
Complexity::High
|
|
} else if messages > 3 || system_prompt.len() > 500 {
|
|
Complexity::Medium
|
|
} else {
|
|
Complexity::Low
|
|
}
|
|
}
|
|
|
|
/// Route a request — returns the routing decision
|
|
pub async fn route(&self, _org_id: &str, request: &serde_json::Value) -> RoutingDecision {
|
|
let complexity = self.classify_complexity(request);
|
|
let original_model = request
|
|
.get("model")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("gpt-4o");
|
|
|
|
// V1 routing logic: downgrade low-complexity to cheaper model
|
|
let (routed_model, strategy) = match complexity {
|
|
Complexity::Low => {
|
|
if original_model.contains("gpt-4") {
|
|
(Some("gpt-4o-mini".to_string()), "cheapest".to_string())
|
|
} else {
|
|
(None, "passthrough".to_string())
|
|
}
|
|
}
|
|
Complexity::Medium => (None, "passthrough".to_string()),
|
|
Complexity::High => (None, "passthrough".to_string()),
|
|
};
|
|
|
|
// Calculate cost delta
|
|
let cost_delta = match (&routed_model, original_model) {
|
|
(Some(routed), orig) => estimate_cost_delta(orig, routed),
|
|
_ => 0.0,
|
|
};
|
|
|
|
RoutingDecision {
|
|
model: routed_model,
|
|
provider: None, // V1: same provider
|
|
strategy,
|
|
cost_delta,
|
|
complexity,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Estimate cost savings per 1K tokens when downgrading models.
///
/// Returns `price(original) - price(routed)`, so the result is positive when
/// `routed` is the cheaper model and negative when it is more expensive.
///
/// A model name is priced from the built-in table when it either equals a
/// table entry or is a table entry followed by `-` and a tag made only of
/// ASCII digits and dashes (a dated/numbered snapshot such as
/// `gpt-4o-2024-08-06` or `claude-3-opus-20240229`). Any other name falls
/// back to the default price, which equals the `gpt-4o` price.
fn estimate_cost_delta(original: &str, routed: &str) -> f64 {
    // Price per 1K tokens for each known base model name.
    const PRICES_PER_1K: [(&str, f64); 7] = [
        ("gpt-4o", 0.005),
        ("gpt-4o-mini", 0.00015),
        ("gpt-4-turbo", 0.01),
        ("gpt-3.5-turbo", 0.0005),
        ("claude-3-opus", 0.015),
        ("claude-3-sonnet", 0.003),
        ("claude-3-haiku", 0.00025),
    ];
    // Unknown models default to gpt-4o pricing.
    const DEFAULT_PRICE_PER_1K: f64 = 0.005;

    // True when `model` is `base` itself or `base` plus a numeric/date tag.
    // Requiring a digits-and-dashes tag keeps `gpt-4o-mini` from being
    // priced as a `gpt-4o` snapshot.
    fn is_same_or_snapshot(model: &str, base: &str) -> bool {
        match model.strip_prefix(base) {
            None => false,
            Some("") => true,
            Some(rest) => match rest.strip_prefix('-') {
                Some(tag) => {
                    !tag.is_empty() && tag.chars().all(|c| c.is_ascii_digit() || c == '-')
                }
                None => false,
            },
        }
    }

    // Looks up the table price for `model`, falling back to the default.
    fn price_per_1k(model: &str) -> f64 {
        PRICES_PER_1K
            .iter()
            .find(|(base, _)| is_same_or_snapshot(model, base))
            .map_or(DEFAULT_PRICE_PER_1K, |&(_, price)| price)
    }

    price_per_1k(original) - price_per_1k(routed)
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` user messages with distinct contents.
    fn user_messages(count: usize) -> Vec<serde_json::Value> {
        (0..count)
            .map(|i| serde_json::json!({"role": "user", "content": format!("msg {}", i)}))
            .collect()
    }

    #[test]
    fn low_complexity_simple_extraction() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "Extract the name from this text"},
                {"role": "user", "content": "My name is Alice"}
            ]
        });

        assert_eq!(brain.classify_complexity(&request), Complexity::Low);
    }

    #[test]
    fn high_complexity_multi_turn_reasoning() {
        let brain = RouterBrain::new();
        let system = serde_json::json!({
            "role": "system",
            "content": "Analyze and compare these approaches"
        });
        // One system message followed by 12 user/assistant exchanges.
        let messages: Vec<serde_json::Value> = std::iter::once(system)
            .chain((0..12).flat_map(|i| {
                [
                    serde_json::json!({"role": "user", "content": format!("Turn {}", i)}),
                    serde_json::json!({"role": "assistant", "content": format!("Response {}", i)}),
                ]
            }))
            .collect();
        let request = serde_json::json!({ "model": "gpt-4o", "messages": messages });

        assert_eq!(brain.classify_complexity(&request), Complexity::High);
    }

    #[tokio::test]
    async fn low_complexity_routes_to_mini() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "What is 2+2?"}]
        });

        let decision = brain.route("org-1", &request).await;

        assert_eq!(decision.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(decision.strategy, "cheapest");
        assert!(decision.cost_delta > 0.0);
    }

    #[tokio::test]
    async fn low_complexity_non_gpt4_passes_through() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({
            "model": "claude-3-haiku",
            "messages": [{"role": "user", "content": "What is 2+2?"}]
        });

        let decision = brain.route("org-1", &request).await;

        assert_eq!(decision.complexity, Complexity::Low);
        assert_eq!(decision.model, None);
        assert_eq!(decision.strategy, "passthrough");
        assert_eq!(decision.cost_delta, 0.0);
    }

    #[tokio::test]
    async fn medium_complexity_passes_through() {
        let brain = RouterBrain::new();
        // More than 3 but at most 10 messages, no system prompt.
        let request = serde_json::json!({ "model": "gpt-4o", "messages": user_messages(5) });

        let decision = brain.route("org-1", &request).await;

        assert_eq!(decision.complexity, Complexity::Medium);
        assert_eq!(decision.model, None);
        assert_eq!(decision.strategy, "passthrough");
        assert_eq!(decision.cost_delta, 0.0);
    }

    #[tokio::test]
    async fn high_complexity_passes_through() {
        let brain = RouterBrain::new();
        let request = serde_json::json!({ "model": "gpt-4o", "messages": user_messages(15) });

        let decision = brain.route("org-1", &request).await;

        assert_eq!(decision.model, None);
        assert_eq!(decision.strategy, "passthrough");
    }

    #[test]
    fn cost_delta_gpt4o_to_mini() {
        let delta = estimate_cost_delta("gpt-4o", "gpt-4o-mini");
        assert!((delta - 0.00485).abs() < 0.0001);
    }
}
|