106 lines
2.9 KiB
Rust
106 lines
2.9 KiB
Rust
use axum::http::HeaderMap;
|
|
use async_trait::async_trait;
|
|
use thiserror::Error;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct AuthContext {
|
|
pub org_id: String,
|
|
pub user_id: Option<String>,
|
|
pub role: Role,
|
|
}
|
|
|
|
/// Permission level carried by an [`AuthContext`].
///
/// Being fieldless, the type is `Copy` and usable as a `HashMap`/`HashSet` key.
/// NOTE(review): what each level may do is decided by the authorization layer, which is not
/// visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    Owner,
    Member,
    Viewer,
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum AuthError {
|
|
#[error("Invalid API key")]
|
|
InvalidKey,
|
|
#[error("Expired token")]
|
|
ExpiredToken,
|
|
#[error("Missing authorization header")]
|
|
MissingAuth,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait AuthProvider: Send + Sync {
|
|
async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError>;
|
|
}
|
|
|
|
/// Local auth — bcrypt passwords + HS256 JWT (self-hosted mode)
|
|
pub struct LocalAuthProvider {
|
|
pool: sqlx::PgPool,
|
|
_jwt_secret: String,
|
|
redis: deadpool_redis::Pool,
|
|
}
|
|
|
|
impl LocalAuthProvider {
|
|
pub fn new(pool: sqlx::PgPool, jwt_secret: String, redis: deadpool_redis::Pool) -> Self {
|
|
Self { pool, _jwt_secret: jwt_secret, redis }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl AuthProvider for LocalAuthProvider {
|
|
async fn authenticate(&self, headers: &HeaderMap) -> Result<AuthContext, AuthError> {
|
|
let key = headers
|
|
.get("authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|v| v.strip_prefix("Bearer "))
|
|
.ok_or(AuthError::MissingAuth)?;
|
|
|
|
// 1. Check Redis cache first
|
|
if let Ok(mut conn) = self.redis.get().await {
|
|
let cached: Option<String> = redis::cmd("GET")
|
|
.arg(format!("apikey:{}", &key[..8])) // prefix lookup
|
|
.query_async(&mut *conn)
|
|
.await
|
|
.unwrap_or(None);
|
|
|
|
if let Some(org_id) = cached {
|
|
return Ok(AuthContext {
|
|
org_id,
|
|
user_id: None,
|
|
role: Role::Member,
|
|
});
|
|
}
|
|
}
|
|
|
|
// 2. Fall back to PostgreSQL
|
|
let row = sqlx::query_as::<_, (String, String)>(
|
|
"SELECT org_id, key_hash FROM api_keys WHERE key_prefix = $1 AND revoked_at IS NULL"
|
|
)
|
|
.bind(&key[..8])
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|_| AuthError::InvalidKey)?
|
|
.ok_or(AuthError::InvalidKey)?;
|
|
|
|
// 3. Verify bcrypt hash
|
|
let valid = bcrypt::verify(key, &row.1).unwrap_or(false);
|
|
if !valid {
|
|
return Err(AuthError::InvalidKey);
|
|
}
|
|
|
|
// 4. Cache in Redis for next time (5 min TTL)
|
|
if let Ok(mut conn) = self.redis.get().await {
|
|
let _: Result<(), _> = redis::cmd("SETEX")
|
|
.arg(format!("apikey:{}", &key[..8]))
|
|
.arg(300)
|
|
.arg(&row.0)
|
|
.query_async(&mut *conn)
|
|
.await;
|
|
}
|
|
|
|
Ok(AuthContext {
|
|
org_id: row.0,
|
|
user_id: None,
|
|
role: Role::Member,
|
|
})
|
|
}
|
|
}
|