re-uploading work
This commit is contained in:
91
crates/api/Cargo.toml
Normal file
91
crates/api/Cargo.toml
Normal file
@@ -0,0 +1,91 @@
|
||||
[package]
|
||||
name = "attune-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "attune_api"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "attune-api"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
# Internal dependencies
|
||||
attune-common = { path = "../common" }
|
||||
attune-worker = { path = "../worker" }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# Web framework
|
||||
axum = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
|
||||
# Database
|
||||
sqlx = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml_ng = { workspace = true }
|
||||
|
||||
# Logging and tracing
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
# Error handling
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
# Configuration
|
||||
config = { workspace = true }
|
||||
|
||||
# Date/Time
|
||||
chrono = { workspace = true }
|
||||
|
||||
# UUID
|
||||
uuid = { workspace = true }
|
||||
|
||||
# Validation
|
||||
validator = { workspace = true }
|
||||
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
|
||||
# JSON Schema
|
||||
schemars = { workspace = true }
|
||||
jsonschema = { workspace = true }
|
||||
|
||||
# HTTP client
|
||||
reqwest = { workspace = true }
|
||||
|
||||
# Authentication
|
||||
jsonwebtoken = { version = "10.2", features = ["rust_crypto"] }
|
||||
argon2 = { workspace = true }
|
||||
rand = "0.9"
|
||||
|
||||
# HMAC and cryptography
|
||||
hmac = "0.12"
|
||||
sha1 = "0.10"
|
||||
sha2 = { workspace = true }
|
||||
hex = "0.4"
|
||||
|
||||
# OpenAPI/Swagger
|
||||
utoipa = { workspace = true, features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "9.0", features = ["axum"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockall = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
reqwest-eventsource = { workspace = true }
|
||||
389
crates/api/src/auth/jwt.rs
Normal file
389
crates/api/src/auth/jwt.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
//! JWT token generation and validation
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum JwtError {
|
||||
#[error("Failed to encode JWT: {0}")]
|
||||
EncodeError(String),
|
||||
#[error("Failed to decode JWT: {0}")]
|
||||
DecodeError(String),
|
||||
#[error("Token has expired")]
|
||||
Expired,
|
||||
#[error("Invalid token")]
|
||||
Invalid,
|
||||
}
|
||||
|
||||
/// JWT Claims structure
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
/// Subject (identity ID)
|
||||
pub sub: String,
|
||||
/// Identity login
|
||||
pub login: String,
|
||||
/// Issued at (Unix timestamp)
|
||||
pub iat: i64,
|
||||
/// Expiration time (Unix timestamp)
|
||||
pub exp: i64,
|
||||
/// Token type (access or refresh)
|
||||
#[serde(default)]
|
||||
pub token_type: TokenType,
|
||||
/// Optional scope (e.g., "sensor", "service")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scope: Option<String>,
|
||||
/// Optional metadata (e.g., trigger_types for sensors)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TokenType {
|
||||
Access,
|
||||
Refresh,
|
||||
Sensor,
|
||||
}
|
||||
|
||||
impl Default for TokenType {
|
||||
fn default() -> Self {
|
||||
Self::Access
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for JWT tokens
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JwtConfig {
|
||||
/// Secret key for signing tokens
|
||||
pub secret: String,
|
||||
/// Access token expiration duration (in seconds)
|
||||
pub access_token_expiration: i64,
|
||||
/// Refresh token expiration duration (in seconds)
|
||||
pub refresh_token_expiration: i64,
|
||||
}
|
||||
|
||||
impl Default for JwtConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
secret: "insecure_default_secret_change_in_production".to_string(),
|
||||
access_token_expiration: 3600, // 1 hour
|
||||
refresh_token_expiration: 604800, // 7 days
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a JWT access token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID
|
||||
/// * `login` - The identity login
|
||||
/// * `config` - JWT configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_access_token(
|
||||
identity_id: i64,
|
||||
login: &str,
|
||||
config: &JwtConfig,
|
||||
) -> Result<String, JwtError> {
|
||||
generate_token(identity_id, login, config, TokenType::Access)
|
||||
}
|
||||
|
||||
/// Generate a JWT refresh token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID
|
||||
/// * `login` - The identity login
|
||||
/// * `config` - JWT configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_refresh_token(
|
||||
identity_id: i64,
|
||||
login: &str,
|
||||
config: &JwtConfig,
|
||||
) -> Result<String, JwtError> {
|
||||
generate_token(identity_id, login, config, TokenType::Refresh)
|
||||
}
|
||||
|
||||
/// Generate a JWT token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID
|
||||
/// * `login` - The identity login
|
||||
/// * `config` - JWT configuration
|
||||
/// * `token_type` - Type of token to generate
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_token(
|
||||
identity_id: i64,
|
||||
login: &str,
|
||||
config: &JwtConfig,
|
||||
token_type: TokenType,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = match token_type {
|
||||
TokenType::Access => config.access_token_expiration,
|
||||
TokenType::Refresh => config.refresh_token_expiration,
|
||||
TokenType::Sensor => 86400, // Sensor tokens handled separately via generate_sensor_token()
|
||||
};
|
||||
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: login.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Generate a sensor token with specific trigger types
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID for the sensor
|
||||
/// * `sensor_ref` - The sensor reference (e.g., "sensor:core.timer")
|
||||
/// * `trigger_types` - List of trigger types this sensor can create events for
|
||||
/// * `config` - JWT configuration
|
||||
/// * `ttl_seconds` - Time to live in seconds (default: 24 hours)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_sensor_token(
|
||||
identity_id: i64,
|
||||
sensor_ref: &str,
|
||||
trigger_types: Vec<String>,
|
||||
config: &JwtConfig,
|
||||
ttl_seconds: Option<i64>,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = ttl_seconds.unwrap_or(86400); // Default: 24 hours
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"trigger_types": trigger_types,
|
||||
});
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: sensor_ref.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type: TokenType::Sensor,
|
||||
scope: Some("sensor".to_string()),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Validate and decode a JWT token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The JWT token string
|
||||
/// * `config` - JWT configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Claims, JwtError>` - The decoded claims if valid
|
||||
pub fn validate_token(token: &str, config: &JwtConfig) -> Result<Claims, JwtError> {
|
||||
let validation = Validation::default();
|
||||
|
||||
decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(config.secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| {
|
||||
if e.to_string().contains("ExpiredSignature") {
|
||||
JwtError::Expired
|
||||
} else {
|
||||
JwtError::DecodeError(e.to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract token from Authorization header
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `auth_header` - The Authorization header value
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Option<&str>` - The token if present and valid format
|
||||
pub fn extract_token_from_header(auth_header: &str) -> Option<&str> {
|
||||
if auth_header.starts_with("Bearer ") {
|
||||
Some(&auth_header[7..])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> JwtConfig {
|
||||
JwtConfig {
|
||||
secret: "test_secret_key_for_testing".to_string(),
|
||||
access_token_expiration: 3600,
|
||||
refresh_token_expiration: 604800,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_and_validate_access_token() {
|
||||
let config = test_config();
|
||||
let token =
|
||||
generate_access_token(123, "testuser", &config).expect("Failed to generate token");
|
||||
|
||||
let claims = validate_token(&token, &config).expect("Failed to validate token");
|
||||
|
||||
assert_eq!(claims.sub, "123");
|
||||
assert_eq!(claims.login, "testuser");
|
||||
assert_eq!(claims.token_type, TokenType::Access);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_and_validate_refresh_token() {
|
||||
let config = test_config();
|
||||
let token =
|
||||
generate_refresh_token(456, "anotheruser", &config).expect("Failed to generate token");
|
||||
|
||||
let claims = validate_token(&token, &config).expect("Failed to validate token");
|
||||
|
||||
assert_eq!(claims.sub, "456");
|
||||
assert_eq!(claims.login, "anotheruser");
|
||||
assert_eq!(claims.token_type, TokenType::Refresh);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_token() {
|
||||
let config = test_config();
|
||||
let result = validate_token("invalid.token.here", &config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_with_wrong_secret() {
|
||||
let config = test_config();
|
||||
let token = generate_access_token(789, "user", &config).expect("Failed to generate token");
|
||||
|
||||
let wrong_config = JwtConfig {
|
||||
secret: "different_secret".to_string(),
|
||||
..config
|
||||
};
|
||||
|
||||
let result = validate_token(&token, &wrong_config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expired_token() {
|
||||
// Create a token that's already expired by setting exp in the past
|
||||
let now = Utc::now().timestamp();
|
||||
let expired_claims = Claims {
|
||||
sub: "999".to_string(),
|
||||
login: "expireduser".to_string(),
|
||||
iat: now - 3600,
|
||||
exp: now - 1800, // Expired 30 minutes ago
|
||||
token_type: TokenType::Access,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let config = test_config();
|
||||
|
||||
let expired_token = encode(
|
||||
&Header::default(),
|
||||
&expired_claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.expect("Failed to encode token");
|
||||
|
||||
// Validate the expired token
|
||||
let result = validate_token(&expired_token, &config);
|
||||
assert!(matches!(result, Err(JwtError::Expired)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_from_header() {
|
||||
let header = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
|
||||
let token = extract_token_from_header(header);
|
||||
assert_eq!(token, Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));
|
||||
|
||||
let invalid_header = "Token abc123";
|
||||
let token = extract_token_from_header(invalid_header);
|
||||
assert_eq!(token, None);
|
||||
|
||||
let no_token = "Bearer ";
|
||||
let token = extract_token_from_header(no_token);
|
||||
assert_eq!(token, Some(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claims_serialization() {
|
||||
let claims = Claims {
|
||||
sub: "123".to_string(),
|
||||
login: "testuser".to_string(),
|
||||
iat: 1234567890,
|
||||
exp: 1234571490,
|
||||
token_type: TokenType::Access,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&claims).expect("Failed to serialize");
|
||||
let deserialized: Claims = serde_json::from_str(&json).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(claims.sub, deserialized.sub);
|
||||
assert_eq!(claims.login, deserialized.login);
|
||||
assert_eq!(claims.token_type, deserialized.token_type);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_sensor_token() {
|
||||
let config = test_config();
|
||||
let trigger_types = vec!["core.timer".to_string(), "core.webhook".to_string()];
|
||||
|
||||
let token = generate_sensor_token(
|
||||
999,
|
||||
"sensor:core.timer",
|
||||
trigger_types.clone(),
|
||||
&config,
|
||||
Some(86400),
|
||||
)
|
||||
.expect("Failed to generate sensor token");
|
||||
|
||||
let claims = validate_token(&token, &config).expect("Failed to validate token");
|
||||
|
||||
assert_eq!(claims.sub, "999");
|
||||
assert_eq!(claims.login, "sensor:core.timer");
|
||||
assert_eq!(claims.token_type, TokenType::Sensor);
|
||||
assert_eq!(claims.scope, Some("sensor".to_string()));
|
||||
|
||||
let metadata = claims.metadata.expect("Metadata should be present");
|
||||
let trigger_types_from_token = metadata["trigger_types"]
|
||||
.as_array()
|
||||
.expect("trigger_types should be an array");
|
||||
|
||||
assert_eq!(trigger_types_from_token.len(), 2);
|
||||
}
|
||||
}
|
||||
176
crates/api/src/auth/middleware.rs
Normal file
176
crates/api/src/auth/middleware.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Authentication middleware for protecting routes
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header::AUTHORIZATION, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::jwt::{extract_token_from_header, validate_token, Claims, JwtConfig, TokenType};
|
||||
|
||||
/// Authentication middleware state
|
||||
#[derive(Clone)]
|
||||
pub struct AuthMiddleware {
|
||||
pub jwt_config: Arc<JwtConfig>,
|
||||
}
|
||||
|
||||
impl AuthMiddleware {
|
||||
pub fn new(jwt_config: JwtConfig) -> Self {
|
||||
Self {
|
||||
jwt_config: Arc::new(jwt_config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension type for storing authenticated claims in request
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthenticatedUser {
|
||||
pub claims: Claims,
|
||||
}
|
||||
|
||||
impl AuthenticatedUser {
|
||||
pub fn identity_id(&self) -> Result<i64, std::num::ParseIntError> {
|
||||
self.claims.sub.parse()
|
||||
}
|
||||
|
||||
pub fn login(&self) -> &str {
|
||||
&self.claims.login
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware function that validates JWT tokens
|
||||
pub async fn require_auth(
|
||||
State(auth): State<AuthMiddleware>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AuthError> {
|
||||
// Extract Authorization header
|
||||
let auth_header = request
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(AuthError::MissingToken)?;
|
||||
|
||||
// Extract token from Bearer scheme
|
||||
let token = extract_token_from_header(auth_header).ok_or(AuthError::InvalidToken)?;
|
||||
|
||||
// Validate token
|
||||
let claims = validate_token(token, &auth.jwt_config).map_err(|e| match e {
|
||||
super::jwt::JwtError::Expired => AuthError::ExpiredToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
})?;
|
||||
|
||||
// Add claims to request extensions
|
||||
request
|
||||
.extensions_mut()
|
||||
.insert(AuthenticatedUser { claims });
|
||||
|
||||
// Continue to next middleware/handler
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
/// Extractor for authenticated user
|
||||
pub struct RequireAuth(pub AuthenticatedUser);
|
||||
|
||||
impl axum::extract::FromRequestParts<crate::state::SharedState> for RequireAuth {
|
||||
type Rejection = AuthError;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &crate::state::SharedState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// First check if middleware already added the user
|
||||
if let Some(user) = parts.extensions.get::<AuthenticatedUser>() {
|
||||
return Ok(RequireAuth(user.clone()));
|
||||
}
|
||||
|
||||
// Otherwise, extract and validate token directly from header
|
||||
// Extract Authorization header
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(AuthError::MissingToken)?;
|
||||
|
||||
// Extract token from Bearer scheme
|
||||
let token = extract_token_from_header(auth_header).ok_or(AuthError::InvalidToken)?;
|
||||
|
||||
// Validate token using jwt_config from app state
|
||||
let claims = validate_token(token, &state.jwt_config).map_err(|e| match e {
|
||||
super::jwt::JwtError::Expired => AuthError::ExpiredToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
})?;
|
||||
|
||||
// Allow both access tokens and sensor tokens
|
||||
if claims.token_type != TokenType::Access && claims.token_type != TokenType::Sensor {
|
||||
return Err(AuthError::InvalidToken);
|
||||
}
|
||||
|
||||
Ok(RequireAuth(AuthenticatedUser { claims }))
|
||||
}
|
||||
}
|
||||
|
||||
/// Authentication errors
|
||||
#[derive(Debug)]
|
||||
pub enum AuthError {
|
||||
MissingToken,
|
||||
InvalidToken,
|
||||
ExpiredToken,
|
||||
Unauthorized,
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, message) = match self {
|
||||
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing authentication token"),
|
||||
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid authentication token"),
|
||||
AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Authentication token expired"),
|
||||
AuthError::Unauthorized => (StatusCode::FORBIDDEN, "Insufficient permissions"),
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
"error": {
|
||||
"code": status.as_u16(),
|
||||
"message": message,
|
||||
}
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_authenticated_user() {
|
||||
let claims = Claims {
|
||||
sub: "123".to_string(),
|
||||
login: "testuser".to_string(),
|
||||
iat: 1234567890,
|
||||
exp: 1234571490,
|
||||
token_type: super::super::jwt::TokenType::Access,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let auth_user = AuthenticatedUser { claims };
|
||||
|
||||
assert_eq!(auth_user.identity_id().unwrap(), 123);
|
||||
assert_eq!(auth_user.login(), "testuser");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_from_header() {
|
||||
let token = extract_token_from_header("Bearer test.token.here");
|
||||
assert_eq!(token, Some("test.token.here"));
|
||||
|
||||
let no_bearer = extract_token_from_header("test.token.here");
|
||||
assert_eq!(no_bearer, None);
|
||||
}
|
||||
}
|
||||
9
crates/api/src/auth/mod.rs
Normal file
9
crates/api/src/auth/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Authentication and authorization module
|
||||
|
||||
pub mod jwt;
|
||||
pub mod middleware;
|
||||
pub mod password;
|
||||
|
||||
pub use jwt::{generate_token, validate_token, Claims};
|
||||
pub use middleware::{AuthMiddleware, RequireAuth};
|
||||
pub use password::{hash_password, verify_password};
|
||||
108
crates/api/src/auth/password.rs
Normal file
108
crates/api/src/auth/password.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
//! Password hashing and verification using Argon2
|
||||
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PasswordError {
|
||||
#[error("Failed to hash password: {0}")]
|
||||
HashError(String),
|
||||
#[error("Failed to verify password: {0}")]
|
||||
VerifyError(String),
|
||||
#[error("Invalid password hash format")]
|
||||
InvalidHash,
|
||||
}
|
||||
|
||||
/// Hash a password using Argon2id
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `password` - The plaintext password to hash
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, PasswordError>` - The hashed password string (PHC format)
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use attune_api::auth::password::hash_password;
|
||||
///
|
||||
/// let hash = hash_password("my_secure_password").expect("Failed to hash password");
|
||||
/// assert!(!hash.is_empty());
|
||||
/// ```
|
||||
pub fn hash_password(password: &str) -> Result<String, PasswordError> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map(|hash| hash.to_string())
|
||||
.map_err(|e| PasswordError::HashError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Verify a password against a hash using Argon2id
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `password` - The plaintext password to verify
|
||||
/// * `hash` - The password hash string (PHC format)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<bool, PasswordError>` - True if password matches, false otherwise
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use attune_api::auth::password::{hash_password, verify_password};
|
||||
///
|
||||
/// let hash = hash_password("my_secure_password").expect("Failed to hash");
|
||||
/// let is_valid = verify_password("my_secure_password", &hash).expect("Failed to verify");
|
||||
/// assert!(is_valid);
|
||||
/// ```
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool, PasswordError> {
|
||||
let parsed_hash = PasswordHash::new(hash).map_err(|_| PasswordError::InvalidHash)?;
|
||||
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
match argon2.verify_password(password.as_bytes(), &parsed_hash) {
|
||||
Ok(_) => Ok(true),
|
||||
Err(argon2::password_hash::Error::Password) => Ok(false),
|
||||
Err(e) => Err(PasswordError::VerifyError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hash_and_verify_password() {
|
||||
let password = "my_secure_password_123";
|
||||
let hash = hash_password(password).expect("Failed to hash password");
|
||||
|
||||
// Verify correct password
|
||||
assert!(verify_password(password, &hash).expect("Failed to verify"));
|
||||
|
||||
// Verify incorrect password
|
||||
assert!(!verify_password("wrong_password", &hash).expect("Failed to verify"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hash_produces_different_salts() {
|
||||
let password = "same_password";
|
||||
let hash1 = hash_password(password).expect("Failed to hash");
|
||||
let hash2 = hash_password(password).expect("Failed to hash");
|
||||
|
||||
// Hashes should be different due to different salts
|
||||
assert_ne!(hash1, hash2);
|
||||
|
||||
// But both should verify correctly
|
||||
assert!(verify_password(password, &hash1).expect("Failed to verify"));
|
||||
assert!(verify_password(password, &hash2).expect("Failed to verify"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_hash_format() {
|
||||
let result = verify_password("password", "not_a_valid_hash");
|
||||
assert!(matches!(result, Err(PasswordError::InvalidHash)));
|
||||
}
|
||||
}
|
||||
324
crates/api/src/dto/action.rs
Normal file
324
crates/api/src/dto/action.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
//! Action DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a new action
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateActionRequest {
|
||||
/// Unique reference identifier (e.g., "core.http", "aws.ec2.start_instance")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference this action belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Post Message to Slack")]
|
||||
pub label: String,
|
||||
|
||||
/// Action description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
|
||||
/// Entry point for action execution (e.g., path to script, function name)
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
#[schema(example = "/actions/slack/post_message.py")]
|
||||
pub entrypoint: String,
|
||||
|
||||
/// Optional runtime ID for this action
|
||||
#[schema(example = 1)]
|
||||
pub runtime: Option<i64>,
|
||||
|
||||
/// Parameter schema (JSON Schema) defining expected inputs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"type": "object", "properties": {"channel": {"type": "string"}, "message": {"type": "string"}}}))]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema (JSON Schema) defining expected outputs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"type": "object", "properties": {"message_id": {"type": "string"}}}))]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Request DTO for updating an action
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateActionRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Post Message to Slack (Updated)")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Action description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Posts a message to a Slack channel with enhanced features")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point for action execution
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
#[schema(example = "/actions/slack/post_message_v2.py")]
|
||||
pub entrypoint: Option<String>,
|
||||
|
||||
/// Runtime ID
|
||||
#[schema(example = 1)]
|
||||
pub runtime: Option<i64>,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Response DTO for action information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ActionResponse {
|
||||
/// Action ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack ID
|
||||
#[schema(example = 1)]
|
||||
pub pack: i64,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Post Message to Slack")]
|
||||
pub label: String,
|
||||
|
||||
/// Action description
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/actions/slack/post_message.py")]
|
||||
pub entrypoint: String,
|
||||
|
||||
/// Runtime ID
|
||||
#[schema(example = 1)]
|
||||
pub runtime: Option<i64>,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Whether this is an ad-hoc action (not from pack installation)
|
||||
#[schema(example = false)]
|
||||
pub is_adhoc: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified action response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ActionSummary {
|
||||
/// Action ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Post Message to Slack")]
|
||||
pub label: String,
|
||||
|
||||
/// Action description
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/actions/slack/post_message.py")]
|
||||
pub entrypoint: String,
|
||||
|
||||
/// Runtime ID
|
||||
#[schema(example = 1)]
|
||||
pub runtime: Option<i64>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from Action model to ActionResponse
|
||||
impl From<attune_common::models::action::Action> for ActionResponse {
|
||||
fn from(action: attune_common::models::action::Action) -> Self {
|
||||
Self {
|
||||
id: action.id,
|
||||
r#ref: action.r#ref,
|
||||
pack: action.pack,
|
||||
pack_ref: action.pack_ref,
|
||||
label: action.label,
|
||||
description: action.description,
|
||||
entrypoint: action.entrypoint,
|
||||
runtime: action.runtime,
|
||||
param_schema: action.param_schema,
|
||||
out_schema: action.out_schema,
|
||||
is_adhoc: action.is_adhoc,
|
||||
created: action.created,
|
||||
updated: action.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Action model to ActionSummary
|
||||
impl From<attune_common::models::action::Action> for ActionSummary {
|
||||
fn from(action: attune_common::models::action::Action) -> Self {
|
||||
Self {
|
||||
id: action.id,
|
||||
r#ref: action.r#ref,
|
||||
pack_ref: action.pack_ref,
|
||||
label: action.label,
|
||||
description: action.description,
|
||||
entrypoint: action.entrypoint,
|
||||
runtime: action.runtime,
|
||||
created: action.created,
|
||||
updated: action.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response DTO for queue statistics
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct QueueStatsResponse {
|
||||
/// Action ID
|
||||
#[schema(example = 1)]
|
||||
pub action_id: i64,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Number of executions waiting in queue
|
||||
#[schema(example = 5)]
|
||||
pub queue_length: i32,
|
||||
|
||||
/// Number of currently running executions
|
||||
#[schema(example = 2)]
|
||||
pub active_count: i32,
|
||||
|
||||
/// Maximum concurrent executions allowed
|
||||
#[schema(example = 3)]
|
||||
pub max_concurrent: i32,
|
||||
|
||||
/// Timestamp of oldest queued execution (if any)
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub oldest_enqueued_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Total executions enqueued since queue creation
|
||||
#[schema(example = 100)]
|
||||
pub total_enqueued: i64,
|
||||
|
||||
/// Total executions completed since queue creation
|
||||
#[schema(example = 95)]
|
||||
pub total_completed: i64,
|
||||
|
||||
/// Timestamp of last statistics update
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub last_updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from QueueStats repository model to QueueStatsResponse
|
||||
impl From<attune_common::repositories::queue_stats::QueueStats> for QueueStatsResponse {
|
||||
fn from(stats: attune_common::repositories::queue_stats::QueueStats) -> Self {
|
||||
Self {
|
||||
action_id: stats.action_id,
|
||||
action_ref: String::new(), // Will be populated by the handler
|
||||
queue_length: stats.queue_length,
|
||||
active_count: stats.active_count,
|
||||
max_concurrent: stats.max_concurrent,
|
||||
oldest_enqueued_at: stats.oldest_enqueued_at,
|
||||
total_enqueued: stats.total_enqueued,
|
||||
total_completed: stats.total_completed,
|
||||
last_updated: stats.last_updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_action_request_validation() {
|
||||
let req = CreateActionRequest {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
entrypoint: "/actions/test.py".to_string(),
|
||||
runtime: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_action_request_valid() {
|
||||
let req = CreateActionRequest {
|
||||
r#ref: "test.action".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
entrypoint: "/actions/test.py".to_string(),
|
||||
runtime: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_action_request_all_none() {
|
||||
let req = UpdateActionRequest {
|
||||
label: None,
|
||||
description: None,
|
||||
entrypoint: None,
|
||||
runtime: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
};
|
||||
|
||||
// Should be valid even with all None values
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
}
|
||||
138
crates/api/src/dto/auth.rs
Normal file
138
crates/api/src/dto/auth.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
//! Authentication DTOs
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Login request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct LoginRequest {
|
||||
/// Identity login (username)
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "admin")]
|
||||
pub login: String,
|
||||
|
||||
/// Password
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "changeme123")]
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
/// Register request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct RegisterRequest {
|
||||
/// Identity login (username)
|
||||
#[validate(length(min = 3, max = 255))]
|
||||
#[schema(example = "newuser")]
|
||||
pub login: String,
|
||||
|
||||
/// Password
|
||||
#[validate(length(min = 8, max = 128))]
|
||||
#[schema(example = "SecurePass123!")]
|
||||
pub password: String,
|
||||
|
||||
/// Display name (optional)
|
||||
#[validate(length(max = 255))]
|
||||
#[schema(example = "New User")]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Token response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct TokenResponse {
|
||||
/// Access token (JWT)
|
||||
#[schema(example = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")]
|
||||
pub access_token: String,
|
||||
|
||||
/// Refresh token
|
||||
#[schema(example = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")]
|
||||
pub refresh_token: String,
|
||||
|
||||
/// Token type (always "Bearer")
|
||||
#[schema(example = "Bearer")]
|
||||
pub token_type: String,
|
||||
|
||||
/// Access token expiration in seconds
|
||||
#[schema(example = 3600)]
|
||||
pub expires_in: i64,
|
||||
|
||||
/// User information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<UserInfo>,
|
||||
}
|
||||
|
||||
/// User information included in token response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct UserInfo {
|
||||
/// Identity ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Identity login
|
||||
#[schema(example = "admin")]
|
||||
pub login: String,
|
||||
|
||||
/// Display name
|
||||
#[schema(example = "Administrator")]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
impl TokenResponse {
|
||||
pub fn new(access_token: String, refresh_token: String, expires_in: i64) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
refresh_token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_user(mut self, id: i64, login: String, display_name: Option<String>) -> Self {
|
||||
self.user = Some(UserInfo {
|
||||
id,
|
||||
login,
|
||||
display_name,
|
||||
});
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Refresh token request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct RefreshTokenRequest {
|
||||
/// Refresh token
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")]
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
/// Change password request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct ChangePasswordRequest {
|
||||
/// Current password
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "OldPassword123!")]
|
||||
pub current_password: String,
|
||||
|
||||
/// New password
|
||||
#[validate(length(min = 8, max = 128))]
|
||||
#[schema(example = "NewPassword456!")]
|
||||
pub new_password: String,
|
||||
}
|
||||
|
||||
/// Current user response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct CurrentUserResponse {
|
||||
/// Identity ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Identity login
|
||||
#[schema(example = "admin")]
|
||||
pub login: String,
|
||||
|
||||
/// Display name
|
||||
#[schema(example = "Administrator")]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
221
crates/api/src/dto/common.rs
Normal file
221
crates/api/src/dto/common.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
//! Common DTO types used across all API endpoints
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
/// Pagination parameters for list endpoints
|
||||
#[derive(Debug, Clone, Deserialize, IntoParams)]
|
||||
pub struct PaginationParams {
|
||||
/// Page number (1-based)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Number of items per page
|
||||
#[serde(default = "default_page_size")]
|
||||
#[param(example = 50, minimum = 1, maximum = 100)]
|
||||
pub page_size: u32,
|
||||
}
|
||||
|
||||
fn default_page() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_page_size() -> u32 {
|
||||
50
|
||||
}
|
||||
|
||||
impl PaginationParams {
|
||||
/// Get the SQL offset value
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page.saturating_sub(1)) * self.page_size
|
||||
}
|
||||
|
||||
/// Get the SQL limit value
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.page_size.min(100) // Max 100 items per page
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PaginationParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
page: default_page(),
|
||||
page_size: default_page_size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Paginated response wrapper
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PaginatedResponse<T> {
|
||||
/// The data items
|
||||
pub data: Vec<T>,
|
||||
|
||||
/// Pagination metadata
|
||||
pub pagination: PaginationMeta,
|
||||
}
|
||||
|
||||
/// Pagination metadata
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PaginationMeta {
|
||||
/// Current page number (1-based)
|
||||
#[schema(example = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Number of items per page
|
||||
#[schema(example = 50)]
|
||||
pub page_size: u32,
|
||||
|
||||
/// Total number of items
|
||||
#[schema(example = 150)]
|
||||
pub total_items: u64,
|
||||
|
||||
/// Total number of pages
|
||||
#[schema(example = 3)]
|
||||
pub total_pages: u32,
|
||||
}
|
||||
|
||||
impl PaginationMeta {
|
||||
/// Create pagination metadata
|
||||
pub fn new(page: u32, page_size: u32, total_items: u64) -> Self {
|
||||
let total_pages = if page_size > 0 {
|
||||
((total_items as f64) / (page_size as f64)).ceil() as u32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Self {
|
||||
page,
|
||||
page_size,
|
||||
total_items,
|
||||
total_pages,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PaginatedResponse<T> {
|
||||
/// Create a new paginated response
|
||||
pub fn new(data: Vec<T>, params: &PaginationParams, total_items: u64) -> Self {
|
||||
Self {
|
||||
data,
|
||||
pagination: PaginationMeta::new(params.page, params.page_size, total_items),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Standard API response wrapper
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ApiResponse<T> {
|
||||
/// Response data
|
||||
pub data: T,
|
||||
|
||||
/// Optional message
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl<T> ApiResponse<T> {
|
||||
/// Create a new API response
|
||||
pub fn new(data: T) -> Self {
|
||||
Self {
|
||||
data,
|
||||
message: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an API response with a message
|
||||
pub fn with_message(data: T, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
data,
|
||||
message: Some(message.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Success message response (for operations that don't return data)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct SuccessResponse {
|
||||
/// Success indicator
|
||||
#[schema(example = true)]
|
||||
pub success: bool,
|
||||
|
||||
/// Message describing the operation
|
||||
#[schema(example = "Operation completed successfully")]
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl SuccessResponse {
|
||||
/// Create a success response
|
||||
pub fn new(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pagination_params_offset() {
|
||||
let params = PaginationParams {
|
||||
page: 1,
|
||||
page_size: 10,
|
||||
};
|
||||
assert_eq!(params.offset(), 0);
|
||||
|
||||
let params = PaginationParams {
|
||||
page: 2,
|
||||
page_size: 10,
|
||||
};
|
||||
assert_eq!(params.offset(), 10);
|
||||
|
||||
let params = PaginationParams {
|
||||
page: 3,
|
||||
page_size: 25,
|
||||
};
|
||||
assert_eq!(params.offset(), 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pagination_params_limit() {
|
||||
let params = PaginationParams {
|
||||
page: 1,
|
||||
page_size: 50,
|
||||
};
|
||||
assert_eq!(params.limit(), 50);
|
||||
|
||||
// Should cap at 100
|
||||
let params = PaginationParams {
|
||||
page: 1,
|
||||
page_size: 200,
|
||||
};
|
||||
assert_eq!(params.limit(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pagination_meta() {
|
||||
let meta = PaginationMeta::new(1, 10, 45);
|
||||
assert_eq!(meta.page, 1);
|
||||
assert_eq!(meta.page_size, 10);
|
||||
assert_eq!(meta.total_items, 45);
|
||||
assert_eq!(meta.total_pages, 5);
|
||||
|
||||
let meta = PaginationMeta::new(2, 20, 100);
|
||||
assert_eq!(meta.total_pages, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paginated_response() {
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let params = PaginationParams::default();
|
||||
let response = PaginatedResponse::new(data.clone(), ¶ms, 100);
|
||||
|
||||
assert_eq!(response.data, data);
|
||||
assert_eq!(response.pagination.total_items, 100);
|
||||
assert_eq!(response.pagination.page, 1);
|
||||
}
|
||||
}
|
||||
344
crates/api/src/dto/event.rs
Normal file
344
crates/api/src/dto/event.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
//! Event and Enforcement data transfer objects
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use attune_common::models::{
|
||||
enums::{EnforcementCondition, EnforcementStatus},
|
||||
event::{Enforcement, Event},
|
||||
Id, JsonDict,
|
||||
};
|
||||
|
||||
/// Full event response with all details
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct EventResponse {
|
||||
/// Event ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub trigger: Option<Id>,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "core.webhook")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Event configuration
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub config: Option<JsonDict>,
|
||||
|
||||
/// Event payload data
|
||||
#[schema(value_type = Object, example = json!({"url": "/webhook", "method": "POST"}))]
|
||||
pub payload: Option<JsonDict>,
|
||||
|
||||
/// Source ID (sensor that generated this event)
|
||||
#[schema(example = 1)]
|
||||
pub source: Option<Id>,
|
||||
|
||||
/// Source reference
|
||||
#[schema(example = "monitoring.webhook_sensor")]
|
||||
pub source_ref: Option<String>,
|
||||
|
||||
/// Rule ID (if event was generated by a specific rule)
|
||||
#[schema(example = 1)]
|
||||
pub rule: Option<Id>,
|
||||
|
||||
/// Rule reference (if event was generated by a specific rule)
|
||||
#[schema(example = "core.timer_rule")]
|
||||
pub rule_ref: Option<String>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Event> for EventResponse {
|
||||
fn from(event: Event) -> Self {
|
||||
Self {
|
||||
id: event.id,
|
||||
trigger: event.trigger,
|
||||
trigger_ref: event.trigger_ref,
|
||||
config: event.config,
|
||||
payload: event.payload,
|
||||
source: event.source,
|
||||
source_ref: event.source_ref,
|
||||
rule: event.rule,
|
||||
rule_ref: event.rule_ref,
|
||||
created: event.created,
|
||||
updated: event.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary event response for list views
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct EventSummary {
|
||||
/// Event ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub trigger: Option<Id>,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "core.webhook")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Source ID
|
||||
#[schema(example = 1)]
|
||||
pub source: Option<Id>,
|
||||
|
||||
/// Source reference
|
||||
#[schema(example = "monitoring.webhook_sensor")]
|
||||
pub source_ref: Option<String>,
|
||||
|
||||
/// Rule ID (if event was generated by a specific rule)
|
||||
#[schema(example = 1)]
|
||||
pub rule: Option<Id>,
|
||||
|
||||
/// Rule reference (if event was generated by a specific rule)
|
||||
#[schema(example = "core.timer_rule")]
|
||||
pub rule_ref: Option<String>,
|
||||
|
||||
/// Whether event has payload data
|
||||
#[schema(example = true)]
|
||||
pub has_payload: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Event> for EventSummary {
|
||||
fn from(event: Event) -> Self {
|
||||
Self {
|
||||
id: event.id,
|
||||
trigger: event.trigger,
|
||||
trigger_ref: event.trigger_ref,
|
||||
source: event.source,
|
||||
source_ref: event.source_ref,
|
||||
rule: event.rule,
|
||||
rule_ref: event.rule_ref,
|
||||
has_payload: event.payload.is_some(),
|
||||
created: event.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query parameters for filtering events
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, IntoParams)]
|
||||
pub struct EventQueryParams {
|
||||
/// Filter by trigger ID
|
||||
#[param(example = 1)]
|
||||
pub trigger: Option<Id>,
|
||||
|
||||
/// Filter by trigger reference
|
||||
#[param(example = "core.webhook")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// Filter by source ID
|
||||
#[param(example = 1)]
|
||||
pub source: Option<Id>,
|
||||
|
||||
/// Page number (1-indexed)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Items per page
|
||||
#[serde(default = "default_per_page")]
|
||||
#[param(example = 50, minimum = 1, maximum = 100)]
|
||||
pub per_page: u32,
|
||||
}
|
||||
|
||||
fn default_page() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_per_page() -> u32 {
|
||||
50
|
||||
}
|
||||
|
||||
impl EventQueryParams {
|
||||
/// Get the offset for pagination
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page - 1) * self.per_page
|
||||
}
|
||||
|
||||
/// Get the limit for pagination
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.per_page
|
||||
}
|
||||
}
|
||||
|
||||
/// Full enforcement response with all details
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct EnforcementResponse {
|
||||
/// Enforcement ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Rule ID
|
||||
#[schema(example = 1)]
|
||||
pub rule: Option<Id>,
|
||||
|
||||
/// Rule reference
|
||||
#[schema(example = "slack.notify_on_error")]
|
||||
pub rule_ref: String,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "system.error_event")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Enforcement configuration
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub config: Option<JsonDict>,
|
||||
|
||||
/// Event ID that triggered this enforcement
|
||||
#[schema(example = 1)]
|
||||
pub event: Option<Id>,
|
||||
|
||||
/// Enforcement status
|
||||
#[schema(example = "succeeded")]
|
||||
pub status: EnforcementStatus,
|
||||
|
||||
/// Enforcement payload
|
||||
#[schema(value_type = Object)]
|
||||
pub payload: JsonDict,
|
||||
|
||||
/// Enforcement condition
|
||||
#[schema(example = "matched")]
|
||||
pub condition: EnforcementCondition,
|
||||
|
||||
/// Enforcement conditions (rule evaluation criteria)
|
||||
#[schema(value_type = Object)]
|
||||
pub conditions: JsonValue,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Enforcement> for EnforcementResponse {
|
||||
fn from(enforcement: Enforcement) -> Self {
|
||||
Self {
|
||||
id: enforcement.id,
|
||||
rule: enforcement.rule,
|
||||
rule_ref: enforcement.rule_ref,
|
||||
trigger_ref: enforcement.trigger_ref,
|
||||
config: enforcement.config,
|
||||
event: enforcement.event,
|
||||
status: enforcement.status,
|
||||
payload: enforcement.payload,
|
||||
condition: enforcement.condition,
|
||||
conditions: enforcement.conditions,
|
||||
created: enforcement.created,
|
||||
updated: enforcement.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary enforcement response for list views
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct EnforcementSummary {
|
||||
/// Enforcement ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Rule ID
|
||||
#[schema(example = 1)]
|
||||
pub rule: Option<Id>,
|
||||
|
||||
/// Rule reference
|
||||
#[schema(example = "slack.notify_on_error")]
|
||||
pub rule_ref: String,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "system.error_event")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Event ID
|
||||
#[schema(example = 1)]
|
||||
pub event: Option<Id>,
|
||||
|
||||
/// Enforcement status
|
||||
#[schema(example = "succeeded")]
|
||||
pub status: EnforcementStatus,
|
||||
|
||||
/// Enforcement condition
|
||||
#[schema(example = "matched")]
|
||||
pub condition: EnforcementCondition,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Enforcement> for EnforcementSummary {
|
||||
fn from(enforcement: Enforcement) -> Self {
|
||||
Self {
|
||||
id: enforcement.id,
|
||||
rule: enforcement.rule,
|
||||
rule_ref: enforcement.rule_ref,
|
||||
trigger_ref: enforcement.trigger_ref,
|
||||
event: enforcement.event,
|
||||
status: enforcement.status,
|
||||
condition: enforcement.condition,
|
||||
created: enforcement.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query parameters for filtering enforcements
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, IntoParams)]
|
||||
pub struct EnforcementQueryParams {
|
||||
/// Filter by rule ID
|
||||
#[param(example = 1)]
|
||||
pub rule: Option<Id>,
|
||||
|
||||
/// Filter by event ID
|
||||
#[param(example = 1)]
|
||||
pub event: Option<Id>,
|
||||
|
||||
/// Filter by status
|
||||
#[param(example = "success")]
|
||||
pub status: Option<EnforcementStatus>,
|
||||
|
||||
/// Filter by trigger reference
|
||||
#[param(example = "core.webhook")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// Page number (1-indexed)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Items per page
|
||||
#[serde(default = "default_per_page")]
|
||||
#[param(example = 50, minimum = 1, maximum = 100)]
|
||||
pub per_page: u32,
|
||||
}
|
||||
|
||||
impl EnforcementQueryParams {
|
||||
/// Get the offset for pagination
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page - 1) * self.per_page
|
||||
}
|
||||
|
||||
/// Get the limit for pagination
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.per_page
|
||||
}
|
||||
}
|
||||
283
crates/api/src/dto/execution.rs
Normal file
283
crates/api/src/dto/execution.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
//! Execution DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use attune_common::models::enums::ExecutionStatus;
|
||||
|
||||
/// Request DTO for creating a manual execution
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
pub struct CreateExecutionRequest {
|
||||
/// Action reference to execute
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Execution parameters/configuration
|
||||
#[schema(value_type = Object, example = json!({"channel": "#alerts", "message": "Manual test"}))]
|
||||
pub parameters: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Response DTO for execution information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ExecutionResponse {
|
||||
/// Execution ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Action ID (optional, may be null for ad-hoc executions)
|
||||
#[schema(example = 1)]
|
||||
pub action: Option<i64>,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Execution configuration/parameters
|
||||
#[schema(value_type = Object, example = json!({"channel": "#alerts", "message": "System error detected"}))]
|
||||
pub config: Option<JsonValue>,
|
||||
|
||||
/// Parent execution ID (for nested/child executions)
|
||||
#[schema(example = 1)]
|
||||
pub parent: Option<i64>,
|
||||
|
||||
/// Enforcement ID (rule enforcement that triggered this)
|
||||
#[schema(example = 1)]
|
||||
pub enforcement: Option<i64>,
|
||||
|
||||
/// Executor ID (worker/executor that ran this)
|
||||
#[schema(example = 1)]
|
||||
pub executor: Option<i64>,
|
||||
|
||||
/// Execution status
|
||||
#[schema(example = "succeeded")]
|
||||
pub status: ExecutionStatus,
|
||||
|
||||
/// Execution result/output
|
||||
#[schema(value_type = Object, example = json!({"message_id": "1234567890.123456"}))]
|
||||
pub result: Option<JsonValue>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:35:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified execution response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ExecutionSummary {
|
||||
/// Execution ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Execution status
|
||||
#[schema(example = "succeeded")]
|
||||
pub status: ExecutionStatus,
|
||||
|
||||
/// Parent execution ID
|
||||
#[schema(example = 1)]
|
||||
pub parent: Option<i64>,
|
||||
|
||||
/// Enforcement ID
|
||||
#[schema(example = 1)]
|
||||
pub enforcement: Option<i64>,
|
||||
|
||||
/// Rule reference (if triggered by a rule)
|
||||
#[schema(example = "core.on_timer")]
|
||||
pub rule_ref: Option<String>,
|
||||
|
||||
/// Trigger reference (if triggered by a trigger)
|
||||
#[schema(example = "core.timer")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:35:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Query parameters for filtering executions
|
||||
#[derive(Debug, Clone, Deserialize, IntoParams)]
|
||||
pub struct ExecutionQueryParams {
|
||||
/// Filter by execution status
|
||||
#[param(example = "succeeded")]
|
||||
pub status: Option<ExecutionStatus>,
|
||||
|
||||
/// Filter by action reference
|
||||
#[param(example = "slack.post_message")]
|
||||
pub action_ref: Option<String>,
|
||||
|
||||
/// Filter by pack name
|
||||
#[param(example = "core")]
|
||||
pub pack_name: Option<String>,
|
||||
|
||||
/// Filter by rule reference
|
||||
#[param(example = "core.on_timer")]
|
||||
pub rule_ref: Option<String>,
|
||||
|
||||
/// Filter by trigger reference
|
||||
#[param(example = "core.timer")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// Filter by executor ID
|
||||
#[param(example = 1)]
|
||||
pub executor: Option<i64>,
|
||||
|
||||
/// Search in result JSON (case-insensitive substring match)
|
||||
#[param(example = "error")]
|
||||
pub result_contains: Option<String>,
|
||||
|
||||
/// Filter by enforcement ID
|
||||
#[param(example = 1)]
|
||||
pub enforcement: Option<i64>,
|
||||
|
||||
/// Filter by parent execution ID
|
||||
#[param(example = 1)]
|
||||
pub parent: Option<i64>,
|
||||
|
||||
/// Page number (for pagination)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Items per page (for pagination)
|
||||
#[serde(default = "default_per_page")]
|
||||
#[param(example = 50, minimum = 1, maximum = 100)]
|
||||
pub per_page: u32,
|
||||
}
|
||||
|
||||
impl ExecutionQueryParams {
|
||||
/// Get the SQL offset value
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page.saturating_sub(1)) * self.per_page
|
||||
}
|
||||
|
||||
/// Get the limit value (with max cap)
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.per_page.min(100)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Execution model to ExecutionResponse
|
||||
impl From<attune_common::models::execution::Execution> for ExecutionResponse {
|
||||
fn from(execution: attune_common::models::execution::Execution) -> Self {
|
||||
Self {
|
||||
id: execution.id,
|
||||
action: execution.action,
|
||||
action_ref: execution.action_ref,
|
||||
config: execution
|
||||
.config
|
||||
.map(|c| serde_json::to_value(c).unwrap_or(JsonValue::Null)),
|
||||
parent: execution.parent,
|
||||
enforcement: execution.enforcement,
|
||||
executor: execution.executor,
|
||||
status: execution.status,
|
||||
result: execution
|
||||
.result
|
||||
.map(|r| serde_json::to_value(r).unwrap_or(JsonValue::Null)),
|
||||
created: execution.created,
|
||||
updated: execution.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Execution model to ExecutionSummary
|
||||
impl From<attune_common::models::execution::Execution> for ExecutionSummary {
|
||||
fn from(execution: attune_common::models::execution::Execution) -> Self {
|
||||
Self {
|
||||
id: execution.id,
|
||||
action_ref: execution.action_ref,
|
||||
status: execution.status,
|
||||
parent: execution.parent,
|
||||
enforcement: execution.enforcement,
|
||||
rule_ref: None, // Populated separately via enforcement lookup
|
||||
trigger_ref: None, // Populated separately via enforcement lookup
|
||||
created: execution.created,
|
||||
updated: execution.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_page() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_per_page() -> u32 {
|
||||
20
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_query_params_defaults() {
|
||||
let json = r#"{}"#;
|
||||
let params: ExecutionQueryParams = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(params.page, 1);
|
||||
assert_eq!(params.per_page, 20);
|
||||
assert!(params.status.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_params_with_filters() {
|
||||
let json = r#"{
|
||||
"status": "completed",
|
||||
"action_ref": "test.action",
|
||||
"page": 2,
|
||||
"per_page": 50
|
||||
}"#;
|
||||
let params: ExecutionQueryParams = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(params.page, 2);
|
||||
assert_eq!(params.per_page, 50);
|
||||
assert_eq!(params.status, Some(ExecutionStatus::Completed));
|
||||
assert_eq!(params.action_ref, Some("test.action".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_params_offset() {
|
||||
let params = ExecutionQueryParams {
|
||||
status: None,
|
||||
action_ref: None,
|
||||
enforcement: None,
|
||||
parent: None,
|
||||
pack_name: None,
|
||||
rule_ref: None,
|
||||
trigger_ref: None,
|
||||
executor: None,
|
||||
result_contains: None,
|
||||
page: 3,
|
||||
per_page: 20,
|
||||
};
|
||||
assert_eq!(params.offset(), 40); // (3-1) * 20
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_params_limit_cap() {
|
||||
let params = ExecutionQueryParams {
|
||||
status: None,
|
||||
action_ref: None,
|
||||
enforcement: None,
|
||||
parent: None,
|
||||
pack_name: None,
|
||||
rule_ref: None,
|
||||
trigger_ref: None,
|
||||
executor: None,
|
||||
result_contains: None,
|
||||
page: 1,
|
||||
per_page: 200, // Exceeds max
|
||||
};
|
||||
assert_eq!(params.limit(), 100); // Capped at 100
|
||||
}
|
||||
}
|
||||
215
crates/api/src/dto/inquiry.rs
Normal file
215
crates/api/src/dto/inquiry.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
//! Inquiry data transfer objects
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::models::{enums::InquiryStatus, inquiry::Inquiry, Id, JsonDict, JsonSchema};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
/// Full inquiry response with all details
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct InquiryResponse {
|
||||
/// Inquiry ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Execution ID this inquiry belongs to
|
||||
#[schema(example = 1)]
|
||||
pub execution: Id,
|
||||
|
||||
/// Prompt text displayed to the user
|
||||
#[schema(example = "Approve deployment to production?")]
|
||||
pub prompt: String,
|
||||
|
||||
/// JSON schema for expected response
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub response_schema: Option<JsonSchema>,
|
||||
|
||||
/// Identity ID this inquiry is assigned to
|
||||
#[schema(example = 1)]
|
||||
pub assigned_to: Option<Id>,
|
||||
|
||||
/// Current status of the inquiry
|
||||
#[schema(example = "pending")]
|
||||
pub status: InquiryStatus,
|
||||
|
||||
/// Response data provided by the user
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub response: Option<JsonDict>,
|
||||
|
||||
/// When the inquiry expires
|
||||
#[schema(example = "2024-01-13T11:30:00Z")]
|
||||
pub timeout_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// When the inquiry was responded to
|
||||
#[schema(example = "2024-01-13T10:45:00Z")]
|
||||
pub responded_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:45:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Inquiry> for InquiryResponse {
|
||||
fn from(inquiry: Inquiry) -> Self {
|
||||
Self {
|
||||
id: inquiry.id,
|
||||
execution: inquiry.execution,
|
||||
prompt: inquiry.prompt,
|
||||
response_schema: inquiry.response_schema,
|
||||
assigned_to: inquiry.assigned_to,
|
||||
status: inquiry.status,
|
||||
response: inquiry.response,
|
||||
timeout_at: inquiry.timeout_at,
|
||||
responded_at: inquiry.responded_at,
|
||||
created: inquiry.created,
|
||||
updated: inquiry.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary inquiry response for list views
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct InquirySummary {
|
||||
/// Inquiry ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Execution ID
|
||||
#[schema(example = 1)]
|
||||
pub execution: Id,
|
||||
|
||||
/// Prompt text
|
||||
#[schema(example = "Approve deployment to production?")]
|
||||
pub prompt: String,
|
||||
|
||||
/// Assigned identity ID
|
||||
#[schema(example = 1)]
|
||||
pub assigned_to: Option<Id>,
|
||||
|
||||
/// Inquiry status
|
||||
#[schema(example = "pending")]
|
||||
pub status: InquiryStatus,
|
||||
|
||||
/// Whether a response has been provided
|
||||
#[schema(example = false)]
|
||||
pub has_response: bool,
|
||||
|
||||
/// Timeout timestamp
|
||||
#[schema(example = "2024-01-13T11:30:00Z")]
|
||||
pub timeout_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Inquiry> for InquirySummary {
|
||||
fn from(inquiry: Inquiry) -> Self {
|
||||
Self {
|
||||
id: inquiry.id,
|
||||
execution: inquiry.execution,
|
||||
prompt: inquiry.prompt,
|
||||
assigned_to: inquiry.assigned_to,
|
||||
status: inquiry.status,
|
||||
has_response: inquiry.response.is_some(),
|
||||
timeout_at: inquiry.timeout_at,
|
||||
created: inquiry.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to create a new inquiry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateInquiryRequest {
|
||||
/// Execution ID this inquiry belongs to
|
||||
#[schema(example = 1)]
|
||||
pub execution: Id,
|
||||
|
||||
/// Prompt text to display to the user
|
||||
#[validate(length(min = 1, max = 10000))]
|
||||
#[schema(example = "Approve deployment to production?")]
|
||||
pub prompt: String,
|
||||
|
||||
/// Optional JSON schema for the expected response format
|
||||
#[schema(value_type = Object, example = json!({"type": "object", "properties": {"approved": {"type": "boolean"}}}))]
|
||||
pub response_schema: Option<JsonSchema>,
|
||||
|
||||
/// Optional identity ID to assign this inquiry to
|
||||
#[schema(example = 1)]
|
||||
pub assigned_to: Option<Id>,
|
||||
|
||||
/// Optional timeout timestamp (when inquiry expires)
|
||||
#[schema(example = "2024-01-13T11:30:00Z")]
|
||||
pub timeout_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Request to update an inquiry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateInquiryRequest {
|
||||
/// Update the inquiry status
|
||||
#[schema(example = "responded")]
|
||||
pub status: Option<InquiryStatus>,
|
||||
|
||||
/// Update the response data
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub response: Option<JsonDict>,
|
||||
|
||||
/// Update the assigned_to identity
|
||||
#[schema(example = 2)]
|
||||
pub assigned_to: Option<Id>,
|
||||
}
|
||||
|
||||
/// Request to respond to an inquiry (user-facing endpoint)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct InquiryRespondRequest {
|
||||
/// Response data conforming to the inquiry's response_schema
|
||||
#[schema(value_type = Object)]
|
||||
pub response: JsonValue,
|
||||
}
|
||||
|
||||
/// Query parameters for filtering inquiries
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, IntoParams)]
|
||||
pub struct InquiryQueryParams {
|
||||
/// Filter by status
|
||||
#[param(example = "pending")]
|
||||
pub status: Option<InquiryStatus>,
|
||||
|
||||
/// Filter by execution ID
|
||||
#[param(example = 1)]
|
||||
pub execution: Option<Id>,
|
||||
|
||||
/// Filter by assigned identity
|
||||
#[param(example = 1)]
|
||||
pub assigned_to: Option<Id>,
|
||||
|
||||
/// Pagination offset
|
||||
#[param(example = 0)]
|
||||
pub offset: Option<usize>,
|
||||
|
||||
/// Pagination limit
|
||||
#[param(example = 50)]
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
/// Paginated list response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ListResponse<T> {
|
||||
/// List of items
|
||||
pub data: Vec<T>,
|
||||
|
||||
/// Total count of items (before pagination)
|
||||
pub total: usize,
|
||||
|
||||
/// Offset used for this page
|
||||
pub offset: usize,
|
||||
|
||||
/// Limit used for this page
|
||||
pub limit: usize,
|
||||
}
|
||||
270
crates/api/src/dto/key.rs
Normal file
270
crates/api/src/dto/key.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! Key/Secret data transfer objects
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::models::{key::Key, Id, OwnerType};
|
||||
|
||||
/// Full key response with all details (value redacted in list views)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct KeyResponse {
|
||||
/// Unique key ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "github_token")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Type of owner
|
||||
pub owner_type: OwnerType,
|
||||
|
||||
/// Owner identifier
|
||||
#[schema(example = "github-integration")]
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Owner identity ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_identity: Option<Id>,
|
||||
|
||||
/// Owner pack ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_pack: Option<Id>,
|
||||
|
||||
/// Owner pack reference
|
||||
#[schema(example = "github")]
|
||||
pub owner_pack_ref: Option<String>,
|
||||
|
||||
/// Owner action ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_action: Option<Id>,
|
||||
|
||||
/// Owner action reference
|
||||
#[schema(example = "github.create_issue")]
|
||||
pub owner_action_ref: Option<String>,
|
||||
|
||||
/// Owner sensor ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_sensor: Option<Id>,
|
||||
|
||||
/// Owner sensor reference
|
||||
#[schema(example = "github.webhook")]
|
||||
pub owner_sensor_ref: Option<String>,
|
||||
|
||||
/// Human-readable name
|
||||
#[schema(example = "GitHub API Token")]
|
||||
pub name: String,
|
||||
|
||||
/// Whether the value is encrypted
|
||||
#[schema(example = true)]
|
||||
pub encrypted: bool,
|
||||
|
||||
/// The secret value (decrypted if encrypted)
|
||||
#[schema(example = "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")]
|
||||
pub value: String,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Key> for KeyResponse {
|
||||
fn from(key: Key) -> Self {
|
||||
Self {
|
||||
id: key.id,
|
||||
r#ref: key.r#ref,
|
||||
owner_type: key.owner_type,
|
||||
owner: key.owner,
|
||||
owner_identity: key.owner_identity,
|
||||
owner_pack: key.owner_pack,
|
||||
owner_pack_ref: key.owner_pack_ref,
|
||||
owner_action: key.owner_action,
|
||||
owner_action_ref: key.owner_action_ref,
|
||||
owner_sensor: key.owner_sensor,
|
||||
owner_sensor_ref: key.owner_sensor_ref,
|
||||
name: key.name,
|
||||
encrypted: key.encrypted,
|
||||
value: key.value,
|
||||
created: key.created,
|
||||
updated: key.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary key response for list views (value redacted)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct KeySummary {
|
||||
/// Unique key ID
|
||||
#[schema(example = 1)]
|
||||
pub id: Id,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "github_token")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Type of owner
|
||||
pub owner_type: OwnerType,
|
||||
|
||||
/// Owner identifier
|
||||
#[schema(example = "github-integration")]
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Human-readable name
|
||||
#[schema(example = "GitHub API Token")]
|
||||
pub name: String,
|
||||
|
||||
/// Whether the value is encrypted
|
||||
#[schema(example = true)]
|
||||
pub encrypted: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<Key> for KeySummary {
|
||||
fn from(key: Key) -> Self {
|
||||
Self {
|
||||
id: key.id,
|
||||
r#ref: key.r#ref,
|
||||
owner_type: key.owner_type,
|
||||
owner: key.owner,
|
||||
name: key.name,
|
||||
encrypted: key.encrypted,
|
||||
created: key.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to create a new key/secret
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateKeyRequest {
|
||||
/// Unique reference for the key (e.g., "github_token", "aws_secret_key")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "github_token")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Type of owner (system, identity, pack, action, sensor)
|
||||
pub owner_type: OwnerType,
|
||||
|
||||
/// Optional owner string identifier
|
||||
#[validate(length(max = 255))]
|
||||
#[schema(example = "github-integration")]
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Optional owner identity ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_identity: Option<Id>,
|
||||
|
||||
/// Optional owner pack ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_pack: Option<Id>,
|
||||
|
||||
/// Optional owner pack reference
|
||||
#[validate(length(max = 255))]
|
||||
#[schema(example = "github")]
|
||||
pub owner_pack_ref: Option<String>,
|
||||
|
||||
/// Optional owner action ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_action: Option<Id>,
|
||||
|
||||
/// Optional owner action reference
|
||||
#[validate(length(max = 255))]
|
||||
#[schema(example = "github.create_issue")]
|
||||
pub owner_action_ref: Option<String>,
|
||||
|
||||
/// Optional owner sensor ID
|
||||
#[schema(example = 1)]
|
||||
pub owner_sensor: Option<Id>,
|
||||
|
||||
/// Optional owner sensor reference
|
||||
#[validate(length(max = 255))]
|
||||
#[schema(example = "github.webhook")]
|
||||
pub owner_sensor_ref: Option<String>,
|
||||
|
||||
/// Human-readable name for the key
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "GitHub API Token")]
|
||||
pub name: String,
|
||||
|
||||
/// The secret value to store
|
||||
#[validate(length(min = 1, max = 10000))]
|
||||
#[schema(example = "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")]
|
||||
pub value: String,
|
||||
|
||||
/// Whether to encrypt the value (recommended: true)
|
||||
#[serde(default = "default_encrypted")]
|
||||
#[schema(example = true)]
|
||||
pub encrypted: bool,
|
||||
}
|
||||
|
||||
fn default_encrypted() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Request to update an existing key/secret
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateKeyRequest {
|
||||
/// Update the human-readable name
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "GitHub API Token (Updated)")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Update the secret value
|
||||
#[validate(length(min = 1, max = 10000))]
|
||||
#[schema(example = "ghp_new_token_xxxxxxxxxxxxxxxxxxxxxxxx")]
|
||||
pub value: Option<String>,
|
||||
|
||||
/// Update encryption status (re-encrypts if changing from false to true)
|
||||
#[schema(example = true)]
|
||||
pub encrypted: Option<bool>,
|
||||
}
|
||||
|
||||
/// Query parameters for filtering keys
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, IntoParams)]
|
||||
pub struct KeyQueryParams {
|
||||
/// Filter by owner type
|
||||
#[param(example = "pack")]
|
||||
pub owner_type: Option<OwnerType>,
|
||||
|
||||
/// Filter by owner string
|
||||
#[param(example = "github-integration")]
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Page number (1-indexed)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Items per page
|
||||
#[serde(default = "default_per_page")]
|
||||
#[param(example = 50, minimum = 1, maximum = 100)]
|
||||
pub per_page: u32,
|
||||
}
|
||||
|
||||
fn default_page() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_per_page() -> u32 {
|
||||
50
|
||||
}
|
||||
|
||||
impl KeyQueryParams {
|
||||
/// Get the offset for pagination
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page - 1) * self.per_page
|
||||
}
|
||||
|
||||
/// Get the limit for pagination
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.per_page
|
||||
}
|
||||
}
|
||||
44
crates/api/src/dto/mod.rs
Normal file
44
crates/api/src/dto/mod.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
//! Data Transfer Objects (DTOs) for API requests and responses
|
||||
|
||||
pub mod action;
|
||||
pub mod auth;
|
||||
pub mod common;
|
||||
pub mod event;
|
||||
pub mod execution;
|
||||
pub mod inquiry;
|
||||
pub mod key;
|
||||
pub mod pack;
|
||||
pub mod rule;
|
||||
pub mod trigger;
|
||||
pub mod webhook;
|
||||
pub mod workflow;
|
||||
|
||||
pub use action::{ActionResponse, ActionSummary, CreateActionRequest, UpdateActionRequest};
|
||||
pub use auth::{
|
||||
ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest, RegisterRequest,
|
||||
TokenResponse,
|
||||
};
|
||||
pub use common::{
|
||||
ApiResponse, PaginatedResponse, PaginationMeta, PaginationParams, SuccessResponse,
|
||||
};
|
||||
pub use event::{
|
||||
EnforcementQueryParams, EnforcementResponse, EnforcementSummary, EventQueryParams,
|
||||
EventResponse, EventSummary,
|
||||
};
|
||||
pub use execution::{CreateExecutionRequest, ExecutionQueryParams, ExecutionResponse, ExecutionSummary};
|
||||
pub use inquiry::{
|
||||
CreateInquiryRequest, InquiryQueryParams, InquiryRespondRequest, InquiryResponse,
|
||||
InquirySummary, UpdateInquiryRequest,
|
||||
};
|
||||
pub use key::{CreateKeyRequest, KeyQueryParams, KeyResponse, KeySummary, UpdateKeyRequest};
|
||||
pub use pack::{CreatePackRequest, PackResponse, PackSummary, UpdatePackRequest};
|
||||
pub use rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest};
|
||||
pub use trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,
|
||||
TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
};
|
||||
pub use webhook::{WebhookReceiverRequest, WebhookReceiverResponse};
|
||||
pub use workflow::{
|
||||
CreateWorkflowRequest, UpdateWorkflowRequest, WorkflowResponse, WorkflowSearchParams,
|
||||
WorkflowSummary,
|
||||
};
|
||||
381
crates/api/src/dto/pack.rs
Normal file
381
crates/api/src/dto/pack.rs
Normal file
@@ -0,0 +1,381 @@
|
||||
//! Pack DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a new pack
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreatePackRequest {
|
||||
/// Unique reference identifier (e.g., "core", "aws", "slack")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Slack Integration")]
|
||||
pub label: String,
|
||||
|
||||
/// Pack description
|
||||
#[schema(example = "Integration with Slack for messaging and notifications")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Pack version (semver format recommended)
|
||||
#[validate(length(min = 1, max = 50))]
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Configuration schema (JSON Schema)
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"type": "object", "properties": {"api_token": {"type": "string"}}}))]
|
||||
pub conf_schema: JsonValue,
|
||||
|
||||
/// Pack configuration values
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"api_token": "xoxb-..."}))]
|
||||
pub config: JsonValue,
|
||||
|
||||
/// Pack metadata
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"author": "Attune Team"}))]
|
||||
pub meta: JsonValue,
|
||||
|
||||
/// Tags for categorization
|
||||
#[serde(default)]
|
||||
#[schema(example = json!(["messaging", "collaboration"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Runtime dependencies (refs of required packs)
|
||||
#[serde(default)]
|
||||
#[schema(example = json!(["core"]))]
|
||||
pub runtime_deps: Vec<String>,
|
||||
|
||||
/// Whether this is a standard/built-in pack
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub is_standard: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for registering a pack from local filesystem
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct RegisterPackRequest {
|
||||
/// Local filesystem path to the pack directory
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "/path/to/packs/mypack")]
|
||||
pub path: String,
|
||||
|
||||
/// Skip running pack tests during registration
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub skip_tests: bool,
|
||||
|
||||
/// Force registration even if tests fail
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub force: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for installing a pack from remote source
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct InstallPackRequest {
|
||||
/// Repository URL or source location
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "https://github.com/attune/pack-slack.git")]
|
||||
pub source: String,
|
||||
|
||||
/// Git branch, tag, or commit reference
|
||||
#[schema(example = "main")]
|
||||
pub ref_spec: Option<String>,
|
||||
|
||||
/// Force reinstall if pack already exists
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub force: bool,
|
||||
|
||||
/// Skip running pack tests during installation
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub skip_tests: bool,
|
||||
|
||||
/// Skip dependency validation (not recommended)
|
||||
#[serde(default)]
|
||||
#[schema(example = false)]
|
||||
pub skip_deps: bool,
|
||||
}
|
||||
|
||||
/// Response for pack install/register operations with test results
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackInstallResponse {
|
||||
/// The installed/registered pack
|
||||
pub pack: PackResponse,
|
||||
|
||||
/// Test execution result (if tests were run)
|
||||
pub test_result: Option<attune_common::models::pack_test::PackTestResult>,
|
||||
|
||||
/// Whether tests were skipped
|
||||
pub tests_skipped: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a pack
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdatePackRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Slack Integration v2")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Pack description
|
||||
#[schema(example = "Enhanced Slack integration with new features")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Pack version
|
||||
#[validate(length(min = 1, max = 50))]
|
||||
#[schema(example = "2.0.0")]
|
||||
pub version: Option<String>,
|
||||
|
||||
/// Configuration schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub conf_schema: Option<JsonValue>,
|
||||
|
||||
/// Pack configuration values
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub config: Option<JsonValue>,
|
||||
|
||||
/// Pack metadata
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub meta: Option<JsonValue>,
|
||||
|
||||
/// Tags for categorization
|
||||
#[schema(example = json!(["messaging", "collaboration", "webhooks"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Runtime dependencies
|
||||
#[schema(example = json!(["core", "http"]))]
|
||||
pub runtime_deps: Option<Vec<String>>,
|
||||
|
||||
/// Whether this is a standard pack
|
||||
#[schema(example = false)]
|
||||
pub is_standard: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for pack information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackResponse {
|
||||
/// Pack ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Slack Integration")]
|
||||
pub label: String,
|
||||
|
||||
/// Pack description
|
||||
#[schema(example = "Integration with Slack for messaging and notifications")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Pack version
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Configuration schema
|
||||
#[schema(value_type = Object)]
|
||||
pub conf_schema: JsonValue,
|
||||
|
||||
/// Pack configuration
|
||||
#[schema(value_type = Object)]
|
||||
pub config: JsonValue,
|
||||
|
||||
/// Pack metadata
|
||||
#[schema(value_type = Object)]
|
||||
pub meta: JsonValue,
|
||||
|
||||
/// Tags
|
||||
#[schema(example = json!(["messaging", "collaboration"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Runtime dependencies
|
||||
#[schema(example = json!(["core"]))]
|
||||
pub runtime_deps: Vec<String>,
|
||||
|
||||
/// Is standard pack
|
||||
#[schema(example = false)]
|
||||
pub is_standard: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified pack response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackSummary {
|
||||
/// Pack ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Slack Integration")]
|
||||
pub label: String,
|
||||
|
||||
/// Pack description
|
||||
#[schema(example = "Integration with Slack for messaging and notifications")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Pack version
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Tags
|
||||
#[schema(example = json!(["messaging", "collaboration"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Is standard pack
|
||||
#[schema(example = false)]
|
||||
pub is_standard: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from Pack model to PackResponse
|
||||
impl From<attune_common::models::Pack> for PackResponse {
|
||||
fn from(pack: attune_common::models::Pack) -> Self {
|
||||
Self {
|
||||
id: pack.id,
|
||||
r#ref: pack.r#ref,
|
||||
label: pack.label,
|
||||
description: pack.description,
|
||||
version: pack.version,
|
||||
conf_schema: pack.conf_schema,
|
||||
config: pack.config,
|
||||
meta: pack.meta,
|
||||
tags: pack.tags,
|
||||
runtime_deps: pack.runtime_deps,
|
||||
is_standard: pack.is_standard,
|
||||
created: pack.created,
|
||||
updated: pack.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Pack model to PackSummary
|
||||
impl From<attune_common::models::Pack> for PackSummary {
|
||||
fn from(pack: attune_common::models::Pack) -> Self {
|
||||
Self {
|
||||
id: pack.id,
|
||||
r#ref: pack.r#ref,
|
||||
label: pack.label,
|
||||
description: pack.description,
|
||||
version: pack.version,
|
||||
tags: pack.tags,
|
||||
is_standard: pack.is_standard,
|
||||
created: pack.created,
|
||||
updated: pack.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response for pack workflow sync operation
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackWorkflowSyncResponse {
|
||||
/// Pack reference
|
||||
pub pack_ref: String,
|
||||
/// Number of workflows loaded from filesystem
|
||||
pub loaded_count: usize,
|
||||
/// Number of workflows registered/updated in database
|
||||
pub registered_count: usize,
|
||||
/// Individual workflow registration results
|
||||
pub workflows: Vec<WorkflowSyncResult>,
|
||||
/// Any errors encountered during sync
|
||||
pub errors: Vec<String>,
|
||||
}
|
||||
|
||||
/// Individual workflow sync result
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct WorkflowSyncResult {
|
||||
/// Workflow reference name
|
||||
pub ref_name: String,
|
||||
/// Whether the workflow was created (false = updated)
|
||||
pub created: bool,
|
||||
/// Workflow definition ID
|
||||
pub workflow_def_id: i64,
|
||||
/// Any warnings during registration
|
||||
pub warnings: Vec<String>,
|
||||
}
|
||||
|
||||
/// Response for pack workflow validation operation
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackWorkflowValidationResponse {
|
||||
/// Pack reference
|
||||
pub pack_ref: String,
|
||||
/// Number of workflows validated
|
||||
pub validated_count: usize,
|
||||
/// Number of workflows with errors
|
||||
pub error_count: usize,
|
||||
/// Validation errors by workflow reference
|
||||
pub errors: std::collections::HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
fn default_empty_object() -> JsonValue {
|
||||
serde_json::json!({})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_pack_request_defaults() {
|
||||
let json = r#"{
|
||||
"ref": "test-pack",
|
||||
"label": "Test Pack",
|
||||
"version": "1.0.0"
|
||||
}"#;
|
||||
|
||||
let req: CreatePackRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.r#ref, "test-pack");
|
||||
assert_eq!(req.label, "Test Pack");
|
||||
assert_eq!(req.version, "1.0.0");
|
||||
assert!(req.tags.is_empty());
|
||||
assert!(req.runtime_deps.is_empty());
|
||||
assert!(!req.is_standard);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_pack_request_validation() {
|
||||
let req = CreatePackRequest {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
label: "Test".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
description: None,
|
||||
conf_schema: default_empty_object(),
|
||||
config: default_empty_object(),
|
||||
meta: default_empty_object(),
|
||||
tags: vec![],
|
||||
runtime_deps: vec![],
|
||||
is_standard: false,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
}
|
||||
363
crates/api/src/dto/rule.rs
Normal file
363
crates/api/src/dto/rule.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
//! Rule DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a new rule
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateRuleRequest {
|
||||
/// Unique reference identifier (e.g., "mypack.notify_on_error")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack.notify_on_error")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference this rule belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Notify on Error")]
|
||||
pub label: String,
|
||||
|
||||
/// Rule description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
|
||||
/// Action reference to execute when rule matches
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Trigger reference that activates this rule
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "system.error_event")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Conditions for rule evaluation (JSON Logic or custom format)
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"var": "event.severity", ">=": 3}))]
|
||||
pub conditions: JsonValue,
|
||||
|
||||
/// Parameters to pass to the action when rule is triggered
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"message": "hello, world"}))]
|
||||
pub action_params: JsonValue,
|
||||
|
||||
/// Parameters for trigger configuration and event filtering
|
||||
#[serde(default = "default_empty_object")]
|
||||
#[schema(value_type = Object, example = json!({"severity": "high"}))]
|
||||
pub trigger_params: JsonValue,
|
||||
|
||||
/// Whether the rule is enabled
|
||||
#[serde(default = "default_true")]
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a rule
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateRuleRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Notify on Error (Updated)")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Rule description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Enhanced error notification with filtering")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Conditions for rule evaluation
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub conditions: Option<JsonValue>,
|
||||
|
||||
/// Parameters to pass to the action when rule is triggered
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub action_params: Option<JsonValue>,
|
||||
|
||||
/// Parameters for trigger configuration and event filtering
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub trigger_params: Option<JsonValue>,
|
||||
|
||||
/// Whether the rule is enabled
|
||||
#[schema(example = false)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for rule information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct RuleResponse {
|
||||
/// Rule ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.notify_on_error")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack ID
|
||||
#[schema(example = 1)]
|
||||
pub pack: i64,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Notify on Error")]
|
||||
pub label: String,
|
||||
|
||||
/// Rule description
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
|
||||
/// Action ID
|
||||
#[schema(example = 1)]
|
||||
pub action: i64,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub trigger: i64,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "system.error_event")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Conditions for rule evaluation
|
||||
#[schema(value_type = Object)]
|
||||
pub conditions: JsonValue,
|
||||
|
||||
/// Parameters to pass to the action when rule is triggered
|
||||
#[schema(value_type = Object)]
|
||||
pub action_params: JsonValue,
|
||||
|
||||
/// Parameters for trigger configuration and event filtering
|
||||
#[schema(value_type = Object)]
|
||||
pub trigger_params: JsonValue,
|
||||
|
||||
/// Whether the rule is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Whether this is an ad-hoc rule (not from pack installation)
|
||||
#[schema(example = false)]
|
||||
pub is_adhoc: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified rule response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct RuleSummary {
|
||||
/// Rule ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.notify_on_error")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Notify on Error")]
|
||||
pub label: String,
|
||||
|
||||
/// Rule description
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
pub action_ref: String,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "system.error_event")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Parameters to pass to the action when rule is triggered
|
||||
#[schema(value_type = Object)]
|
||||
pub action_params: JsonValue,
|
||||
|
||||
/// Parameters for trigger configuration and event filtering
|
||||
#[schema(value_type = Object)]
|
||||
pub trigger_params: JsonValue,
|
||||
|
||||
/// Whether the rule is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from Rule model to RuleResponse
|
||||
impl From<attune_common::models::rule::Rule> for RuleResponse {
|
||||
fn from(rule: attune_common::models::rule::Rule) -> Self {
|
||||
Self {
|
||||
id: rule.id,
|
||||
r#ref: rule.r#ref,
|
||||
pack: rule.pack,
|
||||
pack_ref: rule.pack_ref,
|
||||
label: rule.label,
|
||||
description: rule.description,
|
||||
action: rule.action,
|
||||
action_ref: rule.action_ref,
|
||||
trigger: rule.trigger,
|
||||
trigger_ref: rule.trigger_ref,
|
||||
conditions: rule.conditions,
|
||||
action_params: rule.action_params,
|
||||
trigger_params: rule.trigger_params,
|
||||
enabled: rule.enabled,
|
||||
is_adhoc: rule.is_adhoc,
|
||||
created: rule.created,
|
||||
updated: rule.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Rule model to RuleSummary
|
||||
impl From<attune_common::models::rule::Rule> for RuleSummary {
|
||||
fn from(rule: attune_common::models::rule::Rule) -> Self {
|
||||
Self {
|
||||
id: rule.id,
|
||||
r#ref: rule.r#ref,
|
||||
pack_ref: rule.pack_ref,
|
||||
label: rule.label,
|
||||
description: rule.description,
|
||||
action_ref: rule.action_ref,
|
||||
trigger_ref: rule.trigger_ref,
|
||||
action_params: rule.action_params,
|
||||
trigger_params: rule.trigger_params,
|
||||
enabled: rule.enabled,
|
||||
created: rule.created,
|
||||
updated: rule.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_empty_object() -> JsonValue {
|
||||
serde_json::json!({})
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_rule_request_defaults() {
|
||||
let json = r#"{
|
||||
"ref": "test-rule",
|
||||
"pack_ref": "test-pack",
|
||||
"label": "Test Rule",
|
||||
"description": "Test description",
|
||||
"action_ref": "test.action",
|
||||
"trigger_ref": "test.trigger"
|
||||
}"#;
|
||||
|
||||
let req: CreateRuleRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.r#ref, "test-rule");
|
||||
assert_eq!(req.label, "Test Rule");
|
||||
assert_eq!(req.action_ref, "test.action");
|
||||
assert_eq!(req.trigger_ref, "test.trigger");
|
||||
assert!(req.enabled);
|
||||
assert_eq!(req.conditions, serde_json::json!({}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_rule_request_validation() {
|
||||
let req = CreateRuleRequest {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Rule".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
action_ref: "test.action".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
conditions: default_empty_object(),
|
||||
action_params: default_empty_object(),
|
||||
trigger_params: default_empty_object(),
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_rule_request_valid() {
|
||||
let req = CreateRuleRequest {
|
||||
r#ref: "test.rule".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Rule".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
action_ref: "test.action".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
conditions: serde_json::json!({
|
||||
"and": [
|
||||
{"var": "event.status", "==": "error"},
|
||||
{"var": "event.severity", ">": 3}
|
||||
]
|
||||
}),
|
||||
action_params: default_empty_object(),
|
||||
trigger_params: default_empty_object(),
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_rule_request_all_none() {
|
||||
let req = UpdateRuleRequest {
|
||||
label: None,
|
||||
description: None,
|
||||
conditions: None,
|
||||
action_params: None,
|
||||
trigger_params: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
// Should be valid even with all None values
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_rule_request_partial() {
|
||||
let req = UpdateRuleRequest {
|
||||
label: Some("Updated Rule".to_string()),
|
||||
description: None,
|
||||
conditions: Some(serde_json::json!({"var": "status", "==": "ok"})),
|
||||
action_params: None,
|
||||
trigger_params: None,
|
||||
enabled: Some(false),
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
}
|
||||
519
crates/api/src/dto/trigger.rs
Normal file
519
crates/api/src/dto/trigger.rs
Normal file
@@ -0,0 +1,519 @@
|
||||
//! Trigger and Sensor DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a new trigger
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateTriggerRequest {
|
||||
/// Unique reference identifier (e.g., "core.webhook", "system.timer")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "core.webhook")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Optional pack reference this trigger belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "core")]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Webhook Trigger")]
|
||||
pub label: String,
|
||||
|
||||
/// Trigger description
|
||||
#[schema(example = "Triggers when a webhook is received")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Parameter schema (JSON Schema) defining event payload structure
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"type": "object", "properties": {"url": {"type": "string"}}}))]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema (JSON Schema) defining event data structure
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"type": "object", "properties": {"payload": {"type": "object"}}}))]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Whether the trigger is enabled
|
||||
#[serde(default = "default_true")]
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a trigger
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateTriggerRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Webhook Trigger (Updated)")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Trigger description
|
||||
#[schema(example = "Updated webhook trigger description")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Whether the trigger is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for trigger information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct TriggerResponse {
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "core.webhook")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack ID (optional)
|
||||
#[schema(example = 1)]
|
||||
pub pack: Option<i64>,
|
||||
|
||||
/// Pack reference (optional)
|
||||
#[schema(example = "core")]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Webhook Trigger")]
|
||||
pub label: String,
|
||||
|
||||
/// Trigger description
|
||||
#[schema(example = "Triggers when a webhook is received")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Whether the trigger is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Whether webhooks are enabled for this trigger
|
||||
#[schema(example = false)]
|
||||
pub webhook_enabled: bool,
|
||||
|
||||
/// Webhook key (only present if webhooks are enabled)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "wh_k7j2n9p4m8q1r5w3x6z0a2b5c8d1e4f7g9h2")]
|
||||
pub webhook_key: Option<String>,
|
||||
|
||||
/// Whether this is an ad-hoc trigger (not from pack installation)
|
||||
#[schema(example = false)]
|
||||
pub is_adhoc: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified trigger response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct TriggerSummary {
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "core.webhook")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference (optional)
|
||||
#[schema(example = "core")]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Webhook Trigger")]
|
||||
pub label: String,
|
||||
|
||||
/// Trigger description
|
||||
#[schema(example = "Triggers when a webhook is received")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Whether the trigger is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Whether webhooks are enabled for this trigger
|
||||
#[schema(example = false)]
|
||||
pub webhook_enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Request DTO for creating a new sensor
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateSensorRequest {
|
||||
/// Unique reference identifier (e.g., "mypack.cpu_monitor")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "monitoring.cpu_sensor")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference this sensor belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "monitoring")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "CPU Monitoring Sensor")]
|
||||
pub label: String,
|
||||
|
||||
/// Sensor description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
|
||||
/// Entry point for sensor execution (e.g., path to script, function name)
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
#[schema(example = "/sensors/monitoring/cpu_monitor.py")]
|
||||
pub entrypoint: String,
|
||||
|
||||
/// Runtime reference for this sensor
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "python3")]
|
||||
pub runtime_ref: String,
|
||||
|
||||
/// Trigger reference this sensor monitors for
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "monitoring.cpu_threshold")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Parameter schema (JSON Schema) for sensor configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"type": "object", "properties": {"threshold": {"type": "number"}}}))]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Configuration values for this sensor instance (conforms to param_schema)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"interval": 60, "threshold": 80}))]
|
||||
pub config: Option<JsonValue>,
|
||||
|
||||
/// Whether the sensor is enabled
|
||||
#[serde(default = "default_true")]
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a sensor
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateSensorRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "CPU Monitoring Sensor (Updated)")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Sensor description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Enhanced CPU monitoring with alerts")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point for sensor execution
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
#[schema(example = "/sensors/monitoring/cpu_monitor_v2.py")]
|
||||
pub entrypoint: Option<String>,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Whether the sensor is enabled
|
||||
#[schema(example = false)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for sensor information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct SensorResponse {
|
||||
/// Sensor ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "monitoring.cpu_sensor")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack ID (optional)
|
||||
#[schema(example = 1)]
|
||||
pub pack: Option<i64>,
|
||||
|
||||
/// Pack reference (optional)
|
||||
#[schema(example = "monitoring")]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "CPU Monitoring Sensor")]
|
||||
pub label: String,
|
||||
|
||||
/// Sensor description
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/sensors/monitoring/cpu_monitor.py")]
|
||||
pub entrypoint: String,
|
||||
|
||||
/// Runtime ID
|
||||
#[schema(example = 1)]
|
||||
pub runtime: i64,
|
||||
|
||||
/// Runtime reference
|
||||
#[schema(example = "python3")]
|
||||
pub runtime_ref: String,
|
||||
|
||||
/// Trigger ID
|
||||
#[schema(example = 1)]
|
||||
pub trigger: i64,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "monitoring.cpu_threshold")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Whether the sensor is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified sensor response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct SensorSummary {
|
||||
/// Sensor ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "monitoring.cpu_sensor")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference (optional)
|
||||
#[schema(example = "monitoring")]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "CPU Monitoring Sensor")]
|
||||
pub label: String,
|
||||
|
||||
/// Sensor description
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "monitoring.cpu_threshold")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Whether the sensor is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from Trigger model to TriggerResponse
|
||||
impl From<attune_common::models::trigger::Trigger> for TriggerResponse {
|
||||
fn from(trigger: attune_common::models::trigger::Trigger) -> Self {
|
||||
Self {
|
||||
id: trigger.id,
|
||||
r#ref: trigger.r#ref,
|
||||
pack: trigger.pack,
|
||||
pack_ref: trigger.pack_ref,
|
||||
label: trigger.label,
|
||||
description: trigger.description,
|
||||
enabled: trigger.enabled,
|
||||
param_schema: trigger.param_schema,
|
||||
out_schema: trigger.out_schema,
|
||||
webhook_enabled: trigger.webhook_enabled,
|
||||
webhook_key: trigger.webhook_key,
|
||||
is_adhoc: trigger.is_adhoc,
|
||||
created: trigger.created,
|
||||
updated: trigger.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Trigger model to TriggerSummary
|
||||
impl From<attune_common::models::trigger::Trigger> for TriggerSummary {
|
||||
fn from(trigger: attune_common::models::trigger::Trigger) -> Self {
|
||||
Self {
|
||||
id: trigger.id,
|
||||
r#ref: trigger.r#ref,
|
||||
pack_ref: trigger.pack_ref,
|
||||
label: trigger.label,
|
||||
description: trigger.description,
|
||||
enabled: trigger.enabled,
|
||||
webhook_enabled: trigger.webhook_enabled,
|
||||
created: trigger.created,
|
||||
updated: trigger.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Sensor model to SensorResponse
|
||||
impl From<attune_common::models::trigger::Sensor> for SensorResponse {
|
||||
fn from(sensor: attune_common::models::trigger::Sensor) -> Self {
|
||||
Self {
|
||||
id: sensor.id,
|
||||
r#ref: sensor.r#ref,
|
||||
pack: sensor.pack,
|
||||
pack_ref: sensor.pack_ref,
|
||||
label: sensor.label,
|
||||
description: sensor.description,
|
||||
entrypoint: sensor.entrypoint,
|
||||
runtime: sensor.runtime,
|
||||
runtime_ref: sensor.runtime_ref,
|
||||
trigger: sensor.trigger,
|
||||
trigger_ref: sensor.trigger_ref,
|
||||
enabled: sensor.enabled,
|
||||
param_schema: sensor.param_schema,
|
||||
created: sensor.created,
|
||||
updated: sensor.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Sensor model to SensorSummary
|
||||
impl From<attune_common::models::trigger::Sensor> for SensorSummary {
|
||||
fn from(sensor: attune_common::models::trigger::Sensor) -> Self {
|
||||
Self {
|
||||
id: sensor.id,
|
||||
r#ref: sensor.r#ref,
|
||||
pack_ref: sensor.pack_ref,
|
||||
label: sensor.label,
|
||||
description: sensor.description,
|
||||
trigger_ref: sensor.trigger_ref,
|
||||
enabled: sensor.enabled,
|
||||
created: sensor.created,
|
||||
updated: sensor.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_trigger_request_defaults() {
|
||||
let json = r#"{
|
||||
"ref": "test-trigger",
|
||||
"label": "Test Trigger"
|
||||
}"#;
|
||||
|
||||
let req: CreateTriggerRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.r#ref, "test-trigger");
|
||||
assert_eq!(req.label, "Test Trigger");
|
||||
assert!(req.enabled);
|
||||
assert!(req.pack_ref.is_none());
|
||||
assert!(req.description.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_trigger_request_validation() {
|
||||
let req = CreateTriggerRequest {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: None,
|
||||
label: "Test Trigger".to_string(),
|
||||
description: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_sensor_request_valid() {
|
||||
let req = CreateSensorRequest {
|
||||
r#ref: "test.sensor".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Sensor".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
entrypoint: "/sensors/test.py".to_string(),
|
||||
runtime_ref: "python3".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
param_schema: None,
|
||||
config: None,
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_trigger_request_all_none() {
|
||||
let req = UpdateTriggerRequest {
|
||||
label: None,
|
||||
description: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
// Should be valid even with all None values
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_sensor_request_partial() {
|
||||
let req = UpdateSensorRequest {
|
||||
label: Some("Updated Sensor".to_string()),
|
||||
description: None,
|
||||
entrypoint: Some("/sensors/test_v2.py".to_string()),
|
||||
param_schema: None,
|
||||
enabled: Some(false),
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
}
|
||||
41
crates/api/src/dto/webhook.rs
Normal file
41
crates/api/src/dto/webhook.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
//! Webhook-related DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
/// Request body for webhook receiver endpoint
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct WebhookReceiverRequest {
|
||||
/// Webhook payload (arbitrary JSON)
|
||||
pub payload: JsonValue,
|
||||
|
||||
/// Optional headers from the webhook request (for logging/debugging)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<JsonValue>,
|
||||
|
||||
/// Optional source IP address
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub source_ip: Option<String>,
|
||||
|
||||
/// Optional user agent
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user_agent: Option<String>,
|
||||
}
|
||||
|
||||
/// Response from webhook receiver endpoint
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct WebhookReceiverResponse {
|
||||
/// ID of the event created from this webhook
|
||||
pub event_id: i64,
|
||||
|
||||
/// Reference of the trigger that received this webhook
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Timestamp when the webhook was received
|
||||
pub received_at: DateTime<Utc>,
|
||||
|
||||
/// Success message
|
||||
pub message: String,
|
||||
}
|
||||
327
crates/api/src/dto/workflow.rs
Normal file
327
crates/api/src/dto/workflow.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
//! Workflow DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a new workflow
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateWorkflowRequest {
|
||||
/// Unique reference identifier (e.g., "core.notify_on_failure", "slack.incident_workflow")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack.incident_workflow")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference this workflow belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Incident Response Workflow")]
|
||||
pub label: String,
|
||||
|
||||
/// Workflow description
|
||||
#[schema(example = "Automated incident response workflow with notifications and approvals")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Workflow version (semantic versioning recommended)
|
||||
#[validate(length(min = 1, max = 50))]
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Parameter schema (JSON Schema) defining expected inputs
|
||||
#[schema(value_type = Object, example = json!({"type": "object", "properties": {"severity": {"type": "string"}, "channel": {"type": "string"}}}))]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema (JSON Schema) defining expected outputs
|
||||
#[schema(value_type = Object, example = json!({"type": "object", "properties": {"incident_id": {"type": "string"}}}))]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Workflow definition (complete workflow YAML structure as JSON)
|
||||
#[schema(value_type = Object)]
|
||||
pub definition: JsonValue,
|
||||
|
||||
/// Tags for categorization and search
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a workflow
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateWorkflowRequest {
|
||||
/// Human-readable label
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Incident Response Workflow (Updated)")]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Workflow description
|
||||
#[schema(example = "Enhanced incident response workflow with additional automation")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Workflow version
|
||||
#[validate(length(min = 1, max = 50))]
|
||||
#[schema(example = "1.1.0")]
|
||||
pub version: Option<String>,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Workflow definition
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub definition: Option<JsonValue>,
|
||||
|
||||
/// Tags
|
||||
#[schema(example = json!(["incident", "slack", "approval", "automation"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for workflow information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct WorkflowResponse {
|
||||
/// Workflow ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.incident_workflow")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack ID
|
||||
#[schema(example = 1)]
|
||||
pub pack: i64,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Incident Response Workflow")]
|
||||
pub label: String,
|
||||
|
||||
/// Workflow description
|
||||
#[schema(example = "Automated incident response workflow with notifications and approvals")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Workflow version
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Parameter schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
|
||||
/// Workflow definition
|
||||
#[schema(value_type = Object)]
|
||||
pub definition: JsonValue,
|
||||
|
||||
/// Tags
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified workflow response (for list endpoints)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct WorkflowSummary {
|
||||
/// Workflow ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Unique reference identifier
|
||||
#[schema(example = "slack.incident_workflow")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Pack reference
|
||||
#[schema(example = "slack")]
|
||||
pub pack_ref: String,
|
||||
|
||||
/// Human-readable label
|
||||
#[schema(example = "Incident Response Workflow")]
|
||||
pub label: String,
|
||||
|
||||
/// Workflow description
|
||||
#[schema(example = "Automated incident response workflow with notifications and approvals")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Workflow version
|
||||
#[schema(example = "1.0.0")]
|
||||
pub version: String,
|
||||
|
||||
/// Tags
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Convert from WorkflowDefinition model to WorkflowResponse
|
||||
impl From<attune_common::models::workflow::WorkflowDefinition> for WorkflowResponse {
|
||||
fn from(workflow: attune_common::models::workflow::WorkflowDefinition) -> Self {
|
||||
Self {
|
||||
id: workflow.id,
|
||||
r#ref: workflow.r#ref,
|
||||
pack: workflow.pack,
|
||||
pack_ref: workflow.pack_ref,
|
||||
label: workflow.label,
|
||||
description: workflow.description,
|
||||
version: workflow.version,
|
||||
param_schema: workflow.param_schema,
|
||||
out_schema: workflow.out_schema,
|
||||
definition: workflow.definition,
|
||||
tags: workflow.tags,
|
||||
enabled: workflow.enabled,
|
||||
created: workflow.created,
|
||||
updated: workflow.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from WorkflowDefinition model to WorkflowSummary
|
||||
impl From<attune_common::models::workflow::WorkflowDefinition> for WorkflowSummary {
|
||||
fn from(workflow: attune_common::models::workflow::WorkflowDefinition) -> Self {
|
||||
Self {
|
||||
id: workflow.id,
|
||||
r#ref: workflow.r#ref,
|
||||
pack_ref: workflow.pack_ref,
|
||||
label: workflow.label,
|
||||
description: workflow.description,
|
||||
version: workflow.version,
|
||||
tags: workflow.tags,
|
||||
enabled: workflow.enabled,
|
||||
created: workflow.created,
|
||||
updated: workflow.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query parameters for workflow search and filtering
|
||||
#[derive(Debug, Clone, Deserialize, Validate, IntoParams)]
|
||||
pub struct WorkflowSearchParams {
|
||||
/// Filter by tag(s) - comma-separated list
|
||||
#[param(example = "incident,approval")]
|
||||
pub tags: Option<String>,
|
||||
|
||||
/// Filter by enabled status
|
||||
#[param(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
|
||||
/// Search term for label/description (case-insensitive)
|
||||
#[param(example = "incident")]
|
||||
pub search: Option<String>,
|
||||
|
||||
/// Filter by pack reference
|
||||
#[param(example = "core")]
|
||||
pub pack_ref: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_workflow_request_validation() {
|
||||
let req = CreateWorkflowRequest {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Workflow".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: serde_json::json!({"tasks": []}),
|
||||
tags: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_workflow_request_valid() {
|
||||
let req = CreateWorkflowRequest {
|
||||
r#ref: "test.workflow".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Workflow".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: serde_json::json!({"tasks": []}),
|
||||
tags: Some(vec!["test".to_string()]),
|
||||
enabled: Some(true),
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_workflow_request_all_none() {
|
||||
let req = UpdateWorkflowRequest {
|
||||
label: None,
|
||||
description: None,
|
||||
version: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
// Should be valid even with all None values
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workflow_search_params() {
|
||||
let params = WorkflowSearchParams {
|
||||
tags: Some("incident,approval".to_string()),
|
||||
enabled: Some(true),
|
||||
search: Some("response".to_string()),
|
||||
pack_ref: Some("core".to_string()),
|
||||
};
|
||||
|
||||
assert!(params.validate().is_ok());
|
||||
}
|
||||
}
|
||||
20
crates/api/src/lib.rs
Normal file
20
crates/api/src/lib.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
//! Attune API Service Library
|
||||
//!
|
||||
//! This library provides the core components of the Attune API service,
|
||||
//! including the server, routing, authentication, and state management.
|
||||
//! It is primarily used by the binary target and integration tests.
|
||||
|
||||
pub mod auth;
|
||||
pub mod dto;
|
||||
pub mod middleware;
|
||||
pub mod openapi;
|
||||
pub mod postgres_listener;
|
||||
pub mod routes;
|
||||
pub mod server;
|
||||
pub mod state;
|
||||
pub mod validation;
|
||||
pub mod webhook_security;
|
||||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use server::Server;
|
||||
pub use state::AppState;
|
||||
151
crates/api/src/main.rs
Normal file
151
crates/api/src/main.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
//! Attune API Service
|
||||
//!
|
||||
//! REST API gateway for all client interactions with the Attune platform.
|
||||
//! Provides endpoints for managing packs, actions, triggers, rules, executions,
|
||||
//! inquiries, and other automation components.
|
||||
|
||||
use anyhow::Result;
|
||||
use attune_common::{
|
||||
config::Config,
|
||||
db::Database,
|
||||
mq::{Connection, Publisher, PublisherConfig},
|
||||
};
|
||||
use clap::Parser;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use attune_api::{postgres_listener, AppState, Server};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "attune-api")]
|
||||
#[command(about = "Attune API Service", long_about = None)]
|
||||
struct Args {
|
||||
/// Path to configuration file
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
|
||||
/// Server host address
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
|
||||
/// Server port
|
||||
#[arg(long)]
|
||||
port: Option<u16>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing subscriber
|
||||
tracing_subscriber::fmt()
|
||||
.with_target(false)
|
||||
.with_thread_ids(true)
|
||||
.with_level(true)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
info!("Starting Attune API Service");
|
||||
|
||||
// Load configuration
|
||||
if let Some(config_path) = args.config {
|
||||
std::env::set_var("ATTUNE_CONFIG", config_path);
|
||||
}
|
||||
|
||||
let config = Config::load()?;
|
||||
config.validate()?;
|
||||
|
||||
info!("Configuration loaded successfully");
|
||||
info!("Environment: {}", config.environment);
|
||||
info!(
|
||||
"Server will bind to {}:{}",
|
||||
config.server.host, config.server.port
|
||||
);
|
||||
|
||||
// Initialize database connection pool
|
||||
info!("Connecting to database...");
|
||||
let database = Database::new(&config.database).await?;
|
||||
info!("Database connection established");
|
||||
|
||||
// Initialize message queue connection and publisher (optional)
|
||||
let mut state = AppState::new(database.pool().clone(), config.clone());
|
||||
|
||||
if let Some(ref mq_config) = config.message_queue {
|
||||
info!("Connecting to message queue...");
|
||||
match Connection::connect(&mq_config.url).await {
|
||||
Ok(mq_connection) => {
|
||||
info!("Message queue connection established");
|
||||
|
||||
// Create publisher
|
||||
match Publisher::new(
|
||||
&mq_connection,
|
||||
PublisherConfig {
|
||||
confirm_publish: true,
|
||||
timeout_secs: 30,
|
||||
exchange: "attune.executions".to_string(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(publisher) => {
|
||||
info!("Message queue publisher initialized");
|
||||
state = state.with_publisher(Arc::new(publisher));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create publisher: {}", e);
|
||||
warn!("Executions will not be queued for processing");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to message queue: {}", e);
|
||||
warn!("Executions will not be queued for processing");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("Message queue not configured");
|
||||
warn!("Executions will not be queued for processing");
|
||||
}
|
||||
|
||||
info!(
|
||||
"CORS configured with {} allowed origin(s)",
|
||||
if config.server.cors_origins.is_empty() {
|
||||
"default development"
|
||||
} else {
|
||||
"custom"
|
||||
}
|
||||
);
|
||||
|
||||
// Start PostgreSQL listener for SSE broadcasting
|
||||
let broadcast_tx = state.broadcast_tx.clone();
|
||||
let listener_db = database.pool().clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = postgres_listener::start_postgres_listener(listener_db, broadcast_tx).await
|
||||
{
|
||||
tracing::error!("PostgreSQL listener error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("PostgreSQL notification listener started");
|
||||
|
||||
// Create and start server
|
||||
let server = Server::new(std::sync::Arc::new(state));
|
||||
|
||||
info!("Attune API Service is ready");
|
||||
|
||||
// Run server with graceful shutdown
|
||||
tokio::select! {
|
||||
result = server.run() => {
|
||||
if let Err(e) = result {
|
||||
tracing::error!("Server error: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("Received shutdown signal");
|
||||
}
|
||||
}
|
||||
|
||||
info!("Shutting down Attune API Service");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
61
crates/api/src/middleware/cors.rs
Normal file
61
crates/api/src/middleware/cors.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! CORS middleware configuration
|
||||
|
||||
use axum::http::{header, HeaderValue, Method};
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
|
||||
/// Create CORS layer configured from allowed origins
|
||||
///
|
||||
/// If no origins are provided, defaults to common development origins.
|
||||
/// Cannot use `allow_origin(Any)` with credentials enabled.
|
||||
pub fn create_cors_layer(allowed_origins: Vec<String>) -> CorsLayer {
|
||||
// Get the list of allowed origins
|
||||
let origins = if allowed_origins.is_empty() {
|
||||
// Default development origins
|
||||
vec![
|
||||
"http://localhost:3000".to_string(),
|
||||
"http://localhost:5173".to_string(),
|
||||
"http://localhost:8080".to_string(),
|
||||
"http://127.0.0.1:3000".to_string(),
|
||||
"http://127.0.0.1:5173".to_string(),
|
||||
"http://127.0.0.1:8080".to_string(),
|
||||
]
|
||||
} else {
|
||||
allowed_origins
|
||||
};
|
||||
|
||||
// Convert origins to HeaderValues for matching
|
||||
let allowed_origin_values: Arc<Vec<HeaderValue>> = Arc::new(
|
||||
origins
|
||||
.iter()
|
||||
.filter_map(|o| o.parse::<HeaderValue>().ok())
|
||||
.collect(),
|
||||
);
|
||||
|
||||
CorsLayer::new()
|
||||
// Allow common HTTP methods
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::PATCH,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
// Allow specific headers (required when using credentials)
|
||||
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
|
||||
// Expose headers to the frontend
|
||||
.expose_headers([
|
||||
header::AUTHORIZATION,
|
||||
header::CONTENT_TYPE,
|
||||
header::CONTENT_LENGTH,
|
||||
header::ACCEPT,
|
||||
])
|
||||
// Allow credentials (cookies, authorization headers)
|
||||
.allow_credentials(true)
|
||||
// Use predicate to match against allowed origins
|
||||
// Arc allows the closure to be called multiple times (preflight + actual request)
|
||||
.allow_origin(AllowOrigin::predicate(move |origin: &HeaderValue, _| {
|
||||
allowed_origin_values.contains(origin)
|
||||
}))
|
||||
}
|
||||
251
crates/api/src/middleware/error.rs
Normal file
251
crates/api/src/middleware/error.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
//! Error handling middleware and response types
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Standard API error response
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
/// Error message
|
||||
pub error: String,
|
||||
/// Optional error code
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<String>,
|
||||
/// Optional additional details
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
/// Create a new error response
|
||||
pub fn new(error: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error: error.into(),
|
||||
code: None,
|
||||
details: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set error code
|
||||
pub fn with_code(mut self, code: impl Into<String>) -> Self {
|
||||
self.code = Some(code.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set error details
|
||||
pub fn with_details(mut self, details: serde_json::Value) -> Self {
|
||||
self.details = Some(details);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// API error type that can be converted to HTTP responses
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
/// Bad request (400)
|
||||
BadRequest(String),
|
||||
/// Unauthorized (401)
|
||||
Unauthorized(String),
|
||||
/// Forbidden (403)
|
||||
Forbidden(String),
|
||||
/// Not found (404)
|
||||
NotFound(String),
|
||||
/// Conflict (409)
|
||||
Conflict(String),
|
||||
/// Unprocessable entity (422)
|
||||
UnprocessableEntity(String),
|
||||
/// Too many requests (429)
|
||||
TooManyRequests(String),
|
||||
/// Internal server error (500)
|
||||
InternalServerError(String),
|
||||
/// Not implemented (501)
|
||||
NotImplemented(String),
|
||||
/// Database error
|
||||
DatabaseError(String),
|
||||
/// Validation error
|
||||
ValidationError(String),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Get the HTTP status code for this error
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ApiError::BadRequest(_) => StatusCode::BAD_REQUEST,
|
||||
ApiError::Unauthorized(_) => StatusCode::UNAUTHORIZED,
|
||||
ApiError::Forbidden(_) => StatusCode::FORBIDDEN,
|
||||
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
ApiError::Conflict(_) => StatusCode::CONFLICT,
|
||||
ApiError::UnprocessableEntity(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
ApiError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
ApiError::TooManyRequests(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||
ApiError::NotImplemented(_) => StatusCode::NOT_IMPLEMENTED,
|
||||
ApiError::InternalServerError(_) | ApiError::DatabaseError(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error message
|
||||
pub fn message(&self) -> &str {
|
||||
match self {
|
||||
ApiError::BadRequest(msg)
|
||||
| ApiError::Unauthorized(msg)
|
||||
| ApiError::Forbidden(msg)
|
||||
| ApiError::NotFound(msg)
|
||||
| ApiError::Conflict(msg)
|
||||
| ApiError::UnprocessableEntity(msg)
|
||||
| ApiError::TooManyRequests(msg)
|
||||
| ApiError::NotImplemented(msg)
|
||||
| ApiError::InternalServerError(msg)
|
||||
| ApiError::DatabaseError(msg)
|
||||
| ApiError::ValidationError(msg) => msg,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error code
|
||||
pub fn code(&self) -> &str {
|
||||
match self {
|
||||
ApiError::BadRequest(_) => "BAD_REQUEST",
|
||||
ApiError::Unauthorized(_) => "UNAUTHORIZED",
|
||||
ApiError::Forbidden(_) => "FORBIDDEN",
|
||||
ApiError::NotFound(_) => "NOT_FOUND",
|
||||
ApiError::Conflict(_) => "CONFLICT",
|
||||
ApiError::UnprocessableEntity(_) => "UNPROCESSABLE_ENTITY",
|
||||
ApiError::TooManyRequests(_) => "TOO_MANY_REQUESTS",
|
||||
ApiError::NotImplemented(_) => "NOT_IMPLEMENTED",
|
||||
ApiError::ValidationError(_) => "VALIDATION_ERROR",
|
||||
ApiError::DatabaseError(_) => "DATABASE_ERROR",
|
||||
ApiError::InternalServerError(_) => "INTERNAL_SERVER_ERROR",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ApiError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.message())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ApiError {}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let error_response = ErrorResponse::new(self.message()).with_code(self.code());
|
||||
|
||||
(status, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from common error types
|
||||
impl From<sqlx::Error> for ApiError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
match err {
|
||||
sqlx::Error::RowNotFound => ApiError::NotFound("Resource not found".to_string()),
|
||||
sqlx::Error::Database(db_err) => {
|
||||
// Check for unique constraint violations
|
||||
if let Some(constraint) = db_err.constraint() {
|
||||
ApiError::Conflict(format!("Constraint violation: {}", constraint))
|
||||
} else {
|
||||
ApiError::DatabaseError(format!("Database error: {}", db_err))
|
||||
}
|
||||
}
|
||||
_ => ApiError::DatabaseError(format!("Database error: {}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::error::Error> for ApiError {
|
||||
fn from(err: attune_common::error::Error) -> Self {
|
||||
match err {
|
||||
attune_common::error::Error::NotFound {
|
||||
entity,
|
||||
field,
|
||||
value,
|
||||
} => ApiError::NotFound(format!("{} with {}={} not found", entity, field, value)),
|
||||
attune_common::error::Error::AlreadyExists {
|
||||
entity,
|
||||
field,
|
||||
value,
|
||||
} => ApiError::Conflict(format!(
|
||||
"{} with {}={} already exists",
|
||||
entity, field, value
|
||||
)),
|
||||
attune_common::error::Error::Validation(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::SchemaValidation(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::Database(err) => ApiError::from(err),
|
||||
attune_common::error::Error::InvalidState(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::PermissionDenied(msg) => ApiError::Forbidden(msg),
|
||||
attune_common::error::Error::AuthenticationFailed(msg) => ApiError::Unauthorized(msg),
|
||||
attune_common::error::Error::Configuration(msg) => ApiError::InternalServerError(msg),
|
||||
attune_common::error::Error::Serialization(err) => {
|
||||
ApiError::InternalServerError(format!("{}", err))
|
||||
}
|
||||
attune_common::error::Error::Io(msg)
|
||||
| attune_common::error::Error::Encryption(msg)
|
||||
| attune_common::error::Error::Timeout(msg)
|
||||
| attune_common::error::Error::ExternalService(msg)
|
||||
| attune_common::error::Error::Worker(msg)
|
||||
| attune_common::error::Error::Execution(msg)
|
||||
| attune_common::error::Error::Internal(msg) => ApiError::InternalServerError(msg),
|
||||
attune_common::error::Error::Other(err) => {
|
||||
ApiError::InternalServerError(format!("{}", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<validator::ValidationErrors> for ApiError {
|
||||
fn from(err: validator::ValidationErrors) -> Self {
|
||||
ApiError::ValidationError(format!("Validation failed: {}", err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::auth::jwt::JwtError> for ApiError {
|
||||
fn from(err: crate::auth::jwt::JwtError) -> Self {
|
||||
match err {
|
||||
crate::auth::jwt::JwtError::Expired => {
|
||||
ApiError::Unauthorized("Token has expired".to_string())
|
||||
}
|
||||
crate::auth::jwt::JwtError::Invalid => {
|
||||
ApiError::Unauthorized("Invalid token".to_string())
|
||||
}
|
||||
crate::auth::jwt::JwtError::EncodeError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to encode token: {}", msg))
|
||||
}
|
||||
crate::auth::jwt::JwtError::DecodeError(msg) => {
|
||||
ApiError::Unauthorized(format!("Failed to decode token: {}", msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::auth::password::PasswordError> for ApiError {
|
||||
fn from(err: crate::auth::password::PasswordError) -> Self {
|
||||
match err {
|
||||
crate::auth::password::PasswordError::HashError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to hash password: {}", msg))
|
||||
}
|
||||
crate::auth::password::PasswordError::VerifyError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to verify password: {}", msg))
|
||||
}
|
||||
crate::auth::password::PasswordError::InvalidHash => {
|
||||
ApiError::InternalServerError("Invalid password hash format".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::num::ParseIntError> for ApiError {
|
||||
fn from(err: std::num::ParseIntError) -> Self {
|
||||
ApiError::BadRequest(format!("Invalid number format: {}", err))
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for API handlers
|
||||
pub type ApiResult<T> = Result<T, ApiError>;
|
||||
54
crates/api/src/middleware/logging.rs
Normal file
54
crates/api/src/middleware/logging.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
//! Request/Response logging middleware
|
||||
|
||||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
use std::time::Instant;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Middleware for logging HTTP requests and responses
|
||||
pub async fn log_request(req: Request, next: Next) -> Response {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let version = req.version();
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
info!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
version = ?version,
|
||||
"request started"
|
||||
);
|
||||
|
||||
let response = next.run(req).await;
|
||||
|
||||
let duration = start.elapsed();
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
info!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request completed"
|
||||
);
|
||||
} else if status.is_client_error() {
|
||||
warn!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request failed (client error)"
|
||||
);
|
||||
} else if status.is_server_error() {
|
||||
warn!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request failed (server error)"
|
||||
);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
9
crates/api/src/middleware/mod.rs
Normal file
9
crates/api/src/middleware/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Middleware modules for the API service
|
||||
|
||||
pub mod cors;
|
||||
pub mod error;
|
||||
pub mod logging;
|
||||
|
||||
pub use cors::create_cors_layer;
|
||||
pub use error::{ApiError, ApiResult};
|
||||
pub use logging::log_request;
|
||||
410
crates/api/src/openapi.rs
Normal file
410
crates/api/src/openapi.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
//! OpenAPI specification and documentation
|
||||
|
||||
use utoipa::{
|
||||
openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
|
||||
Modify, OpenApi,
|
||||
};
|
||||
|
||||
use crate::dto::{
|
||||
action::{
|
||||
ActionResponse, ActionSummary, CreateActionRequest, QueueStatsResponse, UpdateActionRequest,
|
||||
},
|
||||
auth::{
|
||||
ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest,
|
||||
RegisterRequest, TokenResponse,
|
||||
},
|
||||
common::{ApiResponse, PaginatedResponse, PaginationMeta, SuccessResponse},
|
||||
event::{EnforcementResponse, EnforcementSummary, EventResponse, EventSummary},
|
||||
execution::{ExecutionResponse, ExecutionSummary},
|
||||
inquiry::{
|
||||
CreateInquiryRequest, InquiryRespondRequest, InquiryResponse, InquirySummary,
|
||||
UpdateInquiryRequest,
|
||||
},
|
||||
key::{CreateKeyRequest, KeyResponse, KeySummary, UpdateKeyRequest},
|
||||
pack::{
|
||||
CreatePackRequest, InstallPackRequest, PackInstallResponse, PackResponse, PackSummary,
|
||||
PackWorkflowSyncResponse, PackWorkflowValidationResponse, RegisterPackRequest,
|
||||
UpdatePackRequest, WorkflowSyncResult,
|
||||
},
|
||||
rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest},
|
||||
trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,
|
||||
TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
},
|
||||
webhook::{WebhookReceiverRequest, WebhookReceiverResponse},
|
||||
workflow::{CreateWorkflowRequest, UpdateWorkflowRequest, WorkflowResponse, WorkflowSummary},
|
||||
};
|
||||
|
||||
/// OpenAPI documentation structure
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
info(
|
||||
title = "Attune API",
|
||||
version = "0.1.0",
|
||||
description = "Event-driven automation and orchestration platform API",
|
||||
contact(
|
||||
name = "Attune Team",
|
||||
url = "https://github.com/yourusername/attune"
|
||||
),
|
||||
license(
|
||||
name = "MIT",
|
||||
url = "https://opensource.org/licenses/MIT"
|
||||
)
|
||||
),
|
||||
servers(
|
||||
(url = "http://localhost:8080", description = "Local development server"),
|
||||
(url = "https://api.attune.example.com", description = "Production server")
|
||||
),
|
||||
paths(
|
||||
// Health check
|
||||
crate::routes::health::health,
|
||||
crate::routes::health::health_detailed,
|
||||
crate::routes::health::readiness,
|
||||
crate::routes::health::liveness,
|
||||
|
||||
// Authentication
|
||||
crate::routes::auth::login,
|
||||
crate::routes::auth::register,
|
||||
crate::routes::auth::refresh_token,
|
||||
crate::routes::auth::get_current_user,
|
||||
crate::routes::auth::change_password,
|
||||
|
||||
// Packs
|
||||
crate::routes::packs::list_packs,
|
||||
crate::routes::packs::get_pack,
|
||||
crate::routes::packs::create_pack,
|
||||
crate::routes::packs::update_pack,
|
||||
crate::routes::packs::delete_pack,
|
||||
crate::routes::packs::register_pack,
|
||||
crate::routes::packs::install_pack,
|
||||
crate::routes::packs::sync_pack_workflows,
|
||||
crate::routes::packs::validate_pack_workflows,
|
||||
crate::routes::packs::test_pack,
|
||||
crate::routes::packs::get_pack_test_history,
|
||||
crate::routes::packs::get_pack_latest_test,
|
||||
|
||||
// Actions
|
||||
crate::routes::actions::list_actions,
|
||||
crate::routes::actions::list_actions_by_pack,
|
||||
crate::routes::actions::get_action,
|
||||
crate::routes::actions::create_action,
|
||||
crate::routes::actions::update_action,
|
||||
crate::routes::actions::delete_action,
|
||||
crate::routes::actions::get_queue_stats,
|
||||
|
||||
// Triggers
|
||||
crate::routes::triggers::list_triggers,
|
||||
crate::routes::triggers::list_enabled_triggers,
|
||||
crate::routes::triggers::list_triggers_by_pack,
|
||||
crate::routes::triggers::get_trigger,
|
||||
crate::routes::triggers::create_trigger,
|
||||
crate::routes::triggers::update_trigger,
|
||||
crate::routes::triggers::delete_trigger,
|
||||
crate::routes::triggers::enable_trigger,
|
||||
crate::routes::triggers::disable_trigger,
|
||||
|
||||
// Sensors
|
||||
crate::routes::triggers::list_sensors,
|
||||
crate::routes::triggers::list_enabled_sensors,
|
||||
crate::routes::triggers::list_sensors_by_pack,
|
||||
crate::routes::triggers::list_sensors_by_trigger,
|
||||
crate::routes::triggers::get_sensor,
|
||||
crate::routes::triggers::create_sensor,
|
||||
crate::routes::triggers::update_sensor,
|
||||
crate::routes::triggers::delete_sensor,
|
||||
crate::routes::triggers::enable_sensor,
|
||||
crate::routes::triggers::disable_sensor,
|
||||
|
||||
// Rules
|
||||
crate::routes::rules::list_rules,
|
||||
crate::routes::rules::list_enabled_rules,
|
||||
crate::routes::rules::list_rules_by_pack,
|
||||
crate::routes::rules::list_rules_by_action,
|
||||
crate::routes::rules::list_rules_by_trigger,
|
||||
crate::routes::rules::get_rule,
|
||||
crate::routes::rules::create_rule,
|
||||
crate::routes::rules::update_rule,
|
||||
crate::routes::rules::delete_rule,
|
||||
crate::routes::rules::enable_rule,
|
||||
crate::routes::rules::disable_rule,
|
||||
|
||||
// Executions
|
||||
crate::routes::executions::list_executions,
|
||||
crate::routes::executions::get_execution,
|
||||
crate::routes::executions::list_executions_by_status,
|
||||
crate::routes::executions::list_executions_by_enforcement,
|
||||
crate::routes::executions::get_execution_stats,
|
||||
|
||||
// Events
|
||||
crate::routes::events::list_events,
|
||||
crate::routes::events::get_event,
|
||||
|
||||
// Enforcements
|
||||
crate::routes::events::list_enforcements,
|
||||
crate::routes::events::get_enforcement,
|
||||
|
||||
// Inquiries
|
||||
crate::routes::inquiries::list_inquiries,
|
||||
crate::routes::inquiries::get_inquiry,
|
||||
crate::routes::inquiries::list_inquiries_by_status,
|
||||
crate::routes::inquiries::list_inquiries_by_execution,
|
||||
crate::routes::inquiries::create_inquiry,
|
||||
crate::routes::inquiries::update_inquiry,
|
||||
crate::routes::inquiries::respond_to_inquiry,
|
||||
crate::routes::inquiries::delete_inquiry,
|
||||
|
||||
// Keys/Secrets
|
||||
crate::routes::keys::list_keys,
|
||||
crate::routes::keys::get_key,
|
||||
crate::routes::keys::create_key,
|
||||
crate::routes::keys::update_key,
|
||||
crate::routes::keys::delete_key,
|
||||
|
||||
// Workflows
|
||||
crate::routes::workflows::list_workflows,
|
||||
crate::routes::workflows::list_workflows_by_pack,
|
||||
crate::routes::workflows::get_workflow,
|
||||
crate::routes::workflows::create_workflow,
|
||||
crate::routes::workflows::update_workflow,
|
||||
crate::routes::workflows::delete_workflow,
|
||||
|
||||
// Webhooks
|
||||
crate::routes::webhooks::enable_webhook,
|
||||
crate::routes::webhooks::disable_webhook,
|
||||
crate::routes::webhooks::regenerate_webhook_key,
|
||||
crate::routes::webhooks::receive_webhook,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
// Common types
|
||||
ApiResponse<TokenResponse>,
|
||||
ApiResponse<CurrentUserResponse>,
|
||||
ApiResponse<PackResponse>,
|
||||
ApiResponse<PackInstallResponse>,
|
||||
ApiResponse<ActionResponse>,
|
||||
ApiResponse<TriggerResponse>,
|
||||
ApiResponse<SensorResponse>,
|
||||
ApiResponse<RuleResponse>,
|
||||
ApiResponse<ExecutionResponse>,
|
||||
ApiResponse<EventResponse>,
|
||||
ApiResponse<EnforcementResponse>,
|
||||
ApiResponse<InquiryResponse>,
|
||||
ApiResponse<KeyResponse>,
|
||||
ApiResponse<WorkflowResponse>,
|
||||
ApiResponse<QueueStatsResponse>,
|
||||
PaginatedResponse<PackSummary>,
|
||||
PaginatedResponse<ActionSummary>,
|
||||
PaginatedResponse<TriggerSummary>,
|
||||
PaginatedResponse<SensorSummary>,
|
||||
PaginatedResponse<RuleSummary>,
|
||||
PaginatedResponse<ExecutionSummary>,
|
||||
PaginatedResponse<EventSummary>,
|
||||
PaginatedResponse<EnforcementSummary>,
|
||||
PaginatedResponse<InquirySummary>,
|
||||
PaginatedResponse<KeySummary>,
|
||||
PaginatedResponse<WorkflowSummary>,
|
||||
PaginationMeta,
|
||||
SuccessResponse,
|
||||
|
||||
// Auth DTOs
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
RefreshTokenRequest,
|
||||
ChangePasswordRequest,
|
||||
TokenResponse,
|
||||
CurrentUserResponse,
|
||||
|
||||
// Pack DTOs
|
||||
CreatePackRequest,
|
||||
UpdatePackRequest,
|
||||
RegisterPackRequest,
|
||||
InstallPackRequest,
|
||||
PackResponse,
|
||||
PackSummary,
|
||||
PackInstallResponse,
|
||||
PackWorkflowSyncResponse,
|
||||
PackWorkflowValidationResponse,
|
||||
WorkflowSyncResult,
|
||||
attune_common::models::pack_test::PackTestResult,
|
||||
attune_common::models::pack_test::PackTestExecution,
|
||||
attune_common::models::pack_test::TestSuiteResult,
|
||||
attune_common::models::pack_test::TestCaseResult,
|
||||
attune_common::models::pack_test::TestStatus,
|
||||
attune_common::models::pack_test::PackTestSummary,
|
||||
PaginatedResponse<attune_common::models::pack_test::PackTestSummary>,
|
||||
|
||||
// Action DTOs
|
||||
CreateActionRequest,
|
||||
UpdateActionRequest,
|
||||
ActionResponse,
|
||||
ActionSummary,
|
||||
QueueStatsResponse,
|
||||
|
||||
// Trigger DTOs
|
||||
CreateTriggerRequest,
|
||||
UpdateTriggerRequest,
|
||||
TriggerResponse,
|
||||
TriggerSummary,
|
||||
|
||||
// Sensor DTOs
|
||||
CreateSensorRequest,
|
||||
UpdateSensorRequest,
|
||||
SensorResponse,
|
||||
SensorSummary,
|
||||
|
||||
// Rule DTOs
|
||||
CreateRuleRequest,
|
||||
UpdateRuleRequest,
|
||||
RuleResponse,
|
||||
RuleSummary,
|
||||
|
||||
// Execution DTOs
|
||||
ExecutionResponse,
|
||||
ExecutionSummary,
|
||||
|
||||
// Event DTOs
|
||||
EventResponse,
|
||||
EventSummary,
|
||||
|
||||
// Enforcement DTOs
|
||||
EnforcementResponse,
|
||||
EnforcementSummary,
|
||||
|
||||
// Inquiry DTOs
|
||||
CreateInquiryRequest,
|
||||
UpdateInquiryRequest,
|
||||
InquiryRespondRequest,
|
||||
InquiryResponse,
|
||||
InquirySummary,
|
||||
|
||||
// Key/Secret DTOs
|
||||
CreateKeyRequest,
|
||||
UpdateKeyRequest,
|
||||
KeyResponse,
|
||||
KeySummary,
|
||||
|
||||
// Workflow DTOs
|
||||
CreateWorkflowRequest,
|
||||
UpdateWorkflowRequest,
|
||||
WorkflowResponse,
|
||||
WorkflowSummary,
|
||||
|
||||
// Webhook DTOs
|
||||
WebhookReceiverRequest,
|
||||
WebhookReceiverResponse,
|
||||
ApiResponse<WebhookReceiverResponse>,
|
||||
)
|
||||
),
|
||||
modifiers(&SecurityAddon),
|
||||
tags(
|
||||
(name = "health", description = "Health check endpoints"),
|
||||
(name = "auth", description = "Authentication and authorization endpoints"),
|
||||
(name = "packs", description = "Pack management endpoints"),
|
||||
(name = "actions", description = "Action management endpoints"),
|
||||
(name = "triggers", description = "Trigger management endpoints"),
|
||||
(name = "sensors", description = "Sensor management endpoints"),
|
||||
(name = "rules", description = "Rule management endpoints"),
|
||||
(name = "executions", description = "Execution query endpoints"),
|
||||
(name = "inquiries", description = "Inquiry (human-in-the-loop) endpoints"),
|
||||
(name = "events", description = "Event query endpoints"),
|
||||
(name = "enforcements", description = "Enforcement query endpoints"),
|
||||
(name = "secrets", description = "Secret management endpoints"),
|
||||
(name = "workflows", description = "Workflow management endpoints"),
|
||||
(name = "webhooks", description = "Webhook management and receiver endpoints"),
|
||||
)
|
||||
)]
|
||||
pub struct ApiDoc;
|
||||
|
||||
/// Security scheme modifier to add JWT Bearer authentication
|
||||
struct SecurityAddon;
|
||||
|
||||
impl Modify for SecurityAddon {
|
||||
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||||
if let Some(components) = openapi.components.as_mut() {
|
||||
components.add_security_scheme(
|
||||
"bearer_auth",
|
||||
SecurityScheme::Http(
|
||||
HttpBuilder::new()
|
||||
.scheme(HttpAuthScheme::Bearer)
|
||||
.bearer_format("JWT")
|
||||
.description(Some(
|
||||
"JWT access token obtained from /auth/login or /auth/register",
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_openapi_spec_generation() {
|
||||
let doc = ApiDoc::openapi();
|
||||
|
||||
// Verify basic info
|
||||
assert_eq!(doc.info.title, "Attune API");
|
||||
assert_eq!(doc.info.version, "0.1.0");
|
||||
|
||||
// Verify we have components
|
||||
assert!(doc.components.is_some());
|
||||
|
||||
// Verify we have security schemes
|
||||
let components = doc.components.unwrap();
|
||||
assert!(components.security_schemes.contains_key("bearer_auth"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openapi_endpoint_count() {
|
||||
let doc = ApiDoc::openapi();
|
||||
|
||||
// Count all paths in the OpenAPI spec
|
||||
let path_count = doc.paths.paths.len();
|
||||
|
||||
// Count all operations (methods on paths)
|
||||
let operation_count: usize = doc
|
||||
.paths
|
||||
.paths
|
||||
.values()
|
||||
.map(|path_item| {
|
||||
let mut count = 0;
|
||||
if path_item.get.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
if path_item.post.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
if path_item.put.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
if path_item.delete.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
if path_item.patch.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
count
|
||||
})
|
||||
.sum();
|
||||
|
||||
// We have 57 unique paths with 81 total operations (HTTP methods)
|
||||
// This test ensures we don't accidentally remove endpoints
|
||||
assert!(
|
||||
path_count >= 57,
|
||||
"Expected at least 57 unique API paths, found {}",
|
||||
path_count
|
||||
);
|
||||
|
||||
assert!(
|
||||
operation_count >= 81,
|
||||
"Expected at least 81 API operations, found {}",
|
||||
operation_count
|
||||
);
|
||||
|
||||
println!("Total API paths: {}", path_count);
|
||||
println!("Total API operations: {}", operation_count);
|
||||
}
|
||||
}
|
||||
67
crates/api/src/postgres_listener.rs
Normal file
67
crates/api/src/postgres_listener.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! PostgreSQL LISTEN/NOTIFY listener for SSE broadcasting
|
||||
|
||||
use sqlx::postgres::{PgListener, PgPool};
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Start listening to PostgreSQL notifications and broadcast them to SSE clients
|
||||
pub async fn start_postgres_listener(
|
||||
db: PgPool,
|
||||
broadcast_tx: broadcast::Sender<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
info!("Starting PostgreSQL notification listener for SSE broadcasting");
|
||||
|
||||
// Create a listener
|
||||
let mut listener = PgListener::connect_with(&db).await?;
|
||||
|
||||
// Subscribe to the notifications channel
|
||||
listener.listen("attune_notifications").await?;
|
||||
|
||||
info!("Listening on channel: attune_notifications");
|
||||
|
||||
// Process notifications in a loop
|
||||
loop {
|
||||
match listener.recv().await {
|
||||
Ok(notification) => {
|
||||
let payload = notification.payload();
|
||||
debug!("Received notification: {}", payload);
|
||||
|
||||
// Broadcast to all SSE clients
|
||||
match broadcast_tx.send(payload.to_string()) {
|
||||
Ok(receiver_count) => {
|
||||
debug!("Broadcasted notification to {} SSE clients", receiver_count);
|
||||
}
|
||||
Err(e) => {
|
||||
// This happens when there are no active receivers, which is normal
|
||||
debug!("No active SSE clients to receive notification: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error receiving notification: {}", e);
|
||||
|
||||
// If the connection is lost, try to reconnect
|
||||
warn!("Attempting to reconnect to PostgreSQL listener...");
|
||||
|
||||
match PgListener::connect_with(&db).await {
|
||||
Ok(mut new_listener) => {
|
||||
match new_listener.listen("attune_notifications").await {
|
||||
Ok(_) => {
|
||||
info!("Successfully reconnected to PostgreSQL listener");
|
||||
listener = new_listener;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to resubscribe after reconnect: {}", e);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to reconnect to PostgreSQL: {}", e);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
353
crates/api/src/routes/actions.rs
Normal file
353
crates/api/src/routes/actions.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
//! Action management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
action::{ActionRepository, CreateActionInput, UpdateActionInput},
|
||||
pack::PackRepository,
|
||||
queue_stats::QueueStatsRepository,
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
action::{
|
||||
ActionResponse, ActionSummary, CreateActionRequest, QueueStatsResponse,
|
||||
UpdateActionRequest,
|
||||
},
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// List all actions with pagination
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/actions",
|
||||
tag = "actions",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of actions", body = PaginatedResponse<ActionSummary>),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_actions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all actions (we'll implement pagination in repository later)
|
||||
let actions = ActionRepository::list(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = actions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(actions.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_actions: Vec<ActionSummary> = actions[start..end]
|
||||
.iter()
|
||||
.map(|a| ActionSummary::from(a.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List actions by pack reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/actions",
|
||||
tag = "actions",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference identifier"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of actions for pack", body = PaginatedResponse<ActionSummary>),
|
||||
(status = 404, description = "Pack not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_actions_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get actions for this pack
|
||||
let actions = ActionRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = actions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(actions.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_actions: Vec<ActionSummary> = actions[start..end]
|
||||
.iter()
|
||||
.map(|a| ActionSummary::from(a.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single action by reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/actions/{ref}",
|
||||
tag = "actions",
|
||||
params(
|
||||
("ref" = String, Path, description = "Action reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Action details", body = inline(ApiResponse<ActionResponse>)),
|
||||
(status = 404, description = "Action not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_action(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(action_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let action = ActionRepository::find_by_ref(&state.db, &action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
let response = ApiResponse::new(ActionResponse::from(action));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new action
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/actions",
|
||||
tag = "actions",
|
||||
request_body = CreateActionRequest,
|
||||
responses(
|
||||
(status = 201, description = "Action created successfully", body = inline(ApiResponse<ActionResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 409, description = "Action with same ref already exists")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_action(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateActionRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if action with same ref already exists
|
||||
if let Some(_) = ActionRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Action with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify pack exists and get its ID
|
||||
let pack = PackRepository::find_by_ref(&state.db, &request.pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", request.pack_ref)))?;
|
||||
|
||||
// If runtime is specified, we could verify it exists (future enhancement)
|
||||
// For now, the database foreign key constraint will handle invalid runtime IDs
|
||||
|
||||
// Create action input
|
||||
let action_input = CreateActionInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
entrypoint: request.entrypoint,
|
||||
runtime: request.runtime,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
is_adhoc: true, // Actions created via API are ad-hoc (not from pack installation)
|
||||
};
|
||||
|
||||
let action = ActionRepository::create(&state.db, action_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(ActionResponse::from(action), "Action created successfully");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing action
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/actions/{ref}",
|
||||
tag = "actions",
|
||||
params(
|
||||
("ref" = String, Path, description = "Action reference identifier")
|
||||
),
|
||||
request_body = UpdateActionRequest,
|
||||
responses(
|
||||
(status = 200, description = "Action updated successfully", body = inline(ApiResponse<ActionResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Action not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_action(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(action_ref): Path<String>,
|
||||
Json(request): Json<UpdateActionRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if action exists
|
||||
let existing_action = ActionRepository::find_by_ref(&state.db, &action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateActionInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
entrypoint: request.entrypoint,
|
||||
runtime: request.runtime,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
};
|
||||
|
||||
let action = ActionRepository::update(&state.db, existing_action.id, update_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(ActionResponse::from(action), "Action updated successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete an action
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/actions/{ref}",
|
||||
tag = "actions",
|
||||
params(
|
||||
("ref" = String, Path, description = "Action reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Action deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Action not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_action(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(action_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if action exists
|
||||
let action = ActionRepository::find_by_ref(&state.db, &action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
// Delete the action
|
||||
let deleted = ActionRepository::delete(&state.db, action.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Action '{}' not found",
|
||||
action_ref
|
||||
)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new(format!("Action '{}' deleted successfully", action_ref));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get queue statistics for an action
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/actions/{ref}/queue-stats",
|
||||
tag = "actions",
|
||||
params(
|
||||
("ref" = String, Path, description = "Action reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Queue statistics", body = inline(ApiResponse<QueueStatsResponse>)),
|
||||
(status = 404, description = "Action not found or no queue statistics available")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_queue_stats(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(action_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Find the action by reference
|
||||
let action = ActionRepository::find_by_ref(&state.db, &action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
// Get queue statistics from database
|
||||
let queue_stats = QueueStatsRepository::find_by_action(&state.db, action.id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!(
|
||||
"No queue statistics available for action '{}'",
|
||||
action_ref
|
||||
))
|
||||
})?;
|
||||
|
||||
// Convert to response DTO and populate action_ref
|
||||
let mut response_stats = QueueStatsResponse::from(queue_stats);
|
||||
response_stats.action_ref = action.r#ref.clone();
|
||||
|
||||
let response = ApiResponse::new(response_stats);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create action routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/actions", get(list_actions).post(create_action))
|
||||
.route(
|
||||
"/actions/{ref}",
|
||||
get(get_action).put(update_action).delete(delete_action),
|
||||
)
|
||||
.route("/actions/{ref}/queue-stats", get(get_queue_stats))
|
||||
.route("/packs/{pack_ref}/actions", get(list_actions_by_pack))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_action_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
464
crates/api/src/routes/auth.rs
Normal file
464
crates/api/src/routes/auth.rs
Normal file
@@ -0,0 +1,464 @@
|
||||
//! Authentication routes
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
identity::{CreateIdentityInput, IdentityRepository},
|
||||
Create, FindById,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::{
|
||||
hash_password,
|
||||
jwt::{
|
||||
generate_access_token, generate_refresh_token, generate_sensor_token, validate_token,
|
||||
TokenType,
|
||||
},
|
||||
middleware::RequireAuth,
|
||||
verify_password,
|
||||
},
|
||||
dto::{
|
||||
ApiResponse, ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest,
|
||||
RegisterRequest, SuccessResponse, TokenResponse,
|
||||
},
|
||||
middleware::error::ApiError,
|
||||
state::SharedState,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
/// Request body for creating sensor tokens
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateSensorTokenRequest {
|
||||
/// Sensor reference (e.g., "core.timer")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
pub sensor_ref: String,
|
||||
|
||||
/// List of trigger types this sensor can create events for
|
||||
#[validate(length(min = 1))]
|
||||
pub trigger_types: Vec<String>,
|
||||
|
||||
/// Optional TTL in seconds (default: 86400 = 24 hours, max: 259200 = 72 hours)
|
||||
#[validate(range(min = 3600, max = 259200))]
|
||||
pub ttl_seconds: Option<i64>,
|
||||
}
|
||||
|
||||
/// Response for sensor token creation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct SensorTokenResponse {
|
||||
pub identity_id: i64,
|
||||
pub sensor_ref: String,
|
||||
pub token: String,
|
||||
pub expires_at: String,
|
||||
pub trigger_types: Vec<String>,
|
||||
}
|
||||
|
||||
/// Create authentication routes
|
||||
pub fn routes() -> Router<SharedState> {
|
||||
Router::new()
|
||||
.route("/login", post(login))
|
||||
.route("/register", post(register))
|
||||
.route("/refresh", post(refresh_token))
|
||||
.route("/me", get(get_current_user))
|
||||
.route("/change-password", post(change_password))
|
||||
.route("/sensor-token", post(create_sensor_token))
|
||||
.route("/internal/sensor-token", post(create_sensor_token_internal))
|
||||
}
|
||||
|
||||
/// Login endpoint
|
||||
///
|
||||
/// POST /auth/login
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/login",
|
||||
tag = "auth",
|
||||
request_body = LoginRequest,
|
||||
responses(
|
||||
(status = 200, description = "Successfully logged in", body = inline(ApiResponse<TokenResponse>)),
|
||||
(status = 401, description = "Invalid credentials"),
|
||||
(status = 400, description = "Validation error")
|
||||
)
|
||||
)]
|
||||
pub async fn login(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid login request: {}", e)))?;
|
||||
|
||||
// Find identity by login
|
||||
let identity = IdentityRepository::find_by_login(&state.db, &payload.login)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid login or password".to_string()))?;
|
||||
|
||||
// Check if identity has a password set
|
||||
let password_hash = identity
|
||||
.password_hash
|
||||
.as_ref()
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid login or password".to_string()))?;
|
||||
|
||||
// Verify password
|
||||
let is_valid = verify_password(&payload.password, password_hash)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid login or password".to_string()))?;
|
||||
|
||||
if !is_valid {
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid login or password".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Generate tokens
|
||||
let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
|
||||
let response = TokenResponse::new(
|
||||
access_token,
|
||||
refresh_token,
|
||||
state.jwt_config.access_token_expiration,
|
||||
)
|
||||
.with_user(
|
||||
identity.id,
|
||||
identity.login.clone(),
|
||||
identity.display_name.clone(),
|
||||
);
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Register endpoint
|
||||
///
|
||||
/// POST /auth/register
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/register",
|
||||
tag = "auth",
|
||||
request_body = RegisterRequest,
|
||||
responses(
|
||||
(status = 200, description = "Successfully registered", body = inline(ApiResponse<TokenResponse>)),
|
||||
(status = 409, description = "User already exists"),
|
||||
(status = 400, description = "Validation error")
|
||||
)
|
||||
)]
|
||||
pub async fn register(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<RegisterRequest>,
|
||||
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid registration request: {}", e)))?;
|
||||
|
||||
// Check if login already exists
|
||||
if let Some(_) = IdentityRepository::find_by_login(&state.db, &payload.login).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Identity with login '{}' already exists",
|
||||
payload.login
|
||||
)));
|
||||
}
|
||||
|
||||
// Hash password
|
||||
let password_hash = hash_password(&payload.password)?;
|
||||
|
||||
// Create identity with password hash
|
||||
let input = CreateIdentityInput {
|
||||
login: payload.login.clone(),
|
||||
display_name: payload.display_name,
|
||||
password_hash: Some(password_hash),
|
||||
attributes: serde_json::json!({}),
|
||||
};
|
||||
|
||||
let identity = IdentityRepository::create(&state.db, input).await?;
|
||||
|
||||
// Generate tokens
|
||||
let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
|
||||
let response = TokenResponse::new(
|
||||
access_token,
|
||||
refresh_token,
|
||||
state.jwt_config.access_token_expiration,
|
||||
)
|
||||
.with_user(
|
||||
identity.id,
|
||||
identity.login.clone(),
|
||||
identity.display_name.clone(),
|
||||
);
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Refresh token endpoint
|
||||
///
|
||||
/// POST /auth/refresh
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/refresh",
|
||||
tag = "auth",
|
||||
request_body = RefreshTokenRequest,
|
||||
responses(
|
||||
(status = 200, description = "Successfully refreshed token", body = inline(ApiResponse<TokenResponse>)),
|
||||
(status = 401, description = "Invalid or expired refresh token"),
|
||||
(status = 400, description = "Validation error")
|
||||
)
|
||||
)]
|
||||
pub async fn refresh_token(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<RefreshTokenRequest>,
|
||||
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid refresh token request: {}", e)))?;
|
||||
|
||||
// Validate refresh token
|
||||
let claims = validate_token(&payload.refresh_token, &state.jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid or expired refresh token".to_string()))?;
|
||||
|
||||
// Ensure it's a refresh token
|
||||
if claims.token_type != TokenType::Refresh {
|
||||
return Err(ApiError::Unauthorized("Invalid token type".to_string()));
|
||||
}
|
||||
|
||||
// Parse identity ID
|
||||
let identity_id: i64 = claims
|
||||
.sub
|
||||
.parse()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid token".to_string()))?;
|
||||
|
||||
// Verify identity still exists
|
||||
let identity = IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Identity not found".to_string()))?;
|
||||
|
||||
// Generate new tokens
|
||||
let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
|
||||
let response = TokenResponse::new(
|
||||
access_token,
|
||||
refresh_token,
|
||||
state.jwt_config.access_token_expiration,
|
||||
);
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Get current user endpoint
|
||||
///
|
||||
/// GET /auth/me
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/auth/me",
|
||||
tag = "auth",
|
||||
responses(
|
||||
(status = 200, description = "Current user information", body = inline(ApiResponse<CurrentUserResponse>)),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Identity not found")
|
||||
),
|
||||
security(
|
||||
("bearer_auth" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn get_current_user(
|
||||
State(state): State<SharedState>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
) -> Result<Json<ApiResponse<CurrentUserResponse>>, ApiError> {
|
||||
let identity_id = user.identity_id()?;
|
||||
|
||||
// Fetch identity from database
|
||||
let identity = IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound("Identity not found".to_string()))?;
|
||||
|
||||
let response = CurrentUserResponse {
|
||||
id: identity.id,
|
||||
login: identity.login,
|
||||
display_name: identity.display_name,
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Change password endpoint
|
||||
///
|
||||
/// POST /auth/change-password
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/change-password",
|
||||
tag = "auth",
|
||||
request_body = ChangePasswordRequest,
|
||||
responses(
|
||||
(status = 200, description = "Password changed successfully", body = inline(ApiResponse<SuccessResponse>)),
|
||||
(status = 401, description = "Invalid current password or unauthorized"),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Identity not found")
|
||||
),
|
||||
security(
|
||||
("bearer_auth" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn change_password(
|
||||
State(state): State<SharedState>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Json(payload): Json<ChangePasswordRequest>,
|
||||
) -> Result<Json<ApiResponse<SuccessResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload.validate().map_err(|e| {
|
||||
ApiError::ValidationError(format!("Invalid change password request: {}", e))
|
||||
})?;
|
||||
|
||||
let identity_id = user.identity_id()?;
|
||||
|
||||
// Fetch identity from database
|
||||
let identity = IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound("Identity not found".to_string()))?;
|
||||
|
||||
// Get current password hash
|
||||
let current_password_hash = identity
|
||||
.password_hash
|
||||
.as_ref()
|
||||
.ok_or_else(|| ApiError::Unauthorized("No password set".to_string()))?;
|
||||
|
||||
// Verify current password
|
||||
let is_valid = verify_password(&payload.current_password, current_password_hash)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid current password".to_string()))?;
|
||||
|
||||
if !is_valid {
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid current password".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
let new_password_hash = hash_password(&payload.new_password)?;
|
||||
|
||||
// Update identity in database with new password hash
|
||||
use attune_common::repositories::identity::UpdateIdentityInput;
|
||||
use attune_common::repositories::Update;
|
||||
|
||||
let update_input = UpdateIdentityInput {
|
||||
display_name: None,
|
||||
password_hash: Some(new_password_hash),
|
||||
attributes: None,
|
||||
};
|
||||
|
||||
IdentityRepository::update(&state.db, identity_id, update_input).await?;
|
||||
|
||||
Ok(Json(ApiResponse::new(SuccessResponse::new(
|
||||
"Password changed successfully",
|
||||
))))
|
||||
}
|
||||
|
||||
/// Create sensor token endpoint (internal use by sensor service)
|
||||
///
|
||||
/// POST /auth/sensor-token
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/sensor-token",
|
||||
tag = "auth",
|
||||
request_body = CreateSensorTokenRequest,
|
||||
responses(
|
||||
(status = 200, description = "Sensor token created successfully", body = inline(ApiResponse<SensorTokenResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 401, description = "Unauthorized")
|
||||
),
|
||||
security(
|
||||
("bearer_auth" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn create_sensor_token(
|
||||
State(state): State<SharedState>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(payload): Json<CreateSensorTokenRequest>,
|
||||
) -> Result<Json<ApiResponse<SensorTokenResponse>>, ApiError> {
|
||||
create_sensor_token_impl(state, payload).await
|
||||
}
|
||||
|
||||
/// Create sensor token endpoint for internal service use (no auth required)
|
||||
///
|
||||
/// POST /auth/internal/sensor-token
|
||||
///
|
||||
/// This endpoint is intended for internal use by the sensor service to provision
|
||||
/// tokens for standalone sensors. In production, this should be restricted by
|
||||
/// network policies or replaced with proper service-to-service authentication.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/internal/sensor-token",
|
||||
tag = "auth",
|
||||
request_body = CreateSensorTokenRequest,
|
||||
responses(
|
||||
(status = 200, description = "Sensor token created successfully", body = inline(ApiResponse<SensorTokenResponse>)),
|
||||
(status = 400, description = "Validation error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_sensor_token_internal(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<CreateSensorTokenRequest>,
|
||||
) -> Result<Json<ApiResponse<SensorTokenResponse>>, ApiError> {
|
||||
create_sensor_token_impl(state, payload).await
|
||||
}
|
||||
|
||||
/// Shared implementation for sensor token creation
|
||||
async fn create_sensor_token_impl(
|
||||
state: SharedState,
|
||||
payload: CreateSensorTokenRequest,
|
||||
) -> Result<Json<ApiResponse<SensorTokenResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid sensor token request: {}", e)))?;
|
||||
|
||||
// Create or find sensor identity
|
||||
let sensor_login = format!("sensor:{}", payload.sensor_ref);
|
||||
|
||||
let identity = match IdentityRepository::find_by_login(&state.db, &sensor_login).await? {
|
||||
Some(identity) => identity,
|
||||
None => {
|
||||
// Create new sensor identity
|
||||
let input = CreateIdentityInput {
|
||||
login: sensor_login.clone(),
|
||||
display_name: Some(format!("Sensor: {}", payload.sensor_ref)),
|
||||
password_hash: None, // Sensors don't use passwords
|
||||
attributes: serde_json::json!({
|
||||
"type": "sensor",
|
||||
"sensor_ref": payload.sensor_ref,
|
||||
"trigger_types": payload.trigger_types,
|
||||
}),
|
||||
};
|
||||
IdentityRepository::create(&state.db, input).await?
|
||||
}
|
||||
};
|
||||
|
||||
// Generate sensor token
|
||||
let ttl_seconds = payload.ttl_seconds.unwrap_or(86400); // Default: 24 hours
|
||||
let token = generate_sensor_token(
|
||||
identity.id,
|
||||
&payload.sensor_ref,
|
||||
payload.trigger_types.clone(),
|
||||
&state.jwt_config,
|
||||
Some(ttl_seconds),
|
||||
)?;
|
||||
|
||||
// Calculate expiration time
|
||||
let expires_at = chrono::Utc::now() + chrono::Duration::seconds(ttl_seconds);
|
||||
|
||||
let response = SensorTokenResponse {
|
||||
identity_id: identity.id,
|
||||
sensor_ref: payload.sensor_ref,
|
||||
token,
|
||||
expires_at: expires_at.to_rfc3339(),
|
||||
trigger_types: payload.trigger_types,
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
391
crates/api/src/routes/events.rs
Normal file
391
crates/api/src/routes/events.rs
Normal file
@@ -0,0 +1,391 @@
|
||||
//! Event and Enforcement query API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::{
|
||||
mq::{EventCreatedPayload, MessageEnvelope, MessageType},
|
||||
repositories::{
|
||||
event::{CreateEventInput, EnforcementRepository, EventRepository},
|
||||
trigger::TriggerRepository,
|
||||
Create, FindById, FindByRef, List,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::auth::RequireAuth;
|
||||
use crate::{
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
event::{
|
||||
EnforcementQueryParams, EnforcementResponse, EnforcementSummary, EventQueryParams,
|
||||
EventResponse, EventSummary,
|
||||
},
|
||||
ApiResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// Request body for creating an event
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateEventRequest {
|
||||
/// Trigger reference (e.g., "core.timer", "core.webhook")
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "core.timer")]
|
||||
pub trigger_ref: String,
|
||||
|
||||
/// Event payload data
|
||||
#[schema(value_type = Object, example = json!({"timestamp": "2024-01-13T10:30:00Z"}))]
|
||||
pub payload: Option<JsonValue>,
|
||||
|
||||
/// Event configuration
|
||||
#[schema(value_type = Object)]
|
||||
pub config: Option<JsonValue>,
|
||||
|
||||
/// Trigger instance ID (for correlation, often rule_id)
|
||||
#[schema(example = "rule_123")]
|
||||
pub trigger_instance_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Create a new event
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/events",
|
||||
tag = "events",
|
||||
request_body = CreateEventRequest,
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 201, description = "Event created successfully", body = ApiResponse<EventResponse>),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_event(
|
||||
user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(payload): Json<CreateEventRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid event request: {}", e)))?;
|
||||
|
||||
// Lookup trigger by reference to get trigger ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &payload.trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Trigger '{}' not found", payload.trigger_ref))
|
||||
})?;
|
||||
|
||||
// Parse trigger_instance_id to extract rule ID (format: "rule_{id}")
|
||||
let (rule_id, rule_ref) = if let Some(instance_id) = &payload.trigger_instance_id {
|
||||
if let Some(id_str) = instance_id.strip_prefix("rule_") {
|
||||
if let Ok(rid) = id_str.parse::<i64>() {
|
||||
// Fetch rule reference from database
|
||||
let fetched_rule_ref: Option<String> =
|
||||
sqlx::query_scalar("SELECT ref FROM rule WHERE id = $1")
|
||||
.bind(rid)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
if let Some(rref) = fetched_rule_ref {
|
||||
tracing::debug!("Event associated with rule {} (id: {})", rref, rid);
|
||||
(Some(rid), Some(rref))
|
||||
} else {
|
||||
tracing::warn!("trigger_instance_id {} provided but rule not found", rid);
|
||||
(None, None)
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Invalid rule ID in trigger_instance_id: {}", instance_id);
|
||||
(None, None)
|
||||
}
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"trigger_instance_id doesn't match rule format: {}",
|
||||
instance_id
|
||||
);
|
||||
(None, None)
|
||||
}
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Determine source (sensor) from authenticated user if it's a sensor token
|
||||
use crate::auth::jwt::TokenType;
|
||||
let (source_id, source_ref) = match user.0.claims.token_type {
|
||||
TokenType::Sensor => {
|
||||
// Extract sensor reference from login
|
||||
let sensor_ref = user.0.claims.login.clone();
|
||||
|
||||
// Look up sensor by reference
|
||||
let sensor_id: Option<i64> = sqlx::query_scalar("SELECT id FROM sensor WHERE ref = $1")
|
||||
.bind(&sensor_ref)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
match sensor_id {
|
||||
Some(id) => {
|
||||
tracing::debug!("Event created by sensor {} (id: {})", sensor_ref, id);
|
||||
(Some(id), Some(sensor_ref))
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("Sensor token for ref '{}' but sensor not found", sensor_ref);
|
||||
(None, Some(sensor_ref))
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
// Create event input
|
||||
let input = CreateEventInput {
|
||||
trigger: Some(trigger.id),
|
||||
trigger_ref: payload.trigger_ref.clone(),
|
||||
config: payload.config,
|
||||
payload: payload.payload,
|
||||
source: source_id,
|
||||
source_ref,
|
||||
rule: rule_id,
|
||||
rule_ref,
|
||||
};
|
||||
|
||||
// Create the event
|
||||
let event = EventRepository::create(&state.db, input).await?;
|
||||
|
||||
// Publish EventCreated message to message queue if publisher is available
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let message_payload = EventCreatedPayload {
|
||||
event_id: event.id,
|
||||
trigger_id: event.trigger,
|
||||
trigger_ref: event.trigger_ref.clone(),
|
||||
sensor_id: event.source,
|
||||
sensor_ref: event.source_ref.clone(),
|
||||
payload: event.payload.clone().unwrap_or(serde_json::json!({})),
|
||||
config: event.config.clone(),
|
||||
};
|
||||
|
||||
let envelope = MessageEnvelope::new(MessageType::EventCreated, message_payload)
|
||||
.with_source("api-service");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
tracing::warn!(
|
||||
"Failed to publish EventCreated message for event {}: {}",
|
||||
event.id,
|
||||
e
|
||||
);
|
||||
// Continue even if message publishing fails - event is already recorded
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"Published EventCreated message for event {} (trigger: {})",
|
||||
event.id,
|
||||
event.trigger_ref
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let response = ApiResponse::new(EventResponse::from(event));
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// List all events with pagination and optional filters
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/events",
|
||||
tag = "events",
|
||||
params(EventQueryParams),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "List of events", body = PaginatedResponse<EventSummary>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_events(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<EventQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get events based on filters
|
||||
let events = if let Some(trigger_id) = query.trigger {
|
||||
// Filter by trigger ID
|
||||
EventRepository::find_by_trigger(&state.db, trigger_id).await?
|
||||
} else if let Some(trigger_ref) = &query.trigger_ref {
|
||||
// Filter by trigger reference
|
||||
EventRepository::find_by_trigger_ref(&state.db, trigger_ref).await?
|
||||
} else {
|
||||
// Get all events
|
||||
EventRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_events = events;
|
||||
|
||||
if let Some(source_id) = query.source {
|
||||
filtered_events.retain(|e| e.source == Some(source_id));
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_events.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_events.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_events: Vec<EventSummary> = filtered_events[start..end]
|
||||
.iter()
|
||||
.map(|event| EventSummary::from(event.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_events, &pagination_params, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single event by ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/events/{id}",
|
||||
tag = "events",
|
||||
params(
|
||||
("id" = i64, Path, description = "Event ID")
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Event details", body = ApiResponse<EventResponse>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Event not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_event(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let event = EventRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Event with ID {} not found", id)))?;
|
||||
|
||||
let response = ApiResponse::new(EventResponse::from(event));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List all enforcements with pagination and optional filters
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/enforcements",
|
||||
tag = "enforcements",
|
||||
params(EnforcementQueryParams),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "List of enforcements", body = PaginatedResponse<EnforcementSummary>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_enforcements(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<EnforcementQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enforcements based on filters
|
||||
let enforcements = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
EnforcementRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(rule_id) = query.rule {
|
||||
// Filter by rule ID
|
||||
EnforcementRepository::find_by_rule(&state.db, rule_id).await?
|
||||
} else if let Some(event_id) = query.event {
|
||||
// Filter by event ID
|
||||
EnforcementRepository::find_by_event(&state.db, event_id).await?
|
||||
} else {
|
||||
// Get all enforcements
|
||||
EnforcementRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_enforcements = enforcements;
|
||||
|
||||
if let Some(trigger_ref) = &query.trigger_ref {
|
||||
filtered_enforcements.retain(|e| e.trigger_ref == *trigger_ref);
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_enforcements.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_enforcements.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_enforcements: Vec<EnforcementSummary> = filtered_enforcements[start..end]
|
||||
.iter()
|
||||
.map(|enforcement| EnforcementSummary::from(enforcement.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single enforcement by ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/enforcements/{id}",
|
||||
tag = "enforcements",
|
||||
params(
|
||||
("id" = i64, Path, description = "Enforcement ID")
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Enforcement details", body = ApiResponse<EnforcementResponse>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Enforcement not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_enforcement(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let enforcement = EnforcementRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Enforcement with ID {} not found", id)))?;
|
||||
|
||||
let response = ApiResponse::new(EnforcementResponse::from(enforcement));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Register event and enforcement routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/events", get(list_events).post(create_event))
|
||||
.route("/events/{id}", get(get_event))
|
||||
.route("/enforcements", get(list_enforcements))
|
||||
.route("/enforcements/{id}", get(get_enforcement))
|
||||
}
|
||||
529
crates/api/src/routes/executions.rs
Normal file
529
crates/api/src/routes/executions.rs
Normal file
@@ -0,0 +1,529 @@
|
||||
//! Execution management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse,
|
||||
},
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use attune_common::models::enums::ExecutionStatus;
|
||||
use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType};
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
execution::{CreateExecutionInput, ExecutionRepository},
|
||||
Create, EnforcementRepository, FindById, FindByRef, List,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
execution::{
|
||||
CreateExecutionRequest, ExecutionQueryParams, ExecutionResponse, ExecutionSummary,
|
||||
},
|
||||
ApiResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// Create a new execution (manual execution)
|
||||
///
|
||||
/// This endpoint allows directly executing an action without a trigger or rule.
|
||||
/// The execution is queued and will be picked up by the executor service.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/executions/execute",
|
||||
tag = "executions",
|
||||
request_body = CreateExecutionRequest,
|
||||
responses(
|
||||
(status = 201, description = "Execution created and queued", body = ExecutionResponse),
|
||||
(status = 404, description = "Action not found"),
|
||||
(status = 400, description = "Invalid request"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_execution(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateExecutionRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate that the action exists
|
||||
let action = ActionRepository::find_by_ref(&state.db, &request.action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", request.action_ref)))?;
|
||||
|
||||
// Create execution input
|
||||
let execution_input = CreateExecutionInput {
|
||||
action: Some(action.id),
|
||||
action_ref: action.r#ref.clone(),
|
||||
config: request
|
||||
.parameters
|
||||
.as_ref()
|
||||
.and_then(|p| serde_json::from_value(p.clone()).ok()),
|
||||
parent: None,
|
||||
enforcement: None,
|
||||
executor: None,
|
||||
status: ExecutionStatus::Requested,
|
||||
result: None,
|
||||
workflow_task: None, // Non-workflow execution
|
||||
};
|
||||
|
||||
// Insert into database
|
||||
let created_execution = ExecutionRepository::create(&state.db, execution_input).await?;
|
||||
|
||||
// Publish ExecutionRequested message to queue
|
||||
let payload = ExecutionRequestedPayload {
|
||||
execution_id: created_execution.id,
|
||||
action_id: Some(action.id),
|
||||
action_ref: action.r#ref.clone(),
|
||||
parent_id: None,
|
||||
enforcement_id: None,
|
||||
config: request.parameters,
|
||||
};
|
||||
|
||||
let message = MessageEnvelope::new(MessageType::ExecutionRequested, payload)
|
||||
.with_source("api-service")
|
||||
.with_correlation_id(uuid::Uuid::new_v4());
|
||||
|
||||
if let Some(publisher) = &state.publisher {
|
||||
publisher.publish_envelope(&message).await.map_err(|e| {
|
||||
ApiError::InternalServerError(format!("Failed to publish message: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
let response = ExecutionResponse::from(created_execution);
|
||||
|
||||
Ok((StatusCode::CREATED, Json(ApiResponse::new(response))))
|
||||
}
|
||||
|
||||
/// List all executions with pagination and optional filters
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions",
|
||||
tag = "executions",
|
||||
params(ExecutionQueryParams),
|
||||
responses(
|
||||
(status = 200, description = "List of executions", body = PaginatedResponse<ExecutionSummary>),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_executions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(query): Query<ExecutionQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions based on filters
|
||||
let executions = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
ExecutionRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(enforcement_id) = query.enforcement {
|
||||
// Filter by enforcement
|
||||
ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?
|
||||
} else {
|
||||
// Get all executions
|
||||
ExecutionRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply additional filters in memory (could be optimized with database queries)
|
||||
let mut filtered_executions = executions;
|
||||
|
||||
if let Some(action_ref) = &query.action_ref {
|
||||
filtered_executions.retain(|e| e.action_ref == *action_ref);
|
||||
}
|
||||
|
||||
if let Some(pack_name) = &query.pack_name {
|
||||
filtered_executions.retain(|e| {
|
||||
// action_ref format is "pack.action"
|
||||
e.action_ref.starts_with(&format!("{}.", pack_name))
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(result_search) = &query.result_contains {
|
||||
let search_lower = result_search.to_lowercase();
|
||||
filtered_executions.retain(|e| {
|
||||
if let Some(result) = &e.result {
|
||||
// Convert result to JSON string and search case-insensitively
|
||||
let result_str = serde_json::to_string(result).unwrap_or_default();
|
||||
result_str.to_lowercase().contains(&search_lower)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(parent_id) = query.parent {
|
||||
filtered_executions.retain(|e| e.parent == Some(parent_id));
|
||||
}
|
||||
|
||||
if let Some(executor_id) = query.executor {
|
||||
filtered_executions.retain(|e| e.executor == Some(executor_id));
|
||||
}
|
||||
|
||||
// Fetch enforcements for all executions to populate rule_ref and trigger_ref
|
||||
let enforcement_ids: Vec<i64> = filtered_executions
|
||||
.iter()
|
||||
.filter_map(|e| e.enforcement)
|
||||
.collect();
|
||||
|
||||
let enforcement_map: std::collections::HashMap<i64, _> = if !enforcement_ids.is_empty() {
|
||||
let enforcements = EnforcementRepository::list(&state.db).await?;
|
||||
enforcements.into_iter().map(|enf| (enf.id, enf)).collect()
|
||||
} else {
|
||||
std::collections::HashMap::new()
|
||||
};
|
||||
|
||||
// Filter by rule_ref if specified
|
||||
if let Some(rule_ref) = &query.rule_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.rule_ref == *rule_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Filter by trigger_ref if specified
|
||||
if let Some(trigger_ref) = &query.trigger_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.trigger_ref == *trigger_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_executions.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_executions.len());
|
||||
|
||||
// Get paginated slice and populate rule_ref/trigger_ref from enforcements
|
||||
let paginated_executions: Vec<ExecutionSummary> = filtered_executions[start..end]
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let mut summary = ExecutionSummary::from(e.clone());
|
||||
if let Some(enf_id) = e.enforcement {
|
||||
if let Some(enforcement) = enforcement_map.get(&enf_id) {
|
||||
summary.rule_ref = Some(enforcement.rule_ref.clone());
|
||||
summary.trigger_ref = Some(enforcement.trigger_ref.clone());
|
||||
}
|
||||
}
|
||||
summary
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination_params, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single execution by ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/{id}",
|
||||
tag = "executions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Execution ID")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Execution details", body = inline(ApiResponse<ExecutionResponse>)),
|
||||
(status = 404, description = "Execution not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_execution(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let execution = ExecutionRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Execution with ID {} not found", id)))?;
|
||||
|
||||
let response = ApiResponse::new(ExecutionResponse::from(execution));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List executions by status
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/status/{status}",
|
||||
tag = "executions",
|
||||
params(
|
||||
("status" = String, Path, description = "Execution status (requested, scheduling, scheduled, running, completed, failed, canceling, cancelled, timeout, abandoned)"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of executions with specified status", body = PaginatedResponse<ExecutionSummary>),
|
||||
(status = 400, description = "Invalid status"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_executions_by_status(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(status_str): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Parse status from string
|
||||
let status = match status_str.to_lowercase().as_str() {
|
||||
"requested" => attune_common::models::enums::ExecutionStatus::Requested,
|
||||
"scheduling" => attune_common::models::enums::ExecutionStatus::Scheduling,
|
||||
"scheduled" => attune_common::models::enums::ExecutionStatus::Scheduled,
|
||||
"running" => attune_common::models::enums::ExecutionStatus::Running,
|
||||
"completed" => attune_common::models::enums::ExecutionStatus::Completed,
|
||||
"failed" => attune_common::models::enums::ExecutionStatus::Failed,
|
||||
"canceling" => attune_common::models::enums::ExecutionStatus::Canceling,
|
||||
"cancelled" => attune_common::models::enums::ExecutionStatus::Cancelled,
|
||||
"timeout" => attune_common::models::enums::ExecutionStatus::Timeout,
|
||||
"abandoned" => attune_common::models::enums::ExecutionStatus::Abandoned,
|
||||
_ => {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Invalid execution status: {}",
|
||||
status_str
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Get executions by status
|
||||
let executions = ExecutionRepository::find_by_status(&state.db, status).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List executions by enforcement ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/enforcement/{enforcement_id}",
|
||||
tag = "executions",
|
||||
params(
|
||||
("enforcement_id" = i64, Path, description = "Enforcement ID"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of executions for enforcement", body = PaginatedResponse<ExecutionSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_executions_by_enforcement(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(enforcement_id): Path<i64>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions by enforcement
|
||||
let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get execution statistics
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/stats",
|
||||
tag = "executions",
|
||||
responses(
|
||||
(status = 200, description = "Execution statistics", body = inline(Object)),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_execution_stats(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all executions (limited by repository to 1000)
|
||||
let executions = ExecutionRepository::list(&state.db).await?;
|
||||
|
||||
// Calculate statistics
|
||||
let total = executions.len();
|
||||
let completed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed)
|
||||
.count();
|
||||
let failed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed)
|
||||
.count();
|
||||
let running = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running)
|
||||
.count();
|
||||
let pending = executions
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
matches!(
|
||||
e.status,
|
||||
attune_common::models::enums::ExecutionStatus::Requested
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduling
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduled
|
||||
)
|
||||
})
|
||||
.count();
|
||||
|
||||
let stats = serde_json::json!({
|
||||
"total": total,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"running": running,
|
||||
"pending": pending,
|
||||
"cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(),
|
||||
"timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(),
|
||||
"abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(),
|
||||
});
|
||||
|
||||
let response = ApiResponse::new(stats);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create execution routes
|
||||
/// Stream execution updates via Server-Sent Events
|
||||
///
|
||||
/// This endpoint streams real-time updates for execution status changes.
|
||||
/// Optionally filter by execution_id to watch a specific execution.
|
||||
///
|
||||
/// Note: Authentication is done via `token` query parameter since EventSource
|
||||
/// doesn't support custom headers.
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/stream",
|
||||
tag = "executions",
|
||||
params(
|
||||
("execution_id" = Option<i64>, Query, description = "Optional execution ID to filter updates"),
|
||||
("token" = String, Query, description = "JWT access token for authentication")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "SSE stream of execution updates", content_type = "text/event-stream"),
|
||||
(status = 401, description = "Unauthorized - invalid or missing token"),
|
||||
)
|
||||
)]
|
||||
pub async fn stream_execution_updates(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<StreamExecutionParams>,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
|
||||
// Validate token from query parameter
|
||||
use crate::auth::jwt::validate_token;
|
||||
|
||||
let token = params.token.as_ref().ok_or(ApiError::Unauthorized(
|
||||
"Missing authentication token".to_string(),
|
||||
))?;
|
||||
|
||||
validate_token(token, &state.jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
let rx = state.broadcast_tx.subscribe();
|
||||
let stream = BroadcastStream::new(rx);
|
||||
|
||||
let filtered_stream = stream.filter_map(move |msg| {
|
||||
async move {
|
||||
match msg {
|
||||
Ok(notification) => {
|
||||
// Parse the notification as JSON
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(¬ification) {
|
||||
// Check if it's an execution update
|
||||
if let Some(entity_type) = value.get("entity_type").and_then(|v| v.as_str())
|
||||
{
|
||||
if entity_type == "execution" {
|
||||
// If filtering by execution_id, check if it matches
|
||||
if let Some(filter_id) = params.execution_id {
|
||||
if let Some(entity_id) =
|
||||
value.get("entity_id").and_then(|v| v.as_i64())
|
||||
{
|
||||
if entity_id != filter_id {
|
||||
return None; // Skip this event
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send the notification as an SSE event
|
||||
return Some(Ok(Event::default().data(notification)));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
Err(_) => None, // Skip broadcast errors
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Sse::new(filtered_stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct StreamExecutionParams {
|
||||
pub execution_id: Option<i64>,
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/executions", get(list_executions))
|
||||
.route("/executions/execute", axum::routing::post(create_execution))
|
||||
.route("/executions/stats", get(get_execution_stats))
|
||||
.route("/executions/stream", get(stream_execution_updates))
|
||||
.route("/executions/{id}", get(get_execution))
|
||||
.route(
|
||||
"/executions/status/{status}",
|
||||
get(list_executions_by_status),
|
||||
)
|
||||
.route(
|
||||
"/enforcements/{enforcement_id}/executions",
|
||||
get(list_executions_by_enforcement),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_execution_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
131
crates/api/src/routes/health.rs
Normal file
131
crates/api/src/routes/health.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
//! Health check endpoints
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Json, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Health check response
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct HealthResponse {
|
||||
/// Service status
|
||||
#[schema(example = "ok")]
|
||||
pub status: String,
|
||||
/// Service version
|
||||
#[schema(example = "0.1.0")]
|
||||
pub version: String,
|
||||
/// Database connectivity status
|
||||
#[schema(example = "connected")]
|
||||
pub database: String,
|
||||
}
|
||||
|
||||
/// Basic health check endpoint
|
||||
///
|
||||
/// Returns 200 OK if the service is running
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/health",
|
||||
tag = "health",
|
||||
responses(
|
||||
(status = 200, description = "Service is healthy", body = inline(Object), example = json!({"status": "ok"}))
|
||||
)
|
||||
)]
|
||||
pub async fn health() -> impl IntoResponse {
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"status": "ok"
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
/// Detailed health check endpoint
|
||||
///
|
||||
/// Checks database connectivity and returns detailed status
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/health/detailed",
|
||||
tag = "health",
|
||||
responses(
|
||||
(status = 200, description = "Service is healthy with details", body = HealthResponse),
|
||||
(status = 503, description = "Service unavailable", body = inline(Object))
|
||||
)
|
||||
)]
|
||||
pub async fn health_detailed(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Check database connectivity
|
||||
let db_status = match sqlx::query("SELECT 1").fetch_one(&state.db).await {
|
||||
Ok(_) => "connected",
|
||||
Err(e) => {
|
||||
tracing::error!("Database health check failed: {}", e);
|
||||
return Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({
|
||||
"status": "error",
|
||||
"database": "disconnected",
|
||||
"error": "Database connectivity check failed"
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let response = HealthResponse {
|
||||
status: "ok".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
database: db_status.to_string(),
|
||||
};
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Readiness check endpoint
|
||||
///
|
||||
/// Returns 200 OK if the service is ready to accept requests
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/health/ready",
|
||||
tag = "health",
|
||||
responses(
|
||||
(status = 200, description = "Service is ready"),
|
||||
(status = 503, description = "Service not ready")
|
||||
)
|
||||
)]
|
||||
pub async fn readiness(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
// Check if database is ready
|
||||
match sqlx::query("SELECT 1").fetch_one(&state.db).await {
|
||||
Ok(_) => Ok(StatusCode::OK),
|
||||
Err(e) => {
|
||||
tracing::error!("Readiness check failed: {}", e);
|
||||
Err(StatusCode::SERVICE_UNAVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Liveness check endpoint
|
||||
///
|
||||
/// Returns 200 OK if the service process is alive
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/health/live",
|
||||
tag = "health",
|
||||
responses(
|
||||
(status = 200, description = "Service is alive")
|
||||
)
|
||||
)]
|
||||
pub async fn liveness() -> impl IntoResponse {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
/// Create health check router
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/health/detailed", get(health_detailed))
|
||||
.route("/health/ready", get(readiness))
|
||||
.route("/health/live", get(liveness))
|
||||
}
|
||||
507
crates/api/src/routes/inquiries.rs
Normal file
507
crates/api/src/routes/inquiries.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
//! Inquiry management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::{
|
||||
mq::{InquiryRespondedPayload, MessageEnvelope, MessageType},
|
||||
repositories::{
|
||||
execution::ExecutionRepository,
|
||||
inquiry::{CreateInquiryInput, InquiryRepository, UpdateInquiryInput},
|
||||
Create, Delete, FindById, List, Update,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::auth::RequireAuth;
|
||||
use crate::{
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
inquiry::{
|
||||
CreateInquiryRequest, InquiryQueryParams, InquiryRespondRequest, InquiryResponse,
|
||||
InquirySummary, UpdateInquiryRequest,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// List all inquiries with pagination and optional filters
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/inquiries",
|
||||
tag = "inquiries",
|
||||
params(InquiryQueryParams),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "List of inquiries", body = PaginatedResponse<InquirySummary>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_inquiries(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<InquiryQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get inquiries based on filters
|
||||
let inquiries = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
InquiryRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(execution_id) = query.execution {
|
||||
// Filter by execution
|
||||
InquiryRepository::find_by_execution(&state.db, execution_id).await?
|
||||
} else {
|
||||
// Get all inquiries
|
||||
InquiryRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_inquiries = inquiries;
|
||||
|
||||
if let Some(assigned_to) = query.assigned_to {
|
||||
filtered_inquiries.retain(|i| i.assigned_to == Some(assigned_to));
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_inquiries.len() as u64;
|
||||
let offset = query.offset.unwrap_or(0);
|
||||
let limit = query.limit.unwrap_or(50).min(500);
|
||||
let start = offset;
|
||||
let end = (start + limit).min(filtered_inquiries.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = filtered_inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: (offset / limit.max(1)) as u32 + 1,
|
||||
page_size: limit as u32,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single inquiry by ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/inquiries/{id}",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("id" = i64, Path, description = "Inquiry ID")
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Inquiry details", body = ApiResponse<InquiryResponse>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Inquiry not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_inquiry(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let inquiry = InquiryRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Inquiry with ID {} not found", id)))?;
|
||||
|
||||
let response = ApiResponse::new(InquiryResponse::from(inquiry));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List inquiries by status
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/inquiries/status/{status}",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("status" = String, Path, description = "Inquiry status (pending, responded, timeout, canceled)"),
|
||||
PaginationParams
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "List of inquiries with specified status", body = PaginatedResponse<InquirySummary>),
|
||||
(status = 400, description = "Invalid status"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_inquiries_by_status(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(status_str): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Parse status from string
|
||||
let status = match status_str.to_lowercase().as_str() {
|
||||
"pending" => attune_common::models::enums::InquiryStatus::Pending,
|
||||
"responded" => attune_common::models::enums::InquiryStatus::Responded,
|
||||
"timeout" => attune_common::models::enums::InquiryStatus::Timeout,
|
||||
"canceled" => attune_common::models::enums::InquiryStatus::Cancelled,
|
||||
_ => {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Invalid inquiry status: '{}'. Valid values are: pending, responded, timeout, canceled",
|
||||
status_str
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let inquiries = InquiryRepository::find_by_status(&state.db, status).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = inquiries.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(inquiries.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List inquiries for a specific execution
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/{execution_id}/inquiries",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("execution_id" = i64, Path, description = "Execution ID"),
|
||||
PaginationParams
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "List of inquiries for execution", body = PaginatedResponse<InquirySummary>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Execution not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_inquiries_by_execution(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(execution_id): Path<i64>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify execution exists
|
||||
let _execution = ExecutionRepository::find_by_id(&state.db, execution_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Execution with ID {} not found", execution_id))
|
||||
})?;
|
||||
|
||||
let inquiries = InquiryRepository::find_by_execution(&state.db, execution_id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = inquiries.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(inquiries.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new inquiry
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/inquiries",
|
||||
tag = "inquiries",
|
||||
request_body = CreateInquiryRequest,
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 201, description = "Inquiry created successfully", body = ApiResponse<InquiryResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Execution not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_inquiry(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<CreateInquiryRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Verify execution exists
|
||||
let _execution = ExecutionRepository::find_by_id(&state.db, request.execution)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Execution with ID {} not found", request.execution))
|
||||
})?;
|
||||
|
||||
// Create inquiry input
|
||||
let inquiry_input = CreateInquiryInput {
|
||||
execution: request.execution,
|
||||
prompt: request.prompt,
|
||||
response_schema: request.response_schema,
|
||||
assigned_to: request.assigned_to,
|
||||
status: attune_common::models::enums::InquiryStatus::Pending,
|
||||
response: None,
|
||||
timeout_at: request.timeout_at,
|
||||
};
|
||||
|
||||
let inquiry = InquiryRepository::create(&state.db, inquiry_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
InquiryResponse::from(inquiry),
|
||||
"Inquiry created successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing inquiry
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/inquiries/{id}",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("id" = i64, Path, description = "Inquiry ID")
|
||||
),
|
||||
request_body = UpdateInquiryRequest,
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Inquiry updated successfully", body = ApiResponse<InquiryResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Inquiry not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn update_inquiry(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<UpdateInquiryRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Verify inquiry exists
|
||||
let _existing = InquiryRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Inquiry with ID {} not found", id)))?;
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateInquiryInput {
|
||||
status: request.status,
|
||||
response: request.response,
|
||||
responded_at: None, // Let the database handle this if needed
|
||||
assigned_to: request.assigned_to,
|
||||
};
|
||||
|
||||
let updated_inquiry = InquiryRepository::update(&state.db, id, update_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
InquiryResponse::from(updated_inquiry),
|
||||
"Inquiry updated successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Respond to an inquiry (user-facing endpoint)
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/inquiries/{id}/respond",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("id" = i64, Path, description = "Inquiry ID")
|
||||
),
|
||||
request_body = InquiryRespondRequest,
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Response submitted successfully", body = ApiResponse<InquiryResponse>),
|
||||
(status = 400, description = "Invalid request or inquiry cannot be responded to"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 403, description = "Not authorized to respond to this inquiry"),
|
||||
(status = 404, description = "Inquiry not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn respond_to_inquiry(
|
||||
user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<InquiryRespondRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Verify inquiry exists and is in pending status
|
||||
let inquiry = InquiryRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Inquiry with ID {} not found", id)))?;
|
||||
|
||||
// Check if inquiry is still pending
|
||||
if inquiry.status != attune_common::models::enums::InquiryStatus::Pending {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Cannot respond to inquiry with status '{:?}'. Only pending inquiries can be responded to.",
|
||||
inquiry.status
|
||||
)));
|
||||
}
|
||||
|
||||
// Check if inquiry is assigned to this user (optional enforcement)
|
||||
if let Some(assigned_to) = inquiry.assigned_to {
|
||||
let user_id = user
|
||||
.0
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::InternalServerError("Invalid user identity".to_string()))?;
|
||||
if assigned_to != user_id {
|
||||
return Err(ApiError::Forbidden(
|
||||
"You are not authorized to respond to this inquiry".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Check if inquiry has timed out
|
||||
if let Some(timeout_at) = inquiry.timeout_at {
|
||||
if timeout_at < chrono::Utc::now() {
|
||||
// Update inquiry to timeout status
|
||||
let timeout_input = UpdateInquiryInput {
|
||||
status: Some(attune_common::models::enums::InquiryStatus::Timeout),
|
||||
response: None,
|
||||
responded_at: None,
|
||||
assigned_to: None,
|
||||
};
|
||||
let _ = InquiryRepository::update(&state.db, id, timeout_input).await?;
|
||||
|
||||
return Err(ApiError::BadRequest(
|
||||
"Inquiry has timed out and can no longer be responded to".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Validate response against response_schema if present
|
||||
// For now, just accept the response as-is
|
||||
|
||||
// Create update input with response
|
||||
let update_input = UpdateInquiryInput {
|
||||
status: Some(attune_common::models::enums::InquiryStatus::Responded),
|
||||
response: Some(request.response.clone()),
|
||||
responded_at: Some(chrono::Utc::now()),
|
||||
assigned_to: None,
|
||||
};
|
||||
|
||||
let updated_inquiry = InquiryRepository::update(&state.db, id, update_input).await?;
|
||||
|
||||
// Publish InquiryResponded message if publisher is available
|
||||
if let Some(publisher) = &state.publisher {
|
||||
let user_id = user
|
||||
.0
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::InternalServerError("Invalid user identity".to_string()))?;
|
||||
|
||||
let payload = InquiryRespondedPayload {
|
||||
inquiry_id: id,
|
||||
execution_id: inquiry.execution,
|
||||
response: request.response.clone(),
|
||||
responded_by: Some(user_id),
|
||||
responded_at: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::InquiryResponded, payload).with_source("api");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
tracing::error!("Failed to publish InquiryResponded message: {}", e);
|
||||
// Don't fail the request - inquiry is already saved
|
||||
} else {
|
||||
tracing::info!("Published InquiryResponded message for inquiry {}", id);
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("No publisher available to publish InquiryResponded message");
|
||||
}
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
InquiryResponse::from(updated_inquiry),
|
||||
"Response submitted successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete an inquiry
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/inquiries/{id}",
|
||||
tag = "inquiries",
|
||||
params(
|
||||
("id" = i64, Path, description = "Inquiry ID")
|
||||
),
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Inquiry deleted successfully", body = SuccessResponse),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Inquiry not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn delete_inquiry(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify inquiry exists
|
||||
let _inquiry = InquiryRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Inquiry with ID {} not found", id)))?;
|
||||
|
||||
// Delete the inquiry
|
||||
let deleted = InquiryRepository::delete(&state.db, id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Inquiry with ID {} not found",
|
||||
id
|
||||
)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new("Inquiry deleted successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Register inquiry routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/inquiries", get(list_inquiries).post(create_inquiry))
|
||||
.route(
|
||||
"/inquiries/{id}",
|
||||
get(get_inquiry).put(update_inquiry).delete(delete_inquiry),
|
||||
)
|
||||
.route("/inquiries/status/{status}", get(list_inquiries_by_status))
|
||||
.route(
|
||||
"/executions/{execution_id}/inquiries",
|
||||
get(list_inquiries_by_execution),
|
||||
)
|
||||
.route("/inquiries/{id}/respond", post(respond_to_inquiry))
|
||||
}
|
||||
363
crates/api/src/routes/keys.rs
Normal file
363
crates/api/src/routes/keys.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
//! Key/Secret management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
key::{CreateKeyInput, KeyRepository, UpdateKeyInput},
|
||||
Create, Delete, List, Update,
|
||||
};
|
||||
|
||||
use crate::auth::RequireAuth;
|
||||
use crate::{
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
key::{CreateKeyRequest, KeyQueryParams, KeyResponse, KeySummary, UpdateKeyRequest},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// List all keys with pagination and optional filters (values redacted)
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/keys",
|
||||
tag = "secrets",
|
||||
params(KeyQueryParams),
|
||||
responses(
|
||||
(status = 200, description = "List of keys (values redacted)", body = PaginatedResponse<KeySummary>),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_keys(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<KeyQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get keys based on filters
|
||||
let keys = if let Some(owner_type) = query.owner_type {
|
||||
// Filter by owner type
|
||||
KeyRepository::find_by_owner_type(&state.db, owner_type).await?
|
||||
} else {
|
||||
// Get all keys
|
||||
KeyRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_keys = keys;
|
||||
|
||||
if let Some(owner) = &query.owner {
|
||||
filtered_keys.retain(|k| k.owner.as_ref() == Some(owner));
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_keys.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_keys.len());
|
||||
|
||||
// Get paginated slice (values redacted in summary)
|
||||
let paginated_keys: Vec<KeySummary> = filtered_keys[start..end]
|
||||
.iter()
|
||||
.map(|key| KeySummary::from(key.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_keys, &pagination_params, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single key by reference (includes decrypted value)
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/keys/{ref}",
|
||||
tag = "secrets",
|
||||
params(
|
||||
("ref" = String, Path, description = "Key reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Key details with decrypted value", body = inline(ApiResponse<KeyResponse>)),
|
||||
(status = 404, description = "Key not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_key(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(key_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let mut key = KeyRepository::find_by_ref(&state.db, &key_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
|
||||
|
||||
// Decrypt value if encrypted
|
||||
if key.encrypted {
|
||||
let encryption_key = state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::InternalServerError("Encryption key not configured on server".to_string())
|
||||
})?;
|
||||
|
||||
let decrypted_value =
|
||||
attune_common::crypto::decrypt(&key.value, encryption_key).map_err(|e| {
|
||||
tracing::error!("Failed to decrypt key '{}': {}", key_ref, e);
|
||||
ApiError::InternalServerError(format!("Failed to decrypt key: {}", e))
|
||||
})?;
|
||||
|
||||
key.value = decrypted_value;
|
||||
}
|
||||
|
||||
let response = ApiResponse::new(KeyResponse::from(key));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new key/secret
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/keys",
|
||||
tag = "secrets",
|
||||
request_body = CreateKeyRequest,
|
||||
responses(
|
||||
(status = 201, description = "Key created successfully", body = inline(ApiResponse<KeyResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 409, description = "Key with same ref already exists")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_key(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<CreateKeyRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if key with same ref already exists
|
||||
if let Some(_) = KeyRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Key with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// Encrypt value if requested
|
||||
let (value, encryption_key_hash) = if request.encrypted {
|
||||
let encryption_key = state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::BadRequest(
|
||||
"Cannot encrypt: encryption key not configured on server".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let encrypted_value = attune_common::crypto::encrypt(&request.value, encryption_key)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to encrypt key value: {}", e);
|
||||
ApiError::InternalServerError(format!("Failed to encrypt value: {}", e))
|
||||
})?;
|
||||
|
||||
let key_hash = attune_common::crypto::hash_encryption_key(encryption_key);
|
||||
|
||||
(encrypted_value, Some(key_hash))
|
||||
} else {
|
||||
// Store in plaintext (not recommended for sensitive data)
|
||||
(request.value.clone(), None)
|
||||
};
|
||||
|
||||
// Create key input
|
||||
let key_input = CreateKeyInput {
|
||||
r#ref: request.r#ref,
|
||||
owner_type: request.owner_type,
|
||||
owner: request.owner,
|
||||
owner_identity: request.owner_identity,
|
||||
owner_pack: request.owner_pack,
|
||||
owner_pack_ref: request.owner_pack_ref,
|
||||
owner_action: request.owner_action,
|
||||
owner_action_ref: request.owner_action_ref,
|
||||
owner_sensor: request.owner_sensor,
|
||||
owner_sensor_ref: request.owner_sensor_ref,
|
||||
name: request.name,
|
||||
encrypted: request.encrypted,
|
||||
encryption_key_hash,
|
||||
value,
|
||||
};
|
||||
|
||||
let mut key = KeyRepository::create(&state.db, key_input).await?;
|
||||
|
||||
// Return decrypted value in response
|
||||
if key.encrypted {
|
||||
let encryption_key = state.config.security.encryption_key.as_ref().unwrap();
|
||||
key.value = attune_common::crypto::decrypt(&key.value, encryption_key).map_err(|e| {
|
||||
tracing::error!("Failed to decrypt newly created key: {}", e);
|
||||
ApiError::InternalServerError(format!("Failed to decrypt value: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
let response = ApiResponse::with_message(KeyResponse::from(key), "Key created successfully");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing key/secret
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/keys/{ref}",
|
||||
tag = "secrets",
|
||||
params(
|
||||
("ref" = String, Path, description = "Key reference identifier")
|
||||
),
|
||||
request_body = UpdateKeyRequest,
|
||||
responses(
|
||||
(status = 200, description = "Key updated successfully", body = inline(ApiResponse<KeyResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Key not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_key(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(key_ref): Path<String>,
|
||||
Json(request): Json<UpdateKeyRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Verify key exists
|
||||
let existing = KeyRepository::find_by_ref(&state.db, &key_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
|
||||
|
||||
// Handle value update with encryption
|
||||
let (value, encrypted, encryption_key_hash) = if let Some(new_value) = request.value {
|
||||
let should_encrypt = request.encrypted.unwrap_or(existing.encrypted);
|
||||
|
||||
if should_encrypt {
|
||||
let encryption_key =
|
||||
state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::BadRequest(
|
||||
"Cannot encrypt: encryption key not configured on server".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let encrypted_value = attune_common::crypto::encrypt(&new_value, encryption_key)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to encrypt key value: {}", e);
|
||||
ApiError::InternalServerError(format!("Failed to encrypt value: {}", e))
|
||||
})?;
|
||||
|
||||
let key_hash = attune_common::crypto::hash_encryption_key(encryption_key);
|
||||
|
||||
(Some(encrypted_value), Some(should_encrypt), Some(key_hash))
|
||||
} else {
|
||||
(Some(new_value), Some(false), None)
|
||||
}
|
||||
} else {
|
||||
// No value update, but might be changing encryption status
|
||||
(None, request.encrypted, None)
|
||||
};
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateKeyInput {
|
||||
name: request.name,
|
||||
value,
|
||||
encrypted,
|
||||
encryption_key_hash,
|
||||
};
|
||||
|
||||
let mut updated_key = KeyRepository::update(&state.db, existing.id, update_input).await?;
|
||||
|
||||
// Return decrypted value in response
|
||||
if updated_key.encrypted {
|
||||
let encryption_key = state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::InternalServerError("Encryption key not configured on server".to_string())
|
||||
})?;
|
||||
|
||||
updated_key.value = attune_common::crypto::decrypt(&updated_key.value, encryption_key)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to decrypt updated key '{}': {}", key_ref, e);
|
||||
ApiError::InternalServerError(format!("Failed to decrypt value: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(KeyResponse::from(updated_key), "Key updated successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete a key/secret
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/keys/{ref}",
|
||||
tag = "secrets",
|
||||
params(
|
||||
("ref" = String, Path, description = "Key reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Key deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Key not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_key(
|
||||
_user: RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(key_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify key exists
|
||||
let key = KeyRepository::find_by_ref(&state.db, &key_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
|
||||
|
||||
// Delete the key
|
||||
let deleted = KeyRepository::delete(&state.db, key.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!("Key '{}' not found", key_ref)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new("Key deleted successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Register key/secret routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/keys", get(list_keys).post(create_key))
|
||||
.route(
|
||||
"/keys/{ref}",
|
||||
get(get_key).put(update_key).delete(delete_key),
|
||||
)
|
||||
}
|
||||
27
crates/api/src/routes/mod.rs
Normal file
27
crates/api/src/routes/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
//! API route modules
|
||||
|
||||
pub mod actions;
|
||||
pub mod auth;
|
||||
pub mod events;
|
||||
pub mod executions;
|
||||
pub mod health;
|
||||
pub mod inquiries;
|
||||
pub mod keys;
|
||||
pub mod packs;
|
||||
pub mod rules;
|
||||
pub mod triggers;
|
||||
pub mod webhooks;
|
||||
pub mod workflows;
|
||||
|
||||
pub use actions::routes as action_routes;
|
||||
pub use auth::routes as auth_routes;
|
||||
pub use events::routes as event_routes;
|
||||
pub use executions::routes as execution_routes;
|
||||
pub use health::routes as health_routes;
|
||||
pub use inquiries::routes as inquiry_routes;
|
||||
pub use keys::routes as key_routes;
|
||||
pub use packs::routes as pack_routes;
|
||||
pub use rules::routes as rule_routes;
|
||||
pub use triggers::routes as trigger_routes;
|
||||
pub use webhooks::routes as webhook_routes;
|
||||
pub use workflows::routes as workflow_routes;
|
||||
1243
crates/api/src/routes/packs.rs
Normal file
1243
crates/api/src/routes/packs.rs
Normal file
File diff suppressed because it is too large
Load Diff
660
crates/api/src/routes/rules.rs
Normal file
660
crates/api/src/routes/rules.rs
Normal file
@@ -0,0 +1,660 @@
|
||||
//! Rule management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::mq::{
|
||||
MessageEnvelope, MessageType, RuleCreatedPayload, RuleDisabledPayload, RuleEnabledPayload,
|
||||
};
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
pack::PackRepository,
|
||||
rule::{CreateRuleInput, RuleRepository, UpdateRuleInput},
|
||||
trigger::TriggerRepository,
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
validation::{validate_action_params, validate_trigger_params},
|
||||
};
|
||||
|
||||
/// List all rules with pagination
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/rules",
|
||||
tag = "rules",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of rules", body = PaginatedResponse<RuleSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_rules(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all rules
|
||||
let rules = RuleRepository::list(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List enabled rules
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/rules/enabled",
|
||||
tag = "rules",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of enabled rules", body = PaginatedResponse<RuleSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_enabled_rules(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled rules
|
||||
let rules = RuleRepository::find_enabled(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List rules by pack reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/rules",
|
||||
tag = "rules",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of rules in pack", body = PaginatedResponse<RuleSummary>),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_rules_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get rules for this pack
|
||||
let rules = RuleRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List rules by action reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/actions/{action_ref}/rules",
|
||||
tag = "rules",
|
||||
params(
|
||||
("action_ref" = String, Path, description = "Action reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of rules using this action", body = PaginatedResponse<RuleSummary>),
|
||||
(status = 404, description = "Action not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_rules_by_action(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(action_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify action exists
|
||||
let action = ActionRepository::find_by_ref(&state.db, &action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
// Get rules for this action
|
||||
let rules = RuleRepository::find_by_action(&state.db, action.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List rules by trigger reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/triggers/{trigger_ref}/rules",
|
||||
tag = "rules",
|
||||
params(
|
||||
("trigger_ref" = String, Path, description = "Trigger reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of rules using this trigger", body = PaginatedResponse<RuleSummary>),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_rules_by_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify trigger exists
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Get rules for this trigger
|
||||
let rules = RuleRepository::find_by_trigger(&state.db, trigger.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single rule by reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/rules/{ref}",
|
||||
tag = "rules",
|
||||
params(
|
||||
("ref" = String, Path, description = "Rule reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Rule details", body = ApiResponse<RuleResponse>),
|
||||
(status = 404, description = "Rule not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(rule_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let rule = RuleRepository::find_by_ref(&state.db, &rule_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
|
||||
|
||||
let response = ApiResponse::new(RuleResponse::from(rule));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new rule
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/rules",
|
||||
tag = "rules",
|
||||
request_body = CreateRuleRequest,
|
||||
responses(
|
||||
(status = 201, description = "Rule created successfully", body = ApiResponse<RuleResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Pack, action, or trigger not found"),
|
||||
(status = 409, description = "Rule with same ref already exists"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateRuleRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if rule with same ref already exists
|
||||
if let Some(_) = RuleRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Rule with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify pack exists and get its ID
|
||||
let pack = PackRepository::find_by_ref(&state.db, &request.pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", request.pack_ref)))?;
|
||||
|
||||
// Verify action exists and get its ID
|
||||
let action = ActionRepository::find_by_ref(&state.db, &request.action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", request.action_ref)))?;
|
||||
|
||||
// Verify trigger exists and get its ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &request.trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Trigger '{}' not found", request.trigger_ref))
|
||||
})?;
|
||||
|
||||
// Validate trigger parameters against schema
|
||||
validate_trigger_params(&trigger, &request.trigger_params)?;
|
||||
|
||||
// Validate action parameters against schema
|
||||
validate_action_params(&action, &request.action_params)?;
|
||||
|
||||
// Create rule input
|
||||
let rule_input = CreateRuleInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
action: action.id,
|
||||
action_ref: action.r#ref.clone(),
|
||||
trigger: trigger.id,
|
||||
trigger_ref: trigger.r#ref.clone(),
|
||||
conditions: request.conditions,
|
||||
action_params: request.action_params,
|
||||
trigger_params: request.trigger_params,
|
||||
enabled: request.enabled,
|
||||
is_adhoc: true, // Rules created via API are ad-hoc (not from pack installation)
|
||||
};
|
||||
|
||||
let rule = RuleRepository::create(&state.db, rule_input).await?;
|
||||
|
||||
// Publish RuleCreated message to notify sensor service
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let payload = RuleCreatedPayload {
|
||||
rule_id: rule.id,
|
||||
rule_ref: rule.r#ref.clone(),
|
||||
trigger_id: Some(rule.trigger),
|
||||
trigger_ref: rule.trigger_ref.clone(),
|
||||
action_id: Some(rule.action),
|
||||
action_ref: rule.action_ref.clone(),
|
||||
trigger_params: Some(rule.trigger_params.clone()),
|
||||
enabled: rule.enabled,
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::RuleCreated, payload).with_source("api-service");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
warn!(
|
||||
"Failed to publish RuleCreated message for rule {}: {}",
|
||||
rule.r#ref, e
|
||||
);
|
||||
} else {
|
||||
info!("Published RuleCreated message for rule {}", rule.r#ref);
|
||||
}
|
||||
}
|
||||
|
||||
let response = ApiResponse::with_message(RuleResponse::from(rule), "Rule created successfully");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing rule
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/rules/{ref}",
|
||||
tag = "rules",
|
||||
params(
|
||||
("ref" = String, Path, description = "Rule reference")
|
||||
),
|
||||
request_body = UpdateRuleRequest,
|
||||
responses(
|
||||
(status = 200, description = "Rule updated successfully", body = ApiResponse<RuleResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Rule not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn update_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(rule_ref): Path<String>,
|
||||
Json(request): Json<UpdateRuleRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if rule exists
|
||||
let existing_rule = RuleRepository::find_by_ref(&state.db, &rule_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
|
||||
|
||||
// If action parameters are being updated, validate against the action's schema
|
||||
if let Some(ref action_params) = request.action_params {
|
||||
let action = ActionRepository::find_by_ref(&state.db, &existing_rule.action_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Action '{}' not found", existing_rule.action_ref))
|
||||
})?;
|
||||
validate_action_params(&action, action_params)?;
|
||||
}
|
||||
|
||||
// If trigger parameters are being updated, validate against the trigger's schema
|
||||
if let Some(ref trigger_params) = request.trigger_params {
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &existing_rule.trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Trigger '{}' not found", existing_rule.trigger_ref))
|
||||
})?;
|
||||
validate_trigger_params(&trigger, trigger_params)?;
|
||||
}
|
||||
|
||||
// Track if trigger params changed
|
||||
let trigger_params_changed = request.trigger_params.is_some()
|
||||
&& request.trigger_params != Some(existing_rule.trigger_params.clone());
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateRuleInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
conditions: request.conditions,
|
||||
action_params: request.action_params,
|
||||
trigger_params: request.trigger_params,
|
||||
enabled: request.enabled,
|
||||
};
|
||||
|
||||
let rule = RuleRepository::update(&state.db, existing_rule.id, update_input).await?;
|
||||
|
||||
// If the rule is enabled and trigger params changed, publish RuleEnabled message
|
||||
// to notify sensors to restart with new parameters
|
||||
if rule.enabled && trigger_params_changed {
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let payload = RuleEnabledPayload {
|
||||
rule_id: rule.id,
|
||||
rule_ref: rule.r#ref.clone(),
|
||||
trigger_ref: rule.trigger_ref.clone(),
|
||||
trigger_params: Some(rule.trigger_params.clone()),
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::RuleEnabled, payload).with_source("api-service");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
warn!(
|
||||
"Failed to publish RuleEnabled message for updated rule {}: {}",
|
||||
rule.r#ref, e
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"Published RuleEnabled message for updated rule {} (trigger params changed)",
|
||||
rule.r#ref
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = ApiResponse::with_message(RuleResponse::from(rule), "Rule updated successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete a rule
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/rules/{ref}",
|
||||
tag = "rules",
|
||||
params(
|
||||
("ref" = String, Path, description = "Rule reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Rule deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Rule not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn delete_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(rule_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if rule exists
|
||||
let rule = RuleRepository::find_by_ref(&state.db, &rule_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
|
||||
|
||||
// Delete the rule
|
||||
let deleted = RuleRepository::delete(&state.db, rule.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!("Rule '{}' not found", rule_ref)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new(format!("Rule '{}' deleted successfully", rule_ref));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Enable a rule
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/rules/{ref}/enable",
|
||||
tag = "rules",
|
||||
params(
|
||||
("ref" = String, Path, description = "Rule reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Rule enabled successfully", body = ApiResponse<RuleResponse>),
|
||||
(status = 404, description = "Rule not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn enable_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(rule_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if rule exists
|
||||
let existing_rule = RuleRepository::find_by_ref(&state.db, &rule_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
|
||||
|
||||
// Update rule to enabled
|
||||
let update_input = UpdateRuleInput {
|
||||
label: None,
|
||||
description: None,
|
||||
conditions: None,
|
||||
action_params: None,
|
||||
trigger_params: None,
|
||||
enabled: Some(true),
|
||||
};
|
||||
|
||||
let rule = RuleRepository::update(&state.db, existing_rule.id, update_input).await?;
|
||||
|
||||
// Publish RuleEnabled message to notify sensor service
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let payload = RuleEnabledPayload {
|
||||
rule_id: rule.id,
|
||||
rule_ref: rule.r#ref.clone(),
|
||||
trigger_ref: rule.trigger_ref.clone(),
|
||||
trigger_params: Some(rule.trigger_params.clone()),
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::RuleEnabled, payload).with_source("api-service");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
warn!(
|
||||
"Failed to publish RuleEnabled message for rule {}: {}",
|
||||
rule.r#ref, e
|
||||
);
|
||||
} else {
|
||||
info!("Published RuleEnabled message for rule {}", rule.r#ref);
|
||||
}
|
||||
}
|
||||
|
||||
let response = ApiResponse::with_message(RuleResponse::from(rule), "Rule enabled successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Disable a rule
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/rules/{ref}/disable",
|
||||
tag = "rules",
|
||||
params(
|
||||
("ref" = String, Path, description = "Rule reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Rule disabled successfully", body = ApiResponse<RuleResponse>),
|
||||
(status = 404, description = "Rule not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn disable_rule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(rule_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if rule exists
|
||||
let existing_rule = RuleRepository::find_by_ref(&state.db, &rule_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
|
||||
|
||||
// Update rule to disabled
|
||||
let update_input = UpdateRuleInput {
|
||||
label: None,
|
||||
description: None,
|
||||
conditions: None,
|
||||
action_params: None,
|
||||
trigger_params: None,
|
||||
enabled: Some(false),
|
||||
};
|
||||
|
||||
let rule = RuleRepository::update(&state.db, existing_rule.id, update_input).await?;
|
||||
|
||||
// Publish RuleDisabled message to notify sensor service
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let payload = RuleDisabledPayload {
|
||||
rule_id: rule.id,
|
||||
rule_ref: rule.r#ref.clone(),
|
||||
trigger_ref: rule.trigger_ref.clone(),
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::RuleDisabled, payload).with_source("api-service");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
warn!(
|
||||
"Failed to publish RuleDisabled message for rule {}: {}",
|
||||
rule.r#ref, e
|
||||
);
|
||||
} else {
|
||||
info!("Published RuleDisabled message for rule {}", rule.r#ref);
|
||||
}
|
||||
}
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(RuleResponse::from(rule), "Rule disabled successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create rule routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/rules", get(list_rules).post(create_rule))
|
||||
.route("/rules/enabled", get(list_enabled_rules))
|
||||
.route(
|
||||
"/rules/{ref}",
|
||||
get(get_rule).put(update_rule).delete(delete_rule),
|
||||
)
|
||||
.route("/rules/{ref}/enable", post(enable_rule))
|
||||
.route("/rules/{ref}/disable", post(disable_rule))
|
||||
.route("/packs/{pack_ref}/rules", get(list_rules_by_pack))
|
||||
.route("/actions/{action_ref}/rules", get(list_rules_by_action))
|
||||
.route("/triggers/{trigger_ref}/rules", get(list_rules_by_trigger))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rule_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
893
crates/api/src/routes/triggers.rs
Normal file
893
crates/api/src/routes/triggers.rs
Normal file
@@ -0,0 +1,893 @@
|
||||
//! Trigger and Sensor management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
runtime::RuntimeRepository,
|
||||
trigger::{
|
||||
CreateSensorInput, CreateTriggerInput, SensorRepository, TriggerRepository,
|
||||
UpdateSensorInput, UpdateTriggerInput,
|
||||
},
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary,
|
||||
TriggerResponse, TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// TRIGGER ENDPOINTS
|
||||
// ============================================================================
|
||||
|
||||
/// List all triggers with pagination
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/triggers",
|
||||
tag = "triggers",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of triggers", body = PaginatedResponse<TriggerSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_triggers(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all triggers
|
||||
let triggers = TriggerRepository::list(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List enabled triggers
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/triggers/enabled",
|
||||
tag = "triggers",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of enabled triggers", body = PaginatedResponse<TriggerSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_enabled_triggers(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled triggers
|
||||
let triggers = TriggerRepository::find_enabled(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List triggers by pack reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/triggers",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of triggers in pack", body = PaginatedResponse<TriggerSummary>),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_triggers_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get triggers for this pack
|
||||
let triggers = TriggerRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single trigger by reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/triggers/{ref}",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Trigger details", body = ApiResponse<TriggerResponse>),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
let response = ApiResponse::new(TriggerResponse::from(trigger));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers",
|
||||
tag = "triggers",
|
||||
request_body = CreateTriggerRequest,
|
||||
responses(
|
||||
(status = 201, description = "Trigger created successfully", body = ApiResponse<TriggerResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 409, description = "Trigger with same ref already exists"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateTriggerRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if trigger with same ref already exists
|
||||
if let Some(_) = TriggerRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Trigger with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// If pack_ref is provided, verify pack exists and get its ID
|
||||
let (pack_id, pack_ref) = if let Some(ref pack_ref_str) = request.pack_ref {
|
||||
let pack = PackRepository::find_by_ref(&state.db, pack_ref_str)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref_str)))?;
|
||||
(Some(pack.id), Some(pack.r#ref.clone()))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Create trigger input
|
||||
let trigger_input = CreateTriggerInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: pack_id,
|
||||
pack_ref,
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
is_adhoc: true, // Triggers created via API are ad-hoc (not from pack installation)
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::create(&state.db, trigger_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
TriggerResponse::from(trigger),
|
||||
"Trigger created successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing trigger
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/triggers/{ref}",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference")
|
||||
),
|
||||
request_body = UpdateTriggerRequest,
|
||||
responses(
|
||||
(status = 200, description = "Trigger updated successfully", body = ApiResponse<TriggerResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn update_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
Json(request): Json<UpdateTriggerRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if trigger exists
|
||||
let existing_trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateTriggerInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::update(&state.db, existing_trigger.id, update_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
TriggerResponse::from(trigger),
|
||||
"Trigger updated successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete a trigger
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/triggers/{ref}",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Trigger deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn delete_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if trigger exists
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Delete the trigger
|
||||
let deleted = TriggerRepository::delete(&state.db, trigger.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Trigger '{}' not found",
|
||||
trigger_ref
|
||||
)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new(format!("Trigger '{}' deleted successfully", trigger_ref));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Enable a trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers/{ref}/enable",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Trigger enabled successfully", body = ApiResponse<TriggerResponse>),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn enable_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if trigger exists
|
||||
let existing_trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Update trigger to enabled
|
||||
let update_input = UpdateTriggerInput {
|
||||
label: None,
|
||||
description: None,
|
||||
enabled: Some(true),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::update(&state.db, existing_trigger.id, update_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
TriggerResponse::from(trigger),
|
||||
"Trigger enabled successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Disable a trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers/{ref}/disable",
|
||||
tag = "triggers",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Trigger disabled successfully", body = ApiResponse<TriggerResponse>),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn disable_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if trigger exists
|
||||
let existing_trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Update trigger to disabled
|
||||
let update_input = UpdateTriggerInput {
|
||||
label: None,
|
||||
description: None,
|
||||
enabled: Some(false),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::update(&state.db, existing_trigger.id, update_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
TriggerResponse::from(trigger),
|
||||
"Trigger disabled successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SENSOR ENDPOINTS
|
||||
// ============================================================================
|
||||
|
||||
/// List all sensors with pagination
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/sensors",
|
||||
tag = "sensors",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of sensors", body = PaginatedResponse<SensorSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_sensors(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all sensors
|
||||
let sensors = SensorRepository::list(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List enabled sensors
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/sensors/enabled",
|
||||
tag = "sensors",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of enabled sensors", body = PaginatedResponse<SensorSummary>),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_enabled_sensors(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled sensors
|
||||
let sensors = SensorRepository::find_enabled(&state.db).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List sensors by pack reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/sensors",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of sensors in pack", body = PaginatedResponse<SensorSummary>),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_sensors_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get sensors for this pack
|
||||
let sensors = SensorRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List sensors by trigger reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/triggers/{trigger_ref}/sensors",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("trigger_ref" = String, Path, description = "Trigger reference"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of sensors for trigger", body = PaginatedResponse<SensorSummary>),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn list_sensors_by_trigger(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify trigger exists
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Get sensors for this trigger
|
||||
let sensors = SensorRepository::find_by_trigger(&state.db, trigger.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single sensor by reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/sensors/{ref}",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("ref" = String, Path, description = "Sensor reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Sensor details", body = ApiResponse<SensorResponse>),
|
||||
(status = 404, description = "Sensor not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(sensor_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let sensor = SensorRepository::find_by_ref(&state.db, &sensor_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Sensor '{}' not found", sensor_ref)))?;
|
||||
|
||||
let response = ApiResponse::new(SensorResponse::from(sensor));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new sensor
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/sensors",
|
||||
tag = "sensors",
|
||||
request_body = CreateSensorRequest,
|
||||
responses(
|
||||
(status = 201, description = "Sensor created successfully", body = ApiResponse<SensorResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Pack, runtime, or trigger not found"),
|
||||
(status = 409, description = "Sensor with same ref already exists"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateSensorRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if sensor with same ref already exists
|
||||
if let Some(_) = SensorRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Sensor with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify pack exists and get its ID
|
||||
let pack = PackRepository::find_by_ref(&state.db, &request.pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", request.pack_ref)))?;
|
||||
|
||||
// Verify runtime exists and get its ID
|
||||
let runtime = RuntimeRepository::find_by_ref(&state.db, &request.runtime_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Runtime '{}' not found", request.runtime_ref))
|
||||
})?;
|
||||
|
||||
// Verify trigger exists and get its ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &request.trigger_ref)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Trigger '{}' not found", request.trigger_ref))
|
||||
})?;
|
||||
|
||||
// Create sensor input
|
||||
let sensor_input = CreateSensorInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: Some(pack.id),
|
||||
pack_ref: Some(pack.r#ref.clone()),
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
entrypoint: request.entrypoint,
|
||||
runtime: runtime.id,
|
||||
runtime_ref: runtime.r#ref.clone(),
|
||||
trigger: trigger.id,
|
||||
trigger_ref: trigger.r#ref.clone(),
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
config: request.config,
|
||||
};
|
||||
|
||||
let sensor = SensorRepository::create(&state.db, sensor_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(SensorResponse::from(sensor), "Sensor created successfully");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing sensor
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/sensors/{ref}",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("ref" = String, Path, description = "Sensor reference")
|
||||
),
|
||||
request_body = UpdateSensorRequest,
|
||||
responses(
|
||||
(status = 200, description = "Sensor updated successfully", body = ApiResponse<SensorResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 404, description = "Sensor not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn update_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(sensor_ref): Path<String>,
|
||||
Json(request): Json<UpdateSensorRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if sensor exists
|
||||
let existing_sensor = SensorRepository::find_by_ref(&state.db, &sensor_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Sensor '{}' not found", sensor_ref)))?;
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateSensorInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
entrypoint: request.entrypoint,
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
};
|
||||
|
||||
let sensor = SensorRepository::update(&state.db, existing_sensor.id, update_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(SensorResponse::from(sensor), "Sensor updated successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete a sensor
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/sensors/{ref}",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("ref" = String, Path, description = "Sensor reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Sensor deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Sensor not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn delete_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(sensor_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if sensor exists
|
||||
let sensor = SensorRepository::find_by_ref(&state.db, &sensor_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Sensor '{}' not found", sensor_ref)))?;
|
||||
|
||||
// Delete the sensor
|
||||
let deleted = SensorRepository::delete(&state.db, sensor.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Sensor '{}' not found",
|
||||
sensor_ref
|
||||
)));
|
||||
}
|
||||
|
||||
let response = SuccessResponse::new(format!("Sensor '{}' deleted successfully", sensor_ref));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Enable a sensor
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/sensors/{ref}/enable",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("ref" = String, Path, description = "Sensor reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Sensor enabled successfully", body = ApiResponse<SensorResponse>),
|
||||
(status = 404, description = "Sensor not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn enable_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(sensor_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if sensor exists
|
||||
let existing_sensor = SensorRepository::find_by_ref(&state.db, &sensor_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Sensor '{}' not found", sensor_ref)))?;
|
||||
|
||||
// Update sensor to enabled
|
||||
let update_input = UpdateSensorInput {
|
||||
label: None,
|
||||
description: None,
|
||||
entrypoint: None,
|
||||
enabled: Some(true),
|
||||
param_schema: None,
|
||||
};
|
||||
|
||||
let sensor = SensorRepository::update(&state.db, existing_sensor.id, update_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(SensorResponse::from(sensor), "Sensor enabled successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Disable a sensor
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/sensors/{ref}/disable",
|
||||
tag = "sensors",
|
||||
params(
|
||||
("ref" = String, Path, description = "Sensor reference")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Sensor disabled successfully", body = ApiResponse<SensorResponse>),
|
||||
(status = 404, description = "Sensor not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn disable_sensor(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(sensor_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if sensor exists
|
||||
let existing_sensor = SensorRepository::find_by_ref(&state.db, &sensor_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Sensor '{}' not found", sensor_ref)))?;
|
||||
|
||||
// Update sensor to disabled
|
||||
let update_input = UpdateSensorInput {
|
||||
label: None,
|
||||
description: None,
|
||||
entrypoint: None,
|
||||
enabled: Some(false),
|
||||
param_schema: None,
|
||||
};
|
||||
|
||||
let sensor = SensorRepository::update(&state.db, existing_sensor.id, update_input).await?;
|
||||
|
||||
let response =
|
||||
ApiResponse::with_message(SensorResponse::from(sensor), "Sensor disabled successfully");
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create trigger and sensor routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
// Trigger routes
|
||||
.route("/triggers", get(list_triggers).post(create_trigger))
|
||||
.route("/triggers/enabled", get(list_enabled_triggers))
|
||||
.route(
|
||||
"/triggers/{ref}",
|
||||
get(get_trigger).put(update_trigger).delete(delete_trigger),
|
||||
)
|
||||
.route("/triggers/{ref}/enable", post(enable_trigger))
|
||||
.route("/triggers/{ref}/disable", post(disable_trigger))
|
||||
.route("/packs/{pack_ref}/triggers", get(list_triggers_by_pack))
|
||||
// Sensor routes
|
||||
.route("/sensors", get(list_sensors).post(create_sensor))
|
||||
.route("/sensors/enabled", get(list_enabled_sensors))
|
||||
.route(
|
||||
"/sensors/{ref}",
|
||||
get(get_sensor).put(update_sensor).delete(delete_sensor),
|
||||
)
|
||||
.route("/sensors/{ref}/enable", post(enable_sensor))
|
||||
.route("/sensors/{ref}/disable", post(disable_sensor))
|
||||
.route("/packs/{pack_ref}/sensors", get(list_sensors_by_pack))
|
||||
.route(
|
||||
"/triggers/{trigger_ref}/sensors",
|
||||
get(list_sensors_by_trigger),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_trigger_sensor_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
808
crates/api/src/routes/webhooks.rs
Normal file
808
crates/api/src/routes/webhooks.rs
Normal file
@@ -0,0 +1,808 @@
|
||||
//! Webhook management and receiver API routes
|
||||
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{Path, State},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use attune_common::{
|
||||
mq::{EventCreatedPayload, MessageEnvelope, MessageType},
|
||||
repositories::{
|
||||
event::{CreateEventInput, EventRepository},
|
||||
trigger::{TriggerRepository, WebhookEventLogInput},
|
||||
Create, FindById, FindByRef,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
trigger::TriggerResponse,
|
||||
webhook::{WebhookReceiverRequest, WebhookReceiverResponse},
|
||||
ApiResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
webhook_security,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// WEBHOOK CONFIG HELPERS
|
||||
// ============================================================================
|
||||
|
||||
/// Helper to extract boolean value from webhook_config JSON using path notation
|
||||
fn get_webhook_config_bool(
|
||||
trigger: &attune_common::models::trigger::Trigger,
|
||||
path: &str,
|
||||
default: bool,
|
||||
) -> bool {
|
||||
let config = match &trigger.webhook_config {
|
||||
Some(c) => c,
|
||||
None => return default,
|
||||
};
|
||||
|
||||
let parts: Vec<&str> = path.split('/').collect();
|
||||
let mut current = config;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if i == parts.len() - 1 {
|
||||
// Last part - extract value
|
||||
return current
|
||||
.get(part)
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(default);
|
||||
} else {
|
||||
// Intermediate part - navigate deeper
|
||||
current = match current.get(part) {
|
||||
Some(v) => v,
|
||||
None => return default,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
default
|
||||
}
|
||||
|
||||
/// Helper to extract string value from webhook_config JSON using path notation
|
||||
fn get_webhook_config_str(
|
||||
trigger: &attune_common::models::trigger::Trigger,
|
||||
path: &str,
|
||||
) -> Option<String> {
|
||||
let config = trigger.webhook_config.as_ref()?;
|
||||
|
||||
let parts: Vec<&str> = path.split('/').collect();
|
||||
let mut current = config;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if i == parts.len() - 1 {
|
||||
// Last part - extract value
|
||||
return current
|
||||
.get(part)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
} else {
|
||||
// Intermediate part - navigate deeper
|
||||
current = current.get(part)?;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Helper to extract i64 value from webhook_config JSON using path notation
|
||||
fn get_webhook_config_i64(
|
||||
trigger: &attune_common::models::trigger::Trigger,
|
||||
path: &str,
|
||||
) -> Option<i64> {
|
||||
let config = trigger.webhook_config.as_ref()?;
|
||||
|
||||
let parts: Vec<&str> = path.split('/').collect();
|
||||
let mut current = config;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if i == parts.len() - 1 {
|
||||
// Last part - extract value
|
||||
return current.get(part).and_then(|v| v.as_i64());
|
||||
} else {
|
||||
// Intermediate part - navigate deeper
|
||||
current = current.get(part)?;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Helper to extract array of strings from webhook_config JSON using path notation
|
||||
fn get_webhook_config_array(
|
||||
trigger: &attune_common::models::trigger::Trigger,
|
||||
path: &str,
|
||||
) -> Option<Vec<String>> {
|
||||
let config = trigger.webhook_config.as_ref()?;
|
||||
|
||||
let parts: Vec<&str> = path.split('/').collect();
|
||||
let mut current = config;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if i == parts.len() - 1 {
|
||||
// Last part - extract array
|
||||
return current.get(part).and_then(|v| {
|
||||
v.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| item.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
})
|
||||
});
|
||||
} else {
|
||||
// Intermediate part - navigate deeper
|
||||
current = current.get(part)?;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WEBHOOK MANAGEMENT ENDPOINTS
|
||||
// ============================================================================
|
||||
|
||||
/// Enable webhooks for a trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers/{ref}/webhooks/enable",
|
||||
tag = "webhooks",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference (pack.name)")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Webhooks enabled", body = TriggerResponse),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("jwt" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn enable_webhook(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Enable webhooks for this trigger
|
||||
let _webhook_info = TriggerRepository::enable_webhook(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?;
|
||||
|
||||
// Fetch the updated trigger to return
|
||||
let updated_trigger = TriggerRepository::find_by_id(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound("Trigger not found after update".to_string()))?;
|
||||
|
||||
let response = TriggerResponse::from(updated_trigger);
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Disable webhooks for a trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers/{ref}/webhooks/disable",
|
||||
tag = "webhooks",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference (pack.name)")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Webhooks disabled", body = TriggerResponse),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("jwt" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn disable_webhook(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Disable webhooks for this trigger
|
||||
TriggerRepository::disable_webhook(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?;
|
||||
|
||||
// Fetch the updated trigger to return
|
||||
let updated_trigger = TriggerRepository::find_by_id(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound("Trigger not found after update".to_string()))?;
|
||||
|
||||
let response = TriggerResponse::from(updated_trigger);
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Regenerate webhook key for a trigger
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/triggers/{ref}/webhooks/regenerate",
|
||||
tag = "webhooks",
|
||||
params(
|
||||
("ref" = String, Path, description = "Trigger reference (pack.name)")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Webhook key regenerated", body = TriggerResponse),
|
||||
(status = 400, description = "Webhooks not enabled for this trigger"),
|
||||
(status = 404, description = "Trigger not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("jwt" = [])
|
||||
)
|
||||
)]
|
||||
pub async fn regenerate_webhook_key(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
let trigger = TriggerRepository::find_by_ref(&state.db, &trigger_ref)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Check if webhooks are enabled
|
||||
if !trigger.webhook_enabled {
|
||||
return Err(ApiError::BadRequest(
|
||||
"Webhooks are not enabled for this trigger. Enable webhooks first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Regenerate the webhook key
|
||||
let _regenerate_result = TriggerRepository::regenerate_webhook_key(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?;
|
||||
|
||||
// Fetch the updated trigger to return
|
||||
let updated_trigger = TriggerRepository::find_by_id(&state.db, trigger.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound("Trigger not found after update".to_string()))?;
|
||||
|
||||
let response = TriggerResponse::from(updated_trigger);
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WEBHOOK RECEIVER ENDPOINT
|
||||
// ============================================================================
|
||||
|
||||
/// Webhook receiver endpoint - receives webhook events and creates events
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/webhooks/{webhook_key}",
|
||||
tag = "webhooks",
|
||||
params(
|
||||
("webhook_key" = String, Path, description = "Webhook key")
|
||||
),
|
||||
request_body = WebhookReceiverRequest,
|
||||
responses(
|
||||
(status = 200, description = "Webhook received and event created", body = WebhookReceiverResponse),
|
||||
(status = 404, description = "Invalid webhook key"),
|
||||
(status = 429, description = "Rate limit exceeded"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn receive_webhook(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(webhook_key): Path<String>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Extract metadata from headers
|
||||
let source_ip = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.or_else(|| headers.get("x-real-ip").and_then(|v| v.to_str().ok()))
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let signature = headers
|
||||
.get("x-webhook-signature")
|
||||
.or_else(|| headers.get("x-hub-signature-256"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Parse JSON payload
|
||||
let payload: WebhookReceiverRequest = serde_json::from_slice(&body)
|
||||
.map_err(|e| ApiError::BadRequest(format!("Invalid JSON payload: {}", e)))?;
|
||||
|
||||
let payload_size_bytes = body.len() as i32;
|
||||
|
||||
// Look up trigger by webhook key
|
||||
let trigger = match TriggerRepository::find_by_webhook_key(&state.db, &webhook_key).await {
|
||||
Ok(Some(t)) => t,
|
||||
Ok(None) => {
|
||||
// Log failed attempt
|
||||
let _ = log_webhook_failure(
|
||||
&state,
|
||||
webhook_key.clone(),
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
404,
|
||||
"Invalid webhook key".to_string(),
|
||||
start_time,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::NotFound("Invalid webhook key".to_string()));
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = log_webhook_failure(
|
||||
&state,
|
||||
webhook_key.clone(),
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
500,
|
||||
e.to_string(),
|
||||
start_time,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::InternalServerError(e.to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
// Verify webhooks are enabled
|
||||
if !trigger.webhook_enabled {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
400,
|
||||
Some("Webhooks not enabled for this trigger".to_string()),
|
||||
start_time,
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::BadRequest(
|
||||
"Webhooks are not enabled for this trigger".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Phase 3: Check payload size limit
|
||||
if let Some(limit_kb) = get_webhook_config_i64(&trigger, "payload_size_limit_kb") {
|
||||
let limit_bytes = limit_kb * 1024;
|
||||
if i64::from(payload_size_bytes) > limit_bytes {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
413,
|
||||
Some(format!(
|
||||
"Payload too large: {} bytes (limit: {} bytes)",
|
||||
payload_size_bytes, limit_bytes
|
||||
)),
|
||||
start_time,
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Payload too large. Maximum size: {} KB",
|
||||
limit_kb
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Check IP whitelist
|
||||
let ip_whitelist_enabled = get_webhook_config_bool(&trigger, "ip_whitelist/enabled", false);
|
||||
let ip_allowed = if ip_whitelist_enabled {
|
||||
if let Some(ref ip) = source_ip {
|
||||
if let Some(whitelist) = get_webhook_config_array(&trigger, "ip_whitelist/ips") {
|
||||
match webhook_security::check_ip_in_whitelist(ip, &whitelist) {
|
||||
Ok(allowed) => {
|
||||
if !allowed {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
403,
|
||||
Some("IP address not in whitelist".to_string()),
|
||||
start_time,
|
||||
None,
|
||||
false,
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::Forbidden("IP address not allowed".to_string()));
|
||||
}
|
||||
Some(true)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("IP whitelist check error: {}", e);
|
||||
Some(false)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(false)
|
||||
}
|
||||
} else {
|
||||
Some(false)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Phase 3: Check rate limit
|
||||
let rate_limit_enabled = get_webhook_config_bool(&trigger, "rate_limit/enabled", false);
|
||||
if rate_limit_enabled {
|
||||
if let (Some(max_requests), Some(window_seconds)) = (
|
||||
get_webhook_config_i64(&trigger, "rate_limit/requests"),
|
||||
get_webhook_config_i64(&trigger, "rate_limit/window_seconds"),
|
||||
) {
|
||||
// Note: Rate limit checking would need to be implemented with a time-series approach
|
||||
// For now, we skip this check as the repository function was removed
|
||||
let allowed = true; // TODO: Implement proper rate limiting
|
||||
|
||||
if !allowed {
|
||||
{
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
429,
|
||||
Some("Rate limit exceeded".to_string()),
|
||||
start_time,
|
||||
None,
|
||||
true,
|
||||
ip_allowed,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::TooManyRequests(format!(
|
||||
"Rate limit exceeded. Maximum {} requests per {} seconds",
|
||||
max_requests, window_seconds
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Verify HMAC signature
|
||||
let hmac_enabled = get_webhook_config_bool(&trigger, "hmac/enabled", false);
|
||||
let hmac_verified = if hmac_enabled {
|
||||
if let (Some(secret), Some(algorithm)) = (
|
||||
get_webhook_config_str(&trigger, "hmac/secret"),
|
||||
get_webhook_config_str(&trigger, "hmac/algorithm"),
|
||||
) {
|
||||
if let Some(sig) = signature {
|
||||
match webhook_security::verify_hmac_signature(&body, &sig, &secret, &algorithm) {
|
||||
Ok(valid) => {
|
||||
if !valid {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
401,
|
||||
Some("Invalid HMAC signature".to_string()),
|
||||
start_time,
|
||||
Some(false),
|
||||
false,
|
||||
ip_allowed,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid webhook signature".to_string(),
|
||||
));
|
||||
}
|
||||
Some(true)
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
401,
|
||||
Some(format!("HMAC verification error: {}", e)),
|
||||
start_time,
|
||||
Some(false),
|
||||
false,
|
||||
ip_allowed,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::Unauthorized(format!(
|
||||
"Signature verification failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
401,
|
||||
Some("HMAC signature required but not provided".to_string()),
|
||||
start_time,
|
||||
Some(false),
|
||||
false,
|
||||
ip_allowed,
|
||||
)
|
||||
.await;
|
||||
return Err(ApiError::Unauthorized("Signature required".to_string()));
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build config with webhook context metadata
|
||||
let mut config = serde_json::json!({
|
||||
"source": "webhook",
|
||||
"webhook_key": webhook_key,
|
||||
"received_at": chrono::Utc::now().to_rfc3339(),
|
||||
});
|
||||
|
||||
// Add optional metadata
|
||||
if let Some(headers) = payload.headers {
|
||||
config["headers"] = headers;
|
||||
}
|
||||
if let Some(ref ip) = source_ip {
|
||||
config["source_ip"] = serde_json::Value::String(ip.clone());
|
||||
}
|
||||
if let Some(ref ua) = user_agent {
|
||||
config["user_agent"] = serde_json::Value::String(ua.clone());
|
||||
}
|
||||
let hmac_enabled = get_webhook_config_bool(&trigger, "hmac/enabled", false);
|
||||
if hmac_enabled {
|
||||
config["hmac_verified"] = serde_json::Value::Bool(hmac_verified.unwrap_or(false));
|
||||
}
|
||||
|
||||
// Create event
|
||||
let event_input = CreateEventInput {
|
||||
trigger: Some(trigger.id),
|
||||
trigger_ref: trigger.r#ref.clone(),
|
||||
config: Some(config),
|
||||
payload: Some(payload.payload),
|
||||
source: None,
|
||||
source_ref: Some("webhook".to_string()),
|
||||
rule: None,
|
||||
rule_ref: None,
|
||||
};
|
||||
|
||||
let event = EventRepository::create(&state.db, event_input)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let _ = futures::executor::block_on(log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
None,
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
500,
|
||||
Some(format!("Failed to create event: {}", e)),
|
||||
start_time,
|
||||
hmac_verified,
|
||||
false,
|
||||
ip_allowed,
|
||||
));
|
||||
ApiError::InternalServerError(e.to_string())
|
||||
})?;
|
||||
|
||||
// Publish EventCreated message to message queue if publisher is available
|
||||
tracing::info!(
|
||||
"Webhook event {} created, attempting to publish EventCreated message",
|
||||
event.id
|
||||
);
|
||||
if let Some(ref publisher) = state.publisher {
|
||||
let message_payload = EventCreatedPayload {
|
||||
event_id: event.id,
|
||||
trigger_id: event.trigger,
|
||||
trigger_ref: event.trigger_ref.clone(),
|
||||
sensor_id: event.source,
|
||||
sensor_ref: event.source_ref.clone(),
|
||||
payload: event.payload.clone().unwrap_or(serde_json::json!({})),
|
||||
config: event.config.clone(),
|
||||
};
|
||||
|
||||
let envelope = MessageEnvelope::new(MessageType::EventCreated, message_payload)
|
||||
.with_source("api-webhook-receiver");
|
||||
|
||||
if let Err(e) = publisher.publish_envelope(&envelope).await {
|
||||
tracing::warn!(
|
||||
"Failed to publish EventCreated message for event {}: {}",
|
||||
event.id,
|
||||
e
|
||||
);
|
||||
// Continue even if message publishing fails - event is already recorded
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Published EventCreated message for event {} (trigger: {})",
|
||||
event.id,
|
||||
event.trigger_ref
|
||||
);
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Publisher not available, cannot publish EventCreated message for event {}",
|
||||
event.id
|
||||
);
|
||||
}
|
||||
|
||||
// Log successful webhook
|
||||
let _ = log_webhook_event(
|
||||
&state,
|
||||
&trigger,
|
||||
&webhook_key,
|
||||
Some(event.id),
|
||||
source_ip.clone(),
|
||||
user_agent.clone(),
|
||||
payload_size_bytes,
|
||||
200,
|
||||
None,
|
||||
start_time,
|
||||
hmac_verified,
|
||||
false,
|
||||
ip_allowed,
|
||||
)
|
||||
.await;
|
||||
|
||||
let response = WebhookReceiverResponse {
|
||||
event_id: event.id,
|
||||
trigger_ref: trigger.r#ref.clone(),
|
||||
received_at: event.created,
|
||||
message: "Webhook received successfully".to_string(),
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
// Helper function to log webhook events
|
||||
async fn log_webhook_event(
|
||||
state: &AppState,
|
||||
trigger: &attune_common::models::trigger::Trigger,
|
||||
webhook_key: &str,
|
||||
event_id: Option<i64>,
|
||||
source_ip: Option<String>,
|
||||
user_agent: Option<String>,
|
||||
payload_size_bytes: i32,
|
||||
status_code: i32,
|
||||
error_message: Option<String>,
|
||||
start_time: Instant,
|
||||
hmac_verified: Option<bool>,
|
||||
rate_limited: bool,
|
||||
ip_allowed: Option<bool>,
|
||||
) -> Result<(), attune_common::error::Error> {
|
||||
let processing_time_ms = start_time.elapsed().as_millis() as i32;
|
||||
|
||||
let log_input = WebhookEventLogInput {
|
||||
trigger_id: trigger.id,
|
||||
trigger_ref: trigger.r#ref.clone(),
|
||||
webhook_key: webhook_key.to_string(),
|
||||
event_id,
|
||||
source_ip,
|
||||
user_agent,
|
||||
payload_size_bytes: Some(payload_size_bytes),
|
||||
headers: None, // Could be added if needed
|
||||
status_code,
|
||||
error_message,
|
||||
processing_time_ms: Some(processing_time_ms),
|
||||
hmac_verified,
|
||||
rate_limited,
|
||||
ip_allowed,
|
||||
};
|
||||
|
||||
TriggerRepository::log_webhook_event(&state.db, log_input).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper function to log failures when trigger is not found
|
||||
async fn log_webhook_failure(
|
||||
_state: &AppState,
|
||||
webhook_key: String,
|
||||
source_ip: Option<String>,
|
||||
user_agent: Option<String>,
|
||||
payload_size_bytes: i32,
|
||||
status_code: i32,
|
||||
error_message: String,
|
||||
start_time: Instant,
|
||||
) -> Result<(), attune_common::error::Error> {
|
||||
let processing_time_ms = start_time.elapsed().as_millis() as i32;
|
||||
|
||||
// We can't log to webhook_event_log without a trigger_id, so just log to tracing
|
||||
tracing::warn!(
|
||||
webhook_key = %webhook_key,
|
||||
source_ip = ?source_ip,
|
||||
user_agent = ?user_agent,
|
||||
payload_size_bytes = payload_size_bytes,
|
||||
status_code = status_code,
|
||||
error_message = %error_message,
|
||||
processing_time_ms = processing_time_ms,
|
||||
"Webhook request failed"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ROUTER
|
||||
// ============================================================================
|
||||
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
// Webhook management routes (protected)
|
||||
.route("/triggers/{ref}/webhooks/enable", post(enable_webhook))
|
||||
.route("/triggers/{ref}/webhooks/disable", post(disable_webhook))
|
||||
.route(
|
||||
"/triggers/{ref}/webhooks/regenerate",
|
||||
post(regenerate_webhook_key),
|
||||
)
|
||||
// TODO: Add Phase 3 management endpoints for HMAC, rate limiting, IP whitelist
|
||||
// Webhook receiver route (public - no auth required)
|
||||
.route("/webhooks/{webhook_key}", post(receive_webhook))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_webhook_routes_structure() {
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
365
crates/api/src/routes/workflows.rs
Normal file
365
crates/api/src/routes/workflows.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Workflow management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
workflow::{
|
||||
CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository,
|
||||
},
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
workflow::{
|
||||
CreateWorkflowRequest, UpdateWorkflowRequest, WorkflowResponse, WorkflowSearchParams,
|
||||
WorkflowSummary,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// List all workflows with pagination and filtering
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/workflows",
|
||||
tag = "workflows",
|
||||
params(PaginationParams, WorkflowSearchParams),
|
||||
responses(
|
||||
(status = 200, description = "List of workflows", body = PaginatedResponse<WorkflowSummary>),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_workflows(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
Query(search_params): Query<WorkflowSearchParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate search params
|
||||
search_params.validate()?;
|
||||
|
||||
// Get workflows based on filters
|
||||
let mut workflows = if let Some(tags_str) = &search_params.tags {
|
||||
// Filter by tags
|
||||
let tags: Vec<&str> = tags_str.split(',').map(|s| s.trim()).collect();
|
||||
let mut results = Vec::new();
|
||||
for tag in tags {
|
||||
let mut tag_results = WorkflowDefinitionRepository::find_by_tag(&state.db, tag).await?;
|
||||
results.append(&mut tag_results);
|
||||
}
|
||||
// Remove duplicates by ID
|
||||
results.sort_by_key(|w| w.id);
|
||||
results.dedup_by_key(|w| w.id);
|
||||
results
|
||||
} else if search_params.enabled == Some(true) {
|
||||
// Filter by enabled status (only return enabled workflows)
|
||||
WorkflowDefinitionRepository::find_enabled(&state.db).await?
|
||||
} else {
|
||||
// Get all workflows
|
||||
WorkflowDefinitionRepository::list(&state.db).await?
|
||||
};
|
||||
|
||||
// Apply enabled filter if specified and not already filtered by it
|
||||
if let Some(enabled) = search_params.enabled {
|
||||
if search_params.tags.is_some() {
|
||||
// If we filtered by tags, also apply enabled filter
|
||||
workflows.retain(|w| w.enabled == enabled);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply search filter if provided
|
||||
if let Some(search_term) = &search_params.search {
|
||||
let search_lower = search_term.to_lowercase();
|
||||
workflows.retain(|w| {
|
||||
w.label.to_lowercase().contains(&search_lower)
|
||||
|| w.description
|
||||
.as_ref()
|
||||
.map(|d| d.to_lowercase().contains(&search_lower))
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Apply pack_ref filter if provided
|
||||
if let Some(pack_ref) = &search_params.pack_ref {
|
||||
workflows.retain(|w| w.pack_ref == *pack_ref);
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = workflows.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(workflows.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
|
||||
.iter()
|
||||
.map(|w| WorkflowSummary::from(w.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// List workflows by pack reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/workflows",
|
||||
tag = "workflows",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference identifier"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of workflows for pack", body = PaginatedResponse<WorkflowSummary>),
|
||||
(status = 404, description = "Pack not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_workflows_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get workflows for this pack
|
||||
let workflows = WorkflowDefinitionRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
|
||||
// Calculate pagination
|
||||
let total = workflows.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(workflows.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
|
||||
.iter()
|
||||
.map(|w| WorkflowSummary::from(w.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single workflow by reference
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/workflows/{ref}",
|
||||
tag = "workflows",
|
||||
params(
|
||||
("ref" = String, Path, description = "Workflow reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Workflow details", body = inline(ApiResponse<WorkflowResponse>)),
|
||||
(status = 404, description = "Workflow not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_workflow(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(workflow_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let workflow = WorkflowDefinitionRepository::find_by_ref(&state.db, &workflow_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Workflow '{}' not found", workflow_ref)))?;
|
||||
|
||||
let response = ApiResponse::new(WorkflowResponse::from(workflow));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create a new workflow
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/workflows",
|
||||
tag = "workflows",
|
||||
request_body = CreateWorkflowRequest,
|
||||
responses(
|
||||
(status = 201, description = "Workflow created successfully", body = inline(ApiResponse<WorkflowResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 409, description = "Workflow with same ref already exists")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_workflow(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateWorkflowRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if workflow with same ref already exists
|
||||
if let Some(_) = WorkflowDefinitionRepository::find_by_ref(&state.db, &request.r#ref).await? {
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Workflow with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify pack exists and get its ID
|
||||
let pack = PackRepository::find_by_ref(&state.db, &request.pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", request.pack_ref)))?;
|
||||
|
||||
// Create workflow input
|
||||
let workflow_input = CreateWorkflowDefinitionInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
version: request.version,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
definition: request.definition,
|
||||
tags: request.tags.unwrap_or_default(),
|
||||
enabled: request.enabled.unwrap_or(true),
|
||||
};
|
||||
|
||||
let workflow = WorkflowDefinitionRepository::create(&state.db, workflow_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
WorkflowResponse::from(workflow),
|
||||
"Workflow created successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Update an existing workflow
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/workflows/{ref}",
|
||||
tag = "workflows",
|
||||
params(
|
||||
("ref" = String, Path, description = "Workflow reference identifier")
|
||||
),
|
||||
request_body = UpdateWorkflowRequest,
|
||||
responses(
|
||||
(status = 200, description = "Workflow updated successfully", body = inline(ApiResponse<WorkflowResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Workflow not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_workflow(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(workflow_ref): Path<String>,
|
||||
Json(request): Json<UpdateWorkflowRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate request
|
||||
request.validate()?;
|
||||
|
||||
// Check if workflow exists
|
||||
let existing_workflow = WorkflowDefinitionRepository::find_by_ref(&state.db, &workflow_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Workflow '{}' not found", workflow_ref)))?;
|
||||
|
||||
// Create update input
|
||||
let update_input = UpdateWorkflowDefinitionInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
version: request.version,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
definition: request.definition,
|
||||
tags: request.tags,
|
||||
enabled: request.enabled,
|
||||
};
|
||||
|
||||
let workflow =
|
||||
WorkflowDefinitionRepository::update(&state.db, existing_workflow.id, update_input).await?;
|
||||
|
||||
let response = ApiResponse::with_message(
|
||||
WorkflowResponse::from(workflow),
|
||||
"Workflow updated successfully",
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Delete a workflow
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/workflows/{ref}",
|
||||
tag = "workflows",
|
||||
params(
|
||||
("ref" = String, Path, description = "Workflow reference identifier")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Workflow deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Workflow not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_workflow(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(workflow_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Check if workflow exists
|
||||
let workflow = WorkflowDefinitionRepository::find_by_ref(&state.db, &workflow_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Workflow '{}' not found", workflow_ref)))?;
|
||||
|
||||
// Delete the workflow
|
||||
let deleted = WorkflowDefinitionRepository::delete(&state.db, workflow.id).await?;
|
||||
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Workflow '{}' not found",
|
||||
workflow_ref
|
||||
)));
|
||||
}
|
||||
|
||||
let response =
|
||||
SuccessResponse::new(format!("Workflow '{}' deleted successfully", workflow_ref));
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Create workflow routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/workflows", get(list_workflows).post(create_workflow))
|
||||
.route(
|
||||
"/workflows/{ref}",
|
||||
get(get_workflow)
|
||||
.put(update_workflow)
|
||||
.delete(delete_workflow),
|
||||
)
|
||||
.route("/packs/{pack_ref}/workflows", get(list_workflows_by_pack))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_workflow_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
125
crates/api/src/server.rs
Normal file
125
crates/api/src/server.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
//! Server setup and lifecycle management
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{middleware, Router};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::info;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use crate::{
|
||||
middleware::{create_cors_layer, log_request},
|
||||
openapi::ApiDoc,
|
||||
routes,
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// Server configuration and lifecycle manager
|
||||
pub struct Server {
|
||||
/// Application state
|
||||
state: Arc<AppState>,
|
||||
/// Server host address
|
||||
host: String,
|
||||
/// Server port
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
/// Create a new server instance
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
let host = state.config.server.host.clone();
|
||||
let port = state.config.server.port;
|
||||
|
||||
Self { state, host, port }
|
||||
}
|
||||
|
||||
/// Get the router for testing purposes
|
||||
pub fn router(&self) -> Router {
|
||||
self.build_router()
|
||||
}
|
||||
|
||||
/// Build the application router with all routes and middleware
|
||||
fn build_router(&self) -> Router {
|
||||
// API v1 routes (versioned endpoints)
|
||||
let api_v1 = Router::new()
|
||||
.merge(routes::pack_routes())
|
||||
.merge(routes::action_routes())
|
||||
.merge(routes::rule_routes())
|
||||
.merge(routes::execution_routes())
|
||||
.merge(routes::trigger_routes())
|
||||
.merge(routes::inquiry_routes())
|
||||
.merge(routes::event_routes())
|
||||
.merge(routes::key_routes())
|
||||
.merge(routes::workflow_routes())
|
||||
.merge(routes::webhook_routes())
|
||||
// TODO: Add more route modules here
|
||||
// etc.
|
||||
.with_state(self.state.clone());
|
||||
|
||||
// Auth routes at root level (not versioned for frontend compatibility)
|
||||
let auth_routes = routes::auth_routes().with_state(self.state.clone());
|
||||
|
||||
// Health endpoint at root level (operational endpoint, not versioned)
|
||||
let health_routes = routes::health_routes().with_state(self.state.clone());
|
||||
|
||||
// Root router with versioning and documentation
|
||||
Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-spec/openapi.json", ApiDoc::openapi()))
|
||||
.merge(health_routes)
|
||||
.nest("/auth", auth_routes)
|
||||
.nest("/api/v1", api_v1)
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
// Add tracing for all requests
|
||||
.layer(TraceLayer::new_for_http())
|
||||
// Add CORS support with configured origins
|
||||
.layer(create_cors_layer(self.state.cors_origins.clone()))
|
||||
// Add custom request logging
|
||||
.layer(middleware::from_fn(log_request)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Start the server and listen for requests
|
||||
pub async fn run(self) -> Result<()> {
|
||||
let router = self.build_router();
|
||||
let addr = format!("{}:{}", self.host, self.port);
|
||||
|
||||
info!("Starting server on {}", addr);
|
||||
info!("API documentation available at http://{}/docs", addr);
|
||||
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
info!("Server listening on {}", addr);
|
||||
|
||||
axum::serve(listener, router).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Graceful shutdown handler
|
||||
pub async fn shutdown(&self) {
|
||||
info!("Shutting down server...");
|
||||
// Perform any cleanup here
|
||||
// - Close database connections
|
||||
// - Flush logs
|
||||
// - Wait for in-flight requests
|
||||
info!("Server shutdown complete");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore until we have test database setup
|
||||
async fn test_server_creation() {
|
||||
// This test is ignored because it requires a test database pool
|
||||
// When implemented, create a test pool and verify server creation
|
||||
// let pool = PgPool::connect(&test_db_url).await.unwrap();
|
||||
// let state = AppState::new(pool);
|
||||
// let server = Server::new(state, "127.0.0.1".to_string(), 8080);
|
||||
// assert_eq!(server.host, "127.0.0.1");
|
||||
// assert_eq!(server.port, 8080);
|
||||
}
|
||||
}
|
||||
67
crates/api/src/state.rs
Normal file
67
crates/api/src/state.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Application state shared across request handlers
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::auth::jwt::JwtConfig;
|
||||
use attune_common::{config::Config, mq::Publisher};
|
||||
|
||||
/// Shared application state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// Database connection pool
|
||||
pub db: PgPool,
|
||||
/// JWT configuration
|
||||
pub jwt_config: Arc<JwtConfig>,
|
||||
/// CORS allowed origins
|
||||
pub cors_origins: Vec<String>,
|
||||
/// Application configuration
|
||||
pub config: Arc<Config>,
|
||||
/// Optional message queue publisher
|
||||
pub publisher: Option<Arc<Publisher>>,
|
||||
/// Broadcast channel for SSE notifications
|
||||
pub broadcast_tx: broadcast::Sender<String>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
/// Create new application state
|
||||
pub fn new(db: PgPool, config: Config) -> Self {
|
||||
let jwt_secret = config.security.jwt_secret.clone().unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
"JWT_SECRET not set in config, using default (INSECURE for production!)"
|
||||
);
|
||||
"insecure_default_secret_change_in_production".to_string()
|
||||
});
|
||||
|
||||
let jwt_config = JwtConfig {
|
||||
secret: jwt_secret,
|
||||
access_token_expiration: config.security.jwt_access_expiration as i64,
|
||||
refresh_token_expiration: config.security.jwt_refresh_expiration as i64,
|
||||
};
|
||||
|
||||
let cors_origins = config.server.cors_origins.clone();
|
||||
|
||||
// Create broadcast channel for SSE notifications (capacity 1000)
|
||||
let (broadcast_tx, _) = broadcast::channel(1000);
|
||||
|
||||
Self {
|
||||
db,
|
||||
jwt_config: Arc::new(jwt_config),
|
||||
cors_origins,
|
||||
config: Arc::new(config),
|
||||
publisher: None,
|
||||
broadcast_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the message queue publisher
|
||||
pub fn with_publisher(mut self, publisher: Arc<Publisher>) -> Self {
|
||||
self.publisher = Some(publisher);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for Arc-wrapped application state
|
||||
/// Used by Axum handlers
|
||||
pub type SharedState = Arc<AppState>;
|
||||
7
crates/api/src/validation/mod.rs
Normal file
7
crates/api/src/validation/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Validation module
|
||||
//!
|
||||
//! Contains validation utilities for API requests and parameters.
|
||||
|
||||
pub mod params;
|
||||
|
||||
pub use params::{validate_action_params, validate_trigger_params};
|
||||
259
crates/api/src/validation/params.rs
Normal file
259
crates/api/src/validation/params.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
//! Parameter validation module
|
||||
//!
|
||||
//! Validates trigger and action parameters against their declared JSON schemas.
|
||||
|
||||
use attune_common::models::{action::Action, trigger::Trigger};
|
||||
use jsonschema::Validator;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::middleware::ApiError;
|
||||
|
||||
/// Validate trigger parameters against the trigger's parameter schema
|
||||
pub fn validate_trigger_params(trigger: &Trigger, params: &Value) -> Result<(), ApiError> {
|
||||
// If no schema is defined, accept any parameters
|
||||
let Some(schema) = &trigger.param_schema else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// If parameters are empty object and schema exists, validate against schema
|
||||
// (schema might allow empty object or have defaults)
|
||||
|
||||
// Compile the JSON schema
|
||||
let compiled_schema = Validator::new(schema).map_err(|e| {
|
||||
ApiError::InternalServerError(format!(
|
||||
"Invalid parameter schema for trigger '{}': {}",
|
||||
trigger.r#ref, e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Validate the parameters
|
||||
let errors: Vec<String> = compiled_schema
|
||||
.iter_errors(params)
|
||||
.map(|e| {
|
||||
let path = e.instance_path().to_string();
|
||||
if path.is_empty() {
|
||||
e.to_string()
|
||||
} else {
|
||||
format!("{} at {}", e, path)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !errors.is_empty() {
|
||||
return Err(ApiError::ValidationError(format!(
|
||||
"Invalid parameters for trigger '{}': {}",
|
||||
trigger.r#ref,
|
||||
errors.join(", ")
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate action parameters against the action's parameter schema
|
||||
pub fn validate_action_params(action: &Action, params: &Value) -> Result<(), ApiError> {
|
||||
// If no schema is defined, accept any parameters
|
||||
let Some(schema) = &action.param_schema else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Compile the JSON schema
|
||||
let compiled_schema = Validator::new(schema).map_err(|e| {
|
||||
ApiError::InternalServerError(format!(
|
||||
"Invalid parameter schema for action '{}': {}",
|
||||
action.r#ref, e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Validate the parameters
|
||||
let errors: Vec<String> = compiled_schema
|
||||
.iter_errors(params)
|
||||
.map(|e| {
|
||||
let path = e.instance_path().to_string();
|
||||
if path.is_empty() {
|
||||
e.to_string()
|
||||
} else {
|
||||
format!("{} at {}", e, path)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !errors.is_empty() {
|
||||
return Err(ApiError::ValidationError(format!(
|
||||
"Invalid parameters for action '{}': {}",
|
||||
action.r#ref,
|
||||
errors.join(", ")
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_validate_trigger_params_with_no_schema() {
|
||||
let trigger = Trigger {
|
||||
id: 1,
|
||||
r#ref: "test.trigger".to_string(),
|
||||
pack: Some(1),
|
||||
pack_ref: Some("test".to_string()),
|
||||
label: "Test Trigger".to_string(),
|
||||
description: None,
|
||||
enabled: true,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
webhook_enabled: false,
|
||||
webhook_key: None,
|
||||
webhook_config: None,
|
||||
is_adhoc: false,
|
||||
created: chrono::Utc::now(),
|
||||
updated: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let params = json!({ "any": "value" });
|
||||
assert!(validate_trigger_params(&trigger, ¶ms).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_trigger_params_with_valid_params() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"unit": { "type": "string", "enum": ["seconds", "minutes", "hours"] },
|
||||
"delta": { "type": "integer", "minimum": 1 }
|
||||
},
|
||||
"required": ["unit", "delta"]
|
||||
});
|
||||
|
||||
let trigger = Trigger {
|
||||
id: 1,
|
||||
r#ref: "test.trigger".to_string(),
|
||||
pack: Some(1),
|
||||
pack_ref: Some("test".to_string()),
|
||||
label: "Test Trigger".to_string(),
|
||||
description: None,
|
||||
enabled: true,
|
||||
param_schema: Some(schema),
|
||||
out_schema: None,
|
||||
webhook_enabled: false,
|
||||
webhook_key: None,
|
||||
webhook_config: None,
|
||||
is_adhoc: false,
|
||||
created: chrono::Utc::now(),
|
||||
updated: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let params = json!({ "unit": "seconds", "delta": 10 });
|
||||
assert!(validate_trigger_params(&trigger, ¶ms).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_trigger_params_with_invalid_params() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"unit": { "type": "string", "enum": ["seconds", "minutes", "hours"] },
|
||||
"delta": { "type": "integer", "minimum": 1 }
|
||||
},
|
||||
"required": ["unit", "delta"]
|
||||
});
|
||||
|
||||
let trigger = Trigger {
|
||||
id: 1,
|
||||
r#ref: "test.trigger".to_string(),
|
||||
pack: Some(1),
|
||||
pack_ref: Some("test".to_string()),
|
||||
label: "Test Trigger".to_string(),
|
||||
description: None,
|
||||
enabled: true,
|
||||
param_schema: Some(schema),
|
||||
out_schema: None,
|
||||
webhook_enabled: false,
|
||||
webhook_key: None,
|
||||
webhook_config: None,
|
||||
is_adhoc: false,
|
||||
created: chrono::Utc::now(),
|
||||
updated: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
// Missing required field 'delta'
|
||||
let params = json!({ "unit": "seconds" });
|
||||
assert!(validate_trigger_params(&trigger, ¶ms).is_err());
|
||||
|
||||
// Invalid enum value for 'unit'
|
||||
let params = json!({ "unit": "days", "delta": 10 });
|
||||
assert!(validate_trigger_params(&trigger, ¶ms).is_err());
|
||||
|
||||
// Invalid type for 'delta'
|
||||
let params = json!({ "unit": "seconds", "delta": "10" });
|
||||
assert!(validate_trigger_params(&trigger, ¶ms).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_action_params_with_valid_params() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" }
|
||||
},
|
||||
"required": ["message"]
|
||||
});
|
||||
|
||||
let action = Action {
|
||||
id: 1,
|
||||
r#ref: "test.action".to_string(),
|
||||
pack: 1,
|
||||
pack_ref: "test".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test action".to_string(),
|
||||
entrypoint: "test.sh".to_string(),
|
||||
runtime: Some(1),
|
||||
param_schema: Some(schema),
|
||||
out_schema: None,
|
||||
is_workflow: false,
|
||||
workflow_def: None,
|
||||
is_adhoc: false,
|
||||
created: chrono::Utc::now(),
|
||||
updated: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let params = json!({ "message": "Hello, world!" });
|
||||
assert!(validate_action_params(&action, ¶ms).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_action_params_with_empty_params_but_required_fields() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" }
|
||||
},
|
||||
"required": ["message"]
|
||||
});
|
||||
|
||||
let action = Action {
|
||||
id: 1,
|
||||
r#ref: "test.action".to_string(),
|
||||
pack: 1,
|
||||
pack_ref: "test".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test action".to_string(),
|
||||
entrypoint: "test.sh".to_string(),
|
||||
runtime: Some(1),
|
||||
param_schema: Some(schema),
|
||||
out_schema: None,
|
||||
is_workflow: false,
|
||||
workflow_def: None,
|
||||
is_adhoc: false,
|
||||
created: chrono::Utc::now(),
|
||||
updated: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let params = json!({});
|
||||
assert!(validate_action_params(&action, ¶ms).is_err());
|
||||
}
|
||||
}
|
||||
274
crates/api/src/webhook_security.rs
Normal file
274
crates/api/src/webhook_security.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
//! Webhook security helpers for HMAC verification and validation
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Sha256, Sha512};
|
||||
use sha1::Sha1;
|
||||
|
||||
/// Verify HMAC signature for webhook payload
|
||||
pub fn verify_hmac_signature(
|
||||
payload: &[u8],
|
||||
signature: &str,
|
||||
secret: &str,
|
||||
algorithm: &str,
|
||||
) -> Result<bool, String> {
|
||||
// Parse signature format (e.g., "sha256=abc123..." or just "abc123...")
|
||||
let (algo_from_sig, hex_signature) = if signature.contains('=') {
|
||||
let parts: Vec<&str> = signature.splitn(2, '=').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Invalid signature format".to_string());
|
||||
}
|
||||
(Some(parts[0]), parts[1])
|
||||
} else {
|
||||
(None, signature)
|
||||
};
|
||||
|
||||
// Verify algorithm matches if specified in signature
|
||||
if let Some(sig_algo) = algo_from_sig {
|
||||
if sig_algo != algorithm {
|
||||
return Err(format!(
|
||||
"Algorithm mismatch: expected {}, got {}",
|
||||
algorithm, sig_algo
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Decode hex signature
|
||||
let expected_signature = hex::decode(hex_signature)
|
||||
.map_err(|e| format!("Invalid hex signature: {}", e))?;
|
||||
|
||||
// Compute HMAC based on algorithm
|
||||
let is_valid = match algorithm {
|
||||
"sha256" => verify_hmac_sha256(payload, &expected_signature, secret),
|
||||
"sha512" => verify_hmac_sha512(payload, &expected_signature, secret),
|
||||
"sha1" => verify_hmac_sha1(payload, &expected_signature, secret),
|
||||
_ => return Err(format!("Unsupported algorithm: {}", algorithm)),
|
||||
};
|
||||
|
||||
Ok(is_valid)
|
||||
}
|
||||
|
||||
/// Verify HMAC-SHA256 signature
|
||||
fn verify_hmac_sha256(payload: &[u8], expected: &[u8], secret: &str) -> bool {
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
mac.update(payload);
|
||||
|
||||
// Use constant-time comparison
|
||||
mac.verify_slice(expected).is_ok()
|
||||
}
|
||||
|
||||
/// Verify HMAC-SHA512 signature
|
||||
fn verify_hmac_sha512(payload: &[u8], expected: &[u8], secret: &str) -> bool {
|
||||
type HmacSha512 = Hmac<Sha512>;
|
||||
|
||||
let mut mac = match HmacSha512::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
mac.update(payload);
|
||||
|
||||
mac.verify_slice(expected).is_ok()
|
||||
}
|
||||
|
||||
/// Verify HMAC-SHA1 signature (legacy, not recommended)
|
||||
fn verify_hmac_sha1(payload: &[u8], expected: &[u8], secret: &str) -> bool {
|
||||
type HmacSha1 = Hmac<Sha1>;
|
||||
|
||||
let mut mac = match HmacSha1::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
mac.update(payload);
|
||||
|
||||
mac.verify_slice(expected).is_ok()
|
||||
}
|
||||
|
||||
/// Generate HMAC signature for testing
|
||||
pub fn generate_hmac_signature(payload: &[u8], secret: &str, algorithm: &str) -> Result<String, String> {
|
||||
let signature = match algorithm {
|
||||
"sha256" => {
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
|
||||
.map_err(|e| format!("Invalid key length: {}", e))?;
|
||||
mac.update(payload);
|
||||
let result = mac.finalize();
|
||||
hex::encode(result.into_bytes())
|
||||
}
|
||||
"sha512" => {
|
||||
type HmacSha512 = Hmac<Sha512>;
|
||||
let mut mac = HmacSha512::new_from_slice(secret.as_bytes())
|
||||
.map_err(|e| format!("Invalid key length: {}", e))?;
|
||||
mac.update(payload);
|
||||
let result = mac.finalize();
|
||||
hex::encode(result.into_bytes())
|
||||
}
|
||||
"sha1" => {
|
||||
type HmacSha1 = Hmac<Sha1>;
|
||||
let mut mac = HmacSha1::new_from_slice(secret.as_bytes())
|
||||
.map_err(|e| format!("Invalid key length: {}", e))?;
|
||||
mac.update(payload);
|
||||
let result = mac.finalize();
|
||||
hex::encode(result.into_bytes())
|
||||
}
|
||||
_ => return Err(format!("Unsupported algorithm: {}", algorithm)),
|
||||
};
|
||||
|
||||
Ok(format!("{}={}", algorithm, signature))
|
||||
}
|
||||
|
||||
/// Check if IP address matches a CIDR block
|
||||
pub fn check_ip_in_cidr(ip: &str, cidr: &str) -> Result<bool, String> {
|
||||
use std::net::IpAddr;
|
||||
|
||||
let ip_addr: IpAddr = ip.parse()
|
||||
.map_err(|e| format!("Invalid IP address: {}", e))?;
|
||||
|
||||
// If CIDR doesn't contain '/', treat it as a single IP
|
||||
if !cidr.contains('/') {
|
||||
let cidr_addr: IpAddr = cidr.parse()
|
||||
.map_err(|e| format!("Invalid CIDR notation: {}", e))?;
|
||||
return Ok(ip_addr == cidr_addr);
|
||||
}
|
||||
|
||||
// Parse CIDR notation
|
||||
let parts: Vec<&str> = cidr.split('/').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Invalid CIDR format".to_string());
|
||||
}
|
||||
|
||||
let network_addr: IpAddr = parts[0].parse()
|
||||
.map_err(|e| format!("Invalid network address: {}", e))?;
|
||||
let prefix_len: u8 = parts[1].parse()
|
||||
.map_err(|e| format!("Invalid prefix length: {}", e))?;
|
||||
|
||||
// Convert to bytes for comparison
|
||||
match (ip_addr, network_addr) {
|
||||
(IpAddr::V4(ip), IpAddr::V4(network)) => {
|
||||
if prefix_len > 32 {
|
||||
return Err("IPv4 prefix length must be <= 32".to_string());
|
||||
}
|
||||
let ip_bits = u32::from(ip);
|
||||
let network_bits = u32::from(network);
|
||||
let mask = if prefix_len == 0 { 0 } else { !0u32 << (32 - prefix_len) };
|
||||
Ok((ip_bits & mask) == (network_bits & mask))
|
||||
}
|
||||
(IpAddr::V6(ip), IpAddr::V6(network)) => {
|
||||
if prefix_len > 128 {
|
||||
return Err("IPv6 prefix length must be <= 128".to_string());
|
||||
}
|
||||
let ip_bits = u128::from(ip);
|
||||
let network_bits = u128::from(network);
|
||||
let mask = if prefix_len == 0 { 0 } else { !0u128 << (128 - prefix_len) };
|
||||
Ok((ip_bits & mask) == (network_bits & mask))
|
||||
}
|
||||
_ => Err("IP address and CIDR must be same version (IPv4 or IPv6)".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if IP is in any of the CIDR blocks in the whitelist
|
||||
pub fn check_ip_in_whitelist(ip: &str, whitelist: &[String]) -> Result<bool, String> {
|
||||
for cidr in whitelist {
|
||||
match check_ip_in_cidr(ip, cidr) {
|
||||
Ok(true) => return Ok(true),
|
||||
Ok(false) => continue,
|
||||
Err(e) => return Err(format!("Error checking CIDR {}: {}", cidr, e)),
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_and_verify_hmac_sha256() {
|
||||
let payload = b"test payload";
|
||||
let secret = "my-secret-key";
|
||||
let signature = generate_hmac_signature(payload, secret, "sha256").unwrap();
|
||||
|
||||
assert!(verify_hmac_signature(payload, &signature, secret, "sha256").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_hmac_wrong_secret() {
|
||||
let payload = b"test payload";
|
||||
let secret = "my-secret-key";
|
||||
let wrong_secret = "wrong-key";
|
||||
let signature = generate_hmac_signature(payload, secret, "sha256").unwrap();
|
||||
|
||||
assert!(!verify_hmac_signature(payload, &signature, wrong_secret, "sha256").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_hmac_wrong_payload() {
|
||||
let payload = b"test payload";
|
||||
let wrong_payload = b"wrong payload";
|
||||
let secret = "my-secret-key";
|
||||
let signature = generate_hmac_signature(payload, secret, "sha256").unwrap();
|
||||
|
||||
assert!(!verify_hmac_signature(wrong_payload, &signature, secret, "sha256").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_hmac_sha512() {
|
||||
let payload = b"test payload";
|
||||
let secret = "my-secret-key";
|
||||
let signature = generate_hmac_signature(payload, secret, "sha512").unwrap();
|
||||
|
||||
assert!(verify_hmac_signature(payload, &signature, secret, "sha512").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_hmac_without_algorithm_prefix() {
|
||||
let payload = b"test payload";
|
||||
let secret = "my-secret-key";
|
||||
let signature = generate_hmac_signature(payload, secret, "sha256").unwrap();
|
||||
|
||||
// Remove the "sha256=" prefix
|
||||
let hex_only = signature.split('=').nth(1).unwrap();
|
||||
|
||||
assert!(verify_hmac_signature(payload, hex_only, secret, "sha256").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_ip_in_cidr_single_ip() {
|
||||
assert!(check_ip_in_cidr("192.168.1.1", "192.168.1.1").unwrap());
|
||||
assert!(!check_ip_in_cidr("192.168.1.2", "192.168.1.1").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_ip_in_cidr_block() {
|
||||
assert!(check_ip_in_cidr("192.168.1.100", "192.168.1.0/24").unwrap());
|
||||
assert!(check_ip_in_cidr("192.168.1.1", "192.168.1.0/24").unwrap());
|
||||
assert!(check_ip_in_cidr("192.168.1.254", "192.168.1.0/24").unwrap());
|
||||
assert!(!check_ip_in_cidr("192.168.2.1", "192.168.1.0/24").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_ip_in_cidr_ipv6() {
|
||||
assert!(check_ip_in_cidr("2001:db8::1", "2001:db8::/32").unwrap());
|
||||
assert!(!check_ip_in_cidr("2001:db9::1", "2001:db8::/32").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_ip_in_whitelist() {
|
||||
let whitelist = vec![
|
||||
"192.168.1.0/24".to_string(),
|
||||
"10.0.0.0/8".to_string(),
|
||||
"172.16.5.10".to_string(),
|
||||
];
|
||||
|
||||
assert!(check_ip_in_whitelist("192.168.1.100", &whitelist).unwrap());
|
||||
assert!(check_ip_in_whitelist("10.20.30.40", &whitelist).unwrap());
|
||||
assert!(check_ip_in_whitelist("172.16.5.10", &whitelist).unwrap());
|
||||
assert!(!check_ip_in_whitelist("8.8.8.8", &whitelist).unwrap());
|
||||
}
|
||||
}
|
||||
145
crates/api/tests/README.md
Normal file
145
crates/api/tests/README.md
Normal file
@@ -0,0 +1,145 @@
|
||||
# API Integration Tests
|
||||
|
||||
This directory contains integration tests for the Attune API service.
|
||||
|
||||
## Test Files
|
||||
|
||||
- `webhook_api_tests.rs` - Basic webhook management and receiver endpoint tests (8 tests)
|
||||
- `webhook_security_tests.rs` - Comprehensive webhook security feature tests (17 tests)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running tests, ensure:
|
||||
|
||||
1. **PostgreSQL is running** on `localhost:5432` (or set `DATABASE_URL`)
|
||||
2. **Database migrations are applied**: `sqlx migrate run`
|
||||
3. **Test user exists** (username: `test_user`, password: `test_password`)
|
||||
|
||||
### Quick Setup
|
||||
|
||||
```bash
|
||||
# Set database URL
|
||||
export DATABASE_URL="postgresql://postgres:postgres@localhost:5432/attune"
|
||||
|
||||
# Run migrations
|
||||
sqlx migrate run
|
||||
|
||||
# Create test user (run from psql or create via API)
|
||||
# The test user is created automatically when you run the API for the first time
|
||||
# Or create manually:
|
||||
psql $DATABASE_URL -c "
|
||||
INSERT INTO attune.identity (username, email, password_hash, enabled)
|
||||
VALUES ('test_user', 'test@example.com',
|
||||
crypt('test_password', gen_salt('bf')), true)
|
||||
ON CONFLICT (username) DO NOTHING;
|
||||
"
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
All tests are marked with `#[ignore]` because they require a database connection.
|
||||
|
||||
### Run all API integration tests
|
||||
```bash
|
||||
cargo test -p attune-api --test '*' -- --ignored
|
||||
```
|
||||
|
||||
### Run webhook API tests only
|
||||
```bash
|
||||
cargo test -p attune-api --test webhook_api_tests -- --ignored
|
||||
```
|
||||
|
||||
### Run webhook security tests only
|
||||
```bash
|
||||
cargo test -p attune-api --test webhook_security_tests -- --ignored
|
||||
```
|
||||
|
||||
### Run a specific test
|
||||
```bash
|
||||
cargo test -p attune-api --test webhook_security_tests test_webhook_hmac_sha256_valid -- --ignored --nocapture
|
||||
```
|
||||
|
||||
### Run tests with output
|
||||
```bash
|
||||
cargo test -p attune-api --test webhook_security_tests -- --ignored --nocapture
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Basic Webhook Tests (`webhook_api_tests.rs`)
|
||||
- Webhook enable/disable/regenerate operations
|
||||
- Webhook receiver with valid/invalid keys
|
||||
- Authentication enforcement
|
||||
- Disabled webhook handling
|
||||
|
||||
### Security Feature Tests (`webhook_security_tests.rs`)
|
||||
|
||||
#### HMAC Signature Tests
|
||||
- `test_webhook_hmac_sha256_valid` - SHA256 signature validation
|
||||
- `test_webhook_hmac_sha512_valid` - SHA512 signature validation
|
||||
- `test_webhook_hmac_invalid_signature` - Invalid signature rejection
|
||||
- `test_webhook_hmac_missing_signature` - Missing signature rejection
|
||||
- `test_webhook_hmac_wrong_secret` - Wrong secret rejection
|
||||
|
||||
#### Rate Limiting Tests
|
||||
- `test_webhook_rate_limit_enforced` - Rate limit enforcement
|
||||
- `test_webhook_rate_limit_disabled` - No rate limit when disabled
|
||||
|
||||
#### IP Whitelisting Tests
|
||||
- `test_webhook_ip_whitelist_allowed` - Allowed IPs pass
|
||||
- `test_webhook_ip_whitelist_blocked` - Blocked IPs rejected
|
||||
|
||||
#### Payload Size Tests
|
||||
- `test_webhook_payload_size_limit_enforced` - Size limit enforcement
|
||||
- `test_webhook_payload_size_within_limit` - Valid size acceptance
|
||||
|
||||
#### Event Logging Tests
|
||||
- `test_webhook_event_logging_success` - Success logging
|
||||
- `test_webhook_event_logging_failure` - Failure logging
|
||||
|
||||
#### Combined Security Tests
|
||||
- `test_webhook_all_security_features_pass` - All features enabled
|
||||
- `test_webhook_multiple_security_failures` - Multiple failures
|
||||
|
||||
#### Error Scenarios
|
||||
- `test_webhook_malformed_json` - Invalid JSON handling
|
||||
- `test_webhook_empty_payload` - Empty payload handling
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Failed to connect to database"
|
||||
- Ensure PostgreSQL is running: `pg_isready -h localhost -p 5432`
|
||||
- Check `DATABASE_URL` is set correctly
|
||||
- Test connection: `psql $DATABASE_URL -c "SELECT 1"`
|
||||
|
||||
### "Trigger not found" or table errors
|
||||
- Run migrations: `sqlx migrate run`
|
||||
- Check schema exists: `psql $DATABASE_URL -c "\dn"`
|
||||
|
||||
### "Authentication required" errors
|
||||
- Ensure test user exists with correct credentials
|
||||
- Check `JWT_SECRET` environment variable is set
|
||||
|
||||
### Tests timeout
|
||||
- Increase timeout with: `cargo test -- --ignored --test-threads=1`
|
||||
- Check database performance
|
||||
- Reduce concurrent test execution
|
||||
|
||||
### Rate limit tests fail
|
||||
- Clear webhook event logs between runs
|
||||
- Ensure tests run in isolation: `cargo test -- --ignored --test-threads=1`
|
||||
|
||||
## Documentation
|
||||
|
||||
For comprehensive test documentation, see:
|
||||
- `docs/webhook-testing.md` - Full test suite documentation
|
||||
- `docs/webhook-manual-testing.md` - Manual testing guide
|
||||
- `docs/webhook-system-architecture.md` - Webhook system architecture
|
||||
|
||||
## CI/CD
|
||||
|
||||
These tests are designed to run in CI with:
|
||||
- PostgreSQL service container
|
||||
- Automatic migration application
|
||||
- Test user creation script
|
||||
- Parallel test execution (where safe)
|
||||
241
crates/api/tests/SSE_TESTS_README.md
Normal file
241
crates/api/tests/SSE_TESTS_README.md
Normal file
@@ -0,0 +1,241 @@
|
||||
# SSE Integration Tests
|
||||
|
||||
This directory contains integration tests for the Server-Sent Events (SSE) execution streaming functionality.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run CI-friendly tests (no server required)
|
||||
cargo test -p attune-api --test sse_execution_stream_tests
|
||||
|
||||
# Expected output:
|
||||
# test result: ok. 2 passed; 0 failed; 3 ignored
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
The SSE tests verify the complete real-time update pipeline:
|
||||
1. PostgreSQL NOTIFY triggers fire on execution changes
|
||||
2. API service listener receives notifications via LISTEN
|
||||
3. Notifications are broadcast to SSE clients
|
||||
4. Web UI receives real-time updates
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Database-Level Tests (No Server Required) ✅ CI-Friendly
|
||||
|
||||
These tests run automatically and do NOT require the API server:
|
||||
|
||||
```bash
|
||||
# Run all non-ignored tests (CI/CD safe)
|
||||
cargo test -p attune-api --test sse_execution_stream_tests
|
||||
|
||||
# Or specifically test PostgreSQL NOTIFY
|
||||
cargo test -p attune-api test_postgresql_notify_trigger_fires -- --nocapture
|
||||
```
|
||||
|
||||
**What they test:**
|
||||
- ✅ PostgreSQL trigger fires on execution INSERT/UPDATE
|
||||
- ✅ Notification payload structure is correct
|
||||
- ✅ LISTEN/NOTIFY mechanism works
|
||||
- ✅ Database-level integration is working
|
||||
|
||||
**Status**: These tests pass automatically in CI/CD
|
||||
|
||||
### 2. End-to-End SSE Tests (Server Required) 🚧 Manual Testing
|
||||
|
||||
These tests are **marked as `#[ignore]`** and require a running API service.
|
||||
They are not run by default in CI/CD.
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start API service
|
||||
cargo run -p attune-api -- -c config.test.yaml
|
||||
|
||||
# Terminal 2: Run ignored SSE tests
|
||||
cargo test -p attune-api --test sse_execution_stream_tests -- --ignored --nocapture --test-threads=1
|
||||
|
||||
# Or run a specific test
|
||||
cargo test -p attune-api test_sse_stream_receives_execution_updates -- --ignored --nocapture
|
||||
```
|
||||
|
||||
**What they test:**
|
||||
- 🔍 SSE endpoint receives notifications from PostgreSQL listener
|
||||
- 🔍 Filtering by execution_id works correctly
|
||||
- 🔍 Authentication is enforced
|
||||
- 🔍 Multiple concurrent SSE connections work
|
||||
- 🔍 Real-time updates are delivered instantly
|
||||
|
||||
**Status**: Manual verification only (marked `#[ignore]`)
|
||||
|
||||
## Test Files
|
||||
|
||||
- `sse_execution_stream_tests.rs` - Main SSE integration tests (539 lines)
|
||||
- 5 comprehensive test cases covering the full SSE pipeline
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Database Setup
|
||||
Each test:
|
||||
1. Creates a clean test database state
|
||||
2. Sets up test pack and action
|
||||
3. Creates test executions
|
||||
|
||||
### SSE Connection
|
||||
Tests use `eventsource-client` crate to:
|
||||
1. Connect to `/api/v1/executions/stream` endpoint
|
||||
2. Authenticate with JWT token
|
||||
3. Subscribe to execution updates
|
||||
4. Verify received events
|
||||
|
||||
### Assertions
|
||||
Tests verify:
|
||||
- Correct event structure
|
||||
- Proper filtering behavior
|
||||
- Authentication requirements
|
||||
- Real-time delivery (no polling delay)
|
||||
|
||||
## Running All Tests
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start API service
|
||||
cargo run -p attune-api -- -c config.test.yaml
|
||||
|
||||
# Terminal 2: Run all SSE tests
|
||||
cargo test -p attune-api --test sse_execution_stream_tests -- --test-threads=1 --nocapture
|
||||
|
||||
# Or run specific test
|
||||
cargo test -p attune-api test_sse_stream_receives_execution_updates -- --nocapture
|
||||
```
|
||||
|
||||
## Expected Output
|
||||
|
||||
### Default Test Run (CI/CD)
|
||||
|
||||
```
|
||||
running 5 tests
|
||||
test test_postgresql_notify_trigger_fires ... ok
|
||||
test test_sse_stream_receives_execution_updates ... ignored
|
||||
test test_sse_stream_filters_by_execution_id ... ignored
|
||||
test test_sse_stream_all_executions ... ignored
|
||||
test test_sse_stream_requires_authentication ... ok
|
||||
|
||||
test result: ok. 2 passed; 0 failed; 3 ignored
|
||||
```
|
||||
|
||||
### Full Test Run (With Server Running)
|
||||
|
||||
```
|
||||
running 5 tests
|
||||
test test_postgresql_notify_trigger_fires ... ok
|
||||
test test_sse_stream_receives_execution_updates ... ok
|
||||
test test_sse_stream_filters_by_execution_id ... ok
|
||||
test test_sse_stream_requires_authentication ... ok
|
||||
test test_sse_stream_all_executions ... ok
|
||||
|
||||
test result: ok. 5 passed; 0 failed; 0 ignored
|
||||
```
|
||||
|
||||
### PostgreSQL Notification Example
|
||||
|
||||
```json
|
||||
{
|
||||
"entity_type": "execution",
|
||||
"entity_id": 123,
|
||||
"timestamp": "2026-01-19T05:02:14.188288+00:00",
|
||||
"data": {
|
||||
"id": 123,
|
||||
"status": "running",
|
||||
"action_id": 42,
|
||||
"action_ref": "test_sse_pack.test_action",
|
||||
"result": null,
|
||||
"created": "2026-01-19T05:02:13.982769+00:00",
|
||||
"updated": "2026-01-19T05:02:14.188288+00:00"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused Error
|
||||
|
||||
```
|
||||
error trying to connect: tcp connect error: Connection refused
|
||||
```
|
||||
|
||||
**Solution**: Make sure the API service is running on port 8080:
|
||||
```bash
|
||||
cargo run -p attune-api -- -c config.test.yaml
|
||||
```
|
||||
|
||||
### Test Database Not Found
|
||||
|
||||
**Solution**: Create the test database:
|
||||
```bash
|
||||
createdb attune_test
|
||||
sqlx migrate run --database-url postgresql://postgres:postgres@localhost:5432/attune_test
|
||||
```
|
||||
|
||||
### Missing Migration
|
||||
|
||||
**Solution**: Apply the execution notify trigger migration:
|
||||
```bash
|
||||
psql postgresql://postgres:postgres@localhost:5432/attune_test < migrations/20260119000001_add_execution_notify_trigger.sql
|
||||
```
|
||||
|
||||
### Tests Hang
|
||||
|
||||
**Cause**: Tests are waiting for SSE events that never arrive
|
||||
|
||||
**Debug steps:**
|
||||
1. Check API service logs for PostgreSQL listener errors
|
||||
2. Verify trigger exists: `\d+ attune.execution` in psql
|
||||
3. Manually update execution and check notifications:
|
||||
```sql
|
||||
UPDATE attune.execution SET status = 'running' WHERE id = 1;
|
||||
LISTEN attune_notifications;
|
||||
```
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
### Recommended Approach (Default)
|
||||
|
||||
Run only the database-level tests in CI/CD:
|
||||
|
||||
```bash
|
||||
# CI-friendly tests (no server required) ✅
|
||||
cargo test -p attune-api --test sse_execution_stream_tests
|
||||
```
|
||||
|
||||
This will:
|
||||
- ✅ Run `test_postgresql_notify_trigger_fires` (database trigger verification)
|
||||
- ✅ Run `test_sse_stream_requires_authentication` (auth logic verification)
|
||||
- ⏭️ Skip 3 tests marked `#[ignore]` (require running server)
|
||||
|
||||
### Full Testing (Optional)
|
||||
|
||||
For complete end-to-end verification in CI/CD:
|
||||
|
||||
```bash
|
||||
# Start API in background
|
||||
cargo run -p attune-api -- -c config.test.yaml &
|
||||
API_PID=$!
|
||||
|
||||
# Wait for server to start
|
||||
sleep 3
|
||||
|
||||
# Run ALL tests including ignored ones
|
||||
cargo test -p attune-api --test sse_execution_stream_tests -- --ignored --test-threads=1
|
||||
|
||||
# Cleanup
|
||||
kill $API_PID
|
||||
```
|
||||
|
||||
**Note**: Full testing adds complexity and time. The database-level tests provide
|
||||
sufficient coverage for the notification pipeline. The ignored tests are for
|
||||
manual verification during development.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [SSE Architecture](../../docs/sse-architecture.md)
|
||||
- [Web UI Integration](../../web/src/hooks/useExecutionStream.ts)
|
||||
- [Session Summary](../../work-summary/session-09-web-ui-detail-pages.md)
|
||||
416
crates/api/tests/health_and_auth_tests.rs
Normal file
416
crates/api/tests/health_and_auth_tests.rs
Normal file
@@ -0,0 +1,416 @@
|
||||
//! Integration tests for health check and authentication endpoints
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use helpers::*;
|
||||
use serde_json::json;
|
||||
|
||||
mod helpers;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_debug() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "debuguser",
|
||||
"password": "TestPassword123!",
|
||||
"display_name": "Debug User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
let status = response.status();
|
||||
println!("Status: {}", status);
|
||||
|
||||
let body_text = response.text().await.expect("Failed to get body");
|
||||
println!("Body: {}", body_text);
|
||||
|
||||
// This test is just for debugging - will fail if not 201
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_check() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/health", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert_eq!(body["status"], "ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_detailed() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/health/detailed", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert_eq!(body["status"], "ok");
|
||||
assert_eq!(body["database"], "connected");
|
||||
assert!(body["version"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_ready() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/health/ready", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Readiness endpoint returns empty body with 200 status
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_live() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/health/live", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Liveness endpoint returns empty body with 200 status
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_user() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "newuser",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "New User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert!(body["data"].is_object());
|
||||
assert!(body["data"]["access_token"].is_string());
|
||||
assert!(body["data"]["refresh_token"].is_string());
|
||||
assert!(body["data"]["user"].is_object());
|
||||
assert_eq!(body["data"]["user"]["login"], "newuser");
|
||||
assert_eq!(body["data"]["user"]["display_name"], "New User");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_duplicate_user() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
// Register first user
|
||||
let _ = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "duplicate",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "Duplicate User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Try to register same user again
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "duplicate",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "Duplicate User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_invalid_password() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "testuser",
|
||||
"password": "weak",
|
||||
"display_name": "Test User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_login_success() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
// Register a user first
|
||||
let _ = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "loginuser",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "Login User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to register user");
|
||||
|
||||
// Now try to login
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/login",
|
||||
json!({
|
||||
"login": "loginuser",
|
||||
"password": "SecurePassword123!"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert!(body["data"]["access_token"].is_string());
|
||||
assert!(body["data"]["refresh_token"].is_string());
|
||||
assert_eq!(body["data"]["user"]["login"], "loginuser");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_login_wrong_password() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
// Register a user first
|
||||
let _ = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "wrongpassuser",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "Wrong Pass User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to register user");
|
||||
|
||||
// Try to login with wrong password
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/login",
|
||||
json!({
|
||||
"login": "wrongpassuser",
|
||||
"password": "WrongPassword123!"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_login_nonexistent_user() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/login",
|
||||
json!({
|
||||
"login": "nonexistent",
|
||||
"password": "SomePassword123!"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_user() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context")
|
||||
.with_auth()
|
||||
.await
|
||||
.expect("Failed to authenticate");
|
||||
|
||||
let response = ctx
|
||||
.get("/auth/me", ctx.token())
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert!(body["data"].is_object());
|
||||
assert!(body["data"]["id"].is_number());
|
||||
assert!(body["data"]["login"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_user_unauthorized() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/auth/me", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_user_invalid_token() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/auth/me", Some("invalid-token"))
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_token() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
// Register a user first
|
||||
let register_response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": "refreshuser",
|
||||
"email": "refresh@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"display_name": "Refresh User"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to register user");
|
||||
|
||||
let register_body: serde_json::Value = register_response
|
||||
.json()
|
||||
.await
|
||||
.expect("Failed to parse JSON");
|
||||
|
||||
let refresh_token = register_body["data"]["refresh_token"]
|
||||
.as_str()
|
||||
.expect("Missing refresh token");
|
||||
|
||||
// Use refresh token to get new access token
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/refresh",
|
||||
json!({
|
||||
"refresh_token": refresh_token
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
assert!(body["data"]["access_token"].is_string());
|
||||
assert!(body["data"]["refresh_token"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_with_invalid_token() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/refresh",
|
||||
json!({
|
||||
"refresh_token": "invalid-refresh-token"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
525
crates/api/tests/helpers.rs
Normal file
525
crates/api/tests/helpers.rs
Normal file
@@ -0,0 +1,525 @@
|
||||
//! Test helpers and utilities for API integration tests
|
||||
//!
|
||||
//! This module provides common test fixtures, server setup/teardown,
|
||||
//! and utility functions for testing API endpoints.
|
||||
|
||||
use attune_common::{
|
||||
config::Config,
|
||||
db::Database,
|
||||
models::*,
|
||||
repositories::{
|
||||
action::{ActionRepository, CreateActionInput},
|
||||
pack::{CreatePackInput, PackRepository},
|
||||
trigger::{CreateTriggerInput, TriggerRepository},
|
||||
workflow::{CreateWorkflowDefinitionInput, WorkflowDefinitionRepository},
|
||||
Create,
|
||||
},
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{header, Method, Request, StatusCode},
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
use std::sync::{Arc, Once};
|
||||
use tower::Service;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
/// Initialize test environment (run once)
|
||||
pub fn init_test_env() {
|
||||
INIT.call_once(|| {
|
||||
// Clear any existing ATTUNE environment variables
|
||||
for (key, _) in std::env::vars() {
|
||||
if key.starts_with("ATTUNE") {
|
||||
std::env::remove_var(&key);
|
||||
}
|
||||
}
|
||||
|
||||
// Don't set environment via env var - let config load from file
|
||||
// The test config file already specifies environment: test
|
||||
|
||||
// Initialize tracing for tests
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::from_default_env()
|
||||
.add_directive(tracing::Level::WARN.into()),
|
||||
)
|
||||
.try_init()
|
||||
.ok();
|
||||
});
|
||||
}
|
||||
|
||||
/// Create a base database pool (connected to attune_test database)
|
||||
async fn create_base_pool() -> Result<PgPool> {
|
||||
init_test_env();
|
||||
|
||||
// Load config from project root (crates/api is 2 levels deep)
|
||||
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
||||
let config_path = format!("{}/../../config.test.yaml", manifest_dir);
|
||||
|
||||
let config = Config::load_from_file(&config_path)
|
||||
.map_err(|e| format!("Failed to load config from {}: {}", config_path, e))?;
|
||||
|
||||
// Create base pool without setting search_path (for creating schemas)
|
||||
// Don't use Database::new as it sets search_path - we just need a plain connection
|
||||
let pool = sqlx::PgPool::connect(&config.database.url).await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Create a test database pool with a unique schema for this test
|
||||
async fn create_schema_pool(schema_name: &str) -> Result<PgPool> {
|
||||
let base_pool = create_base_pool().await?;
|
||||
|
||||
// Create the test schema
|
||||
tracing::debug!("Creating test schema: {}", schema_name);
|
||||
let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS {}", schema_name);
|
||||
sqlx::query(&create_schema_sql).execute(&base_pool).await?;
|
||||
tracing::debug!("Test schema created successfully: {}", schema_name);
|
||||
|
||||
// Run migrations in the new schema
|
||||
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
||||
let migrations_path = format!("{}/../../migrations", manifest_dir);
|
||||
|
||||
// Create a config with our test schema and add search_path to the URL
|
||||
let config_path = format!("{}/../../config.test.yaml", manifest_dir);
|
||||
let mut config = Config::load_from_file(&config_path)?;
|
||||
config.database.schema = Some(schema_name.to_string());
|
||||
|
||||
// Add search_path parameter to the database URL for the migrator
|
||||
// PostgreSQL supports setting options in the connection URL
|
||||
let separator = if config.database.url.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
|
||||
// Use proper URL encoding for search_path option
|
||||
let _url_with_schema = format!(
|
||||
"{}{}options=--search_path%3D{}",
|
||||
config.database.url, separator, schema_name
|
||||
);
|
||||
|
||||
// Create a pool directly with the modified URL for migrations
|
||||
// Also set after_connect hook to ensure all connections from pool have search_path
|
||||
let migration_pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.after_connect({
|
||||
let schema = schema_name.to_string();
|
||||
move |conn, _meta| {
|
||||
let schema = schema.clone();
|
||||
Box::pin(async move {
|
||||
sqlx::query(&format!("SET search_path TO {}", schema))
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
})
|
||||
.connect(&config.database.url)
|
||||
.await?;
|
||||
|
||||
// Manually run migration SQL files instead of using SQLx migrator
|
||||
// This is necessary because SQLx migrator has issues with per-schema search_path
|
||||
let migration_files = std::fs::read_dir(&migrations_path)?;
|
||||
let mut migrations: Vec<_> = migration_files
|
||||
.filter_map(|entry| entry.ok())
|
||||
.filter(|entry| entry.path().extension().and_then(|s| s.to_str()) == Some("sql"))
|
||||
.collect();
|
||||
|
||||
// Sort by filename to ensure migrations run in version order
|
||||
migrations.sort_by_key(|entry| entry.path().clone());
|
||||
|
||||
for migration_file in migrations {
|
||||
let migration_path = migration_file.path();
|
||||
let sql = std::fs::read_to_string(&migration_path)?;
|
||||
|
||||
// Execute search_path setting and migration in sequence
|
||||
// First set the search_path
|
||||
sqlx::query(&format!("SET search_path TO {}", schema_name))
|
||||
.execute(&migration_pool)
|
||||
.await?;
|
||||
|
||||
// Then execute the migration SQL
|
||||
// This preserves DO blocks, CREATE TYPE statements, etc.
|
||||
if let Err(e) = sqlx::raw_sql(&sql).execute(&migration_pool).await {
|
||||
// Ignore "already exists" errors since enums may be global
|
||||
let error_msg = format!("{:?}", e);
|
||||
if !error_msg.contains("already exists") && !error_msg.contains("duplicate") {
|
||||
eprintln!(
|
||||
"Migration error in {}: {}",
|
||||
migration_file.path().display(),
|
||||
e
|
||||
);
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now create the proper Database instance for use in tests
|
||||
let database = Database::new(&config.database).await?;
|
||||
let pool = database.pool().clone();
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Cleanup a test schema (drop it)
|
||||
pub async fn cleanup_test_schema(schema_name: &str) -> Result<()> {
|
||||
let base_pool = create_base_pool().await?;
|
||||
|
||||
// Drop the schema and all its contents
|
||||
tracing::debug!("Dropping test schema: {}", schema_name);
|
||||
let drop_schema_sql = format!("DROP SCHEMA IF EXISTS {} CASCADE", schema_name);
|
||||
sqlx::query(&drop_schema_sql).execute(&base_pool).await?;
|
||||
tracing::debug!("Test schema dropped successfully: {}", schema_name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create unique test packs directory for this test
|
||||
pub fn create_test_packs_dir(schema: &str) -> Result<std::path::PathBuf> {
|
||||
let test_packs_dir = std::path::PathBuf::from(format!("/tmp/attune-test-packs-{}", schema));
|
||||
if test_packs_dir.exists() {
|
||||
std::fs::remove_dir_all(&test_packs_dir)?;
|
||||
}
|
||||
std::fs::create_dir_all(&test_packs_dir)?;
|
||||
Ok(test_packs_dir)
|
||||
}
|
||||
|
||||
/// Test context with server and authentication
|
||||
pub struct TestContext {
|
||||
#[allow(dead_code)]
|
||||
pub pool: PgPool,
|
||||
pub app: axum::Router,
|
||||
pub token: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
pub user: Option<Identity>,
|
||||
pub schema: String,
|
||||
pub test_packs_dir: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
/// Create a new test context with a unique schema
|
||||
pub async fn new() -> Result<Self> {
|
||||
// Generate a unique schema name for this test
|
||||
let schema = format!("test_{}", uuid::Uuid::new_v4().to_string().replace("-", ""));
|
||||
|
||||
tracing::info!("Initializing test context with schema: {}", schema);
|
||||
|
||||
// Create unique test packs directory for this test
|
||||
let test_packs_dir = create_test_packs_dir(&schema)?;
|
||||
|
||||
// Create pool with the test schema
|
||||
let pool = create_schema_pool(&schema).await?;
|
||||
|
||||
// Load config from project root
|
||||
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
||||
let config_path = format!("{}/../../config.test.yaml", manifest_dir);
|
||||
let mut config = Config::load_from_file(&config_path)?;
|
||||
config.database.schema = Some(schema.clone());
|
||||
|
||||
let state = attune_api::state::AppState::new(pool.clone(), config.clone());
|
||||
let server = attune_api::server::Server::new(Arc::new(state));
|
||||
let app = server.router();
|
||||
|
||||
Ok(Self {
|
||||
pool,
|
||||
app,
|
||||
token: None,
|
||||
user: None,
|
||||
schema,
|
||||
test_packs_dir,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create and authenticate a test user
|
||||
pub async fn with_auth(mut self) -> Result<Self> {
|
||||
// Generate unique username to avoid conflicts in parallel tests
|
||||
let unique_id = uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
|
||||
let login = format!("testuser_{}", unique_id);
|
||||
let token = self.create_test_user(&login).await?;
|
||||
self.token = Some(token);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Create a test user and return access token
|
||||
async fn create_test_user(&self, login: &str) -> Result<String> {
|
||||
// Register via API to get real token
|
||||
let response = self
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": login,
|
||||
"password": "TestPassword123!",
|
||||
"display_name": format!("Test User {}", login)
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let body: Value = response.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(
|
||||
format!("Failed to register user: status={}, body={}", status, body).into(),
|
||||
);
|
||||
}
|
||||
|
||||
let token = body["data"]["access_token"]
|
||||
.as_str()
|
||||
.ok_or_else(|| format!("No access token in response: {}", body))?
|
||||
.to_string();
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Make a GET request
|
||||
#[allow(dead_code)]
|
||||
pub async fn get(&self, path: &str, token: Option<&str>) -> Result<TestResponse> {
|
||||
self.request(Method::GET, path, None::<Value>, token).await
|
||||
}
|
||||
|
||||
/// Make a POST request
|
||||
pub async fn post<T: serde::Serialize>(
|
||||
&self,
|
||||
path: &str,
|
||||
body: T,
|
||||
token: Option<&str>,
|
||||
) -> Result<TestResponse> {
|
||||
self.request(Method::POST, path, Some(body), token).await
|
||||
}
|
||||
|
||||
/// Make a PUT request
|
||||
#[allow(dead_code)]
|
||||
pub async fn put<T: serde::Serialize>(
|
||||
&self,
|
||||
path: &str,
|
||||
body: T,
|
||||
token: Option<&str>,
|
||||
) -> Result<TestResponse> {
|
||||
self.request(Method::PUT, path, Some(body), token).await
|
||||
}
|
||||
|
||||
/// Make a DELETE request
|
||||
#[allow(dead_code)]
|
||||
pub async fn delete(&self, path: &str, token: Option<&str>) -> Result<TestResponse> {
|
||||
self.request(Method::DELETE, path, None::<Value>, token)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Make a generic HTTP request
|
||||
async fn request<T: serde::Serialize>(
|
||||
&self,
|
||||
method: Method,
|
||||
path: &str,
|
||||
body: Option<T>,
|
||||
token: Option<&str>,
|
||||
) -> Result<TestResponse> {
|
||||
let mut request = Request::builder()
|
||||
.method(method)
|
||||
.uri(path)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
|
||||
// Add authorization header if token provided
|
||||
if let Some(token) = token.or(self.token.as_deref()) {
|
||||
request = request.header(header::AUTHORIZATION, format!("Bearer {}", token));
|
||||
}
|
||||
|
||||
let request = if let Some(body) = body {
|
||||
request.body(Body::from(serde_json::to_string(&body).unwrap()))
|
||||
} else {
|
||||
request.body(Body::empty())
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let response = self
|
||||
.app
|
||||
.clone()
|
||||
.call(request)
|
||||
.await
|
||||
.expect("Failed to execute request");
|
||||
|
||||
Ok(TestResponse::new(response))
|
||||
}
|
||||
|
||||
/// Get authenticated token
|
||||
pub fn token(&self) -> Option<&str> {
|
||||
self.token.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestContext {
|
||||
fn drop(&mut self) {
|
||||
// Cleanup the test schema when the context is dropped
|
||||
// Best-effort async cleanup - schema will be dropped shortly after test completes
|
||||
// If tests are interrupted, run ./scripts/cleanup-test-schemas.sh
|
||||
let schema = self.schema.clone();
|
||||
let test_packs_dir = self.test_packs_dir.clone();
|
||||
|
||||
// Spawn cleanup task in background
|
||||
let _ = tokio::spawn(async move {
|
||||
if let Err(e) = cleanup_test_schema(&schema).await {
|
||||
eprintln!("Failed to cleanup test schema {}: {}", schema, e);
|
||||
}
|
||||
});
|
||||
|
||||
// Cleanup the test packs directory synchronously
|
||||
let _ = std::fs::remove_dir_all(&test_packs_dir);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test response wrapper
|
||||
pub struct TestResponse {
|
||||
response: axum::response::Response,
|
||||
}
|
||||
|
||||
impl TestResponse {
|
||||
pub fn new(response: axum::response::Response) -> Self {
|
||||
Self { response }
|
||||
}
|
||||
|
||||
/// Get response status code
|
||||
pub fn status(&self) -> StatusCode {
|
||||
self.response.status()
|
||||
}
|
||||
|
||||
/// Deserialize response body as JSON
|
||||
pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
|
||||
let body = self.response.into_body();
|
||||
let bytes = axum::body::to_bytes(body, usize::MAX).await?;
|
||||
Ok(serde_json::from_slice(&bytes)?)
|
||||
}
|
||||
|
||||
/// Get response body as text
|
||||
#[allow(dead_code)]
|
||||
pub async fn text(self) -> Result<String> {
|
||||
let body = self.response.into_body();
|
||||
let bytes = axum::body::to_bytes(body, usize::MAX).await?;
|
||||
Ok(String::from_utf8(bytes.to_vec())?)
|
||||
}
|
||||
|
||||
/// Assert status code
|
||||
#[allow(dead_code)]
|
||||
pub fn assert_status(self, expected: StatusCode) -> Self {
|
||||
assert_eq!(
|
||||
self.response.status(),
|
||||
expected,
|
||||
"Expected status {}, got {}",
|
||||
expected,
|
||||
self.response.status()
|
||||
);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixture for creating test packs
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_test_pack(pool: &PgPool, ref_name: &str) -> Result<Pack> {
|
||||
let input = CreatePackInput {
|
||||
r#ref: ref_name.to_string(),
|
||||
label: format!("Test Pack {}", ref_name),
|
||||
description: Some(format!("Test pack for {}", ref_name)),
|
||||
version: "1.0.0".to_string(),
|
||||
conf_schema: json!({}),
|
||||
config: json!({}),
|
||||
meta: json!({
|
||||
"author": "test",
|
||||
"keywords": ["test"]
|
||||
}),
|
||||
tags: vec!["test".to_string()],
|
||||
runtime_deps: vec![],
|
||||
is_standard: false,
|
||||
};
|
||||
|
||||
Ok(PackRepository::create(pool, input).await?)
|
||||
}
|
||||
|
||||
/// Fixture for creating test actions
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_test_action(pool: &PgPool, pack_id: i64, ref_name: &str) -> Result<Action> {
|
||||
let input = CreateActionInput {
|
||||
r#ref: ref_name.to_string(),
|
||||
pack: pack_id,
|
||||
pack_ref: format!("pack_{}", pack_id),
|
||||
label: format!("Test Action {}", ref_name),
|
||||
description: format!("Test action for {}", ref_name),
|
||||
entrypoint: "main.py".to_string(),
|
||||
runtime: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
is_adhoc: false,
|
||||
};
|
||||
|
||||
Ok(ActionRepository::create(pool, input).await?)
|
||||
}
|
||||
|
||||
/// Fixture for creating test triggers
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_test_trigger(pool: &PgPool, pack_id: i64, ref_name: &str) -> Result<Trigger> {
|
||||
let input = CreateTriggerInput {
|
||||
r#ref: ref_name.to_string(),
|
||||
pack: Some(pack_id),
|
||||
pack_ref: Some(format!("pack_{}", pack_id)),
|
||||
label: format!("Test Trigger {}", ref_name),
|
||||
description: Some(format!("Test trigger for {}", ref_name)),
|
||||
enabled: true,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
is_adhoc: false,
|
||||
};
|
||||
|
||||
Ok(TriggerRepository::create(pool, input).await?)
|
||||
}
|
||||
|
||||
/// Fixture for creating test workflows
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_test_workflow(
|
||||
pool: &PgPool,
|
||||
pack_id: i64,
|
||||
pack_ref: &str,
|
||||
ref_name: &str,
|
||||
) -> Result<attune_common::models::workflow::WorkflowDefinition> {
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: ref_name.to_string(),
|
||||
pack: pack_id,
|
||||
pack_ref: pack_ref.to_string(),
|
||||
label: format!("Test Workflow {}", ref_name),
|
||||
description: Some(format!("Test workflow for {}", ref_name)),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({
|
||||
"tasks": [
|
||||
{
|
||||
"name": "test_task",
|
||||
"action": "core.echo",
|
||||
"input": {"message": "test"}
|
||||
}
|
||||
]
|
||||
}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
Ok(WorkflowDefinitionRepository::create(pool, input).await?)
|
||||
}
|
||||
|
||||
/// Assert that a value matches expected JSON structure
|
||||
#[macro_export]
|
||||
macro_rules! assert_json_contains {
|
||||
($actual:expr, $expected:expr) => {
|
||||
let actual: serde_json::Value = $actual;
|
||||
let expected: serde_json::Value = $expected;
|
||||
|
||||
// This is a simple implementation - you might want more sophisticated matching
|
||||
assert!(
|
||||
actual.get("data").is_some(),
|
||||
"Response should have 'data' field"
|
||||
);
|
||||
};
|
||||
}
|
||||
686
crates/api/tests/pack_registry_tests.rs
Normal file
686
crates/api/tests/pack_registry_tests.rs
Normal file
@@ -0,0 +1,686 @@
|
||||
//! Integration tests for pack registry system
|
||||
//!
|
||||
//! This module tests:
|
||||
//! - End-to-end pack installation from all sources (git, archive, local, registry)
|
||||
//! - Dependency validation during installation
|
||||
//! - Installation metadata tracking
|
||||
//! - Checksum verification
|
||||
//! - Error handling and edge cases
|
||||
|
||||
mod helpers;
|
||||
|
||||
use attune_common::{
|
||||
models::Pack,
|
||||
pack_registry::calculate_directory_checksum,
|
||||
repositories::{pack::PackRepository, pack_installation::PackInstallationRepository, List},
|
||||
};
|
||||
use helpers::{Result, TestContext};
|
||||
use serde_json::json;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Helper to create a test pack directory with pack.yaml
|
||||
fn create_test_pack_dir(name: &str, version: &str) -> Result<TempDir> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
let pack_yaml = format!(
|
||||
r#"
|
||||
ref: {}
|
||||
name: Test Pack {}
|
||||
version: {}
|
||||
description: Test pack for integration tests
|
||||
author: Test Author
|
||||
email: test@example.com
|
||||
keywords:
|
||||
- test
|
||||
- integration
|
||||
dependencies: []
|
||||
python: "3.8"
|
||||
actions:
|
||||
test_action:
|
||||
entry_point: test.py
|
||||
runner_type: python-script
|
||||
"#,
|
||||
name, name, version
|
||||
);
|
||||
|
||||
fs::write(temp_dir.path().join("pack.yaml"), pack_yaml)?;
|
||||
|
||||
// Create a simple action file
|
||||
let action_content = r#"
|
||||
#!/usr/bin/env python3
|
||||
print("Test action executed")
|
||||
"#;
|
||||
fs::write(temp_dir.path().join("test.py"), action_content)?;
|
||||
|
||||
Ok(temp_dir)
|
||||
}
|
||||
|
||||
/// Helper to create a pack with dependencies
|
||||
fn create_pack_with_deps(name: &str, deps: &[&str]) -> Result<TempDir> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
let deps_yaml = deps
|
||||
.iter()
|
||||
.map(|d| format!(" - {}", d))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let pack_yaml = format!(
|
||||
r#"
|
||||
ref: {}
|
||||
name: Test Pack {}
|
||||
version: 1.0.0
|
||||
description: Test pack with dependencies
|
||||
author: Test Author
|
||||
dependencies:
|
||||
{}
|
||||
python: "3.8"
|
||||
actions:
|
||||
test_action:
|
||||
entry_point: test.py
|
||||
runner_type: python-script
|
||||
"#,
|
||||
name, name, deps_yaml
|
||||
);
|
||||
|
||||
fs::write(temp_dir.path().join("pack.yaml"), pack_yaml)?;
|
||||
fs::write(temp_dir.path().join("test.py"), "print('test')")?;
|
||||
|
||||
Ok(temp_dir)
|
||||
}
|
||||
|
||||
/// Helper to create a pack with specific runtime requirements
|
||||
fn create_pack_with_runtime(
|
||||
name: &str,
|
||||
python: Option<&str>,
|
||||
nodejs: Option<&str>,
|
||||
) -> Result<TempDir> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
let python_line = python
|
||||
.map(|v| format!("python: \"{}\"", v))
|
||||
.unwrap_or_default();
|
||||
let nodejs_line = nodejs
|
||||
.map(|v| format!("nodejs: \"{}\"", v))
|
||||
.unwrap_or_default();
|
||||
|
||||
let pack_yaml = format!(
|
||||
r#"
|
||||
ref: {}
|
||||
name: Test Pack {}
|
||||
version: 1.0.0
|
||||
description: Test pack with runtime requirements
|
||||
author: Test Author
|
||||
{}
|
||||
{}
|
||||
actions:
|
||||
test_action:
|
||||
entry_point: test.py
|
||||
runner_type: python-script
|
||||
"#,
|
||||
name, name, python_line, nodejs_line
|
||||
);
|
||||
|
||||
fs::write(temp_dir.path().join("pack.yaml"), pack_yaml)?;
|
||||
fs::write(temp_dir.path().join("test.py"), "print('test')")?;
|
||||
|
||||
Ok(temp_dir)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_from_local_directory() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create a test pack directory
|
||||
let pack_dir = create_test_pack_dir("local-test", "1.0.0")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
// Install pack from local directory
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let body_text = response.text().await?;
|
||||
|
||||
if status != 200 {
|
||||
eprintln!("Error response (status {}): {}", status, body_text);
|
||||
}
|
||||
assert_eq!(status, 200, "Installation should succeed");
|
||||
|
||||
let body: serde_json::Value = serde_json::from_str(&body_text)?;
|
||||
assert_eq!(body["data"]["pack"]["ref"], "local-test");
|
||||
assert_eq!(body["data"]["pack"]["version"], "1.0.0");
|
||||
assert_eq!(body["data"]["tests_skipped"], true);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_with_dependency_validation_success() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// First, install a dependency pack
|
||||
let dep_pack_dir = create_test_pack_dir("core", "1.0.0")?;
|
||||
let dep_path = dep_pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
ctx.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": dep_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Now install a pack that depends on it
|
||||
let pack_dir = create_pack_with_deps("dependent-pack", &["core"])?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": false // Enable dependency validation
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
200,
|
||||
"Installation should succeed when dependencies are met"
|
||||
);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
assert_eq!(body["data"]["pack"]["ref"], "dependent-pack");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_with_missing_dependency_fails() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create a pack with an unmet dependency
|
||||
let pack_dir = create_pack_with_deps("dependent-pack", &["missing-pack"])?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": false // Enable dependency validation
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Should fail with 400 Bad Request
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
400,
|
||||
"Installation should fail when dependencies are missing"
|
||||
);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let error_msg = body["error"].as_str().unwrap();
|
||||
assert!(
|
||||
error_msg.contains("dependency validation failed") || error_msg.contains("missing-pack"),
|
||||
"Error should mention dependency validation failure"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_skip_deps_bypasses_validation() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create a pack with an unmet dependency
|
||||
let pack_dir = create_pack_with_deps("dependent-pack", &["missing-pack"])?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true // Skip dependency validation
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Should succeed because validation is skipped
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
200,
|
||||
"Installation should succeed when validation is skipped"
|
||||
);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
assert_eq!(body["data"]["pack"]["ref"], "dependent-pack");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_with_runtime_validation() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create a pack with reasonable runtime requirements
|
||||
let pack_dir = create_pack_with_runtime("runtime-test", Some("3.8"), None)?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": false // Enable validation
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Result depends on whether Python 3.8+ is available in test environment
|
||||
// We just verify the response is well-formed
|
||||
let status = response.status();
|
||||
assert!(
|
||||
status == 200 || status == 400,
|
||||
"Should either succeed or fail gracefully"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_metadata_tracking() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Install a pack
|
||||
let pack_dir = create_test_pack_dir("metadata-test", "1.0.0")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
let original_checksum = calculate_directory_checksum(pack_dir.path())?;
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let pack_id = body["data"]["pack"]["id"].as_i64().unwrap();
|
||||
|
||||
// Verify installation metadata was created
|
||||
let installation_repo = PackInstallationRepository::new(ctx.pool.clone());
|
||||
let installation = installation_repo
|
||||
.get_by_pack_id(pack_id)
|
||||
.await?
|
||||
.expect("Should have installation record");
|
||||
|
||||
assert_eq!(installation.pack_id, pack_id);
|
||||
assert_eq!(installation.source_type, "local_directory");
|
||||
assert!(installation.source_url.is_some());
|
||||
assert!(installation.checksum.is_some());
|
||||
|
||||
// Verify checksum matches
|
||||
let stored_checksum = installation.checksum.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
stored_checksum, &original_checksum,
|
||||
"Stored checksum should match calculated checksum"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_force_reinstall() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
let pack_dir = create_test_pack_dir("force-test", "1.0.0")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
// Install once
|
||||
let response1 = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": &pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response1.status(), 200);
|
||||
|
||||
// Try to install again without force - should work but might replace
|
||||
let response2 = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": &pack_path,
|
||||
"force": true,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response2.status(), 200, "Force reinstall should succeed");
|
||||
|
||||
// Verify pack exists
|
||||
let packs = PackRepository::list(&ctx.pool).await?;
|
||||
let force_test_packs: Vec<&Pack> = packs.iter().filter(|p| p.r#ref == "force-test").collect();
|
||||
assert_eq!(
|
||||
force_test_packs.len(),
|
||||
1,
|
||||
"Should have exactly one force-test pack"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_storage_path_created() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
let pack_dir = create_test_pack_dir("storage-test", "2.3.4")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let pack_id = body["data"]["pack"]["id"].as_i64().unwrap();
|
||||
|
||||
// Verify installation metadata has storage path
|
||||
let installation_repo = PackInstallationRepository::new(ctx.pool.clone());
|
||||
let installation = installation_repo
|
||||
.get_by_pack_id(pack_id)
|
||||
.await?
|
||||
.expect("Should have installation record");
|
||||
|
||||
let storage_path = &installation.storage_path;
|
||||
assert!(
|
||||
storage_path.contains("storage-test"),
|
||||
"Storage path should contain pack ref"
|
||||
);
|
||||
assert!(
|
||||
storage_path.contains("2.3.4"),
|
||||
"Storage path should contain version"
|
||||
);
|
||||
|
||||
// Note: We can't verify the actual filesystem without knowing the config path
|
||||
// but we verify the path structure is correct
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_invalid_source() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": "/nonexistent/path/to/pack",
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
404,
|
||||
"Should fail with not found status for nonexistent path"
|
||||
);
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
assert!(body["error"].is_string(), "Should have error message");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_missing_pack_yaml() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create directory without pack.yaml
|
||||
let temp_dir = TempDir::new()?;
|
||||
fs::write(temp_dir.path().join("readme.txt"), "No pack.yaml here")?;
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": temp_dir.path().to_string_lossy(),
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), 400, "Should fail with bad request");
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let error = body["error"].as_str().unwrap();
|
||||
assert!(
|
||||
error.contains("pack.yaml"),
|
||||
"Error should mention pack.yaml"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_invalid_pack_yaml() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create pack.yaml with invalid content
|
||||
let temp_dir = TempDir::new()?;
|
||||
fs::write(temp_dir.path().join("pack.yaml"), "invalid: yaml: content:")?;
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": temp_dir.path().to_string_lossy(),
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Should fail with error status
|
||||
assert!(response.status().is_client_error() || response.status().is_server_error());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_without_auth_fails() -> Result<()> {
|
||||
let ctx = TestContext::new().await?; // No auth
|
||||
|
||||
let pack_dir = create_test_pack_dir("auth-test", "1.0.0")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
None, // No token
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), 401, "Should require authentication");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_pack_installations() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Install multiple packs
|
||||
for i in 1..=3 {
|
||||
let pack_dir = create_test_pack_dir(&format!("multi-pack-{}", i), "1.0.0")?;
|
||||
let pack_path = pack_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_path,
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
200,
|
||||
"Pack {} installation should succeed",
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
// Verify all packs are installed
|
||||
let packs = <PackRepository as List>::list(&ctx.pool).await?;
|
||||
let multi_packs: Vec<&Pack> = packs
|
||||
.iter()
|
||||
.filter(|p| p.r#ref.starts_with("multi-pack-"))
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
multi_packs.len(),
|
||||
3,
|
||||
"Should have 3 multi-pack installations"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_install_pack_version_upgrade() -> Result<()> {
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Install version 1.0.0
|
||||
let pack_dir_v1 = create_test_pack_dir("version-test", "1.0.0")?;
|
||||
let response1 = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_dir_v1.path().to_string_lossy(),
|
||||
"force": false,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response1.status(), 200);
|
||||
|
||||
// Install version 2.0.0 with force
|
||||
let pack_dir_v2 = create_test_pack_dir("version-test", "2.0.0")?;
|
||||
let response2 = ctx
|
||||
.post(
|
||||
"/api/v1/packs/install",
|
||||
json!({
|
||||
"source": pack_dir_v2.path().to_string_lossy(),
|
||||
"force": true,
|
||||
"skip_tests": true,
|
||||
"skip_deps": true
|
||||
}),
|
||||
Some(token),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response2.status(), 200);
|
||||
|
||||
let body: serde_json::Value = response2.json().await?;
|
||||
assert_eq!(
|
||||
body["data"]["pack"]["version"], "2.0.0",
|
||||
"Should be upgraded to version 2.0.0"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
261
crates/api/tests/pack_workflow_tests.rs
Normal file
261
crates/api/tests/pack_workflow_tests.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! Integration tests for pack workflow sync and validation
|
||||
|
||||
mod helpers;
|
||||
|
||||
use helpers::{create_test_pack, TestContext};
|
||||
use serde_json::json;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Create test pack structure with workflows on filesystem
|
||||
fn create_pack_with_workflows(base_dir: &std::path::Path, pack_name: &str) {
|
||||
let pack_dir = base_dir.join(pack_name);
|
||||
let workflows_dir = pack_dir.join("workflows");
|
||||
|
||||
// Create directory structure
|
||||
fs::create_dir_all(&workflows_dir).unwrap();
|
||||
|
||||
// Create a valid workflow YAML
|
||||
let workflow_yaml = format!(
|
||||
r#"
|
||||
ref: {}.example_workflow
|
||||
label: Example Workflow
|
||||
description: A test workflow for integration testing
|
||||
version: "1.0.0"
|
||||
enabled: true
|
||||
parameters:
|
||||
message:
|
||||
type: string
|
||||
required: true
|
||||
description: "Message to display"
|
||||
tasks:
|
||||
- name: display_message
|
||||
action: core.echo
|
||||
input:
|
||||
message: "{{{{ parameters.message }}}}"
|
||||
"#,
|
||||
pack_name
|
||||
);
|
||||
|
||||
fs::write(workflows_dir.join("example_workflow.yaml"), workflow_yaml).unwrap();
|
||||
|
||||
// Create another workflow
|
||||
let workflow2_yaml = format!(
|
||||
r#"
|
||||
ref: {}.another_workflow
|
||||
label: Another Workflow
|
||||
description: Second test workflow
|
||||
version: "1.0.0"
|
||||
enabled: false
|
||||
tasks:
|
||||
- name: task1
|
||||
action: core.noop
|
||||
"#,
|
||||
pack_name
|
||||
);
|
||||
|
||||
fs::write(workflows_dir.join("another_workflow.yaml"), workflow2_yaml).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sync_pack_workflows_endpoint() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Use unique pack name to avoid conflicts in parallel tests
|
||||
let pack_name = format!(
|
||||
"test_pack_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string()
|
||||
);
|
||||
|
||||
// Create temporary directory for pack workflows
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
create_pack_with_workflows(temp_dir.path(), &pack_name);
|
||||
|
||||
// Create pack in database
|
||||
create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
// Note: This test will fail in CI without proper packs_base_dir configuration
|
||||
// The sync endpoint expects workflows to be in /opt/attune/packs by default
|
||||
// In a real integration test environment, we would need to:
|
||||
// 1. Configure packs_base_dir to point to temp_dir
|
||||
// 2. Or mount temp_dir to /opt/attune/packs
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
&format!("/api/v1/packs/{}/workflows/sync", pack_name),
|
||||
json!({}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// This might return 200 with 0 workflows if pack dir doesn't exist in configured location
|
||||
assert!(response.status().is_success() || response.status().is_client_error());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_pack_workflows_endpoint() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Use unique pack name to avoid conflicts in parallel tests
|
||||
let pack_name = format!(
|
||||
"test_pack_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string()
|
||||
);
|
||||
|
||||
// Create pack in database
|
||||
create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
&format!("/api/v1/packs/{}/workflows/validate", pack_name),
|
||||
json!({}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should succeed even if no workflows exist
|
||||
assert!(response.status().is_success() || response.status().is_client_error());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sync_nonexistent_pack_returns_404() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/nonexistent_pack/workflows/sync",
|
||||
json!({}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 404);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_nonexistent_pack_returns_404() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs/nonexistent_pack/workflows/validate",
|
||||
json!({}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 404);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sync_workflows_requires_authentication() {
|
||||
let ctx = TestContext::new().await.unwrap();
|
||||
|
||||
// Use unique pack name to avoid conflicts in parallel tests
|
||||
let pack_name = format!(
|
||||
"test_pack_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string()
|
||||
);
|
||||
|
||||
// Create pack in database
|
||||
create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
&format!("/api/v1/packs/{}/workflows/sync", pack_name),
|
||||
json!({}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// TODO: API endpoints don't currently enforce authentication
|
||||
// This should be 401 once auth middleware is implemented
|
||||
assert!(response.status().is_success() || response.status().is_client_error());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_workflows_requires_authentication() {
|
||||
let ctx = TestContext::new().await.unwrap();
|
||||
|
||||
// Use unique pack name to avoid conflicts in parallel tests
|
||||
let pack_name = format!(
|
||||
"test_pack_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string()
|
||||
);
|
||||
|
||||
// Create pack in database
|
||||
create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
&format!("/api/v1/packs/{}/workflows/validate", pack_name),
|
||||
json!({}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// TODO: API endpoints don't currently enforce authentication
|
||||
// This should be 401 once auth middleware is implemented
|
||||
assert!(response.status().is_success() || response.status().is_client_error());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pack_creation_with_auto_sync() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create pack via API (should auto-sync workflows if they exist on filesystem)
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/packs",
|
||||
json!({
|
||||
"ref": "auto_sync_pack",
|
||||
"label": "Auto Sync Pack",
|
||||
"version": "1.0.0",
|
||||
"description": "A test pack with auto-sync"
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 201);
|
||||
|
||||
// Verify pack was created
|
||||
let get_response = ctx
|
||||
.get("/api/v1/packs/auto_sync_pack", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(get_response.status(), 200);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pack_update_with_auto_resync() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create pack first
|
||||
create_test_pack(&ctx.pool, "update_test_pack")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Update pack (should trigger workflow resync)
|
||||
let response = ctx
|
||||
.put(
|
||||
"/api/v1/packs/update_test_pack",
|
||||
json!({
|
||||
"label": "Updated Test Pack",
|
||||
"version": "1.1.0"
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
}
|
||||
537
crates/api/tests/sse_execution_stream_tests.rs
Normal file
537
crates/api/tests/sse_execution_stream_tests.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
//! Integration tests for SSE execution stream endpoint
|
||||
//!
|
||||
//! These tests verify that:
|
||||
//! 1. PostgreSQL LISTEN/NOTIFY correctly triggers notifications
|
||||
//! 2. The SSE endpoint streams execution updates in real-time
|
||||
//! 3. Filtering by execution_id works correctly
|
||||
//! 4. Authentication is properly enforced
|
||||
//! 5. Reconnection and error handling work as expected
|
||||
|
||||
use attune_common::{
|
||||
models::*,
|
||||
repositories::{
|
||||
action::{ActionRepository, CreateActionInput},
|
||||
execution::{CreateExecutionInput, ExecutionRepository},
|
||||
pack::{CreatePackInput, PackRepository},
|
||||
Create,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
mod helpers;
|
||||
use helpers::TestContext;
|
||||
|
||||
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
/// Helper to set up test pack and action
|
||||
async fn setup_test_pack_and_action(pool: &PgPool) -> Result<(Pack, Action)> {
|
||||
let pack_input = CreatePackInput {
|
||||
r#ref: "test_sse_pack".to_string(),
|
||||
label: "Test SSE Pack".to_string(),
|
||||
description: Some("Pack for SSE testing".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
conf_schema: json!({}),
|
||||
config: json!({}),
|
||||
meta: json!({"author": "test"}),
|
||||
tags: vec!["test".to_string()],
|
||||
runtime_deps: vec![],
|
||||
is_standard: false,
|
||||
};
|
||||
let pack = PackRepository::create(pool, pack_input).await?;
|
||||
|
||||
let action_input = CreateActionInput {
|
||||
r#ref: format!("{}.test_action", pack.r#ref),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test action for SSE tests".to_string(),
|
||||
entrypoint: "test.sh".to_string(),
|
||||
runtime: None,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
is_adhoc: false,
|
||||
};
|
||||
let action = ActionRepository::create(pool, action_input).await?;
|
||||
|
||||
Ok((pack, action))
|
||||
}
|
||||
|
||||
/// Helper to create a test execution
|
||||
async fn create_test_execution(pool: &PgPool, action_id: i64) -> Result<Execution> {
|
||||
let input = CreateExecutionInput {
|
||||
action: Some(action_id),
|
||||
action_ref: format!("action_{}", action_id),
|
||||
config: None,
|
||||
parent: None,
|
||||
enforcement: None,
|
||||
executor: None,
|
||||
status: ExecutionStatus::Scheduled,
|
||||
result: None,
|
||||
workflow_task: None,
|
||||
};
|
||||
Ok(ExecutionRepository::create(pool, input).await?)
|
||||
}
|
||||
|
||||
/// This test requires a running API server on port 8080
|
||||
/// Run with: cargo test test_sse_stream_receives_execution_updates -- --ignored --nocapture
|
||||
/// After starting: cargo run -p attune-api -- -c config.test.yaml
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_sse_stream_receives_execution_updates() -> Result<()> {
|
||||
// Set up test context with auth
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create test pack, action, and execution
|
||||
let (_pack, action) = setup_test_pack_and_action(&ctx.pool).await?;
|
||||
let execution = create_test_execution(&ctx.pool, action.id).await?;
|
||||
|
||||
println!(
|
||||
"Created execution: id={}, status={:?}",
|
||||
execution.id, execution.status
|
||||
);
|
||||
|
||||
// Build SSE URL with authentication
|
||||
let sse_url = format!(
|
||||
"http://localhost:8080/api/v1/executions/stream?execution_id={}&token={}",
|
||||
execution.id, token
|
||||
);
|
||||
|
||||
// Create SSE stream
|
||||
let mut stream = EventSource::get(&sse_url);
|
||||
|
||||
// Spawn a task to update the execution status after a short delay
|
||||
let pool_clone = ctx.pool.clone();
|
||||
let execution_id = execution.id;
|
||||
tokio::spawn(async move {
|
||||
// Wait a bit to ensure SSE connection is established
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
println!("Updating execution {} to 'running' status", execution_id);
|
||||
|
||||
// Update execution status - this should trigger PostgreSQL NOTIFY
|
||||
let _ = sqlx::query(
|
||||
"UPDATE execution SET status = 'running', start_time = NOW() WHERE id = $1",
|
||||
)
|
||||
.bind(execution_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Update executed, waiting before setting to succeeded");
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Update to succeeded
|
||||
let _ = sqlx::query(
|
||||
"UPDATE execution SET status = 'succeeded', end_time = NOW() WHERE id = $1",
|
||||
)
|
||||
.bind(execution_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Execution {} updated to 'succeeded'", execution_id);
|
||||
});
|
||||
|
||||
// Wait for SSE events with timeout
|
||||
let mut received_running = false;
|
||||
let mut received_succeeded = false;
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 20; // 10 seconds total
|
||||
|
||||
while attempts < max_attempts && (!received_running || !received_succeeded) {
|
||||
match timeout(Duration::from_millis(500), stream.next()).await {
|
||||
Ok(Some(Ok(event))) => {
|
||||
println!("Received SSE event: {:?}", event);
|
||||
|
||||
match event {
|
||||
Event::Open => {
|
||||
println!("SSE connection established");
|
||||
}
|
||||
Event::Message(msg) => {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&msg.data) {
|
||||
println!(
|
||||
"Parsed event data: {}",
|
||||
serde_json::to_string_pretty(&data)?
|
||||
);
|
||||
|
||||
if let Some(entity_type) =
|
||||
data.get("entity_type").and_then(|v| v.as_str())
|
||||
{
|
||||
if entity_type == "execution" {
|
||||
if let Some(event_data) = data.get("data") {
|
||||
if let Some(status) =
|
||||
event_data.get("status").and_then(|v| v.as_str())
|
||||
{
|
||||
println!(
|
||||
"Received execution update with status: {}",
|
||||
status
|
||||
);
|
||||
|
||||
if status == "running" {
|
||||
received_running = true;
|
||||
println!("✓ Received 'running' status");
|
||||
} else if status == "succeeded" {
|
||||
received_succeeded = true;
|
||||
println!("✓ Received 'succeeded' status");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Some(Err(e))) => {
|
||||
eprintln!("SSE stream error: {}", e);
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
println!("SSE stream ended");
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout waiting for next event
|
||||
attempts += 1;
|
||||
println!(
|
||||
"Timeout waiting for event (attempt {}/{})",
|
||||
attempts, max_attempts
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received both updates
|
||||
assert!(
|
||||
received_running,
|
||||
"Should have received execution update with status 'running'"
|
||||
);
|
||||
assert!(
|
||||
received_succeeded,
|
||||
"Should have received execution update with status 'succeeded'"
|
||||
);
|
||||
|
||||
println!("✓ Test passed: SSE stream received all expected updates");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that SSE stream correctly filters by execution_id
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_sse_stream_filters_by_execution_id() -> Result<()> {
|
||||
// Set up test context with auth
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create test pack, action, and TWO executions
|
||||
let (_pack, action) = setup_test_pack_and_action(&ctx.pool).await?;
|
||||
let execution1 = create_test_execution(&ctx.pool, action.id).await?;
|
||||
let execution2 = create_test_execution(&ctx.pool, action.id).await?;
|
||||
|
||||
println!(
|
||||
"Created executions: id1={}, id2={}",
|
||||
execution1.id, execution2.id
|
||||
);
|
||||
|
||||
// Subscribe to updates for execution1 only
|
||||
let sse_url = format!(
|
||||
"http://localhost:8080/api/v1/executions/stream?execution_id={}&token={}",
|
||||
execution1.id, token
|
||||
);
|
||||
|
||||
let mut stream = EventSource::get(&sse_url);
|
||||
|
||||
// Update both executions
|
||||
let pool_clone = ctx.pool.clone();
|
||||
let exec1_id = execution1.id;
|
||||
let exec2_id = execution2.id;
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Update execution2 (should NOT appear in filtered stream)
|
||||
let _ = sqlx::query("UPDATE execution SET status = 'completed' WHERE id = $1")
|
||||
.bind(exec2_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Updated execution2 {} to 'completed'", exec2_id);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
// Update execution1 (SHOULD appear in filtered stream)
|
||||
let _ = sqlx::query("UPDATE execution SET status = 'running' WHERE id = $1")
|
||||
.bind(exec1_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Updated execution1 {} to 'running'", exec1_id);
|
||||
});
|
||||
|
||||
// Wait for events
|
||||
let mut received_exec1_update = false;
|
||||
let mut received_exec2_update = false;
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 20;
|
||||
|
||||
while attempts < max_attempts && !received_exec1_update {
|
||||
match timeout(Duration::from_millis(500), stream.next()).await {
|
||||
Ok(Some(Ok(event))) => match event {
|
||||
Event::Open => {}
|
||||
Event::Message(msg) => {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&msg.data) {
|
||||
if let Some(entity_id) = data.get("entity_id").and_then(|v| v.as_i64()) {
|
||||
println!("Received update for execution: {}", entity_id);
|
||||
|
||||
if entity_id == execution1.id {
|
||||
received_exec1_update = true;
|
||||
println!("✓ Received update for execution1 (correct)");
|
||||
} else if entity_id == execution2.id {
|
||||
received_exec2_update = true;
|
||||
println!(
|
||||
"✗ Received update for execution2 (should be filtered out)"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Ok(Some(Err(_))) | Ok(None) => break,
|
||||
Err(_) => {
|
||||
attempts += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive execution1 update but NOT execution2
|
||||
assert!(
|
||||
received_exec1_update,
|
||||
"Should have received update for execution1"
|
||||
);
|
||||
assert!(
|
||||
!received_exec2_update,
|
||||
"Should NOT have received update for execution2 (filtered out)"
|
||||
);
|
||||
|
||||
println!("✓ Test passed: SSE stream correctly filters by execution_id");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_sse_stream_requires_authentication() -> Result<()> {
|
||||
// Try to connect without token
|
||||
let sse_url = "http://localhost:8080/api/v1/executions/stream";
|
||||
|
||||
let mut stream = EventSource::get(sse_url);
|
||||
|
||||
// Should receive an error due to missing authentication
|
||||
let mut received_error = false;
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 5;
|
||||
|
||||
while attempts < max_attempts && !received_error {
|
||||
match timeout(Duration::from_millis(500), stream.next()).await {
|
||||
Ok(Some(Ok(_))) => {
|
||||
// Should not receive successful events without auth
|
||||
panic!("Received SSE event without authentication - this should not happen");
|
||||
}
|
||||
Ok(Some(Err(e))) => {
|
||||
println!("Correctly received error without auth: {}", e);
|
||||
received_error = true;
|
||||
}
|
||||
Ok(None) => {
|
||||
println!("Stream ended (expected behavior for unauthorized)");
|
||||
received_error = true;
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
attempts += 1;
|
||||
println!("Timeout waiting for response (attempt {})", attempts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
received_error,
|
||||
"Should have received error or stream closure due to missing authentication"
|
||||
);
|
||||
|
||||
println!("✓ Test passed: SSE stream requires authentication");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test streaming all executions (no filter)
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_sse_stream_all_executions() -> Result<()> {
|
||||
// Set up test context with auth
|
||||
let ctx = TestContext::new().await?.with_auth().await?;
|
||||
let token = ctx.token().unwrap();
|
||||
|
||||
// Create test pack, action, and multiple executions
|
||||
let (_pack, action) = setup_test_pack_and_action(&ctx.pool).await?;
|
||||
let execution1 = create_test_execution(&ctx.pool, action.id).await?;
|
||||
let execution2 = create_test_execution(&ctx.pool, action.id).await?;
|
||||
|
||||
println!(
|
||||
"Created executions: id1={}, id2={}",
|
||||
execution1.id, execution2.id
|
||||
);
|
||||
|
||||
// Subscribe to ALL execution updates (no execution_id filter)
|
||||
let sse_url = format!(
|
||||
"http://localhost:8080/api/v1/executions/stream?token={}",
|
||||
token
|
||||
);
|
||||
|
||||
let mut stream = EventSource::get(&sse_url);
|
||||
|
||||
// Update both executions
|
||||
let pool_clone = ctx.pool.clone();
|
||||
let exec1_id = execution1.id;
|
||||
let exec2_id = execution2.id;
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Update execution1
|
||||
let _ = sqlx::query("UPDATE execution SET status = 'running' WHERE id = $1")
|
||||
.bind(exec1_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Updated execution1 {} to 'running'", exec1_id);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
// Update execution2
|
||||
let _ = sqlx::query("UPDATE execution SET status = 'running' WHERE id = $1")
|
||||
.bind(exec2_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
|
||||
println!("Updated execution2 {} to 'running'", exec2_id);
|
||||
});
|
||||
|
||||
// Wait for events from BOTH executions
|
||||
let mut received_updates = std::collections::HashSet::new();
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 20;
|
||||
|
||||
while attempts < max_attempts && received_updates.len() < 2 {
|
||||
match timeout(Duration::from_millis(500), stream.next()).await {
|
||||
Ok(Some(Ok(event))) => match event {
|
||||
Event::Open => {}
|
||||
Event::Message(msg) => {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&msg.data) {
|
||||
if let Some(entity_id) = data.get("entity_id").and_then(|v| v.as_i64()) {
|
||||
println!("Received update for execution: {}", entity_id);
|
||||
received_updates.insert(entity_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Ok(Some(Err(_))) | Ok(None) => break,
|
||||
Err(_) => {
|
||||
attempts += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should have received updates for BOTH executions
|
||||
assert!(
|
||||
received_updates.contains(&execution1.id),
|
||||
"Should have received update for execution1"
|
||||
);
|
||||
assert!(
|
||||
received_updates.contains(&execution2.id),
|
||||
"Should have received update for execution2"
|
||||
);
|
||||
|
||||
println!("✓ Test passed: SSE stream received updates for all executions (no filter)");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that PostgreSQL NOTIFY triggers actually fire
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_postgresql_notify_trigger_fires() -> Result<()> {
|
||||
let ctx = TestContext::new().await?;
|
||||
|
||||
// Create test pack, action, and execution
|
||||
let (_pack, action) = setup_test_pack_and_action(&ctx.pool).await?;
|
||||
let execution = create_test_execution(&ctx.pool, action.id).await?;
|
||||
|
||||
println!("Created execution: id={}", execution.id);
|
||||
|
||||
// Set up a listener on the PostgreSQL channel
|
||||
let mut listener = sqlx::postgres::PgListener::connect_with(&ctx.pool).await?;
|
||||
listener.listen("execution_events").await?;
|
||||
|
||||
println!("Listening on channel 'execution_events'");
|
||||
|
||||
// Update the execution in another task
|
||||
let pool_clone = ctx.pool.clone();
|
||||
let execution_id = execution.id;
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
println!("Updating execution {} to trigger NOTIFY", execution_id);
|
||||
|
||||
let _ = sqlx::query("UPDATE execution SET status = 'running' WHERE id = $1")
|
||||
.bind(execution_id)
|
||||
.execute(&pool_clone)
|
||||
.await;
|
||||
});
|
||||
|
||||
// Wait for the NOTIFY with a timeout
|
||||
let mut received_notification = false;
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 10;
|
||||
|
||||
while attempts < max_attempts && !received_notification {
|
||||
match timeout(Duration::from_millis(1000), listener.recv()).await {
|
||||
Ok(Ok(notification)) => {
|
||||
println!("Received NOTIFY: channel={}", notification.channel());
|
||||
println!("Payload: {}", notification.payload());
|
||||
|
||||
// Parse the payload
|
||||
if let Ok(data) = serde_json::from_str::<Value>(notification.payload()) {
|
||||
if let Some(entity_id) = data.get("entity_id").and_then(|v| v.as_i64()) {
|
||||
if entity_id == execution.id {
|
||||
println!("✓ Received NOTIFY for our execution");
|
||||
received_notification = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
eprintln!("Error receiving notification: {}", e);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
attempts += 1;
|
||||
println!("Timeout waiting for NOTIFY (attempt {})", attempts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
received_notification,
|
||||
"Should have received PostgreSQL NOTIFY when execution was updated"
|
||||
);
|
||||
|
||||
println!("✓ Test passed: PostgreSQL NOTIFY trigger fires correctly");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
518
crates/api/tests/webhook_api_tests.rs
Normal file
518
crates/api/tests/webhook_api_tests.rs
Normal file
@@ -0,0 +1,518 @@
|
||||
//! Integration tests for webhook API endpoints
|
||||
|
||||
use attune_api::{AppState, Server};
|
||||
use attune_common::{
|
||||
config::Config,
|
||||
db::Database,
|
||||
repositories::{
|
||||
pack::{CreatePackInput, PackRepository},
|
||||
trigger::{CreateTriggerInput, TriggerRepository},
|
||||
Create,
|
||||
},
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use serde_json::json;
|
||||
use tower::ServiceExt;
|
||||
|
||||
/// Helper to create test database and state
|
||||
async fn setup_test_state() -> AppState {
|
||||
let config = Config::load().expect("Failed to load config");
|
||||
let database = Database::new(&config.database)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
AppState::new(database.pool().clone(), config)
|
||||
}
|
||||
|
||||
/// Helper to create a test pack
|
||||
async fn create_test_pack(state: &AppState, name: &str) -> i64 {
|
||||
let input = CreatePackInput {
|
||||
r#ref: name.to_string(),
|
||||
label: format!("{} Pack", name),
|
||||
description: Some(format!("Test pack for {}", name)),
|
||||
version: "1.0.0".to_string(),
|
||||
conf_schema: serde_json::json!({}),
|
||||
config: serde_json::json!({}),
|
||||
meta: serde_json::json!({}),
|
||||
tags: vec![],
|
||||
runtime_deps: vec![],
|
||||
is_standard: false,
|
||||
};
|
||||
|
||||
let pack = PackRepository::create(&state.db, input)
|
||||
.await
|
||||
.expect("Failed to create pack");
|
||||
|
||||
pack.id
|
||||
}
|
||||
|
||||
/// Helper to create a test trigger
|
||||
async fn create_test_trigger(
|
||||
state: &AppState,
|
||||
pack_id: i64,
|
||||
pack_ref: &str,
|
||||
trigger_ref: &str,
|
||||
) -> i64 {
|
||||
let input = CreateTriggerInput {
|
||||
r#ref: trigger_ref.to_string(),
|
||||
pack: Some(pack_id),
|
||||
pack_ref: Some(pack_ref.to_string()),
|
||||
label: format!("{} Trigger", trigger_ref),
|
||||
description: Some(format!("Test trigger {}", trigger_ref)),
|
||||
enabled: true,
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
is_adhoc: false,
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::create(&state.db, input)
|
||||
.await
|
||||
.expect("Failed to create trigger");
|
||||
|
||||
trigger.id
|
||||
}
|
||||
|
||||
/// Helper to get JWT token for authenticated requests
|
||||
async fn get_auth_token(app: &axum::Router, username: &str, password: &str) -> String {
|
||||
let login_request = json!({
|
||||
"username": username,
|
||||
"password": password
|
||||
});
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/auth/login")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&login_request).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
json["data"]["access_token"].as_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Run with --ignored flag when database is available
|
||||
async fn test_enable_webhook() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_test").await;
|
||||
let _trigger_id =
|
||||
create_test_trigger(&state, pack_id, "webhook_test", "webhook_test.trigger").await;
|
||||
|
||||
// Get auth token (assumes a test user exists)
|
||||
let token = get_auth_token(&app, "test_user", "test_password").await;
|
||||
|
||||
// Enable webhooks
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/triggers/webhook_test.trigger/webhooks/enable")
|
||||
.header("authorization", format!("Bearer {}", token))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response structure
|
||||
assert!(json["data"]["webhook_enabled"].as_bool().unwrap());
|
||||
assert!(json["data"]["webhook_key"].is_string());
|
||||
let webhook_key = json["data"]["webhook_key"].as_str().unwrap();
|
||||
assert!(webhook_key.starts_with("wh_"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_disable_webhook() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_disable_test").await;
|
||||
let trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_disable_test",
|
||||
"webhook_disable_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Enable webhooks first
|
||||
let _ = TriggerRepository::enable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to enable webhook");
|
||||
|
||||
// Get auth token
|
||||
let token = get_auth_token(&app, "test_user", "test_password").await;
|
||||
|
||||
// Disable webhooks
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/triggers/webhook_disable_test.trigger/webhooks/disable")
|
||||
.header("authorization", format!("Bearer {}", token))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify webhooks are disabled
|
||||
assert!(!json["data"]["webhook_enabled"].as_bool().unwrap());
|
||||
assert!(json["data"]["webhook_key"].is_null());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_regenerate_webhook_key() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_regen_test").await;
|
||||
let trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_regen_test",
|
||||
"webhook_regen_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Enable webhooks first
|
||||
let original_info = TriggerRepository::enable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to enable webhook");
|
||||
|
||||
// Get auth token
|
||||
let token = get_auth_token(&app, "test_user", "test_password").await;
|
||||
|
||||
// Regenerate webhook key
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/triggers/webhook_regen_test.trigger/webhooks/regenerate")
|
||||
.header("authorization", format!("Bearer {}", token))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify new key is different from original
|
||||
let new_key = json["data"]["webhook_key"].as_str().unwrap();
|
||||
assert_ne!(new_key, original_info.webhook_key);
|
||||
assert!(new_key.starts_with("wh_"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_regenerate_webhook_key_not_enabled() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data without enabling webhooks
|
||||
let pack_id = create_test_pack(&state, "webhook_not_enabled_test").await;
|
||||
let _trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_not_enabled_test",
|
||||
"webhook_not_enabled_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Get auth token
|
||||
let token = get_auth_token(&app, "test_user", "test_password").await;
|
||||
|
||||
// Try to regenerate without enabling first
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/triggers/webhook_not_enabled_test.trigger/webhooks/regenerate")
|
||||
.header("authorization", format!("Bearer {}", token))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_receive_webhook() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_receive_test").await;
|
||||
let trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_receive_test",
|
||||
"webhook_receive_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Enable webhooks
|
||||
let webhook_info = TriggerRepository::enable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to enable webhook");
|
||||
|
||||
// Send webhook
|
||||
let webhook_payload = json!({
|
||||
"payload": {
|
||||
"event": "test_event",
|
||||
"data": {
|
||||
"foo": "bar",
|
||||
"number": 42
|
||||
}
|
||||
},
|
||||
"headers": {
|
||||
"X-Test-Header": "test-value"
|
||||
},
|
||||
"source_ip": "192.168.1.1",
|
||||
"user_agent": "Test Agent/1.0"
|
||||
});
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri(format!("/api/v1/webhooks/{}", webhook_info.webhook_key))
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&webhook_payload).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response
|
||||
assert!(json["data"]["event_id"].is_number());
|
||||
assert_eq!(
|
||||
json["data"]["trigger_ref"].as_str().unwrap(),
|
||||
"webhook_receive_test.trigger"
|
||||
);
|
||||
assert!(json["data"]["received_at"].is_string());
|
||||
assert_eq!(
|
||||
json["data"]["message"].as_str().unwrap(),
|
||||
"Webhook received successfully"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_receive_webhook_invalid_key() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state));
|
||||
let app = server.router();
|
||||
|
||||
// Try to send webhook with invalid key
|
||||
let webhook_payload = json!({
|
||||
"payload": {
|
||||
"event": "test_event"
|
||||
}
|
||||
});
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/webhooks/wh_invalid_key_12345")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&webhook_payload).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_receive_webhook_disabled() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_disabled_test").await;
|
||||
let trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_disabled_test",
|
||||
"webhook_disabled_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Enable then disable webhooks
|
||||
let webhook_info = TriggerRepository::enable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to enable webhook");
|
||||
|
||||
TriggerRepository::disable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to disable webhook");
|
||||
|
||||
// Try to send webhook with disabled key
|
||||
let webhook_payload = json!({
|
||||
"payload": {
|
||||
"event": "test_event"
|
||||
}
|
||||
});
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri(format!("/api/v1/webhooks/{}", webhook_info.webhook_key))
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&webhook_payload).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should return 404 because disabled webhook keys are not found
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_webhook_requires_auth_for_management() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_auth_test").await;
|
||||
let _trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_auth_test",
|
||||
"webhook_auth_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Try to enable without auth
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/triggers/webhook_auth_test.trigger/webhooks/enable")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_receive_webhook_minimal_payload() {
|
||||
let state = setup_test_state().await;
|
||||
let server = Server::new(std::sync::Arc::new(state.clone()));
|
||||
let app = server.router();
|
||||
|
||||
// Create test data
|
||||
let pack_id = create_test_pack(&state, "webhook_minimal_test").await;
|
||||
let trigger_id = create_test_trigger(
|
||||
&state,
|
||||
pack_id,
|
||||
"webhook_minimal_test",
|
||||
"webhook_minimal_test.trigger",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Enable webhooks
|
||||
let webhook_info = TriggerRepository::enable_webhook(&state.db, trigger_id)
|
||||
.await
|
||||
.expect("Failed to enable webhook");
|
||||
|
||||
// Send webhook with minimal payload (only required fields)
|
||||
let webhook_payload = json!({
|
||||
"payload": {
|
||||
"message": "minimal test"
|
||||
}
|
||||
});
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri(format!("/api/v1/webhooks/{}", webhook_info.webhook_key))
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&webhook_payload).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
1119
crates/api/tests/webhook_security_tests.rs
Normal file
1119
crates/api/tests/webhook_security_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
547
crates/api/tests/workflow_tests.rs
Normal file
547
crates/api/tests/workflow_tests.rs
Normal file
@@ -0,0 +1,547 @@
|
||||
//! Integration tests for workflow API endpoints
|
||||
|
||||
use attune_common::repositories::{
|
||||
workflow::{CreateWorkflowDefinitionInput, WorkflowDefinitionRepository},
|
||||
Create,
|
||||
};
|
||||
use axum::http::StatusCode;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
mod helpers;
|
||||
use helpers::*;
|
||||
|
||||
/// Generate a unique pack name for testing to avoid conflicts
|
||||
fn unique_pack_name() -> String {
|
||||
format!(
|
||||
"test_pack_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string()
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_workflow_success() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack first
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
// Create workflow via API
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "test-pack.test_workflow",
|
||||
"pack_ref": pack.r#ref,
|
||||
"label": "Test Workflow",
|
||||
"description": "A test workflow",
|
||||
"version": "1.0.0",
|
||||
"definition": {
|
||||
"tasks": [
|
||||
{
|
||||
"name": "task1",
|
||||
"action": "core.echo",
|
||||
"input": {"message": "Hello"}
|
||||
}
|
||||
]
|
||||
},
|
||||
"tags": ["test", "automation"],
|
||||
"enabled": true
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::CREATED);
|
||||
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"]["ref"], "test-pack.test_workflow");
|
||||
assert_eq!(body["data"]["label"], "Test Workflow");
|
||||
assert_eq!(body["data"]["version"], "1.0.0");
|
||||
assert_eq!(body["data"]["enabled"], true);
|
||||
assert!(body["data"]["tags"].as_array().unwrap().len() == 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_workflow_duplicate_ref() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack first
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
// Create workflow directly in DB
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: "test-pack.existing_workflow".to_string(),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "Existing Workflow".to_string(),
|
||||
description: Some("An existing workflow".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Try to create workflow with same ref via API
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "test-pack.existing_workflow",
|
||||
"pack_ref": pack.r#ref,
|
||||
"label": "Duplicate Workflow",
|
||||
"version": "1.0.0",
|
||||
"definition": {"tasks": []}
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_workflow_pack_not_found() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "nonexistent.workflow",
|
||||
"pack_ref": "nonexistent-pack",
|
||||
"label": "Test Workflow",
|
||||
"version": "1.0.0",
|
||||
"definition": {"tasks": []}
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_workflow_by_ref() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack and workflow
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: "test-pack.my_workflow".to_string(),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "My Workflow".to_string(),
|
||||
description: Some("A workflow".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": [{"name": "task1"}]}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get workflow via API
|
||||
let response = ctx
|
||||
.get("/api/v1/workflows/test-pack.my_workflow", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"]["ref"], "test-pack.my_workflow");
|
||||
assert_eq!(body["data"]["label"], "My Workflow");
|
||||
assert_eq!(body["data"]["version"], "1.0.0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_workflow_not_found() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/workflows/nonexistent.workflow", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_workflows() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack and multiple workflows
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
for i in 1..=3 {
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: format!("test-pack.workflow_{}", i),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: format!("Workflow {}", i),
|
||||
description: Some(format!("Workflow number {}", i)),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: i % 2 == 1, // Odd ones enabled
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// List all workflows (filtered by pack_ref for test isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!(
|
||||
"/api/v1/workflows?page=1&per_page=10&pack_ref={}",
|
||||
pack_name
|
||||
),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"].as_array().unwrap().len(), 3);
|
||||
assert_eq!(body["pagination"]["total_items"], 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_workflows_by_pack() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create two packs
|
||||
let pack1_name = unique_pack_name();
|
||||
let pack2_name = unique_pack_name();
|
||||
let pack1 = create_test_pack(&ctx.pool, &pack1_name).await.unwrap();
|
||||
let pack2 = create_test_pack(&ctx.pool, &pack2_name).await.unwrap();
|
||||
|
||||
// Create workflows for pack1
|
||||
for i in 1..=2 {
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: format!("pack1.workflow_{}", i),
|
||||
pack: pack1.id,
|
||||
pack_ref: pack1.r#ref.clone(),
|
||||
label: format!("Pack1 Workflow {}", i),
|
||||
description: None,
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Create workflows for pack2
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: "pack2.workflow_1".to_string(),
|
||||
pack: pack2.id,
|
||||
pack_ref: pack2.r#ref.clone(),
|
||||
label: "Pack2 Workflow".to_string(),
|
||||
description: None,
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// List workflows for pack1
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!("/api/v1/packs/{}/workflows", pack1_name),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = response.json().await.unwrap();
|
||||
let workflows = body["data"].as_array().unwrap();
|
||||
assert_eq!(workflows.len(), 2);
|
||||
assert!(workflows
|
||||
.iter()
|
||||
.all(|w| w["pack_ref"] == pack1.r#ref.as_str()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_workflows_with_filters() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
// Create workflows with different tags and enabled status
|
||||
let workflows = vec![
|
||||
("workflow1", vec!["incident", "approval"], true),
|
||||
("workflow2", vec!["incident"], false),
|
||||
("workflow3", vec!["automation"], true),
|
||||
];
|
||||
|
||||
for (ref_name, tags, enabled) in workflows {
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: format!("test-pack.{}", ref_name),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: format!("Workflow {}", ref_name),
|
||||
description: Some(format!("Description for {}", ref_name)),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: tags.iter().map(|s| s.to_string()).collect(),
|
||||
enabled,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Filter by enabled (and pack_ref for isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!("/api/v1/workflows?enabled=true&pack_ref={}", pack_name),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"].as_array().unwrap().len(), 2);
|
||||
|
||||
// Filter by tag (and pack_ref for isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!("/api/v1/workflows?tags=incident&pack_ref={}", pack_name),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"].as_array().unwrap().len(), 2);
|
||||
|
||||
// Search by label (and pack_ref for isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!("/api/v1/workflows?search=workflow1&pack_ref={}", pack_name),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"].as_array().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_workflow() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack and workflow
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: "test-pack.update_test".to_string(),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "Original Label".to_string(),
|
||||
description: Some("Original description".to_string()),
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Update workflow via API
|
||||
let response = ctx
|
||||
.put(
|
||||
"/api/v1/workflows/test-pack.update_test",
|
||||
json!({
|
||||
"label": "Updated Label",
|
||||
"description": "Updated description",
|
||||
"version": "1.1.0",
|
||||
"enabled": false
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"]["label"], "Updated Label");
|
||||
assert_eq!(body["data"]["description"], "Updated description");
|
||||
assert_eq!(body["data"]["version"], "1.1.0");
|
||||
assert_eq!(body["data"]["enabled"], false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_workflow_not_found() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.put(
|
||||
"/api/v1/workflows/nonexistent.workflow",
|
||||
json!({
|
||||
"label": "Updated Label"
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_workflow() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Create a pack and workflow
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: "test-pack.delete_test".to_string(),
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "To Be Deleted".to_string(),
|
||||
description: None,
|
||||
version: "1.0.0".to_string(),
|
||||
param_schema: None,
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Delete workflow via API
|
||||
let response = ctx
|
||||
.delete("/api/v1/workflows/test-pack.delete_test", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Verify it's deleted
|
||||
let response = ctx
|
||||
.get("/api/v1/workflows/test-pack.delete_test", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_workflow_not_found() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.delete("/api/v1/workflows/nonexistent.workflow", ctx.token())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_workflow_requires_auth() {
|
||||
let ctx = TestContext::new().await.unwrap();
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "test.workflow",
|
||||
"pack_ref": "test",
|
||||
"label": "Test",
|
||||
"version": "1.0.0",
|
||||
"definition": {"tasks": []}
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// TODO: API endpoints don't currently enforce authentication
|
||||
// This should be 401 once auth middleware is implemented
|
||||
assert!(response.status().is_success() || response.status().is_client_error());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_validation() {
|
||||
let ctx = TestContext::new().await.unwrap().with_auth().await.unwrap();
|
||||
|
||||
// Test empty ref
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "",
|
||||
"pack_ref": "test",
|
||||
"label": "Test",
|
||||
"version": "1.0.0",
|
||||
"definition": {"tasks": []}
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// API returns 422 (Unprocessable Entity) for validation errors
|
||||
assert!(response.status().is_client_error());
|
||||
|
||||
// Test empty label
|
||||
let response = ctx
|
||||
.post(
|
||||
"/api/v1/workflows",
|
||||
json!({
|
||||
"ref": "test.workflow",
|
||||
"pack_ref": "test",
|
||||
"label": "",
|
||||
"version": "1.0.0",
|
||||
"definition": {"tasks": []}
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// API returns 422 (Unprocessable Entity) for validation errors
|
||||
assert!(response.status().is_client_error());
|
||||
}
|
||||
Reference in New Issue
Block a user