WIP
This commit is contained in:
@@ -26,7 +26,7 @@ async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# Web framework
|
||||
axum = { workspace = true }
|
||||
axum = { workspace = true, features = ["multipart"] }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
|
||||
@@ -69,7 +69,6 @@ jsonschema = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
|
||||
# Authentication
|
||||
jsonwebtoken = { version = "10.2", features = ["rust_crypto"] }
|
||||
argon2 = { workspace = true }
|
||||
rand = "0.9"
|
||||
|
||||
|
||||
@@ -1,389 +1,11 @@
|
||||
//! JWT token generation and validation
|
||||
//!
|
||||
//! This module re-exports all JWT functionality from `attune_common::auth::jwt`.
|
||||
//! The canonical implementation lives in the common crate so that all services
|
||||
//! (API, worker, sensor) share the same token types and signing logic.
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur while generating or validating a JWT.
#[derive(Debug, Error)]
pub enum JwtError {
    /// Signing/serialization failed while producing a token.
    #[error("Failed to encode JWT: {0}")]
    EncodeError(String),
    /// The token could not be parsed or its signature did not verify.
    #[error("Failed to decode JWT: {0}")]
    DecodeError(String),
    /// The token's `exp` claim is in the past.
    #[error("Token has expired")]
    Expired,
    /// The token is structurally invalid.
    #[error("Invalid token")]
    Invalid,
}
|
||||
|
||||
/// JWT Claims structure shared by access, refresh, and sensor tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject (identity ID, stored stringified — see `generate_token`)
    pub sub: String,
    /// Identity login
    pub login: String,
    /// Issued at (Unix timestamp)
    pub iat: i64,
    /// Expiration time (Unix timestamp)
    pub exp: i64,
    /// Token type (access or refresh); deserializes to `Access` when absent
    #[serde(default)]
    pub token_type: TokenType,
    /// Optional scope (e.g., "sensor", "service")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Optional metadata (e.g., trigger_types for sensors)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TokenType {
|
||||
Access,
|
||||
Refresh,
|
||||
Sensor,
|
||||
}
|
||||
|
||||
impl Default for TokenType {
|
||||
fn default() -> Self {
|
||||
Self::Access
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for JWT tokens.
///
/// All expirations are expressed in seconds.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Secret key for signing tokens (HMAC, used for both encode and decode)
    pub secret: String,
    /// Access token expiration duration (in seconds)
    pub access_token_expiration: i64,
    /// Refresh token expiration duration (in seconds)
    pub refresh_token_expiration: i64,
}
|
||||
|
||||
impl Default for JwtConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
secret: "insecure_default_secret_change_in_production".to_string(),
|
||||
access_token_expiration: 3600, // 1 hour
|
||||
refresh_token_expiration: 604800, // 7 days
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a JWT access token
///
/// Thin wrapper over [`generate_token`] with `TokenType::Access`.
///
/// # Arguments
/// * `identity_id` - The identity ID
/// * `login` - The identity login
/// * `config` - JWT configuration
///
/// # Returns
/// * `Result<String, JwtError>` - The encoded JWT token
pub fn generate_access_token(
    identity_id: i64,
    login: &str,
    config: &JwtConfig,
) -> Result<String, JwtError> {
    generate_token(identity_id, login, config, TokenType::Access)
}
|
||||
|
||||
/// Generate a JWT refresh token
///
/// Thin wrapper over [`generate_token`] with `TokenType::Refresh`.
///
/// # Arguments
/// * `identity_id` - The identity ID
/// * `login` - The identity login
/// * `config` - JWT configuration
///
/// # Returns
/// * `Result<String, JwtError>` - The encoded JWT token
pub fn generate_refresh_token(
    identity_id: i64,
    login: &str,
    config: &JwtConfig,
) -> Result<String, JwtError> {
    generate_token(identity_id, login, config, TokenType::Refresh)
}
|
||||
|
||||
/// Generate a JWT token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID
|
||||
/// * `login` - The identity login
|
||||
/// * `config` - JWT configuration
|
||||
/// * `token_type` - Type of token to generate
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_token(
|
||||
identity_id: i64,
|
||||
login: &str,
|
||||
config: &JwtConfig,
|
||||
token_type: TokenType,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = match token_type {
|
||||
TokenType::Access => config.access_token_expiration,
|
||||
TokenType::Refresh => config.refresh_token_expiration,
|
||||
TokenType::Sensor => 86400, // Sensor tokens handled separately via generate_sensor_token()
|
||||
};
|
||||
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: login.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Generate a sensor token with specific trigger types
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID for the sensor
|
||||
/// * `sensor_ref` - The sensor reference (e.g., "sensor:core.timer")
|
||||
/// * `trigger_types` - List of trigger types this sensor can create events for
|
||||
/// * `config` - JWT configuration
|
||||
/// * `ttl_seconds` - Time to live in seconds (default: 24 hours)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, JwtError>` - The encoded JWT token
|
||||
pub fn generate_sensor_token(
|
||||
identity_id: i64,
|
||||
sensor_ref: &str,
|
||||
trigger_types: Vec<String>,
|
||||
config: &JwtConfig,
|
||||
ttl_seconds: Option<i64>,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = ttl_seconds.unwrap_or(86400); // Default: 24 hours
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"trigger_types": trigger_types,
|
||||
});
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: sensor_ref.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type: TokenType::Sensor,
|
||||
scope: Some("sensor".to_string()),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Validate and decode a JWT token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The JWT token string
|
||||
/// * `config` - JWT configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Claims, JwtError>` - The decoded claims if valid
|
||||
pub fn validate_token(token: &str, config: &JwtConfig) -> Result<Claims, JwtError> {
|
||||
let validation = Validation::default();
|
||||
|
||||
decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(config.secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| {
|
||||
if e.to_string().contains("ExpiredSignature") {
|
||||
JwtError::Expired
|
||||
} else {
|
||||
JwtError::DecodeError(e.to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a bearer token from an `Authorization` header value.
///
/// # Arguments
/// * `auth_header` - The Authorization header value (e.g. `"Bearer <jwt>"`)
///
/// # Returns
/// * `Option<&str>` - The token if the header uses the `Bearer ` scheme,
///   `None` otherwise. An empty token (`"Bearer "`) yields `Some("")`,
///   matching the original slicing behavior.
pub fn extract_token_from_header(auth_header: &str) -> Option<&str> {
    // strip_prefix replaces the manual starts_with + [7..] slice; it cannot
    // panic and makes the prefix/remainder relationship explicit.
    auth_header.strip_prefix("Bearer ")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: config with a deterministic secret and standard TTLs.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test_secret_key_for_testing".to_string(),
            access_token_expiration: 3600,
            refresh_token_expiration: 604800,
        }
    }

    // Round-trip: a generated access token validates and keeps its claims.
    #[test]
    fn test_generate_and_validate_access_token() {
        let config = test_config();
        let token =
            generate_access_token(123, "testuser", &config).expect("Failed to generate token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "123");
        assert_eq!(claims.login, "testuser");
        assert_eq!(claims.token_type, TokenType::Access);
    }

    // Round-trip for refresh tokens.
    #[test]
    fn test_generate_and_validate_refresh_token() {
        let config = test_config();
        let token =
            generate_refresh_token(456, "anotheruser", &config).expect("Failed to generate token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "456");
        assert_eq!(claims.login, "anotheruser");
        assert_eq!(claims.token_type, TokenType::Refresh);
    }

    // Garbage input must be rejected, not panic.
    #[test]
    fn test_invalid_token() {
        let config = test_config();
        let result = validate_token("invalid.token.here", &config);
        assert!(result.is_err());
    }

    // A token signed with one secret must not validate under another.
    #[test]
    fn test_token_with_wrong_secret() {
        let config = test_config();
        let token = generate_access_token(789, "user", &config).expect("Failed to generate token");

        let wrong_config = JwtConfig {
            secret: "different_secret".to_string(),
            ..config
        };

        let result = validate_token(&token, &wrong_config);
        assert!(result.is_err());
    }

    // Expired tokens must map to the dedicated JwtError::Expired variant.
    #[test]
    fn test_expired_token() {
        // Create a token that's already expired by setting exp in the past
        let now = Utc::now().timestamp();
        let expired_claims = Claims {
            sub: "999".to_string(),
            login: "expireduser".to_string(),
            iat: now - 3600,
            exp: now - 1800, // Expired 30 minutes ago
            token_type: TokenType::Access,
            scope: None,
            metadata: None,
        };

        let config = test_config();

        // Encode directly (bypassing generate_token) so we control `exp`.
        let expired_token = encode(
            &Header::default(),
            &expired_claims,
            &EncodingKey::from_secret(config.secret.as_bytes()),
        )
        .expect("Failed to encode token");

        // Validate the expired token
        let result = validate_token(&expired_token, &config);
        assert!(matches!(result, Err(JwtError::Expired)));
    }

    // Header parsing: Bearer scheme only; an empty token is still Some("").
    #[test]
    fn test_extract_token_from_header() {
        let header = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
        let token = extract_token_from_header(header);
        assert_eq!(token, Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));

        let invalid_header = "Token abc123";
        let token = extract_token_from_header(invalid_header);
        assert_eq!(token, None);

        let no_token = "Bearer ";
        let token = extract_token_from_header(no_token);
        assert_eq!(token, Some(""));
    }

    // Claims survive a serde round-trip unchanged.
    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "123".to_string(),
            login: "testuser".to_string(),
            iat: 1234567890,
            exp: 1234571490,
            token_type: TokenType::Access,
            scope: None,
            metadata: None,
        };

        let json = serde_json::to_string(&claims).expect("Failed to serialize");
        let deserialized: Claims = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(claims.sub, deserialized.sub);
        assert_eq!(claims.login, deserialized.login);
        assert_eq!(claims.token_type, deserialized.token_type);
    }

    // Sensor tokens carry scope + trigger-type metadata in their claims.
    #[test]
    fn test_generate_sensor_token() {
        let config = test_config();
        let trigger_types = vec!["core.timer".to_string(), "core.webhook".to_string()];

        let token = generate_sensor_token(
            999,
            "sensor:core.timer",
            trigger_types.clone(),
            &config,
            Some(86400),
        )
        .expect("Failed to generate sensor token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "999");
        assert_eq!(claims.login, "sensor:core.timer");
        assert_eq!(claims.token_type, TokenType::Sensor);
        assert_eq!(claims.scope, Some("sensor".to_string()));

        let metadata = claims.metadata.expect("Metadata should be present");
        let trigger_types_from_token = metadata["trigger_types"]
            .as_array()
            .expect("trigger_types should be an array");

        assert_eq!(trigger_types_from_token.len(), 2);
    }
}
|
||||
pub use attune_common::auth::jwt::{
|
||||
extract_token_from_header, generate_access_token, generate_execution_token,
|
||||
generate_refresh_token, generate_sensor_token, generate_token, validate_token, Claims,
|
||||
JwtConfig, JwtError, TokenType,
|
||||
};
|
||||
|
||||
@@ -10,7 +10,9 @@ use axum::{
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::jwt::{extract_token_from_header, validate_token, Claims, JwtConfig, TokenType};
|
||||
use attune_common::auth::jwt::{
|
||||
extract_token_from_header, validate_token, Claims, JwtConfig, TokenType,
|
||||
};
|
||||
|
||||
/// Authentication middleware state
|
||||
#[derive(Clone)]
|
||||
@@ -105,8 +107,11 @@ impl axum::extract::FromRequestParts<crate::state::SharedState> for RequireAuth
|
||||
_ => AuthError::InvalidToken,
|
||||
})?;
|
||||
|
||||
// Allow both access tokens and sensor tokens
|
||||
if claims.token_type != TokenType::Access && claims.token_type != TokenType::Sensor {
|
||||
// Allow access, sensor, and execution-scoped tokens
|
||||
if claims.token_type != TokenType::Access
|
||||
&& claims.token_type != TokenType::Sensor
|
||||
&& claims.token_type != TokenType::Execution
|
||||
{
|
||||
return Err(AuthError::InvalidToken);
|
||||
}
|
||||
|
||||
@@ -154,7 +159,7 @@ mod tests {
|
||||
login: "testuser".to_string(),
|
||||
iat: 1234567890,
|
||||
exp: 1234571490,
|
||||
token_type: super::super::jwt::TokenType::Access,
|
||||
token_type: TokenType::Access,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
471
crates/api/src/dto/artifact.rs
Normal file
471
crates/api/src/dto/artifact.rs
Normal file
@@ -0,0 +1,471 @@
|
||||
//! Artifact DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use attune_common::models::enums::{ArtifactType, OwnerType, RetentionPolicyType};
|
||||
|
||||
// ============================================================================
|
||||
// Artifact DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Request DTO for creating a new artifact
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
pub struct CreateArtifactRequest {
|
||||
/// Artifact reference (unique identifier, e.g. "build.log", "test.results")
|
||||
#[schema(example = "mypack.build_log")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Owner scope type
|
||||
#[schema(example = "action")]
|
||||
pub scope: OwnerType,
|
||||
|
||||
/// Owner identifier (ref string of the owning entity)
|
||||
#[schema(example = "mypack.deploy")]
|
||||
pub owner: String,
|
||||
|
||||
/// Artifact type
|
||||
#[schema(example = "file_text")]
|
||||
pub r#type: ArtifactType,
|
||||
|
||||
/// Retention policy type
|
||||
#[serde(default = "default_retention_policy")]
|
||||
#[schema(example = "versions")]
|
||||
pub retention_policy: RetentionPolicyType,
|
||||
|
||||
/// Retention limit (number of versions, days, hours, or minutes depending on policy)
|
||||
#[serde(default = "default_retention_limit")]
|
||||
#[schema(example = 5)]
|
||||
pub retention_limit: i32,
|
||||
|
||||
/// Human-readable name
|
||||
#[schema(example = "Build Log")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Optional description
|
||||
#[schema(example = "Output log from the build action")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// MIME content type (e.g. "text/plain", "application/json")
|
||||
#[schema(example = "text/plain")]
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Execution ID that produced this artifact
|
||||
#[schema(example = 42)]
|
||||
pub execution: Option<i64>,
|
||||
|
||||
/// Initial structured data (for progress-type artifacts or metadata)
|
||||
#[schema(value_type = Option<Object>)]
|
||||
pub data: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Serde default: keep a fixed number of versions.
fn default_retention_policy() -> RetentionPolicyType {
    RetentionPolicyType::Versions
}

/// Serde default: retain the 5 most recent versions.
fn default_retention_limit() -> i32 {
    5
}
|
||||
|
||||
/// Request DTO for updating an existing artifact
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
pub struct UpdateArtifactRequest {
|
||||
/// Updated owner scope
|
||||
pub scope: Option<OwnerType>,
|
||||
|
||||
/// Updated owner identifier
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Updated artifact type
|
||||
pub r#type: Option<ArtifactType>,
|
||||
|
||||
/// Updated retention policy
|
||||
pub retention_policy: Option<RetentionPolicyType>,
|
||||
|
||||
/// Updated retention limit
|
||||
pub retention_limit: Option<i32>,
|
||||
|
||||
/// Updated name
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Updated description
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Updated content type
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Updated structured data (replaces existing data entirely)
|
||||
pub data: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Request DTO for appending to a progress-type artifact.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct AppendProgressRequest {
    /// The entry to append to the progress data array.
    /// Can be any JSON value (string, object, number, etc.)
    #[schema(value_type = Object, example = json!({"step": "compile", "status": "done", "duration_ms": 1234}))]
    pub entry: JsonValue,
}
|
||||
|
||||
/// Request DTO for setting the full data payload on an artifact.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct SetDataRequest {
    /// The data to set (replaces existing data entirely — no merge)
    #[schema(value_type = Object)]
    pub data: JsonValue,
}
|
||||
|
||||
/// Response DTO for artifact information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ArtifactResponse {
|
||||
/// Artifact ID
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
/// Artifact reference
|
||||
#[schema(example = "mypack.build_log")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Owner scope type
|
||||
pub scope: OwnerType,
|
||||
|
||||
/// Owner identifier
|
||||
#[schema(example = "mypack.deploy")]
|
||||
pub owner: String,
|
||||
|
||||
/// Artifact type
|
||||
pub r#type: ArtifactType,
|
||||
|
||||
/// Retention policy
|
||||
pub retention_policy: RetentionPolicyType,
|
||||
|
||||
/// Retention limit
|
||||
#[schema(example = 5)]
|
||||
pub retention_limit: i32,
|
||||
|
||||
/// Human-readable name
|
||||
#[schema(example = "Build Log")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Description
|
||||
pub description: Option<String>,
|
||||
|
||||
/// MIME content type
|
||||
#[schema(example = "text/plain")]
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Size of the latest version in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
|
||||
/// Execution that produced this artifact
|
||||
pub execution: Option<i64>,
|
||||
|
||||
/// Structured data (progress entries, metadata, etc.)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<JsonValue>,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified artifact for list endpoints
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ArtifactSummary {
|
||||
/// Artifact ID
|
||||
pub id: i64,
|
||||
|
||||
/// Artifact reference
|
||||
pub r#ref: String,
|
||||
|
||||
/// Artifact type
|
||||
pub r#type: ArtifactType,
|
||||
|
||||
/// Human-readable name
|
||||
pub name: Option<String>,
|
||||
|
||||
/// MIME content type
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Size of latest version in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
|
||||
/// Execution that produced this artifact
|
||||
pub execution: Option<i64>,
|
||||
|
||||
/// Owner scope
|
||||
pub scope: OwnerType,
|
||||
|
||||
/// Owner identifier
|
||||
pub owner: String,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Query parameters for filtering artifacts
|
||||
#[derive(Debug, Clone, Deserialize, IntoParams)]
|
||||
pub struct ArtifactQueryParams {
|
||||
/// Filter by owner scope type
|
||||
pub scope: Option<OwnerType>,
|
||||
|
||||
/// Filter by owner identifier
|
||||
pub owner: Option<String>,
|
||||
|
||||
/// Filter by artifact type
|
||||
pub r#type: Option<ArtifactType>,
|
||||
|
||||
/// Filter by execution ID
|
||||
pub execution: Option<i64>,
|
||||
|
||||
/// Search by name (case-insensitive substring match)
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Page number (1-based)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
pub page: u32,
|
||||
|
||||
/// Items per page
|
||||
#[serde(default = "default_per_page")]
|
||||
#[param(example = 20, minimum = 1, maximum = 100)]
|
||||
pub per_page: u32,
|
||||
}
|
||||
|
||||
impl ArtifactQueryParams {
|
||||
pub fn offset(&self) -> u32 {
|
||||
(self.page.saturating_sub(1)) * self.per_page
|
||||
}
|
||||
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.per_page.min(100)
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default: first page (pagination is 1-based).
fn default_page() -> u32 {
    1
}

/// Serde default: 20 items per page.
fn default_per_page() -> u32 {
    20
}
|
||||
|
||||
// ============================================================================
|
||||
// ArtifactVersion DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Request DTO for creating a new artifact version with JSON content
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
pub struct CreateVersionJsonRequest {
|
||||
/// Structured JSON content for this version
|
||||
#[schema(value_type = Object)]
|
||||
pub content: JsonValue,
|
||||
|
||||
/// MIME content type override (defaults to "application/json")
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Free-form metadata about this version
|
||||
#[schema(value_type = Option<Object>)]
|
||||
pub meta: Option<JsonValue>,
|
||||
|
||||
/// Who created this version (e.g. action ref, identity, "system")
|
||||
pub created_by: Option<String>,
|
||||
}
|
||||
|
||||
/// Response DTO for an artifact version (without binary content)
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ArtifactVersionResponse {
|
||||
/// Version ID
|
||||
pub id: i64,
|
||||
|
||||
/// Parent artifact ID
|
||||
pub artifact: i64,
|
||||
|
||||
/// Version number (1-based)
|
||||
pub version: i32,
|
||||
|
||||
/// MIME content type
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Size of content in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
|
||||
/// Structured JSON content (if this version has JSON data)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content_json: Option<JsonValue>,
|
||||
|
||||
/// Free-form metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<JsonValue>,
|
||||
|
||||
/// Who created this version
|
||||
pub created_by: Option<String>,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Simplified version for list endpoints
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ArtifactVersionSummary {
|
||||
/// Version ID
|
||||
pub id: i64,
|
||||
|
||||
/// Version number
|
||||
pub version: i32,
|
||||
|
||||
/// MIME content type
|
||||
pub content_type: Option<String>,
|
||||
|
||||
/// Size of content in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
|
||||
/// Who created this version
|
||||
pub created_by: Option<String>,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<attune_common::models::artifact::Artifact> for ArtifactResponse {
|
||||
fn from(a: attune_common::models::artifact::Artifact) -> Self {
|
||||
Self {
|
||||
id: a.id,
|
||||
r#ref: a.r#ref,
|
||||
scope: a.scope,
|
||||
owner: a.owner,
|
||||
r#type: a.r#type,
|
||||
retention_policy: a.retention_policy,
|
||||
retention_limit: a.retention_limit,
|
||||
name: a.name,
|
||||
description: a.description,
|
||||
content_type: a.content_type,
|
||||
size_bytes: a.size_bytes,
|
||||
execution: a.execution,
|
||||
data: a.data,
|
||||
created: a.created,
|
||||
updated: a.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::models::artifact::Artifact> for ArtifactSummary {
|
||||
fn from(a: attune_common::models::artifact::Artifact) -> Self {
|
||||
Self {
|
||||
id: a.id,
|
||||
r#ref: a.r#ref,
|
||||
r#type: a.r#type,
|
||||
name: a.name,
|
||||
content_type: a.content_type,
|
||||
size_bytes: a.size_bytes,
|
||||
execution: a.execution,
|
||||
scope: a.scope,
|
||||
owner: a.owner,
|
||||
created: a.created,
|
||||
updated: a.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::models::artifact_version::ArtifactVersion> for ArtifactVersionResponse {
|
||||
fn from(v: attune_common::models::artifact_version::ArtifactVersion) -> Self {
|
||||
Self {
|
||||
id: v.id,
|
||||
artifact: v.artifact,
|
||||
version: v.version,
|
||||
content_type: v.content_type,
|
||||
size_bytes: v.size_bytes,
|
||||
content_json: v.content_json,
|
||||
meta: v.meta,
|
||||
created_by: v.created_by,
|
||||
created: v.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::models::artifact_version::ArtifactVersion> for ArtifactVersionSummary {
|
||||
fn from(v: attune_common::models::artifact_version::ArtifactVersion) -> Self {
|
||||
Self {
|
||||
id: v.id,
|
||||
version: v.version,
|
||||
content_type: v.content_type,
|
||||
size_bytes: v.size_bytes,
|
||||
created_by: v.created_by,
|
||||
created: v.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_query_params_defaults() {
|
||||
let json = r#"{}"#;
|
||||
let params: ArtifactQueryParams = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(params.page, 1);
|
||||
assert_eq!(params.per_page, 20);
|
||||
assert!(params.scope.is_none());
|
||||
assert!(params.r#type.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_params_offset() {
|
||||
let params = ArtifactQueryParams {
|
||||
scope: None,
|
||||
owner: None,
|
||||
r#type: None,
|
||||
execution: None,
|
||||
name: None,
|
||||
page: 3,
|
||||
per_page: 20,
|
||||
};
|
||||
assert_eq!(params.offset(), 40);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_params_limit_cap() {
|
||||
let params = ArtifactQueryParams {
|
||||
scope: None,
|
||||
owner: None,
|
||||
r#type: None,
|
||||
execution: None,
|
||||
name: None,
|
||||
page: 1,
|
||||
per_page: 200,
|
||||
};
|
||||
assert_eq!(params.limit(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_request_defaults() {
|
||||
let json = r#"{
|
||||
"ref": "test.artifact",
|
||||
"scope": "system",
|
||||
"owner": "",
|
||||
"type": "file_text"
|
||||
}"#;
|
||||
let req: CreateArtifactRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.retention_policy, RetentionPolicyType::Versions);
|
||||
assert_eq!(req.retention_limit, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append_progress_request() {
|
||||
let json = r#"{"entry": {"step": "build", "status": "done"}}"#;
|
||||
let req: AppendProgressRequest = serde_json::from_str(json).unwrap();
|
||||
assert!(req.entry.is_object());
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod action;
|
||||
pub mod analytics;
|
||||
pub mod artifact;
|
||||
pub mod auth;
|
||||
pub mod common;
|
||||
pub mod event;
|
||||
@@ -21,6 +22,11 @@ pub use analytics::{
|
||||
ExecutionStatusTimeSeriesResponse, ExecutionThroughputResponse, FailureRateResponse,
|
||||
TimeSeriesPoint,
|
||||
};
|
||||
pub use artifact::{
|
||||
AppendProgressRequest, ArtifactQueryParams, ArtifactResponse, ArtifactSummary,
|
||||
ArtifactVersionResponse, ArtifactVersionSummary, CreateArtifactRequest,
|
||||
CreateVersionJsonRequest, SetDataRequest, UpdateArtifactRequest,
|
||||
};
|
||||
pub use auth::{
|
||||
ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest, RegisterRequest,
|
||||
TokenResponse,
|
||||
|
||||
978
crates/api/src/routes/artifacts.rs
Normal file
978
crates/api/src/routes/artifacts.rs
Normal file
@@ -0,0 +1,978 @@
|
||||
//! Artifact management API routes
|
||||
//!
|
||||
//! Provides endpoints for:
|
||||
//! - CRUD operations on artifacts (metadata + data)
|
||||
//! - File upload (binary) and download for file-type artifacts
|
||||
//! - JSON content versioning for structured artifacts
|
||||
//! - Progress append for progress-type artifacts (streaming updates)
|
||||
//! - Listing artifacts by execution
|
||||
//! - Version history and retrieval
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Multipart, Path, Query, State},
|
||||
http::{header, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use attune_common::models::enums::ArtifactType;
|
||||
use attune_common::repositories::{
|
||||
artifact::{
|
||||
ArtifactRepository, ArtifactSearchFilters, ArtifactVersionRepository, CreateArtifactInput,
|
||||
CreateArtifactVersionInput, UpdateArtifactInput,
|
||||
},
|
||||
Create, Delete, FindById, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
artifact::{
|
||||
AppendProgressRequest, ArtifactQueryParams, ArtifactResponse, ArtifactSummary,
|
||||
ArtifactVersionResponse, ArtifactVersionSummary, CreateArtifactRequest,
|
||||
CreateVersionJsonRequest, SetDataRequest, UpdateArtifactRequest,
|
||||
},
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Artifact CRUD
|
||||
// ============================================================================
|
||||
|
||||
/// List artifacts with pagination and optional filters
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts",
|
||||
tag = "artifacts",
|
||||
params(ArtifactQueryParams),
|
||||
responses(
|
||||
(status = 200, description = "List of artifacts", body = PaginatedResponse<ArtifactSummary>),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_artifacts(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<ArtifactQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let filters = ArtifactSearchFilters {
|
||||
scope: query.scope,
|
||||
owner: query.owner.clone(),
|
||||
r#type: query.r#type,
|
||||
execution: query.execution,
|
||||
name_contains: query.name.clone(),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
let result = ArtifactRepository::search(&state.db, &filters).await?;
|
||||
|
||||
let items: Vec<ArtifactSummary> = result.rows.into_iter().map(ArtifactSummary::from).collect();
|
||||
|
||||
let pagination = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(items, &pagination, result.total as u64);
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a single artifact by ID
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
responses(
|
||||
(status = 200, description = "Artifact details", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactResponse::from(artifact))),
|
||||
))
|
||||
}
|
||||
|
||||
/// Get a single artifact by ref
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/ref/{ref}",
|
||||
tag = "artifacts",
|
||||
params(("ref" = String, Path, description = "Artifact reference")),
|
||||
responses(
|
||||
(status = 200, description = "Artifact details", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_artifact_by_ref(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(artifact_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifact = ArtifactRepository::find_by_ref(&state.db, &artifact_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact '{}' not found", artifact_ref)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactResponse::from(artifact))),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new artifact
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/artifacts",
|
||||
tag = "artifacts",
|
||||
request_body = CreateArtifactRequest,
|
||||
responses(
|
||||
(status = 201, description = "Artifact created", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 409, description = "Artifact with same ref already exists"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<CreateArtifactRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Validate ref is not empty
|
||||
if request.r#ref.trim().is_empty() {
|
||||
return Err(ApiError::BadRequest(
|
||||
"Artifact ref must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check for duplicate ref
|
||||
if ArtifactRepository::find_by_ref(&state.db, &request.r#ref)
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Artifact with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
let input = CreateArtifactInput {
|
||||
r#ref: request.r#ref,
|
||||
scope: request.scope,
|
||||
owner: request.owner,
|
||||
r#type: request.r#type,
|
||||
retention_policy: request.retention_policy,
|
||||
retention_limit: request.retention_limit,
|
||||
name: request.name,
|
||||
description: request.description,
|
||||
content_type: request.content_type,
|
||||
execution: request.execution,
|
||||
data: request.data,
|
||||
};
|
||||
|
||||
let artifact = ArtifactRepository::create(&state.db, input).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactResponse::from(artifact),
|
||||
"Artifact created successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Update an existing artifact
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/artifacts/{id}",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
request_body = UpdateArtifactRequest,
|
||||
responses(
|
||||
(status = 200, description = "Artifact updated", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<UpdateArtifactRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let input = UpdateArtifactInput {
|
||||
r#ref: None, // Ref is immutable after creation
|
||||
scope: request.scope,
|
||||
owner: request.owner,
|
||||
r#type: request.r#type,
|
||||
retention_policy: request.retention_policy,
|
||||
retention_limit: request.retention_limit,
|
||||
name: request.name,
|
||||
description: request.description,
|
||||
content_type: request.content_type,
|
||||
size_bytes: None, // Managed by version creation trigger
|
||||
data: request.data,
|
||||
};
|
||||
|
||||
let updated = ArtifactRepository::update(&state.db, id, input).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactResponse::from(updated),
|
||||
"Artifact updated successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Delete an artifact (cascades to all versions)
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/artifacts/{id}",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
responses(
|
||||
(status = 200, description = "Artifact deleted", body = SuccessResponse),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let deleted = ArtifactRepository::delete(&state.db, id).await?;
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Artifact with ID {} not found",
|
||||
id
|
||||
)));
|
||||
}
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(SuccessResponse::new("Artifact deleted successfully")),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Artifacts by Execution
|
||||
// ============================================================================
|
||||
|
||||
/// List all artifacts for a given execution
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/{execution_id}/artifacts",
|
||||
tag = "artifacts",
|
||||
params(("execution_id" = i64, Path, description = "Execution ID")),
|
||||
responses(
|
||||
(status = 200, description = "List of artifacts for execution", body = inline(ApiResponse<Vec<ArtifactSummary>>)),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_artifacts_by_execution(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(execution_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifacts = ArtifactRepository::find_by_execution(&state.db, execution_id).await?;
|
||||
let items: Vec<ArtifactSummary> = artifacts.into_iter().map(ArtifactSummary::from).collect();
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::new(items))))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Progress Artifacts
|
||||
// ============================================================================
|
||||
|
||||
/// Append an entry to a progress-type artifact's data array.
|
||||
///
|
||||
/// The entry is atomically appended to `artifact.data` (initialized as `[]` if null).
|
||||
/// This is the primary mechanism for actions to stream progress updates.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/artifacts/{id}/progress",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID (must be progress type)")),
|
||||
request_body = AppendProgressRequest,
|
||||
responses(
|
||||
(status = 200, description = "Entry appended", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 400, description = "Artifact is not a progress type"),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn append_progress(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<AppendProgressRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
if artifact.r#type != ArtifactType::Progress {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Artifact '{}' is type {:?}, not progress. Use version endpoints for file artifacts.",
|
||||
artifact.r#ref, artifact.r#type
|
||||
)));
|
||||
}
|
||||
|
||||
let updated = ArtifactRepository::append_progress(&state.db, id, &request.entry).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactResponse::from(updated),
|
||||
"Progress entry appended",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Set the full data payload on an artifact (replaces existing data).
|
||||
///
|
||||
/// Useful for resetting progress, updating metadata, or setting structured content.
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/artifacts/{id}/data",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
request_body = SetDataRequest,
|
||||
responses(
|
||||
(status = 200, description = "Data set", body = inline(ApiResponse<ArtifactResponse>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn set_artifact_data(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<SetDataRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let updated = ArtifactRepository::set_data(&state.db, id, &request.data).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactResponse::from(updated),
|
||||
"Artifact data updated",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Version Management
|
||||
// ============================================================================
|
||||
|
||||
/// List all versions for an artifact (without binary content)
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}/versions",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
responses(
|
||||
(status = 200, description = "List of versions", body = inline(ApiResponse<Vec<ArtifactVersionSummary>>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_versions(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let versions = ArtifactVersionRepository::list_by_artifact(&state.db, id).await?;
|
||||
let items: Vec<ArtifactVersionSummary> = versions
|
||||
.into_iter()
|
||||
.map(ArtifactVersionSummary::from)
|
||||
.collect();
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::new(items))))
|
||||
}
|
||||
|
||||
/// Get a specific version's metadata and JSON content (no binary)
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}/versions/{version}",
|
||||
tag = "artifacts",
|
||||
params(
|
||||
("id" = i64, Path, description = "Artifact ID"),
|
||||
("version" = i32, Path, description = "Version number"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Version details", body = inline(ApiResponse<ArtifactVersionResponse>)),
|
||||
(status = 404, description = "Artifact or version not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_by_version(&state.db, id, version)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Version {} not found for artifact {}", version, id))
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactVersionResponse::from(ver))),
|
||||
))
|
||||
}
|
||||
|
||||
/// Get the latest version's metadata and JSON content
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}/versions/latest",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
responses(
|
||||
(status = 200, description = "Latest version", body = inline(ApiResponse<ArtifactVersionResponse>)),
|
||||
(status = 404, description = "Artifact not found or no versions"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_latest_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_latest(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("No versions found for artifact {}", id)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactVersionResponse::from(ver))),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new version with JSON content
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/artifacts/{id}/versions",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
request_body = CreateVersionJsonRequest,
|
||||
responses(
|
||||
(status = 201, description = "Version created", body = inline(ApiResponse<ArtifactVersionResponse>)),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_version_json(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<CreateVersionJsonRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let input = CreateArtifactVersionInput {
|
||||
artifact: id,
|
||||
content_type: Some(
|
||||
request
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/json".to_string()),
|
||||
),
|
||||
content: None,
|
||||
content_json: Some(request.content),
|
||||
meta: request.meta,
|
||||
created_by: request.created_by,
|
||||
};
|
||||
|
||||
let version = ArtifactVersionRepository::create(&state.db, input).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactVersionResponse::from(version),
|
||||
"Version created successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Upload a binary file as a new version (multipart/form-data)
|
||||
///
|
||||
/// The file is sent as a multipart form field named `file`. Optional fields:
|
||||
/// - `content_type`: MIME type override (auto-detected from filename if omitted)
|
||||
/// - `meta`: JSON metadata string
|
||||
/// - `created_by`: Creator identifier
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/artifacts/{id}/versions/upload",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
request_body(content = String, content_type = "multipart/form-data"),
|
||||
responses(
|
||||
(status = 201, description = "File version created", body = inline(ApiResponse<ArtifactVersionResponse>)),
|
||||
(status = 400, description = "Missing file field"),
|
||||
(status = 404, description = "Artifact not found"),
|
||||
(status = 413, description = "File too large"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn upload_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
mut multipart: Multipart,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let mut file_data: Option<Vec<u8>> = None;
|
||||
let mut content_type: Option<String> = None;
|
||||
let mut meta: Option<serde_json::Value> = None;
|
||||
let mut created_by: Option<String> = None;
|
||||
let mut file_content_type: Option<String> = None;
|
||||
|
||||
// 50 MB limit
|
||||
const MAX_FILE_SIZE: usize = 50 * 1024 * 1024;
|
||||
|
||||
while let Some(field) = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|e| ApiError::BadRequest(format!("Multipart error: {}", e)))?
|
||||
{
|
||||
let name = field.name().unwrap_or("").to_string();
|
||||
match name.as_str() {
|
||||
"file" => {
|
||||
// Capture content type from the multipart field itself
|
||||
file_content_type = field.content_type().map(|s| s.to_string());
|
||||
|
||||
let bytes = field
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| ApiError::BadRequest(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
if bytes.len() > MAX_FILE_SIZE {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"File exceeds maximum size of {} bytes",
|
||||
MAX_FILE_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
file_data = Some(bytes.to_vec());
|
||||
}
|
||||
"content_type" => {
|
||||
let text = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ApiError::BadRequest(format!("Failed to read field: {}", e)))?;
|
||||
if !text.is_empty() {
|
||||
content_type = Some(text);
|
||||
}
|
||||
}
|
||||
"meta" => {
|
||||
let text = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ApiError::BadRequest(format!("Failed to read field: {}", e)))?;
|
||||
if !text.is_empty() {
|
||||
meta =
|
||||
Some(serde_json::from_str(&text).map_err(|e| {
|
||||
ApiError::BadRequest(format!("Invalid meta JSON: {}", e))
|
||||
})?);
|
||||
}
|
||||
}
|
||||
"created_by" => {
|
||||
let text = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ApiError::BadRequest(format!("Failed to read field: {}", e)))?;
|
||||
if !text.is_empty() {
|
||||
created_by = Some(text);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Skip unknown fields
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let file_bytes = file_data.ok_or_else(|| {
|
||||
ApiError::BadRequest("Missing required 'file' field in multipart upload".to_string())
|
||||
})?;
|
||||
|
||||
// Resolve content type: explicit > multipart header > fallback
|
||||
let resolved_ct = content_type
|
||||
.or(file_content_type)
|
||||
.unwrap_or_else(|| "application/octet-stream".to_string());
|
||||
|
||||
let input = CreateArtifactVersionInput {
|
||||
artifact: id,
|
||||
content_type: Some(resolved_ct),
|
||||
content: Some(file_bytes),
|
||||
content_json: None,
|
||||
meta,
|
||||
created_by,
|
||||
};
|
||||
|
||||
let version = ArtifactVersionRepository::create(&state.db, input).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::with_message(
|
||||
ArtifactVersionResponse::from(version),
|
||||
"File version uploaded successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Download the binary content of a specific version
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}/versions/{version}/download",
|
||||
tag = "artifacts",
|
||||
params(
|
||||
("id" = i64, Path, description = "Artifact ID"),
|
||||
("version" = i32, Path, description = "Version number"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Binary file content", content_type = "application/octet-stream"),
|
||||
(status = 404, description = "Artifact, version, or content not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn download_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_by_version_with_content(&state.db, id, version)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Version {} not found for artifact {}", version, id))
|
||||
})?;
|
||||
|
||||
// For binary content
|
||||
if let Some(bytes) = ver.content {
|
||||
let ct = ver
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/octet-stream".to_string());
|
||||
|
||||
let filename = format!(
|
||||
"{}_v{}.{}",
|
||||
artifact.r#ref.replace('.', "_"),
|
||||
version,
|
||||
extension_from_content_type(&ct)
|
||||
);
|
||||
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, ct),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
),
|
||||
],
|
||||
Body::from(bytes),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// For JSON content, serialize and return
|
||||
if let Some(json) = ver.content_json {
|
||||
let bytes = serde_json::to_vec_pretty(&json).map_err(|e| {
|
||||
ApiError::InternalServerError(format!("Failed to serialize JSON: {}", e))
|
||||
})?;
|
||||
|
||||
let ct = ver
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/json".to_string());
|
||||
|
||||
let filename = format!("{}_v{}.json", artifact.r#ref.replace('.', "_"), version,);
|
||||
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, ct),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
),
|
||||
],
|
||||
Body::from(bytes),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
Err(ApiError::NotFound(format!(
|
||||
"Version {} of artifact {} has no downloadable content",
|
||||
version, id
|
||||
)))
|
||||
}
|
||||
|
||||
/// Download the latest version's content
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/artifacts/{id}/download",
|
||||
tag = "artifacts",
|
||||
params(("id" = i64, Path, description = "Artifact ID")),
|
||||
responses(
|
||||
(status = 200, description = "Binary file content of latest version", content_type = "application/octet-stream"),
|
||||
(status = 404, description = "Artifact not found or no versions"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn download_latest(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_latest_with_content(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("No versions found for artifact {}", id)))?;
|
||||
|
||||
let version = ver.version;
|
||||
|
||||
// For binary content
|
||||
if let Some(bytes) = ver.content {
|
||||
let ct = ver
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/octet-stream".to_string());
|
||||
|
||||
let filename = format!(
|
||||
"{}_v{}.{}",
|
||||
artifact.r#ref.replace('.', "_"),
|
||||
version,
|
||||
extension_from_content_type(&ct)
|
||||
);
|
||||
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, ct),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
),
|
||||
],
|
||||
Body::from(bytes),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// For JSON content
|
||||
if let Some(json) = ver.content_json {
|
||||
let bytes = serde_json::to_vec_pretty(&json).map_err(|e| {
|
||||
ApiError::InternalServerError(format!("Failed to serialize JSON: {}", e))
|
||||
})?;
|
||||
|
||||
let ct = ver
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/json".to_string());
|
||||
|
||||
let filename = format!("{}_v{}.json", artifact.r#ref.replace('.', "_"), version,);
|
||||
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, ct),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
),
|
||||
],
|
||||
Body::from(bytes),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
Err(ApiError::NotFound(format!(
|
||||
"Latest version of artifact {} has no downloadable content",
|
||||
id
|
||||
)))
|
||||
}
|
||||
|
||||
/// Delete a specific version by version number
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/artifacts/{id}/versions/{version}",
|
||||
tag = "artifacts",
|
||||
params(
|
||||
("id" = i64, Path, description = "Artifact ID"),
|
||||
("version" = i32, Path, description = "Version number"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Version deleted", body = SuccessResponse),
|
||||
(status = 404, description = "Artifact or version not found"),
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
// Find the version by artifact + version number
|
||||
let ver = ArtifactVersionRepository::find_by_version(&state.db, id, version)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Version {} not found for artifact {}", version, id))
|
||||
})?;
|
||||
|
||||
ArtifactVersionRepository::delete(&state.db, ver.id).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(SuccessResponse::new("Version deleted successfully")),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Map a MIME content type to a plain file extension; unrecognized types
/// fall back to "bin".
fn extension_from_content_type(ct: &str) -> &str {
    // Known MIME type -> extension pairs; linear scan is fine for a list
    // this small.
    const KNOWN: &[(&str, &str)] = &[
        ("text/plain", "txt"),
        ("text/html", "html"),
        ("text/css", "css"),
        ("text/csv", "csv"),
        ("text/xml", "xml"),
        ("application/json", "json"),
        ("application/xml", "xml"),
        ("application/pdf", "pdf"),
        ("application/zip", "zip"),
        ("application/gzip", "gz"),
        ("application/octet-stream", "bin"),
        ("image/png", "png"),
        ("image/jpeg", "jpg"),
        ("image/gif", "gif"),
        ("image/svg+xml", "svg"),
        ("image/webp", "webp"),
    ];

    KNOWN
        .iter()
        .find(|(mime, _)| *mime == ct)
        .map(|(_, ext)| *ext)
        .unwrap_or("bin")
}
|
||||
|
||||
// ============================================================================
|
||||
// Router
|
||||
// ============================================================================
|
||||
|
||||
/// Register artifact routes
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
// Artifact CRUD
|
||||
.route("/artifacts", get(list_artifacts).post(create_artifact))
|
||||
.route(
|
||||
"/artifacts/{id}",
|
||||
get(get_artifact)
|
||||
.put(update_artifact)
|
||||
.delete(delete_artifact),
|
||||
)
|
||||
.route("/artifacts/ref/{ref}", get(get_artifact_by_ref))
|
||||
// Progress / data
|
||||
.route("/artifacts/{id}/progress", post(append_progress))
|
||||
.route(
|
||||
"/artifacts/{id}/data",
|
||||
axum::routing::put(set_artifact_data),
|
||||
)
|
||||
// Download (latest)
|
||||
.route("/artifacts/{id}/download", get(download_latest))
|
||||
// Version management
|
||||
.route(
|
||||
"/artifacts/{id}/versions",
|
||||
get(list_versions).post(create_version_json),
|
||||
)
|
||||
.route("/artifacts/{id}/versions/latest", get(get_latest_version))
|
||||
.route("/artifacts/{id}/versions/upload", post(upload_version))
|
||||
.route(
|
||||
"/artifacts/{id}/versions/{version}",
|
||||
get(get_version).delete(delete_version),
|
||||
)
|
||||
.route(
|
||||
"/artifacts/{id}/versions/{version}/download",
|
||||
get(download_version),
|
||||
)
|
||||
// By execution
|
||||
.route(
|
||||
"/executions/{execution_id}/artifacts",
|
||||
get(list_artifacts_by_execution),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The router must assemble without panicking (axum panics at build time
    /// on malformed or conflicting route paths).
    #[test]
    fn router_builds() {
        let _ = routes();
    }

    /// Known MIME types map to their extensions; anything else is "bin".
    #[test]
    fn mime_to_extension_mapping() {
        let cases = [
            ("text/plain", "txt"),
            ("application/json", "json"),
            ("image/png", "png"),
            ("unknown/type", "bin"),
        ];
        for (mime, expected) in cases {
            assert_eq!(extension_from_content_type(mime), expected);
        }
    }
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod actions;
|
||||
pub mod analytics;
|
||||
pub mod artifacts;
|
||||
pub mod auth;
|
||||
pub mod events;
|
||||
pub mod executions;
|
||||
@@ -17,6 +18,7 @@ pub mod workflows;
|
||||
|
||||
pub use actions::routes as action_routes;
|
||||
pub use analytics::routes as analytics_routes;
|
||||
pub use artifacts::routes as artifact_routes;
|
||||
pub use auth::routes as auth_routes;
|
||||
pub use events::routes as event_routes;
|
||||
pub use executions::routes as execution_routes;
|
||||
|
||||
@@ -57,8 +57,7 @@ impl Server {
|
||||
.merge(routes::webhook_routes())
|
||||
.merge(routes::history_routes())
|
||||
.merge(routes::analytics_routes())
|
||||
// TODO: Add more route modules here
|
||||
// etc.
|
||||
.merge(routes::artifact_routes())
|
||||
.with_state(self.state.clone());
|
||||
|
||||
// Auth routes at root level (not versioned for frontend compatibility)
|
||||
|
||||
@@ -53,6 +53,9 @@ jsonschema = { workspace = true }
|
||||
# OpenAPI
|
||||
utoipa = { workspace = true }
|
||||
|
||||
# JWT
|
||||
jsonwebtoken = { workspace = true }
|
||||
|
||||
# Encryption
|
||||
argon2 = { workspace = true }
|
||||
ring = { workspace = true }
|
||||
|
||||
460
crates/common/src/auth/jwt.rs
Normal file
460
crates/common/src/auth/jwt.rs
Normal file
@@ -0,0 +1,460 @@
|
||||
//! JWT token generation and validation
|
||||
//!
|
||||
//! Shared across all Attune services. Token types:
|
||||
//! - **Access**: Standard user login tokens (1h default)
|
||||
//! - **Refresh**: Long-lived refresh tokens (7d default)
|
||||
//! - **Sensor**: Sensor service tokens with trigger type metadata (24h default)
|
||||
//! - **Execution**: Short-lived tokens scoped to a single execution (matching execution timeout)
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors produced by JWT encoding, decoding, and validation.
#[derive(Debug, Error)]
pub enum JwtError {
    /// Signing or claim serialization failed while creating a token.
    #[error("Failed to encode JWT: {0}")]
    EncodeError(String),
    /// The token could not be decoded (bad signature, malformed payload, ...).
    #[error("Failed to decode JWT: {0}")]
    DecodeError(String),
    /// The token's `exp` claim is in the past.
    #[error("Token has expired")]
    Expired,
    /// Catch-all for tokens rejected for other reasons.
    #[error("Invalid token")]
    Invalid,
}
|
||||
|
||||
/// JWT Claims structure
///
/// Serialized as the JWT payload; `sub`, `iat`, and `exp` follow the
/// RFC 7519 registered-claim names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject (identity ID)
    pub sub: String,
    /// Identity login (or descriptor like "execution:123")
    pub login: String,
    /// Issued at (Unix timestamp)
    pub iat: i64,
    /// Expiration time (Unix timestamp)
    pub exp: i64,
    /// Token type (access, refresh, sensor, or execution).
    /// `#[serde(default)]` lets older tokens without this field decode
    /// as `TokenType::Access`.
    #[serde(default)]
    pub token_type: TokenType,
    /// Optional scope (e.g., "sensor", "execution")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Optional metadata (e.g., trigger_types for sensors, execution_id for execution tokens)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TokenType {
|
||||
Access,
|
||||
Refresh,
|
||||
Sensor,
|
||||
Execution,
|
||||
}
|
||||
|
||||
impl Default for TokenType {
|
||||
fn default() -> Self {
|
||||
Self::Access
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for JWT tokens
///
/// One shared HMAC secret signs every token type; only access/refresh
/// TTLs live here — sensor and execution tokens carry explicit TTLs.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Secret key for signing tokens
    pub secret: String,
    /// Access token expiration duration (in seconds)
    pub access_token_expiration: i64,
    /// Refresh token expiration duration (in seconds)
    pub refresh_token_expiration: i64,
}

impl Default for JwtConfig {
    fn default() -> Self {
        Self {
            // Deliberately alarming placeholder: deployments must override
            // this via configuration before going to production.
            secret: "insecure_default_secret_change_in_production".to_string(),
            access_token_expiration: 3600,    // 1 hour
            refresh_token_expiration: 604800, // 7 days
        }
    }
}
|
||||
|
||||
/// Generate a JWT access token
///
/// Convenience wrapper over [`generate_token`] with
/// [`TokenType::Access`]; the TTL comes from
/// `config.access_token_expiration`.
pub fn generate_access_token(
    identity_id: i64,
    login: &str,
    config: &JwtConfig,
) -> Result<String, JwtError> {
    generate_token(identity_id, login, config, TokenType::Access)
}
|
||||
|
||||
/// Generate a JWT refresh token
///
/// Convenience wrapper over [`generate_token`] with
/// [`TokenType::Refresh`]; the TTL comes from
/// `config.refresh_token_expiration`.
pub fn generate_refresh_token(
    identity_id: i64,
    login: &str,
    config: &JwtConfig,
) -> Result<String, JwtError> {
    generate_token(identity_id, login, config, TokenType::Refresh)
}
|
||||
|
||||
/// Generate a JWT token with a specific type
|
||||
pub fn generate_token(
|
||||
identity_id: i64,
|
||||
login: &str,
|
||||
config: &JwtConfig,
|
||||
token_type: TokenType,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = match token_type {
|
||||
TokenType::Access => config.access_token_expiration,
|
||||
TokenType::Refresh => config.refresh_token_expiration,
|
||||
// Sensor and Execution tokens are generated via their own dedicated functions
|
||||
// with explicit TTLs; this fallback should not normally be reached.
|
||||
TokenType::Sensor => 86400,
|
||||
TokenType::Execution => 300,
|
||||
};
|
||||
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: login.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type,
|
||||
scope: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Generate a sensor token with specific trigger types
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID for the sensor
|
||||
/// * `sensor_ref` - The sensor reference (e.g., "sensor:core.timer")
|
||||
/// * `trigger_types` - List of trigger types this sensor can create events for
|
||||
/// * `config` - JWT configuration
|
||||
/// * `ttl_seconds` - Time to live in seconds (default: 24 hours)
|
||||
pub fn generate_sensor_token(
|
||||
identity_id: i64,
|
||||
sensor_ref: &str,
|
||||
trigger_types: Vec<String>,
|
||||
config: &JwtConfig,
|
||||
ttl_seconds: Option<i64>,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = ttl_seconds.unwrap_or(86400); // Default: 24 hours
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"trigger_types": trigger_types,
|
||||
});
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: sensor_ref.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type: TokenType::Sensor,
|
||||
scope: Some("sensor".to_string()),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Generate an execution-scoped token.
|
||||
///
|
||||
/// These tokens are short-lived (matching the execution timeout) and scoped
|
||||
/// to a single execution. They allow actions to call back into the Attune API
|
||||
/// (e.g., to create artifacts, update progress) without full user credentials.
|
||||
///
|
||||
/// The token is automatically invalidated when it expires. The TTL defaults to
|
||||
/// the execution timeout plus a 60-second grace period to account for cleanup.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identity_id` - The identity ID that triggered the execution
|
||||
/// * `execution_id` - The execution ID this token is scoped to
|
||||
/// * `action_ref` - The action reference for audit/logging
|
||||
/// * `config` - JWT configuration (uses the same signing secret as all tokens)
|
||||
/// * `ttl_seconds` - Time to live in seconds (defaults to 360 = 5 min timeout + 60s grace)
|
||||
pub fn generate_execution_token(
|
||||
identity_id: i64,
|
||||
execution_id: i64,
|
||||
action_ref: &str,
|
||||
config: &JwtConfig,
|
||||
ttl_seconds: Option<i64>,
|
||||
) -> Result<String, JwtError> {
|
||||
let now = Utc::now();
|
||||
let expiration = ttl_seconds.unwrap_or(360); // Default: 6 minutes (5 min timeout + grace)
|
||||
let exp = (now + Duration::seconds(expiration)).timestamp();
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"execution_id": execution_id,
|
||||
"action_ref": action_ref,
|
||||
});
|
||||
|
||||
let claims = Claims {
|
||||
sub: identity_id.to_string(),
|
||||
login: format!("execution:{}", execution_id),
|
||||
iat: now.timestamp(),
|
||||
exp,
|
||||
token_type: TokenType::Execution,
|
||||
scope: Some("execution".to_string()),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| JwtError::EncodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Validate and decode a JWT token
|
||||
pub fn validate_token(token: &str, config: &JwtConfig) -> Result<Claims, JwtError> {
|
||||
let validation = Validation::default();
|
||||
|
||||
decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(config.secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| {
|
||||
if e.to_string().contains("ExpiredSignature") {
|
||||
JwtError::Expired
|
||||
} else {
|
||||
JwtError::DecodeError(e.to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract token from Authorization header
///
/// Returns the token following the `Bearer ` scheme prefix
/// (case-sensitive, matching RFC 6750's canonical form), or `None` for any
/// other scheme. Note `"Bearer "` with no token yields `Some("")`; callers
/// should reject the empty token during validation.
pub fn extract_token_from_header(auth_header: &str) -> Option<&str> {
    // `strip_prefix` tests and removes the scheme in one step, replacing the
    // manual `starts_with` + `[7..]` slice pair.
    auth_header.strip_prefix("Bearer ")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed config with a known secret so tokens round-trip deterministically.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test_secret_key_for_testing".to_string(),
            access_token_expiration: 3600,
            refresh_token_expiration: 604800,
        }
    }

    #[test]
    fn test_generate_and_validate_access_token() {
        let config = test_config();
        let token =
            generate_access_token(123, "testuser", &config).expect("Failed to generate token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "123");
        assert_eq!(claims.login, "testuser");
        assert_eq!(claims.token_type, TokenType::Access);
    }

    #[test]
    fn test_generate_and_validate_refresh_token() {
        let config = test_config();
        let token =
            generate_refresh_token(456, "anotheruser", &config).expect("Failed to generate token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "456");
        assert_eq!(claims.login, "anotheruser");
        assert_eq!(claims.token_type, TokenType::Refresh);
    }

    // Malformed input must be rejected, not panic.
    #[test]
    fn test_invalid_token() {
        let config = test_config();
        let result = validate_token("invalid.token.here", &config);
        assert!(result.is_err());
    }

    // A token signed with one secret must fail validation under another.
    #[test]
    fn test_token_with_wrong_secret() {
        let config = test_config();
        let token = generate_access_token(789, "user", &config).expect("Failed to generate token");

        let wrong_config = JwtConfig {
            secret: "different_secret".to_string(),
            ..config
        };

        let result = validate_token(&token, &wrong_config);
        assert!(result.is_err());
    }

    // Hand-encode claims with exp in the past to exercise the Expired branch.
    #[test]
    fn test_expired_token() {
        let now = Utc::now().timestamp();
        let expired_claims = Claims {
            sub: "999".to_string(),
            login: "expireduser".to_string(),
            iat: now - 3600,
            exp: now - 1800,
            token_type: TokenType::Access,
            scope: None,
            metadata: None,
        };

        let config = test_config();

        let expired_token = encode(
            &Header::default(),
            &expired_claims,
            &EncodingKey::from_secret(config.secret.as_bytes()),
        )
        .expect("Failed to encode token");

        let result = validate_token(&expired_token, &config);
        assert!(matches!(result, Err(JwtError::Expired)));
    }

    #[test]
    fn test_extract_token_from_header() {
        let header = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
        let token = extract_token_from_header(header);
        assert_eq!(token, Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));

        let invalid_header = "Token abc123";
        let token = extract_token_from_header(invalid_header);
        assert_eq!(token, None);

        // "Bearer " with no token yields Some(""); rejected later in validation.
        let no_token = "Bearer ";
        let token = extract_token_from_header(no_token);
        assert_eq!(token, Some(""));
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "123".to_string(),
            login: "testuser".to_string(),
            iat: 1234567890,
            exp: 1234571490,
            token_type: TokenType::Access,
            scope: None,
            metadata: None,
        };

        let json = serde_json::to_string(&claims).expect("Failed to serialize");
        let deserialized: Claims = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(claims.sub, deserialized.sub);
        assert_eq!(claims.login, deserialized.login);
        assert_eq!(claims.token_type, deserialized.token_type);
    }

    #[test]
    fn test_generate_sensor_token() {
        let config = test_config();
        let trigger_types = vec!["core.timer".to_string(), "core.webhook".to_string()];

        let token = generate_sensor_token(
            999,
            "sensor:core.timer",
            trigger_types.clone(),
            &config,
            Some(86400),
        )
        .expect("Failed to generate sensor token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "999");
        assert_eq!(claims.login, "sensor:core.timer");
        assert_eq!(claims.token_type, TokenType::Sensor);
        assert_eq!(claims.scope, Some("sensor".to_string()));

        // trigger_types must survive the metadata round-trip.
        let metadata = claims.metadata.expect("Metadata should be present");
        let trigger_types_from_token = metadata["trigger_types"]
            .as_array()
            .expect("trigger_types should be an array");

        assert_eq!(trigger_types_from_token.len(), 2);
    }

    #[test]
    fn test_generate_execution_token() {
        let config = test_config();

        let token =
            generate_execution_token(42, 12345, "python_example.artifact_demo", &config, None)
                .expect("Failed to generate execution token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        assert_eq!(claims.sub, "42");
        assert_eq!(claims.login, "execution:12345");
        assert_eq!(claims.token_type, TokenType::Execution);
        assert_eq!(claims.scope, Some("execution".to_string()));

        let metadata = claims.metadata.expect("Metadata should be present");
        assert_eq!(metadata["execution_id"], 12345);
        assert_eq!(metadata["action_ref"], "python_example.artifact_demo");
    }

    #[test]
    fn test_execution_token_custom_ttl() {
        let config = test_config();

        let token = generate_execution_token(1, 100, "core.echo", &config, Some(600))
            .expect("Failed to generate execution token");

        let claims = validate_token(&token, &config).expect("Failed to validate token");

        // Should expire roughly 600 seconds from now
        let now = Utc::now().timestamp();
        let diff = claims.exp - now;
        assert!(
            diff > 590 && diff <= 600,
            "TTL should be ~600s, got {}s",
            diff
        );
    }

    #[test]
    fn test_token_type_serialization() {
        // Ensure all token types round-trip through JSON correctly
        for tt in [
            TokenType::Access,
            TokenType::Refresh,
            TokenType::Sensor,
            TokenType::Execution,
        ] {
            let json = serde_json::to_string(&tt).expect("Failed to serialize");
            let deserialized: TokenType =
                serde_json::from_str(&json).expect("Failed to deserialize");
            assert_eq!(tt, deserialized);
        }
    }
}
|
||||
13
crates/common/src/auth/mod.rs
Normal file
13
crates/common/src/auth/mod.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! Authentication primitives shared across Attune services.
|
||||
//!
|
||||
//! This module provides JWT token types, generation, and validation functions
|
||||
//! that are used by the API (for all token types), the worker (for execution-scoped
|
||||
//! tokens), and the sensor service (for sensor tokens).
|
||||
|
||||
pub mod jwt;
|
||||
|
||||
pub use jwt::{
|
||||
extract_token_from_header, generate_access_token, generate_execution_token,
|
||||
generate_refresh_token, generate_sensor_token, generate_token, validate_token, Claims,
|
||||
JwtConfig, JwtError, TokenType,
|
||||
};
|
||||
@@ -6,6 +6,7 @@
|
||||
//! - Configuration
|
||||
//! - Utilities
|
||||
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod crypto;
|
||||
pub mod db;
|
||||
|
||||
@@ -10,6 +10,8 @@ use sqlx::FromRow;
|
||||
|
||||
// Re-export common types
|
||||
pub use action::*;
|
||||
pub use artifact::Artifact;
|
||||
pub use artifact_version::ArtifactVersion;
|
||||
pub use entity_history::*;
|
||||
pub use enums::*;
|
||||
pub use event::*;
|
||||
@@ -355,7 +357,7 @@ pub mod enums {
|
||||
Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type, ToSchema)]
|
||||
#[sqlx(type_name = "artifact_retention_enum", rename_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum RetentionPolicyType {
|
||||
@@ -1268,9 +1270,66 @@ pub mod artifact {
|
||||
pub r#type: ArtifactType,
|
||||
pub retention_policy: RetentionPolicyType,
|
||||
pub retention_limit: i32,
|
||||
/// Human-readable name (e.g. "Build Log", "Test Results")
|
||||
pub name: Option<String>,
|
||||
/// Optional longer description
|
||||
pub description: Option<String>,
|
||||
/// MIME content type (e.g. "application/json", "text/plain")
|
||||
pub content_type: Option<String>,
|
||||
/// Size of the latest version's content in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
/// Execution that produced this artifact (no FK — execution is a hypertable)
|
||||
pub execution: Option<Id>,
|
||||
/// Structured JSONB data for progress artifacts or metadata
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Select columns for Artifact queries (excludes DB-only columns if any arise).
|
||||
/// Must be kept in sync with the Artifact struct field order.
|
||||
pub const SELECT_COLUMNS: &str =
|
||||
"id, ref, scope, owner, type, retention_policy, retention_limit, \
|
||||
name, description, content_type, size_bytes, execution, data, \
|
||||
created, updated";
|
||||
}
|
||||
|
||||
/// Artifact version model — immutable content snapshots
|
||||
pub mod artifact_version {
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct ArtifactVersion {
|
||||
pub id: Id,
|
||||
/// Parent artifact
|
||||
pub artifact: Id,
|
||||
/// Version number (1-based, monotonically increasing per artifact)
|
||||
pub version: i32,
|
||||
/// MIME content type for this version
|
||||
pub content_type: Option<String>,
|
||||
/// Size of content in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
/// Binary content (file data) — not included in default queries for performance
|
||||
#[serde(skip_serializing)]
|
||||
pub content: Option<Vec<u8>>,
|
||||
/// Structured JSON content
|
||||
pub content_json: Option<serde_json::Value>,
|
||||
/// Free-form metadata about this version
|
||||
pub meta: Option<serde_json::Value>,
|
||||
/// Who created this version
|
||||
pub created_by: Option<String>,
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Select columns WITHOUT the potentially large `content` BYTEA column.
|
||||
/// Use `SELECT_COLUMNS_WITH_CONTENT` when you need the binary payload.
|
||||
pub const SELECT_COLUMNS: &str = "id, artifact, version, content_type, size_bytes, \
|
||||
NULL::bytea AS content, content_json, meta, created_by, created";
|
||||
|
||||
/// Select columns INCLUDING the binary `content` column.
|
||||
pub const SELECT_COLUMNS_WITH_CONTENT: &str =
|
||||
"id, artifact, version, content_type, size_bytes, \
|
||||
content, content_json, meta, created_by, created";
|
||||
}
|
||||
|
||||
/// Workflow orchestration models
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
//! Artifact repository for database operations
|
||||
//! Artifact and ArtifactVersion repositories for database operations
|
||||
|
||||
use crate::models::{
|
||||
artifact::*,
|
||||
artifact_version::ArtifactVersion,
|
||||
enums::{ArtifactType, OwnerType, RetentionPolicyType},
|
||||
};
|
||||
use crate::Result;
|
||||
@@ -9,6 +10,10 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
|
||||
// ============================================================================
|
||||
// ArtifactRepository
|
||||
// ============================================================================
|
||||
|
||||
pub struct ArtifactRepository;
|
||||
|
||||
impl Repository for ArtifactRepository {
|
||||
@@ -26,6 +31,11 @@ pub struct CreateArtifactInput {
|
||||
pub r#type: ArtifactType,
|
||||
pub retention_policy: RetentionPolicyType,
|
||||
pub retention_limit: i32,
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub execution: Option<i64>,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -36,6 +46,29 @@ pub struct UpdateArtifactInput {
|
||||
pub r#type: Option<ArtifactType>,
|
||||
pub retention_policy: Option<RetentionPolicyType>,
|
||||
pub retention_limit: Option<i32>,
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub size_bytes: Option<i64>,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Filters for searching artifacts
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ArtifactSearchFilters {
|
||||
pub scope: Option<OwnerType>,
|
||||
pub owner: Option<String>,
|
||||
pub r#type: Option<ArtifactType>,
|
||||
pub execution: Option<i64>,
|
||||
pub name_contains: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Search result with total count
|
||||
pub struct ArtifactSearchResult {
|
||||
pub rows: Vec<Artifact>,
|
||||
pub total: i64,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -44,15 +77,12 @@ impl FindById for ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!("SELECT {} FROM artifact WHERE id = $1", SELECT_COLUMNS);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,15 +92,12 @@ impl FindByRef for ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE ref = $1",
|
||||
)
|
||||
.bind(ref_str)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!("SELECT {} FROM artifact WHERE ref = $1", SELECT_COLUMNS);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(ref_str)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,15 +107,14 @@ impl List for ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
ORDER BY created DESC
|
||||
LIMIT 1000",
|
||||
)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact ORDER BY created DESC LIMIT 1000",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,20 +126,28 @@ impl Create for ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"INSERT INTO artifact (ref, scope, owner, type, retention_policy, retention_limit)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, ref, scope, owner, type, retention_policy, retention_limit, created, updated",
|
||||
)
|
||||
.bind(&input.r#ref)
|
||||
.bind(input.scope)
|
||||
.bind(&input.owner)
|
||||
.bind(input.r#type)
|
||||
.bind(input.retention_policy)
|
||||
.bind(input.retention_limit)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"INSERT INTO artifact (ref, scope, owner, type, retention_policy, retention_limit, \
|
||||
name, description, content_type, execution, data) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) \
|
||||
RETURNING {}",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(&input.r#ref)
|
||||
.bind(input.scope)
|
||||
.bind(&input.owner)
|
||||
.bind(input.r#type)
|
||||
.bind(input.retention_policy)
|
||||
.bind(input.retention_limit)
|
||||
.bind(&input.name)
|
||||
.bind(&input.description)
|
||||
.bind(&input.content_type)
|
||||
.bind(input.execution)
|
||||
.bind(&input.data)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,59 +159,40 @@ impl Update for ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
// Build update query dynamically
|
||||
let mut query = QueryBuilder::new("UPDATE artifact SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(ref_value) = &input.r#ref {
|
||||
query.push("ref = ").push_bind(ref_value);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(scope) = input.scope {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("scope = ").push_bind(scope);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(owner) = &input.owner {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("owner = ").push_bind(owner);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(artifact_type) = input.r#type {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("type = ").push_bind(artifact_type);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(retention_policy) = input.retention_policy {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query
|
||||
.push("retention_policy = ")
|
||||
.push_bind(retention_policy);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(retention_limit) = input.retention_limit {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("retention_limit = ").push_bind(retention_limit);
|
||||
has_updates = true;
|
||||
macro_rules! push_field {
|
||||
($field:expr, $col:expr) => {
|
||||
if let Some(val) = $field {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push(concat!($col, " = ")).push_bind(val);
|
||||
has_updates = true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
push_field!(&input.r#ref, "ref");
|
||||
push_field!(input.scope, "scope");
|
||||
push_field!(&input.owner, "owner");
|
||||
push_field!(input.r#type, "type");
|
||||
push_field!(input.retention_policy, "retention_policy");
|
||||
push_field!(input.retention_limit, "retention_limit");
|
||||
push_field!(&input.name, "name");
|
||||
push_field!(&input.description, "description");
|
||||
push_field!(&input.content_type, "content_type");
|
||||
push_field!(input.size_bytes, "size_bytes");
|
||||
push_field!(&input.data, "data");
|
||||
|
||||
if !has_updates {
|
||||
// No updates requested, fetch and return existing entity
|
||||
return Self::get_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
query.push(", updated = NOW() WHERE id = ").push_bind(id);
|
||||
query.push(" RETURNING id, ref, scope, owner, type, retention_policy, retention_limit, created, updated");
|
||||
query.push(" RETURNING ");
|
||||
query.push(SELECT_COLUMNS);
|
||||
|
||||
query
|
||||
.build_query_as::<Artifact>()
|
||||
@@ -202,21 +217,113 @@ impl Delete for ArtifactRepository {
|
||||
}
|
||||
|
||||
impl ArtifactRepository {
|
||||
/// Search artifacts with filters and pagination
|
||||
pub async fn search<'e, E>(
|
||||
executor: E,
|
||||
filters: &ArtifactSearchFilters,
|
||||
) -> Result<ArtifactSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
// Build WHERE clauses
|
||||
let mut conditions: Vec<String> = Vec::new();
|
||||
let mut param_idx: usize = 0;
|
||||
|
||||
if filters.scope.is_some() {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("scope = ${}", param_idx));
|
||||
}
|
||||
if filters.owner.is_some() {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("owner = ${}", param_idx));
|
||||
}
|
||||
if filters.r#type.is_some() {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("type = ${}", param_idx));
|
||||
}
|
||||
if filters.execution.is_some() {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("execution = ${}", param_idx));
|
||||
}
|
||||
if filters.name_contains.is_some() {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("name ILIKE '%' || ${} || '%'", param_idx));
|
||||
}
|
||||
|
||||
let where_clause = if conditions.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("WHERE {}", conditions.join(" AND "))
|
||||
};
|
||||
|
||||
// Count query
|
||||
let count_sql = format!("SELECT COUNT(*) AS cnt FROM artifact {}", where_clause);
|
||||
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
|
||||
// Bind params for count
|
||||
if let Some(scope) = filters.scope {
|
||||
count_query = count_query.bind(scope);
|
||||
}
|
||||
if let Some(ref owner) = filters.owner {
|
||||
count_query = count_query.bind(owner.clone());
|
||||
}
|
||||
if let Some(r#type) = filters.r#type {
|
||||
count_query = count_query.bind(r#type);
|
||||
}
|
||||
if let Some(execution) = filters.execution {
|
||||
count_query = count_query.bind(execution);
|
||||
}
|
||||
if let Some(ref name) = filters.name_contains {
|
||||
count_query = count_query.bind(name.clone());
|
||||
}
|
||||
|
||||
let total = count_query.fetch_one(executor).await?;
|
||||
|
||||
// Data query
|
||||
let limit = filters.limit.min(1000);
|
||||
let offset = filters.offset;
|
||||
let data_sql = format!(
|
||||
"SELECT {} FROM artifact {} ORDER BY created DESC LIMIT {} OFFSET {}",
|
||||
SELECT_COLUMNS, where_clause, limit, offset
|
||||
);
|
||||
|
||||
let mut data_query = sqlx::query_as::<_, Artifact>(&data_sql);
|
||||
|
||||
if let Some(scope) = filters.scope {
|
||||
data_query = data_query.bind(scope);
|
||||
}
|
||||
if let Some(ref owner) = filters.owner {
|
||||
data_query = data_query.bind(owner.clone());
|
||||
}
|
||||
if let Some(r#type) = filters.r#type {
|
||||
data_query = data_query.bind(r#type);
|
||||
}
|
||||
if let Some(execution) = filters.execution {
|
||||
data_query = data_query.bind(execution);
|
||||
}
|
||||
if let Some(ref name) = filters.name_contains {
|
||||
data_query = data_query.bind(name.clone());
|
||||
}
|
||||
|
||||
let rows = data_query.fetch_all(executor).await?;
|
||||
|
||||
Ok(ArtifactSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find artifacts by scope
|
||||
pub async fn find_by_scope<'e, E>(executor: E, scope: OwnerType) -> Result<Vec<Artifact>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE scope = $1
|
||||
ORDER BY created DESC",
|
||||
)
|
||||
.bind(scope)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE scope = $1 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(scope)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find artifacts by owner
|
||||
@@ -224,16 +331,15 @@ impl ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE owner = $1
|
||||
ORDER BY created DESC",
|
||||
)
|
||||
.bind(owner)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE owner = $1 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(owner)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find artifacts by type
|
||||
@@ -244,19 +350,18 @@ impl ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE type = $1
|
||||
ORDER BY created DESC",
|
||||
)
|
||||
.bind(artifact_type)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE type = $1 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(artifact_type)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find artifacts by scope and owner (common query pattern)
|
||||
/// Find artifacts by scope and owner
|
||||
pub async fn find_by_scope_and_owner<'e, E>(
|
||||
executor: E,
|
||||
scope: OwnerType,
|
||||
@@ -265,17 +370,32 @@ impl ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE scope = $1 AND owner = $2
|
||||
ORDER BY created DESC",
|
||||
)
|
||||
.bind(scope)
|
||||
.bind(owner)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE scope = $1 AND owner = $2 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(scope)
|
||||
.bind(owner)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find artifacts by execution ID
|
||||
pub async fn find_by_execution<'e, E>(executor: E, execution_id: i64) -> Result<Vec<Artifact>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE execution = $1 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(execution_id)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find artifacts by retention policy
|
||||
@@ -286,15 +406,297 @@ impl ArtifactRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Artifact>(
|
||||
"SELECT id, ref, scope, owner, type, retention_policy, retention_limit, created, updated
|
||||
FROM artifact
|
||||
WHERE retention_policy = $1
|
||||
ORDER BY created DESC",
|
||||
)
|
||||
.bind(retention_policy)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact WHERE retention_policy = $1 ORDER BY created DESC",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(retention_policy)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Append data to a progress-type artifact.
|
||||
///
|
||||
/// If `artifact.data` is currently NULL, it is initialized as a JSON array
|
||||
/// containing the new entry. Otherwise the entry is appended to the existing
|
||||
/// array. This is done atomically in a single SQL statement.
|
||||
pub async fn append_progress<'e, E>(
|
||||
executor: E,
|
||||
id: i64,
|
||||
entry: &serde_json::Value,
|
||||
) -> Result<Artifact>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"UPDATE artifact \
|
||||
SET data = CASE \
|
||||
WHEN data IS NULL THEN jsonb_build_array($2::jsonb) \
|
||||
ELSE data || jsonb_build_array($2::jsonb) \
|
||||
END, \
|
||||
updated = NOW() \
|
||||
WHERE id = $1 AND type = 'progress' \
|
||||
RETURNING {}",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(id)
|
||||
.bind(entry)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Replace the full data payload on a progress-type artifact (for "set" semantics).
|
||||
pub async fn set_data<'e, E>(executor: E, id: i64, data: &serde_json::Value) -> Result<Artifact>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"UPDATE artifact SET data = $2, updated = NOW() \
|
||||
WHERE id = $1 RETURNING {}",
|
||||
SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Artifact>(&query)
|
||||
.bind(id)
|
||||
.bind(data)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ArtifactVersionRepository
|
||||
// ============================================================================
|
||||
|
||||
use crate::models::artifact_version;
|
||||
|
||||
/// Repository for rows of the `artifact_version` table.
pub struct ArtifactVersionRepository;

impl Repository for ArtifactVersionRepository {
    type Entity = ArtifactVersion;

    /// Name of the table backing this repository.
    fn table_name() -> &'static str {
        "artifact_version"
    }
}
|
||||
|
||||
/// Input for creating a new artifact version.
///
/// Either `content` (binary) or `content_json` may be provided; the version
/// number itself is assigned by the database on insert.
#[derive(Debug, Clone)]
pub struct CreateArtifactVersionInput {
    // ID of the parent artifact this version belongs to.
    pub artifact: i64,
    // MIME type of the payload, when known.
    pub content_type: Option<String>,
    // Raw binary payload, if this version stores bytes.
    pub content: Option<Vec<u8>>,
    // JSON payload, if this version stores structured data.
    pub content_json: Option<serde_json::Value>,
    // Arbitrary metadata attached to the version.
    pub meta: Option<serde_json::Value>,
    // Identifier of the creator, when known.
    pub created_by: Option<String>,
}
|
||||
|
||||
impl ArtifactVersionRepository {
|
||||
/// Find a version by ID (without binary content for performance)
|
||||
pub async fn find_by_id<'e, E>(executor: E, id: i64) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE id = $1",
|
||||
artifact_version::SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find a version by ID including binary content
|
||||
pub async fn find_by_id_with_content<'e, E>(
|
||||
executor: E,
|
||||
id: i64,
|
||||
) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE id = $1",
|
||||
artifact_version::SELECT_COLUMNS_WITH_CONTENT
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// List all versions for an artifact (without binary content), newest first
|
||||
pub async fn list_by_artifact<'e, E>(
|
||||
executor: E,
|
||||
artifact_id: i64,
|
||||
) -> Result<Vec<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE artifact = $1 ORDER BY version DESC",
|
||||
artifact_version::SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(artifact_id)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Get the latest version for an artifact (without binary content)
|
||||
pub async fn find_latest<'e, E>(
|
||||
executor: E,
|
||||
artifact_id: i64,
|
||||
) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE artifact = $1 ORDER BY version DESC LIMIT 1",
|
||||
artifact_version::SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(artifact_id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Get the latest version for an artifact (with binary content)
|
||||
pub async fn find_latest_with_content<'e, E>(
|
||||
executor: E,
|
||||
artifact_id: i64,
|
||||
) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE artifact = $1 ORDER BY version DESC LIMIT 1",
|
||||
artifact_version::SELECT_COLUMNS_WITH_CONTENT
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(artifact_id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Get a specific version by artifact and version number (without binary content)
|
||||
pub async fn find_by_version<'e, E>(
|
||||
executor: E,
|
||||
artifact_id: i64,
|
||||
version: i32,
|
||||
) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE artifact = $1 AND version = $2",
|
||||
artifact_version::SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(artifact_id)
|
||||
.bind(version)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Get a specific version by artifact and version number (with binary content)
|
||||
pub async fn find_by_version_with_content<'e, E>(
|
||||
executor: E,
|
||||
artifact_id: i64,
|
||||
version: i32,
|
||||
) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT {} FROM artifact_version WHERE artifact = $1 AND version = $2",
|
||||
artifact_version::SELECT_COLUMNS_WITH_CONTENT
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(artifact_id)
|
||||
.bind(version)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Create a new artifact version. The version number is auto-assigned
|
||||
/// (MAX(version) + 1) and the retention trigger fires after insert.
|
||||
pub async fn create<'e, E>(
|
||||
executor: E,
|
||||
input: CreateArtifactVersionInput,
|
||||
) -> Result<ArtifactVersion>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let size_bytes = input.content.as_ref().map(|c| c.len() as i64).or_else(|| {
|
||||
input
|
||||
.content_json
|
||||
.as_ref()
|
||||
.map(|j| serde_json::to_string(j).unwrap_or_default().len() as i64)
|
||||
});
|
||||
|
||||
let query = format!(
|
||||
"INSERT INTO artifact_version \
|
||||
(artifact, version, content_type, size_bytes, content, content_json, meta, created_by) \
|
||||
VALUES ($1, next_artifact_version($1), $2, $3, $4, $5, $6, $7) \
|
||||
RETURNING {}",
|
||||
artifact_version::SELECT_COLUMNS_WITH_CONTENT
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(input.artifact)
|
||||
.bind(&input.content_type)
|
||||
.bind(size_bytes)
|
||||
.bind(&input.content)
|
||||
.bind(&input.content_json)
|
||||
.bind(&input.meta)
|
||||
.bind(&input.created_by)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Delete a specific version by ID
|
||||
pub async fn delete<'e, E>(executor: E, id: i64) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM artifact_version WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
/// Delete all versions for an artifact
|
||||
pub async fn delete_all_for_artifact<'e, E>(executor: E, artifact_id: i64) -> Result<u64>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM artifact_version WHERE artifact = $1")
|
||||
.bind(artifact_id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Count versions for an artifact
|
||||
pub async fn count_by_artifact<'e, E>(executor: E, artifact_id: i64) -> Result<i64>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM artifact_version WHERE artifact = $1")
|
||||
.bind(artifact_id)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ pub mod workflow;
|
||||
// Re-export repository types
|
||||
pub use action::{ActionRepository, PolicyRepository};
|
||||
pub use analytics::AnalyticsRepository;
|
||||
pub use artifact::ArtifactRepository;
|
||||
pub use artifact::{ArtifactRepository, ArtifactVersionRepository};
|
||||
pub use entity_history::EntityHistoryRepository;
|
||||
pub use event::{EnforcementRepository, EventRepository};
|
||||
pub use execution::ExecutionRepository;
|
||||
|
||||
@@ -67,6 +67,11 @@ impl ArtifactFixture {
|
||||
r#type: ArtifactType::FileText,
|
||||
retention_policy: RetentionPolicyType::Versions,
|
||||
retention_limit: 5,
|
||||
name: None,
|
||||
description: None,
|
||||
content_type: None,
|
||||
execution: None,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,6 +254,11 @@ async fn test_update_artifact_all_fields() {
|
||||
r#type: Some(ArtifactType::FileImage),
|
||||
retention_policy: Some(RetentionPolicyType::Days),
|
||||
retention_limit: Some(30),
|
||||
name: Some("Updated Name".to_string()),
|
||||
description: Some("Updated description".to_string()),
|
||||
content_type: Some("image/png".to_string()),
|
||||
size_bytes: Some(12345),
|
||||
data: Some(serde_json::json!({"key": "value"})),
|
||||
};
|
||||
|
||||
let updated = ArtifactRepository::update(&pool, created.id, update_input.clone())
|
||||
|
||||
@@ -31,7 +31,6 @@ clap = { workspace = true }
|
||||
lapin = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
tera = "1.19"
|
||||
serde_yaml_ng = { workspace = true }
|
||||
validator = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
@@ -28,7 +28,3 @@ pub use queue_manager::{ExecutionQueueManager, QueueConfig, QueueStats};
|
||||
pub use retry_manager::{RetryAnalysis, RetryConfig, RetryManager, RetryReason};
|
||||
pub use timeout_monitor::{ExecutionTimeoutMonitor, TimeoutMonitorConfig};
|
||||
pub use worker_health::{HealthMetrics, HealthProbeConfig, HealthStatus, WorkerHealthProbe};
|
||||
pub use workflow::{
|
||||
parse_workflow_yaml, BackoffStrategy, ParseError, TemplateEngine, VariableContext,
|
||||
WorkflowDefinition, WorkflowValidator,
|
||||
};
|
||||
|
||||
@@ -61,9 +61,6 @@ pub type ContextResult<T> = Result<T, ContextError>;
|
||||
/// Errors that can occur during context operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ContextError {
|
||||
#[error("Template rendering error: {0}")]
|
||||
TemplateError(String),
|
||||
|
||||
#[error("Variable not found: {0}")]
|
||||
VariableNotFound(String),
|
||||
|
||||
@@ -200,16 +197,19 @@ impl WorkflowContext {
|
||||
}
|
||||
|
||||
/// Get a workflow-scoped variable by name.
///
/// Returns a clone of the stored value, or `None` when the variable
/// has not been set.
#[allow(dead_code)] // Part of complete context API; used in tests
pub fn get_var(&self, name: &str) -> Option<JsonValue> {
    self.variables.get(name).map(|entry| entry.value().clone())
}
|
||||
|
||||
/// Store a completed task's result (accessible as `task.<name>.*`).
///
/// Overwrites any previous result stored under the same task name.
#[allow(dead_code)] // Part of complete context API; used in tests
pub fn set_task_result(&mut self, task_name: &str, result: JsonValue) {
    self.task_results.insert(task_name.to_string(), result);
}
|
||||
|
||||
/// Get a task result by task name.
|
||||
#[allow(dead_code)] // Part of complete context API; used in tests
|
||||
pub fn get_task_result(&self, task_name: &str) -> Option<JsonValue> {
|
||||
self.task_results
|
||||
.get(task_name)
|
||||
@@ -217,11 +217,13 @@ impl WorkflowContext {
|
||||
}
|
||||
|
||||
/// Set the pack configuration (accessible as `config.<key>`).
///
/// Replaces the whole configuration value; it is stored behind an `Arc`
/// so sharing the context does not copy the JSON.
#[allow(dead_code)] // Part of complete context API; used in tests
pub fn set_pack_config(&mut self, config: JsonValue) {
    self.pack_config = Arc::new(config);
}
|
||||
|
||||
/// Set the keystore secrets (accessible as `keystore.<key>`).
///
/// Replaces the whole secrets value; stored behind an `Arc` like the
/// pack configuration.
#[allow(dead_code)] // Part of complete context API; used in tests
pub fn set_keystore(&mut self, secrets: JsonValue) {
    self.keystore = Arc::new(secrets);
}
|
||||
@@ -233,6 +235,7 @@ impl WorkflowContext {
|
||||
}
|
||||
|
||||
/// Clear current item
|
||||
#[allow(dead_code)] // Part of complete context API; symmetric with set_current_item
|
||||
pub fn clear_current_item(&mut self) {
|
||||
self.current_item = None;
|
||||
self.current_index = None;
|
||||
@@ -440,6 +443,7 @@ impl WorkflowContext {
|
||||
}
|
||||
|
||||
/// Export context for storage
|
||||
#[allow(dead_code)] // Part of complete context API; used in tests
|
||||
pub fn export(&self) -> JsonValue {
|
||||
let variables: HashMap<String, JsonValue> = self
|
||||
.variables
|
||||
@@ -470,6 +474,7 @@ impl WorkflowContext {
|
||||
}
|
||||
|
||||
/// Import context from stored data
|
||||
#[allow(dead_code)] // Part of complete context API; used in tests
|
||||
pub fn import(data: JsonValue) -> ContextResult<Self> {
|
||||
let variables = DashMap::new();
|
||||
if let Some(obj) = data["variables"].as_object() {
|
||||
@@ -677,7 +682,9 @@ mod tests {
|
||||
ctx.set_var("greeting", json!("Hello"));
|
||||
|
||||
// Canonical: workflow.<name>
|
||||
let result = ctx.render_template("{{ workflow.greeting }} World").unwrap();
|
||||
let result = ctx
|
||||
.render_template("{{ workflow.greeting }} World")
|
||||
.unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
@@ -699,7 +706,9 @@ mod tests {
|
||||
let ctx = WorkflowContext::new(json!({}), vars);
|
||||
|
||||
// Backward-compat alias: variables.<name>
|
||||
let result = ctx.render_template("{{ variables.greeting }} World").unwrap();
|
||||
let result = ctx
|
||||
.render_template("{{ variables.greeting }} World")
|
||||
.unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
@@ -735,7 +744,9 @@ mod tests {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("fetch", json!({"result": {"data": {"id": 42}}}));
|
||||
|
||||
let val = ctx.evaluate_expression("task.fetch.result.data.id").unwrap();
|
||||
let val = ctx
|
||||
.evaluate_expression("task.fetch.result.data.id")
|
||||
.unwrap();
|
||||
assert_eq!(val, json!(42));
|
||||
}
|
||||
|
||||
@@ -744,7 +755,9 @@ mod tests {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("run_cmd", json!({"result": {"stdout": "hello world"}}));
|
||||
|
||||
let val = ctx.evaluate_expression("task.run_cmd.result.stdout").unwrap();
|
||||
let val = ctx
|
||||
.evaluate_expression("task.run_cmd.result.stdout")
|
||||
.unwrap();
|
||||
assert_eq!(val, json!("hello world"));
|
||||
}
|
||||
|
||||
@@ -755,14 +768,14 @@ mod tests {
|
||||
/// Verifies pack configuration is reachable via the `config.` namespace in
/// both expression evaluation and template rendering.
#[test]
fn test_config_namespace() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    // WIP residue removed: the old one-line and new multi-line forms of this
    // call (and of the render_template statement below) were both present.
    ctx.set_pack_config(
        json!({"api_token": "tok_abc123", "base_url": "https://api.example.com"}),
    );

    let val = ctx.evaluate_expression("config.api_token").unwrap();
    assert_eq!(val, json!("tok_abc123"));

    let result = ctx.render_template("URL: {{ config.base_url }}").unwrap();
    assert_eq!(result, "URL: https://api.example.com");
}
|
||||
|
||||
@@ -796,7 +809,9 @@ mod tests {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_keystore(json!({"My Secret Key": "value123"}));
|
||||
|
||||
let val = ctx.evaluate_expression("keystore[\"My Secret Key\"]").unwrap();
|
||||
let val = ctx
|
||||
.evaluate_expression("keystore[\"My Secret Key\"]")
|
||||
.unwrap();
|
||||
assert_eq!(val, json!("value123"));
|
||||
}
|
||||
|
||||
@@ -850,9 +865,7 @@ mod tests {
|
||||
assert!(ctx
|
||||
.evaluate_condition("parameters.x > 50 or parameters.y > 15")
|
||||
.unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("not parameters.x > 50")
|
||||
.unwrap());
|
||||
assert!(ctx.evaluate_condition("not parameters.x > 50").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -863,16 +876,15 @@ mod tests {
|
||||
assert!(ctx.evaluate_condition("\"admin\" in roles").unwrap());
|
||||
assert!(!ctx.evaluate_condition("\"root\" in roles").unwrap());
|
||||
// Via canonical workflow namespace
|
||||
assert!(ctx.evaluate_condition("\"admin\" in workflow.roles").unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("\"admin\" in workflow.roles")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_function_calls() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_last_task_outcome(
|
||||
json!({"status": "ok", "code": 200}),
|
||||
TaskOutcome::Succeeded,
|
||||
);
|
||||
ctx.set_last_task_outcome(json!({"status": "ok", "code": 200}), TaskOutcome::Succeeded);
|
||||
assert!(ctx.evaluate_condition("succeeded()").unwrap());
|
||||
assert!(!ctx.evaluate_condition("failed()").unwrap());
|
||||
assert!(ctx
|
||||
@@ -889,9 +901,7 @@ mod tests {
|
||||
ctx.set_var("items", json!([1, 2, 3, 4, 5]));
|
||||
assert!(ctx.evaluate_condition("length(items) > 3").unwrap());
|
||||
assert!(!ctx.evaluate_condition("length(items) > 10").unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("length(items) == 5")
|
||||
.unwrap());
|
||||
assert!(ctx.evaluate_condition("length(items) == 5").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -916,10 +926,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_expression_string_concat() {
|
||||
let ctx = WorkflowContext::new(
|
||||
json!({"first": "Hello", "second": "World"}),
|
||||
HashMap::new(),
|
||||
);
|
||||
let ctx =
|
||||
WorkflowContext::new(json!({"first": "Hello", "second": "World"}), HashMap::new());
|
||||
let input = json!({"msg": "{{ parameters.first + \" \" + parameters.second }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["msg"], json!("Hello World"));
|
||||
|
||||
@@ -1,776 +0,0 @@
|
||||
//! Workflow Execution Coordinator
|
||||
//!
|
||||
//! This module orchestrates workflow execution, managing task dependencies,
|
||||
//! parallel execution, state transitions, and error handling.
|
||||
|
||||
use crate::workflow::context::WorkflowContext;
|
||||
use crate::workflow::graph::{TaskGraph, TaskNode};
|
||||
use crate::workflow::task_executor::{TaskExecutionResult, TaskExecutionStatus, TaskExecutor};
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::{
|
||||
execution::{Execution, WorkflowTaskMetadata},
|
||||
ExecutionStatus, Id, WorkflowExecution,
|
||||
};
|
||||
use attune_common::mq::MessageQueue;
|
||||
use attune_common::workflow::WorkflowDefinition;
|
||||
use chrono::Utc;
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use sqlx::PgPool;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Workflow execution coordinator
///
/// Owns the database pool, the message-queue handle, and the task executor
/// used to run individual workflow tasks.
pub struct WorkflowCoordinator {
    // Postgres pool for execution bookkeeping (inserts/updates below).
    db_pool: PgPool,
    // Message queue handle; NOTE(review): presumably used for task dispatch
    // elsewhere in this type — confirm against the full file.
    mq: MessageQueue,
    // Executes individual workflow tasks.
    task_executor: TaskExecutor,
}
|
||||
|
||||
impl WorkflowCoordinator {
|
||||
/// Create a new workflow coordinator
|
||||
pub fn new(db_pool: PgPool, mq: MessageQueue) -> Self {
|
||||
let task_executor = TaskExecutor::new(db_pool.clone(), mq.clone());
|
||||
|
||||
Self {
|
||||
db_pool,
|
||||
mq,
|
||||
task_executor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a new workflow execution
///
/// Loads the workflow definition by ref, validates that it is enabled,
/// builds the task graph, inserts the parent `execution` row and the
/// `workflow_execution` row, and returns a handle wrapping the initial
/// in-memory state (all task sets empty, status `Running`).
///
/// # Errors
/// Returns not-found if the workflow ref is unknown, validation errors for
/// disabled workflows / malformed definitions / bad task graphs, and
/// database errors from the inserts.
pub async fn start_workflow(
    &self,
    workflow_ref: &str,
    parameters: JsonValue,
    parent_execution_id: Option<Id>,
) -> Result<WorkflowExecutionHandle> {
    info!(
        "Starting workflow: {} with params: {:?}",
        workflow_ref, parameters
    );

    // Load workflow definition
    let workflow_def = sqlx::query_as::<_, attune_common::models::WorkflowDefinition>(
        "SELECT * FROM attune.workflow_definition WHERE ref = $1",
    )
    .bind(workflow_ref)
    .fetch_optional(&self.db_pool)
    .await?
    .ok_or_else(|| Error::not_found("workflow_definition", "ref", workflow_ref))?;

    if !workflow_def.enabled {
        return Err(Error::validation("Workflow is disabled"));
    }

    // Parse workflow definition (stored as JSON in the DB row)
    let definition: WorkflowDefinition = serde_json::from_value(workflow_def.definition)
        .map_err(|e| Error::validation(format!("Invalid workflow definition: {}", e)))?;

    // Build task graph
    let graph = TaskGraph::from_workflow(&definition)
        .map_err(|e| Error::validation(format!("Failed to build task graph: {}", e)))?;

    // Create parent execution record
    // TODO: Implement proper execution creation
    // NOTE(review): this binding is unused dead code left from WIP — confirm
    // and remove once proper execution creation lands.
    let _parent_execution_id_temp = parent_execution_id.unwrap_or(1); // Placeholder

    let parent_execution = sqlx::query_as::<_, attune_common::models::Execution>(
        r#"
        INSERT INTO attune.execution (action_ref, pack, input, parent, status)
        VALUES ($1, $2, $3, $4, $5)
        RETURNING *
        "#,
    )
    .bind(workflow_ref)
    .bind(workflow_def.pack)
    .bind(&parameters)
    .bind(parent_execution_id)
    .bind(ExecutionStatus::Running)
    .fetch_one(&self.db_pool)
    .await?;

    // Initialize workflow context with the definition's declared variables
    let initial_vars: HashMap<String, JsonValue> = definition
        .vars
        .iter()
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect();
    let context = WorkflowContext::new(parameters, initial_vars);

    // Create workflow execution record
    let workflow_execution = self
        .create_workflow_execution_record(
            parent_execution.id,
            workflow_def.id,
            &graph,
            &context,
        )
        .await?;

    info!(
        "Created workflow execution {} for workflow {}",
        workflow_execution.id, workflow_ref
    );

    // Create execution handle with fresh (empty) in-memory task state
    let handle = WorkflowExecutionHandle {
        coordinator: Arc::new(self.clone_ref()),
        execution_id: workflow_execution.id,
        parent_execution_id: parent_execution.id,
        workflow_def_id: workflow_def.id,
        graph,
        state: Arc::new(Mutex::new(WorkflowExecutionState {
            context,
            status: ExecutionStatus::Running,
            completed_tasks: HashSet::new(),
            failed_tasks: HashSet::new(),
            skipped_tasks: HashSet::new(),
            executing_tasks: HashSet::new(),
            scheduled_tasks: HashSet::new(),
            join_state: HashMap::new(),
            task_executions: HashMap::new(),
            paused: false,
            pause_reason: None,
            error_message: None,
        })),
    };

    // Update execution status to running
    self.update_workflow_execution_status(workflow_execution.id, ExecutionStatus::Running)
        .await?;

    Ok(handle)
}
|
||||
|
||||
/// Create workflow execution record in database
///
/// Serializes the task graph and the exported context variables, then
/// inserts a `workflow_execution` row in `Running` state with empty task
/// lists (they are populated as the workflow progresses).
///
/// # Errors
/// Returns an internal error if the task graph cannot be serialized, or a
/// database error from the insert.
async fn create_workflow_execution_record(
    &self,
    execution_id: Id,
    workflow_def_id: Id,
    graph: &TaskGraph,
    context: &WorkflowContext,
) -> Result<WorkflowExecution> {
    let task_graph_json = serde_json::to_value(graph)
        .map_err(|e| Error::internal(format!("Failed to serialize task graph: {}", e)))?;

    let variables = context.export();

    sqlx::query_as::<_, WorkflowExecution>(
        r#"
        INSERT INTO attune.workflow_execution (
            execution, workflow_def, current_tasks, completed_tasks,
            failed_tasks, skipped_tasks, variables, task_graph,
            status, paused
        )
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        RETURNING *
        "#,
    )
    .bind(execution_id)
    .bind(workflow_def_id)
    // All four task lists start empty.
    .bind(&[] as &[String])
    .bind(&[] as &[String])
    .bind(&[] as &[String])
    .bind(&[] as &[String])
    .bind(variables)
    .bind(task_graph_json)
    .bind(ExecutionStatus::Running)
    .bind(false)
    .fetch_one(&self.db_pool)
    .await
    .map_err(Into::into)
}
|
||||
|
||||
/// Update workflow execution status
///
/// Sets only the `status` column (plus the `updated` timestamp) on the
/// given `workflow_execution` row.
///
/// # Errors
/// Returns any database error from the update.
async fn update_workflow_execution_status(
    &self,
    workflow_execution_id: Id,
    status: ExecutionStatus,
) -> Result<()> {
    sqlx::query(
        r#"
        UPDATE attune.workflow_execution
        SET status = $1, updated = NOW()
        WHERE id = $2
        "#,
    )
    .bind(status)
    .bind(workflow_execution_id)
    .execute(&self.db_pool)
    .await?;

    Ok(())
}
|
||||
|
||||
/// Update workflow execution state
///
/// Persists a full snapshot of the in-memory state: task sets (executing,
/// completed, failed, skipped), exported context variables, status, pause
/// flags, and any error message.
///
/// # Errors
/// Returns any database error from the update.
async fn update_workflow_execution_state(
    &self,
    workflow_execution_id: Id,
    state: &WorkflowExecutionState,
) -> Result<()> {
    // Materialize the HashSets as Vecs for array binding; iteration order is
    // unspecified, so the stored arrays carry no ordering guarantee.
    let current_tasks: Vec<String> = state.executing_tasks.iter().cloned().collect();
    let completed_tasks: Vec<String> = state.completed_tasks.iter().cloned().collect();
    let failed_tasks: Vec<String> = state.failed_tasks.iter().cloned().collect();
    let skipped_tasks: Vec<String> = state.skipped_tasks.iter().cloned().collect();

    sqlx::query(
        r#"
        UPDATE attune.workflow_execution
        SET
            current_tasks = $1,
            completed_tasks = $2,
            failed_tasks = $3,
            skipped_tasks = $4,
            variables = $5,
            status = $6,
            paused = $7,
            pause_reason = $8,
            error_message = $9,
            updated = NOW()
        WHERE id = $10
        "#,
    )
    .bind(&current_tasks)
    .bind(&completed_tasks)
    .bind(&failed_tasks)
    .bind(&skipped_tasks)
    .bind(state.context.export())
    .bind(state.status)
    .bind(state.paused)
    .bind(&state.pause_reason)
    .bind(&state.error_message)
    .bind(workflow_execution_id)
    .execute(&self.db_pool)
    .await?;

    Ok(())
}
|
||||
|
||||
/// Create a task execution record
///
/// Inserts an `attune.execution` row for one workflow task, attaching
/// [`WorkflowTaskMetadata`] (retry/timeout bookkeeping, start timestamp) as
/// JSON. `task_index`/`task_batch` identify the item within a with-items
/// iteration, or `None` for a plain task.
///
/// # Errors
/// Propagates any database error from the INSERT.
async fn create_task_execution_record(
    &self,
    workflow_execution_id: Id,
    parent_execution_id: Id,
    task: &TaskNode,
    task_index: Option<i32>,
    task_batch: Option<i32>,
) -> Result<Execution> {
    // Retry budget comes from the task's retry config; 0 when absent.
    let max_retries = task.retry.as_ref().map(|r| r.count as i32).unwrap_or(0);
    // NOTE(review): `t as i32` silently truncates if the timeout type is wider
    // than i32 — confirm the declared range of `task.timeout`.
    let timeout = task.timeout.map(|t| t as i32);

    // Create workflow task metadata
    let workflow_task = WorkflowTaskMetadata {
        workflow_execution: workflow_execution_id,
        task_name: task.name.clone(),
        task_index,
        task_batch,
        retry_count: 0,
        max_retries,
        next_retry_at: None,
        timeout_seconds: timeout,
        timed_out: false,
        duration_ms: None,
        started_at: Some(Utc::now()),
        completed_at: None,
    };

    // NOTE(review): the task *name* is bound as `action_ref` here, not the
    // task's action reference — confirm this is intentional.
    sqlx::query_as::<_, Execution>(
        r#"
        INSERT INTO attune.execution (
            action_ref, parent, status, workflow_task
        )
        VALUES ($1, $2, $3, $4)
        RETURNING *
        "#,
    )
    .bind(&task.name)
    .bind(parent_execution_id)
    .bind(ExecutionStatus::Running)
    .bind(sqlx::types::Json(&workflow_task))
    .fetch_one(&self.db_pool)
    .await
    .map_err(Into::into)
}
|
||||
|
||||
/// Update task execution record
|
||||
async fn update_task_execution_record(
|
||||
&self,
|
||||
task_execution_id: Id,
|
||||
result: &TaskExecutionResult,
|
||||
) -> Result<()> {
|
||||
let status = match result.status {
|
||||
TaskExecutionStatus::Success => ExecutionStatus::Completed,
|
||||
TaskExecutionStatus::Failed => ExecutionStatus::Failed,
|
||||
TaskExecutionStatus::Timeout => ExecutionStatus::Timeout,
|
||||
TaskExecutionStatus::Skipped => ExecutionStatus::Cancelled,
|
||||
};
|
||||
|
||||
// Fetch current execution to get workflow_task metadata
|
||||
let execution =
|
||||
sqlx::query_as::<_, Execution>("SELECT * FROM attune.execution WHERE id = $1")
|
||||
.bind(task_execution_id)
|
||||
.fetch_one(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
// Update workflow_task metadata
|
||||
if let Some(mut workflow_task) = execution.workflow_task {
|
||||
workflow_task.completed_at = if result.status == TaskExecutionStatus::Success {
|
||||
Some(Utc::now())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
workflow_task.duration_ms = Some(result.duration_ms);
|
||||
workflow_task.retry_count = result.retry_count;
|
||||
workflow_task.next_retry_at = result.next_retry_at;
|
||||
workflow_task.timed_out = result.status == TaskExecutionStatus::Timeout;
|
||||
|
||||
let _error_json = result.error.as_ref().map(|e| {
|
||||
json!({
|
||||
"message": e.message,
|
||||
"type": e.error_type,
|
||||
"details": e.details
|
||||
})
|
||||
});
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE attune.execution
|
||||
SET
|
||||
status = $1,
|
||||
result = $2,
|
||||
workflow_task = $3,
|
||||
updated = NOW()
|
||||
WHERE id = $4
|
||||
"#,
|
||||
)
|
||||
.bind(status)
|
||||
.bind(&result.output)
|
||||
.bind(sqlx::types::Json(&workflow_task))
|
||||
.bind(task_execution_id)
|
||||
.execute(&self.db_pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clone reference for Arc sharing
|
||||
fn clone_ref(&self) -> Self {
|
||||
Self {
|
||||
db_pool: self.db_pool.clone(),
|
||||
mq: self.mq.clone(),
|
||||
task_executor: TaskExecutor::new(self.db_pool.clone(), self.mq.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Workflow execution state
///
/// In-memory scheduling state for one workflow run; shared behind
/// `Arc<Mutex<_>>` between the scheduling loop and spawned task futures, and
/// periodically persisted via `update_workflow_execution_state`.
#[derive(Debug, Clone)]
pub struct WorkflowExecutionState {
    /// Workflow variable context (task results are merged in as tasks finish).
    pub context: WorkflowContext,
    /// Overall workflow status.
    pub status: ExecutionStatus,
    /// Names of tasks that finished successfully.
    pub completed_tasks: HashSet<String>,
    /// Names of tasks that failed (including timeouts).
    pub failed_tasks: HashSet<String>,
    /// Names of tasks skipped (e.g. by a `when` condition).
    pub skipped_tasks: HashSet<String>,
    /// Tasks currently executing
    pub executing_tasks: HashSet<String>,
    /// Tasks scheduled but not yet executing
    pub scheduled_tasks: HashSet<String>,
    /// Join state tracking: task_name -> set of completed predecessor tasks
    pub join_state: HashMap<String, HashSet<String>>,
    /// Execution-record ids per task name (multiple ids for retried/iterated
    /// tasks).
    pub task_executions: HashMap<String, Vec<Id>>,
    /// Whether the workflow is paused (the scheduling loop stops when set).
    pub paused: bool,
    /// Human-readable reason for the pause, if any.
    pub pause_reason: Option<String>,
    /// Terminal error message when the workflow failed.
    pub error_message: Option<String>,
}
|
||||
|
||||
/// Handle for managing a workflow execution
///
/// Owns the shared state for one run and exposes `execute`, `pause`,
/// `resume`, `cancel`, and `status`. Cheap-clone handles (`Arc`s) allow the
/// scheduling loop and spawned task futures to cooperate.
pub struct WorkflowExecutionHandle {
    /// Shared coordinator used for DB persistence and task execution.
    coordinator: Arc<WorkflowCoordinator>,
    /// Id of the workflow-execution row this handle drives.
    execution_id: Id,
    /// Id of the parent execution row task executions attach to.
    parent_execution_id: Id,
    #[allow(dead_code)]
    workflow_def_id: Id,
    /// The compiled task graph being executed.
    graph: TaskGraph,
    /// Mutable run state shared with spawned task futures.
    state: Arc<Mutex<WorkflowExecutionState>>,
}
|
||||
|
||||
impl WorkflowExecutionHandle {
    /// Execute the workflow to completion
    ///
    /// Seeds the entry-point tasks into `scheduled_tasks`, then loops:
    /// spawn whatever is scheduled, sleep 100ms, and check for pause or
    /// completion (nothing executing and nothing scheduled). Returns a
    /// summary built from the final state.
    pub async fn execute(&self) -> Result<WorkflowExecutionResult> {
        info!("Executing workflow {}", self.execution_id);

        // Start with entry point tasks
        {
            let mut state = self.state.lock().await;
            for task_name in &self.graph.entry_points {
                info!("Scheduling entry point task: {}", task_name);
                state.scheduled_tasks.insert(task_name.clone());
            }
        }

        // Wait for all tasks to complete
        loop {
            // Check for and spawn scheduled tasks. The lock is released at the
            // end of this block so the spawned futures can acquire it.
            let tasks_to_spawn = {
                let mut state = self.state.lock().await;
                let mut to_spawn = Vec::new();
                for task_name in state.scheduled_tasks.iter() {
                    to_spawn.push(task_name.clone());
                }
                // Clear scheduled tasks as we're about to spawn them
                state.scheduled_tasks.clear();
                to_spawn
            };

            // Spawn scheduled tasks
            for task_name in tasks_to_spawn {
                self.spawn_task_execution(task_name).await;
            }

            // Polling interval for the scheduling loop.
            // NOTE(review): if a freshly spawned tokio task has not run yet
            // (and so has not inserted itself into `executing_tasks`) when the
            // completion check below executes, the workflow can be declared
            // complete prematurely — confirm whether the task should be marked
            // executing *before* spawning.
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

            let state = self.state.lock().await;

            // Check if workflow is paused
            if state.paused {
                info!("Workflow {} is paused", self.execution_id);
                break;
            }

            // Check if workflow is complete (nothing executing and nothing scheduled)
            if state.executing_tasks.is_empty() && state.scheduled_tasks.is_empty() {
                info!("Workflow {} completed", self.execution_id);
                // Drop the shared guard and re-acquire mutably to record the
                // terminal status. NOTE(review): there is a brief unlocked gap
                // here; benign today only because nothing is executing.
                drop(state);

                let mut state = self.state.lock().await;
                if state.failed_tasks.is_empty() {
                    state.status = ExecutionStatus::Completed;
                } else {
                    state.status = ExecutionStatus::Failed;
                    state.error_message = Some(format!(
                        "Workflow failed: {} tasks failed",
                        state.failed_tasks.len()
                    ));
                }
                // Persist the terminal state before leaving the loop.
                self.coordinator
                    .update_workflow_execution_state(self.execution_id, &state)
                    .await?;
                break;
            }
        }

        // Snapshot the final state into the caller-facing result.
        let state = self.state.lock().await;
        Ok(WorkflowExecutionResult {
            status: state.status,
            output: state.context.export(),
            completed_tasks: state.completed_tasks.len(),
            failed_tasks: state.failed_tasks.len(),
            skipped_tasks: state.skipped_tasks.len(),
            error_message: state.error_message.clone(),
        })
    }

    /// Spawn a task execution in a new tokio task
    ///
    /// Clones the shared handles so the spawned future is `'static`; errors
    /// are logged inside the task rather than propagated to the caller.
    async fn spawn_task_execution(&self, task_name: String) {
        let coordinator = self.coordinator.clone();
        let state_arc = self.state.clone();
        let workflow_execution_id = self.execution_id;
        let parent_execution_id = self.parent_execution_id;
        let graph = self.graph.clone();

        tokio::spawn(async move {
            if let Err(e) = Self::execute_task_async(
                coordinator,
                state_arc,
                workflow_execution_id,
                parent_execution_id,
                graph,
                task_name,
            )
            .await
            {
                error!("Task execution failed: {}", e);
            }
        });
    }

    /// Execute a single task asynchronously
    ///
    /// Marks the task executing, creates its execution record, runs it via the
    /// coordinator's task executor, persists the outcome and shared state, and
    /// finally evaluates transitions to schedule successor tasks.
    async fn execute_task_async(
        coordinator: Arc<WorkflowCoordinator>,
        state: Arc<Mutex<WorkflowExecutionState>>,
        workflow_execution_id: Id,
        parent_execution_id: Id,
        graph: TaskGraph,
        task_name: String,
    ) -> Result<()> {
        // Move task from scheduled to executing. (The scheduler loop normally
        // clears `scheduled_tasks` before spawning, so the remove here is a
        // harmless no-op in that path.)
        let task = {
            let mut state = state.lock().await;
            state.scheduled_tasks.remove(&task_name);
            state.executing_tasks.insert(task_name.clone());

            // Get the task node
            match graph.get_task(&task_name) {
                Some(task) => task.clone(),
                None => {
                    // NOTE(review): the name stays in `executing_tasks` on this
                    // path, which would prevent the completion check from ever
                    // passing — confirm it should be removed before returning.
                    error!("Task {} not found in graph", task_name);
                    return Ok(());
                }
            }
        };

        info!("Executing task: {}", task.name);

        // Create task execution record
        let task_execution = coordinator
            .create_task_execution_record(
                workflow_execution_id,
                parent_execution_id,
                &task,
                None,
                None,
            )
            .await?;

        // Get context for execution. This is a clone, so concurrent tasks do
        // not observe each other's in-flight context mutations.
        let mut context = {
            let state = state.lock().await;
            state.context.clone()
        };

        // Execute task
        let result = coordinator
            .task_executor
            .execute_task(
                &task,
                &mut context,
                workflow_execution_id,
                parent_execution_id,
            )
            .await?;

        // Update task execution record
        coordinator
            .update_task_execution_record(task_execution.id, &result)
            .await?;

        // Update workflow state based on result
        let success = matches!(result.status, TaskExecutionStatus::Success);

        {
            let mut state = state.lock().await;
            state.executing_tasks.remove(&task.name);

            match result.status {
                TaskExecutionStatus::Success => {
                    state.completed_tasks.insert(task.name.clone());
                    // Update context with task result
                    if let Some(output) = result.output {
                        state.context.set_task_result(&task.name, output);
                    }
                }
                TaskExecutionStatus::Failed => {
                    if result.should_retry {
                        // Task will be retried, keep it in scheduled
                        info!("Task {} will be retried", task.name);
                        state.scheduled_tasks.insert(task.name.clone());
                        // TODO: Schedule retry with delay
                    } else {
                        state.failed_tasks.insert(task.name.clone());
                        if let Some(ref error) = result.error {
                            warn!("Task {} failed: {}", task.name, error.message);
                        }
                    }
                }
                TaskExecutionStatus::Timeout => {
                    // Timeouts count as failures for workflow bookkeeping.
                    state.failed_tasks.insert(task.name.clone());
                    warn!("Task {} timed out", task.name);
                }
                TaskExecutionStatus::Skipped => {
                    state.skipped_tasks.insert(task.name.clone());
                    debug!("Task {} skipped", task.name);
                }
            }

            // Persist state (still holding the lock; tokio::sync::Mutex is
            // safe to hold across this await).
            coordinator
                .update_workflow_execution_state(workflow_execution_id, &state)
                .await?;
        }

        // Evaluate transitions and schedule next tasks
        Self::on_task_completion(state.clone(), graph.clone(), task.name.clone(), success).await?;

        Ok(())
    }

    /// Handle task completion by evaluating transitions and scheduling next tasks
    ///
    /// Inserts each eligible successor into `scheduled_tasks`. Join tasks are
    /// only scheduled once `join_state` records enough distinct completed
    /// predecessors.
    async fn on_task_completion(
        state: Arc<Mutex<WorkflowExecutionState>>,
        graph: TaskGraph,
        completed_task: String,
        success: bool,
    ) -> Result<()> {
        // Get next tasks based on transitions
        let next_tasks = graph.next_tasks(&completed_task, success);

        info!(
            "Task {} completed (success={}), next tasks: {:?}",
            completed_task, success, next_tasks
        );

        // Collect tasks to schedule.
        // NOTE(review): `tasks_to_schedule` is pushed to but never read after
        // the loop — actual scheduling happens via the `scheduled_tasks`
        // inserts; confirm the vec can be removed.
        let mut tasks_to_schedule = Vec::new();

        for next_task_name in next_tasks {
            // The lock is re-acquired per successor rather than held across
            // the whole loop.
            let mut state = state.lock().await;

            // Check if task already scheduled or executing
            if state.scheduled_tasks.contains(&next_task_name)
                || state.executing_tasks.contains(&next_task_name)
            {
                continue;
            }

            if let Some(task_node) = graph.get_task(&next_task_name) {
                // Check join conditions
                if let Some(join_count) = task_node.join {
                    // Update join state: record this predecessor's completion.
                    let join_completions = state
                        .join_state
                        .entry(next_task_name.clone())
                        .or_insert_with(HashSet::new);
                    join_completions.insert(completed_task.clone());

                    // Check if join is satisfied
                    if join_completions.len() >= join_count {
                        info!(
                            "Join condition satisfied for task {}: {}/{} completed",
                            next_task_name,
                            join_completions.len(),
                            join_count
                        );
                        state.scheduled_tasks.insert(next_task_name.clone());
                        tasks_to_schedule.push(next_task_name);
                    } else {
                        info!(
                            "Join condition not yet satisfied for task {}: {}/{} completed",
                            next_task_name,
                            join_completions.len(),
                            join_count
                        );
                    }
                } else {
                    // No join, schedule immediately
                    state.scheduled_tasks.insert(next_task_name.clone());
                    tasks_to_schedule.push(next_task_name);
                }
            } else {
                error!("Next task {} not found in graph", next_task_name);
            }
        }

        Ok(())
    }

    /// Pause workflow execution
    ///
    /// Sets the in-memory pause flag (observed by the scheduling loop, which
    /// then breaks out) and persists the paused state.
    pub async fn pause(&self, reason: Option<String>) -> Result<()> {
        let mut state = self.state.lock().await;
        state.paused = true;
        state.pause_reason = reason;

        self.coordinator
            .update_workflow_execution_state(self.execution_id, &state)
            .await?;

        info!("Workflow {} paused", self.execution_id);
        Ok(())
    }

    /// Resume workflow execution
    ///
    /// Clears the pause flag/reason and persists the change.
    /// NOTE(review): `execute` breaks out of its loop when paused, and this
    /// method does not restart it — confirm how a paused workflow is
    /// re-driven after resume.
    pub async fn resume(&self) -> Result<()> {
        let mut state = self.state.lock().await;
        state.paused = false;
        state.pause_reason = None;

        self.coordinator
            .update_workflow_execution_state(self.execution_id, &state)
            .await?;

        info!("Workflow {} resumed", self.execution_id);
        Ok(())
    }

    /// Cancel workflow execution
    ///
    /// Marks the workflow `Cancelled` and persists the state. In-flight task
    /// futures are not interrupted here.
    pub async fn cancel(&self) -> Result<()> {
        let mut state = self.state.lock().await;
        state.status = ExecutionStatus::Cancelled;

        self.coordinator
            .update_workflow_execution_state(self.execution_id, &state)
            .await?;

        info!("Workflow {} cancelled", self.execution_id);
        Ok(())
    }

    /// Get current execution status
    ///
    /// Returns a point-in-time snapshot of counters and task sets.
    pub async fn status(&self) -> WorkflowExecutionStatus {
        let state = self.state.lock().await;
        WorkflowExecutionStatus {
            execution_id: self.execution_id,
            status: state.status,
            completed_tasks: state.completed_tasks.len(),
            failed_tasks: state.failed_tasks.len(),
            skipped_tasks: state.skipped_tasks.len(),
            executing_tasks: state.executing_tasks.iter().cloned().collect(),
            scheduled_tasks: state.scheduled_tasks.iter().cloned().collect(),
            total_tasks: self.graph.nodes.len(),
            paused: state.paused,
        }
    }
}
|
||||
|
||||
/// Result of workflow execution
///
/// Summary returned by [`WorkflowExecutionHandle::execute`] once the
/// scheduling loop exits (completed, failed, or paused).
#[derive(Debug, Clone)]
pub struct WorkflowExecutionResult {
    /// Final workflow status.
    pub status: ExecutionStatus,
    /// Exported workflow context variables at exit.
    pub output: JsonValue,
    /// Count of tasks that finished successfully.
    pub completed_tasks: usize,
    /// Count of tasks that failed (including timeouts).
    pub failed_tasks: usize,
    /// Count of tasks that were skipped.
    pub skipped_tasks: usize,
    /// Terminal error message when the workflow failed.
    pub error_message: Option<String>,
}
|
||||
|
||||
/// Current status of workflow execution
///
/// Point-in-time snapshot returned by [`WorkflowExecutionHandle::status`].
#[derive(Debug, Clone)]
pub struct WorkflowExecutionStatus {
    /// Id of the workflow-execution row.
    pub execution_id: Id,
    /// Current workflow status.
    pub status: ExecutionStatus,
    /// Count of successfully finished tasks.
    pub completed_tasks: usize,
    /// Count of failed tasks (including timeouts).
    pub failed_tasks: usize,
    /// Count of skipped tasks.
    pub skipped_tasks: usize,
    /// Names of tasks currently executing.
    pub executing_tasks: Vec<String>,
    /// Names of tasks scheduled but not yet executing.
    pub scheduled_tasks: Vec<String>,
    /// Total number of tasks in the graph.
    pub total_tasks: usize,
    /// Whether the workflow is currently paused.
    pub paused: bool,
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    // Note: These tests require a database connection and are integration tests
    // They should be run with `cargo test --features integration-tests`

    /// Placeholder until the integration-test harness (database pool setup,
    /// schema migration) is available. The previous `assert!(true)` was a
    /// no-op (clippy::assertions_on_constants) and has been removed.
    #[tokio::test]
    #[ignore] // Requires database
    async fn test_workflow_coordinator_creation() {
        // Intentionally empty: real assertions need a live database.
    }
}
|
||||
@@ -21,9 +21,6 @@ pub type GraphResult<T> = Result<T, GraphError>;
|
||||
pub enum GraphError {
    /// A transition or entry point referenced a task name that does not
    /// exist in the workflow definition.
    #[error("Invalid task reference: {0}")]
    InvalidTaskReference(String),

    /// The executable graph could not be assembled from the definition.
    #[error("Graph building error: {0}")]
    BuildError(String),
}
|
||||
|
||||
/// Executable task graph
|
||||
@@ -197,6 +194,7 @@ impl TaskGraph {
|
||||
}
|
||||
|
||||
/// Get all tasks that can transition into the given task (inbound edges)
|
||||
#[allow(dead_code)] // Part of complete graph API; used in tests
|
||||
pub fn get_inbound_tasks(&self, task_name: &str) -> Vec<String> {
|
||||
self.inbound_edges
|
||||
.get(task_name)
|
||||
@@ -221,7 +219,8 @@ impl TaskGraph {
|
||||
/// * `success` - Whether the task succeeded
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of (task_name, publish_vars) tuples to schedule next
|
||||
/// A vector of task names to schedule next
|
||||
#[allow(dead_code)] // Part of complete graph API; used in tests
|
||||
pub fn next_tasks(&self, task_name: &str, success: bool) -> Vec<String> {
|
||||
let mut next = Vec::new();
|
||||
|
||||
@@ -251,7 +250,8 @@ impl TaskGraph {
|
||||
/// Get the next tasks with full transition information.
|
||||
///
|
||||
/// Returns matching transitions with their publish directives and targets,
|
||||
/// giving the coordinator full context for variable publishing.
|
||||
/// giving the caller full context for variable publishing.
|
||||
#[allow(dead_code)] // Part of complete graph API; used in tests
|
||||
pub fn matching_transitions(&self, task_name: &str, success: bool) -> Vec<&GraphTransition> {
|
||||
let mut matching = Vec::new();
|
||||
|
||||
@@ -275,6 +275,7 @@ impl TaskGraph {
|
||||
}
|
||||
|
||||
/// Collect all unique target task names from all transitions of a given task.
|
||||
#[allow(dead_code)] // Part of complete graph API; used in tests
|
||||
pub fn all_transition_targets(&self, task_name: &str) -> HashSet<String> {
|
||||
let mut targets = HashSet::new();
|
||||
if let Some(node) = self.nodes.get(task_name) {
|
||||
|
||||
@@ -1,60 +1,12 @@
|
||||
//! Workflow orchestration module
|
||||
//!
|
||||
//! This module provides workflow execution, orchestration, parsing, validation,
|
||||
//! and template rendering capabilities for the Attune workflow orchestration system.
|
||||
//! This module provides workflow execution context, graph building, and
|
||||
//! orchestration capabilities for the Attune workflow engine.
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! - `parser`: Parse YAML workflow definitions into structured types
|
||||
//! - `graph`: Build executable task graphs from workflow definitions
|
||||
//! - `context`: Manage workflow execution context and variables
|
||||
//! - `task_executor`: Execute individual workflow tasks
|
||||
//! - `coordinator`: Orchestrate workflow execution with state management
|
||||
//! - `template`: Template engine for variable interpolation (Jinja2-like syntax)
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use attune_executor::workflow::{parse_workflow_yaml, WorkflowCoordinator};
|
||||
//!
|
||||
//! // Parse a workflow YAML file
|
||||
//! let yaml = r#"
|
||||
//! ref: my_pack.my_workflow
|
||||
//! label: My Workflow
|
||||
//! version: 1.0.0
|
||||
//! tasks:
|
||||
//! - name: hello
|
||||
//! action: core.echo
|
||||
//! input:
|
||||
//! message: "{{ parameters.name }}"
|
||||
//! "#;
|
||||
//!
|
||||
//! let workflow = parse_workflow_yaml(yaml).expect("Failed to parse workflow");
|
||||
//! ```
|
||||
//! - `graph`: Build executable task graphs from workflow definitions
|
||||
|
||||
// Phase 2: Workflow Execution Engine
|
||||
pub mod context;
|
||||
pub mod coordinator;
|
||||
pub mod graph;
|
||||
pub mod task_executor;
|
||||
pub mod template;
|
||||
|
||||
// Re-export workflow utilities from common crate
|
||||
pub use attune_common::workflow::{
|
||||
parse_workflow_file, parse_workflow_yaml, workflow_to_json, BackoffStrategy, DecisionBranch,
|
||||
LoadedWorkflow, LoaderConfig, ParseError, ParseResult, PublishDirective, RegistrationOptions,
|
||||
RegistrationResult, RetryConfig, Task, TaskType, ValidationError, ValidationResult,
|
||||
WorkflowDefinition, WorkflowFile, WorkflowLoader, WorkflowRegistrar, WorkflowValidator,
|
||||
};
|
||||
|
||||
// Re-export Phase 2 components
|
||||
pub use context::{ContextError, ContextResult, WorkflowContext};
|
||||
pub use coordinator::{
|
||||
WorkflowCoordinator, WorkflowExecutionHandle, WorkflowExecutionResult, WorkflowExecutionState,
|
||||
WorkflowExecutionStatus,
|
||||
};
|
||||
pub use graph::{GraphError, GraphResult, GraphTransition, TaskGraph, TaskNode};
|
||||
pub use task_executor::{
|
||||
TaskExecutionError, TaskExecutionResult, TaskExecutionStatus, TaskExecutor,
|
||||
};
|
||||
pub use template::{TemplateEngine, TemplateError, TemplateResult, VariableContext, VariableScope};
|
||||
|
||||
@@ -1,871 +0,0 @@
|
||||
//! Task Executor
|
||||
//!
|
||||
//! This module handles the execution of individual workflow tasks,
|
||||
//! including action invocation, retries, timeouts, and with-items iteration.
|
||||
|
||||
use crate::workflow::context::WorkflowContext;
|
||||
use crate::workflow::graph::{BackoffStrategy, RetryConfig, TaskNode};
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::Id;
|
||||
use attune_common::mq::MessageQueue;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use sqlx::PgPool;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Task execution result
///
/// Outcome of one task attempt as produced by [`TaskExecutor::execute_task`];
/// consumed by the coordinator to update workflow state and persistence.
#[derive(Debug, Clone)]
pub struct TaskExecutionResult {
    /// Execution status
    pub status: TaskExecutionStatus,

    /// Task output/result (JSON), when the task produced one
    pub output: Option<JsonValue>,

    /// Error information, populated on failure
    pub error: Option<TaskExecutionError>,

    /// Execution duration in milliseconds (wall clock)
    pub duration_ms: i64,

    /// Whether the task should be retried (failed with retry budget left)
    pub should_retry: bool,

    /// Next retry time (if applicable)
    pub next_retry_at: Option<DateTime<Utc>>,

    /// Number of retries performed so far
    pub retry_count: i32,
}
|
||||
|
||||
/// Task execution status
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskExecutionStatus {
    /// Task completed successfully.
    Success,
    /// Task failed (may still be retried — see `TaskExecutionResult::should_retry`).
    Failed,
    /// Task exceeded its configured timeout.
    Timeout,
    /// Task was skipped (e.g. its `when` condition evaluated to false).
    Skipped,
}
|
||||
|
||||
/// Task execution error
///
/// Structured error payload attached to a failed [`TaskExecutionResult`].
#[derive(Debug, Clone)]
pub struct TaskExecutionError {
    /// Human-readable error message.
    pub message: String,
    /// Machine-readable category (e.g. "template_error", "configuration_error").
    pub error_type: String,
    /// Optional structured details for diagnostics.
    pub details: Option<JsonValue>,
}
|
||||
|
||||
/// Task executor
///
/// Executes individual workflow tasks (actions, parallel groups, nested
/// workflows) against the database and message queue.
pub struct TaskExecutor {
    // Postgres pool used to create/update execution records.
    db_pool: PgPool,
    // Message queue handle for dispatching work to workers.
    mq: MessageQueue,
}
|
||||
|
||||
impl TaskExecutor {
|
||||
/// Create a new task executor
|
||||
pub fn new(db_pool: PgPool, mq: MessageQueue) -> Self {
|
||||
Self { db_pool, mq }
|
||||
}
|
||||
|
||||
/// Execute a task
///
/// Entry point for running one task: evaluates its `when` condition (skip),
/// delegates to with-items iteration when configured, otherwise runs a single
/// attempt, stores the output into the context, and publishes variables from
/// matching transitions.
///
/// # Errors
/// Propagates errors from the underlying execution; expected task failures
/// are reported inside the returned `TaskExecutionResult` instead.
pub async fn execute_task(
    &self,
    task: &TaskNode,
    context: &mut WorkflowContext,
    workflow_execution_id: Id,
    parent_execution_id: Id,
) -> Result<TaskExecutionResult> {
    info!("Executing task: {}", task.name);

    let start_time = Utc::now();

    // Check if task should be skipped (when condition)
    if let Some(ref condition) = task.when {
        match context.evaluate_condition(condition) {
            Ok(should_run) => {
                if !should_run {
                    info!("Task {} skipped due to when condition", task.name);
                    return Ok(TaskExecutionResult {
                        status: TaskExecutionStatus::Skipped,
                        output: None,
                        error: None,
                        duration_ms: 0,
                        should_retry: false,
                        next_retry_at: None,
                        retry_count: 0,
                    });
                }
            }
            Err(e) => {
                warn!(
                    "Failed to evaluate when condition for task {}: {}",
                    task.name, e
                );
                // Continue execution if condition evaluation fails
            }
        }
    }

    // Check if this is a with-items task.
    // NOTE(review): this early return bypasses the duration accounting and
    // transition-publish logic below — confirm execute_with_items handles
    // both itself.
    if let Some(ref with_items_expr) = task.with_items {
        return self
            .execute_with_items(
                task,
                context,
                workflow_execution_id,
                parent_execution_id,
                with_items_expr,
            )
            .await;
    }

    // Execute single task
    let result = self
        .execute_single_task(task, context, workflow_execution_id, parent_execution_id, 0)
        .await?;

    let duration_ms = (Utc::now() - start_time).num_milliseconds();

    // Store task result in context
    if let Some(ref output) = result.output {
        context.set_task_result(&task.name, output.clone());

        // Publish variables from matching transitions
        let success = matches!(result.status, TaskExecutionStatus::Success);
        for transition in &task.transitions {
            // NOTE(review): TimedOut transitions fire on *any* non-success,
            // not specifically on timeout — confirm intended semantics.
            let should_fire = match transition.kind() {
                super::graph::TransitionKind::Succeeded => success,
                super::graph::TransitionKind::Failed => !success,
                super::graph::TransitionKind::TimedOut => !success,
                super::graph::TransitionKind::Always => true,
                super::graph::TransitionKind::Custom => true,
            };
            if should_fire && !transition.publish.is_empty() {
                let var_names: Vec<String> =
                    transition.publish.iter().map(|p| p.name.clone()).collect();
                // Publishing failures are logged, not fatal to the task.
                if let Err(e) = context.publish_from_result(output, &var_names, None) {
                    warn!("Failed to publish variables for task {}: {}", task.name, e);
                }
            }
        }
    }

    // Overwrite the attempt-level duration with the overall wall-clock time.
    Ok(TaskExecutionResult {
        duration_ms,
        ..result
    })
}
|
||||
|
||||
/// Execute a single task (without with-items iteration)
///
/// Renders the task input through the template context, dispatches on task
/// type (action / parallel / nested workflow), applies timeout handling, and
/// decides retry eligibility on failure.
///
/// # Errors
/// Propagates infrastructure errors; template and configuration failures are
/// reported inside the returned `TaskExecutionResult`.
async fn execute_single_task(
    &self,
    task: &TaskNode,
    context: &WorkflowContext,
    workflow_execution_id: Id,
    parent_execution_id: Id,
    retry_count: i32,
) -> Result<TaskExecutionResult> {
    let start_time = Utc::now();

    // Render task input (template interpolation against the context).
    let input = match context.render_json(&task.input) {
        Ok(rendered) => rendered,
        Err(e) => {
            // A bad template is a task failure, not an infrastructure error.
            error!("Failed to render task input for {}: {}", task.name, e);
            return Ok(TaskExecutionResult {
                status: TaskExecutionStatus::Failed,
                output: None,
                error: Some(TaskExecutionError {
                    message: format!("Failed to render task input: {}", e),
                    error_type: "template_error".to_string(),
                    details: None,
                }),
                duration_ms: 0,
                should_retry: false,
                next_retry_at: None,
                retry_count,
            });
        }
    };

    // Execute based on task type
    let result = match task.task_type {
        attune_common::workflow::TaskType::Action => {
            self.execute_action(task, input, workflow_execution_id, parent_execution_id)
                .await
        }
        attune_common::workflow::TaskType::Parallel => {
            self.execute_parallel(task, context, workflow_execution_id, parent_execution_id)
                .await
        }
        attune_common::workflow::TaskType::Workflow => {
            self.execute_workflow(task, input, workflow_execution_id, parent_execution_id)
                .await
        }
    };

    let duration_ms = (Utc::now() - start_time).num_milliseconds();

    // Apply timeout if specified.
    // NOTE(review): the task future has already completed by this point, so
    // apply_timeout can only reclassify the result after the fact — it cannot
    // cancel a long-running task. Confirm whether the timeout should wrap the
    // execution future (e.g. tokio::time::timeout) instead.
    let result = if let Some(timeout_secs) = task.timeout {
        self.apply_timeout(result, timeout_secs).await
    } else {
        result
    };

    // Handle retries
    let mut result = result?;
    result.retry_count = retry_count;

    if result.status == TaskExecutionStatus::Failed {
        if let Some(ref retry_config) = task.retry {
            if retry_count < retry_config.count as i32 {
                // Check if we should retry based on error condition
                let should_retry = if let Some(ref _on_error) = retry_config.on_error {
                    // TODO: Evaluate error condition (currently always retries)
                    true
                } else {
                    true
                };

                if should_retry {
                    result.should_retry = true;
                    result.next_retry_at =
                        Some(calculate_retry_time(retry_config, retry_count));
                    info!(
                        "Task {} failed, will retry (attempt {}/{})",
                        task.name,
                        retry_count + 1,
                        retry_config.count
                    );
                }
            }
        }
    }

    result.duration_ms = duration_ms;
    Ok(result)
}
|
||||
|
||||
/// Execute an action task
///
/// Creates a `Scheduled` execution row for the task's action. Actual dispatch
/// to a worker via the message queue is not yet implemented; the method
/// currently returns `Success` with a `{"execution_id", "status": "queued"}`
/// payload rather than waiting for real completion (see TODO below).
///
/// # Errors
/// Propagates database errors from the INSERT.
async fn execute_action(
    &self,
    task: &TaskNode,
    input: JsonValue,
    _workflow_execution_id: Id,
    parent_execution_id: Id,
) -> Result<TaskExecutionResult> {
    // An action task without an action reference is a configuration error,
    // reported as a task failure rather than an Err.
    let action_ref = match &task.action {
        Some(action) => action,
        None => {
            return Ok(TaskExecutionResult {
                status: TaskExecutionStatus::Failed,
                output: None,
                error: Some(TaskExecutionError {
                    message: "Action task missing action reference".to_string(),
                    error_type: "configuration_error".to_string(),
                    details: None,
                }),
                duration_ms: 0,
                should_retry: false,
                next_retry_at: None,
                retry_count: 0,
            });
        }
    };

    debug!("Executing action: {} with input: {:?}", action_ref, input);

    // Create execution record in database
    let execution = sqlx::query_as::<_, attune_common::models::Execution>(
        r#"
        INSERT INTO attune.execution (action_ref, input, parent, status)
        VALUES ($1, $2, $3, $4)
        RETURNING *
        "#,
    )
    .bind(action_ref)
    .bind(&input)
    .bind(parent_execution_id)
    .bind(attune_common::models::ExecutionStatus::Scheduled)
    .fetch_one(&self.db_pool)
    .await?;

    // Queue action for execution by worker
    // TODO: Implement proper message queue publishing
    info!(
        "Created action execution {} for task {} (queuing not yet implemented)",
        execution.id, task.name
    );

    // For now, return pending status
    // In a real implementation, we would wait for completion via message queue
    Ok(TaskExecutionResult {
        status: TaskExecutionStatus::Success,
        output: Some(json!({
            "execution_id": execution.id,
            "status": "queued"
        })),
        error: None,
        duration_ms: 0,
        should_retry: false,
        next_retry_at: None,
        retry_count: 0,
    })
}
|
||||
|
||||
    /// Execute parallel tasks
    ///
    /// Runs every entry in `task.sub_tasks` concurrently via
    /// `futures::future::join_all`, then aggregates the per-task outcomes
    /// into a single result. The overall status is `Success` only when every
    /// sub-task succeeded; otherwise it is `Failed` with a
    /// `parallel_execution_error` carrying the per-task error details.
    async fn execute_parallel(
        &self,
        task: &TaskNode,
        context: &WorkflowContext,
        workflow_execution_id: Id,
        parent_execution_id: Id,
    ) -> Result<TaskExecutionResult> {
        // A parallel task with no sub-tasks is a configuration error, not a
        // hard failure: return a Failed result rather than Err.
        let sub_tasks = match &task.sub_tasks {
            Some(tasks) => tasks,
            None => {
                return Ok(TaskExecutionResult {
                    status: TaskExecutionStatus::Failed,
                    output: None,
                    error: Some(TaskExecutionError {
                        message: "Parallel task missing sub-tasks".to_string(),
                        error_type: "configuration_error".to_string(),
                        details: None,
                    }),
                    duration_ms: 0,
                    should_retry: false,
                    next_retry_at: None,
                    retry_count: 0,
                });
            }
        };

        info!("Executing {} parallel tasks", sub_tasks.len());

        // Execute all sub-tasks in parallel. Each future gets its own
        // TaskExecutor built from cloned pool/mq handles so they can run
        // independently of `self`.
        let mut futures = Vec::new();

        for subtask in sub_tasks {
            let subtask_clone = subtask.clone();
            let subtask_name = subtask.name.clone();
            let context = context.clone();
            let db_pool = self.db_pool.clone();
            let mq = self.mq.clone();

            let future = async move {
                let executor = TaskExecutor::new(db_pool, mq);
                let result = executor
                    .execute_single_task(
                        &subtask_clone,
                        &context,
                        workflow_execution_id,
                        parent_execution_id,
                        0, // each sub-task starts with retry_count 0
                    )
                    .await;
                // Pair the result with the task name for reporting below.
                (subtask_name, result)
            };

            futures.push(future);
        }

        // Wait for all tasks to complete. join_all never short-circuits:
        // every sub-task runs to completion even if an earlier one fails.
        let task_results = futures::future::join_all(futures).await;

        let mut results = Vec::new();
        let mut all_succeeded = true;
        let mut errors = Vec::new();

        for (task_name, result) in task_results {
            match result {
                Ok(result) => {
                    if result.status != TaskExecutionStatus::Success {
                        all_succeeded = false;
                        if let Some(error) = &result.error {
                            errors.push(json!({
                                "task": task_name,
                                "error": error.message
                            }));
                        }
                    }
                    // Every completed sub-task is recorded in `results`,
                    // successful or not.
                    results.push(json!({
                        "task": task_name,
                        "status": format!("{:?}", result.status),
                        "output": result.output
                    }));
                }
                Err(e) => {
                    // Executor-level failure (Err, not a Failed result):
                    // recorded only in `errors`, not in `results`.
                    all_succeeded = false;
                    errors.push(json!({
                        "task": task_name,
                        "error": e.to_string()
                    }));
                }
            }
        }

        let status = if all_succeeded {
            TaskExecutionStatus::Success
        } else {
            TaskExecutionStatus::Failed
        };

        Ok(TaskExecutionResult {
            status,
            output: Some(json!({
                "results": results
            })),
            error: if errors.is_empty() {
                None
            } else {
                Some(TaskExecutionError {
                    message: format!("{} parallel tasks failed", errors.len()),
                    error_type: "parallel_execution_error".to_string(),
                    details: Some(json!({"errors": errors})),
                })
            },
            duration_ms: 0,
            should_retry: false,
            next_retry_at: None,
            retry_count: 0,
        })
    }
|
||||
|
||||
/// Execute a workflow task (nested workflow)
|
||||
async fn execute_workflow(
|
||||
&self,
|
||||
_task: &TaskNode,
|
||||
_input: JsonValue,
|
||||
_workflow_execution_id: Id,
|
||||
_parent_execution_id: Id,
|
||||
) -> Result<TaskExecutionResult> {
|
||||
// TODO: Implement nested workflow execution
|
||||
// For now, return not implemented
|
||||
warn!("Workflow task execution not yet implemented");
|
||||
|
||||
Ok(TaskExecutionResult {
|
||||
status: TaskExecutionStatus::Failed,
|
||||
output: None,
|
||||
error: Some(TaskExecutionError {
|
||||
message: "Nested workflow execution not yet implemented".to_string(),
|
||||
error_type: "not_implemented".to_string(),
|
||||
details: None,
|
||||
}),
|
||||
duration_ms: 0,
|
||||
should_retry: false,
|
||||
next_retry_at: None,
|
||||
retry_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
    /// Execute task with with-items iteration
    ///
    /// Renders `items_expr` against the workflow context, parses the result
    /// as a JSON array, and runs the task once per item (or once per batch of
    /// items when `task.batch_size` is set). Iterations run concurrently,
    /// bounded by `task.concurrency` (default 10) via a Tokio semaphore.
    async fn execute_with_items(
        &self,
        task: &TaskNode,
        context: &mut WorkflowContext,
        workflow_execution_id: Id,
        parent_execution_id: Id,
        items_expr: &str,
    ) -> Result<TaskExecutionResult> {
        // Render items expression
        let items_str = context.render_template(items_expr).map_err(|e| {
            Error::validation(format!("Failed to render with-items expression: {}", e))
        })?;

        // Parse items (should be a JSON array)
        let items: Vec<JsonValue> = serde_json::from_str(&items_str).map_err(|e| {
            Error::validation(format!(
                "with-items expression did not produce valid JSON array: {}",
                e
            ))
        })?;

        info!("Executing task {} with {} items", task.name, items.len());

        let items_len = items.len(); // Store length before consuming items
        let concurrency = task.concurrency.unwrap_or(10);

        let mut all_results = Vec::new();
        let mut all_succeeded = true;
        let mut errors = Vec::new();

        // Check if batch processing is enabled
        if let Some(batch_size) = task.batch_size {
            // Batch mode: split items into batches and pass as arrays
            debug!(
                "Processing {} items in batches of {} (batch mode)",
                items.len(),
                batch_size
            );

            let batches: Vec<Vec<JsonValue>> = items
                .chunks(batch_size)
                .map(|chunk| chunk.to_vec())
                .collect();

            debug!("Created {} batches", batches.len());

            // Execute batches with concurrency limit. The permit is acquired
            // *before* spawning, so this loop blocks once `concurrency` tasks
            // are in flight — that is the throttling mechanism.
            let mut handles = Vec::new();
            let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency));

            for (batch_idx, batch) in batches.into_iter().enumerate() {
                // unwrap: acquire_owned only errs if the semaphore is closed,
                // which never happens here.
                let permit = semaphore.clone().acquire_owned().await.unwrap();

                let executor = TaskExecutor::new(self.db_pool.clone(), self.mq.clone());
                let task = task.clone();
                let mut batch_context = context.clone();

                // Set current_item to the batch array
                batch_context.set_current_item(json!(batch), batch_idx);

                let handle = tokio::spawn(async move {
                    let result = executor
                        .execute_single_task(
                            &task,
                            &batch_context,
                            workflow_execution_id,
                            parent_execution_id,
                            0,
                        )
                        .await;
                    // Release the concurrency slot only after the batch finishes.
                    drop(permit);
                    (batch_idx, result)
                });

                handles.push(handle);
            }

            // Wait for all batches to complete
            for handle in handles {
                match handle.await {
                    // Task completed; inspect its status for success/failure.
                    Ok((batch_idx, Ok(result))) => {
                        if result.status != TaskExecutionStatus::Success {
                            all_succeeded = false;
                            if let Some(error) = &result.error {
                                errors.push(json!({
                                    "batch": batch_idx,
                                    "error": error.message
                                }));
                            }
                        }
                        all_results.push(json!({
                            "batch": batch_idx,
                            "status": format!("{:?}", result.status),
                            "output": result.output
                        }));
                    }
                    // Executor-level error for this batch.
                    Ok((batch_idx, Err(e))) => {
                        all_succeeded = false;
                        errors.push(json!({
                            "batch": batch_idx,
                            "error": e.to_string()
                        }));
                    }
                    // The spawned task panicked or was cancelled.
                    Err(e) => {
                        all_succeeded = false;
                        errors.push(json!({
                            "error": format!("Task panicked: {}", e)
                        }));
                    }
                }
            }
        } else {
            // Individual mode: process each item separately
            debug!(
                "Processing {} items individually (no batch_size specified)",
                items.len()
            );

            // Execute items with concurrency limit (same throttling pattern
            // as batch mode: acquire a permit before spawning).
            let mut handles = Vec::new();
            let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency));

            for (item_idx, item) in items.into_iter().enumerate() {
                let permit = semaphore.clone().acquire_owned().await.unwrap();

                let executor = TaskExecutor::new(self.db_pool.clone(), self.mq.clone());
                let task = task.clone();
                let mut item_context = context.clone();

                // Set current_item to the individual item
                item_context.set_current_item(item, item_idx);

                let handle = tokio::spawn(async move {
                    let result = executor
                        .execute_single_task(
                            &task,
                            &item_context,
                            workflow_execution_id,
                            parent_execution_id,
                            0,
                        )
                        .await;
                    drop(permit);
                    (item_idx, result)
                });

                handles.push(handle);
            }

            // Wait for all items to complete
            for handle in handles {
                match handle.await {
                    Ok((idx, Ok(result))) => {
                        if result.status != TaskExecutionStatus::Success {
                            all_succeeded = false;
                            if let Some(error) = &result.error {
                                errors.push(json!({
                                    "index": idx,
                                    "error": error.message
                                }));
                            }
                        }
                        all_results.push(json!({
                            "index": idx,
                            "status": format!("{:?}", result.status),
                            "output": result.output
                        }));
                    }
                    Ok((idx, Err(e))) => {
                        all_succeeded = false;
                        errors.push(json!({
                            "index": idx,
                            "error": e.to_string()
                        }));
                    }
                    Err(e) => {
                        all_succeeded = false;
                        errors.push(json!({
                            "error": format!("Task panicked: {}", e)
                        }));
                    }
                }
            }
        }

        // Per-iteration contexts were clones; clear the (unused) current_item
        // on the shared context before returning.
        context.clear_current_item();

        let status = if all_succeeded {
            TaskExecutionStatus::Success
        } else {
            TaskExecutionStatus::Failed
        };

        Ok(TaskExecutionResult {
            status,
            output: Some(json!({
                "results": all_results,
                "total": items_len
            })),
            error: if errors.is_empty() {
                None
            } else {
                Some(TaskExecutionError {
                    message: format!("{} items failed", errors.len()),
                    error_type: "with_items_error".to_string(),
                    details: Some(json!({"errors": errors})),
                })
            },
            duration_ms: 0,
            should_retry: false,
            next_retry_at: None,
            retry_count: 0,
        })
    }
|
||||
|
||||
    /// Apply timeout to task execution
    ///
    /// NOTE(review): `result_future` is an already-evaluated `Result`, not a
    /// future — the task has finished by the time it is passed here. Wrapping
    /// it in `async { .. }` resolves immediately, so the `Err(_)` timeout
    /// branch below can never fire for real task work. Making the timeout
    /// effective requires taking `impl Future<Output = Result<TaskExecutionResult>>`
    /// instead, which is an interface change — left for a follow-up.
    async fn apply_timeout(
        &self,
        result_future: Result<TaskExecutionResult>,
        timeout_secs: u32,
    ) -> Result<TaskExecutionResult> {
        match timeout(Duration::from_secs(timeout_secs as u64), async {
            result_future
        })
        .await
        {
            // Inner future resolved within the deadline: pass through as-is.
            Ok(result) => result,
            Err(_) => {
                // Deadline elapsed: report a Timeout *status* (Ok result, not
                // Err) so callers can apply retry policy uniformly.
                warn!("Task execution timed out after {} seconds", timeout_secs);
                Ok(TaskExecutionResult {
                    status: TaskExecutionStatus::Timeout,
                    output: None,
                    error: Some(TaskExecutionError {
                        message: format!("Task timed out after {} seconds", timeout_secs),
                        error_type: "timeout".to_string(),
                        details: None,
                    }),
                    // NOTE(review): `timeout_secs * 1000` is u32 arithmetic and
                    // can overflow for very large timeouts before the i64 cast.
                    duration_ms: (timeout_secs * 1000) as i64,
                    should_retry: false,
                    next_retry_at: None,
                    retry_count: 0,
                })
            }
        }
    }
|
||||
}
|
||||
|
||||
/// Calculate next retry time based on retry configuration
|
||||
fn calculate_retry_time(config: &RetryConfig, retry_count: i32) -> DateTime<Utc> {
|
||||
let delay_secs = match config.backoff {
|
||||
BackoffStrategy::Constant => config.delay,
|
||||
BackoffStrategy::Linear => config.delay * (retry_count as u32 + 1),
|
||||
BackoffStrategy::Exponential => {
|
||||
let exp_delay = config.delay * 2_u32.pow(retry_count as u32);
|
||||
if let Some(max_delay) = config.max_delay {
|
||||
exp_delay.min(max_delay)
|
||||
} else {
|
||||
exp_delay
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Utc::now() + chrono::Duration::seconds(delay_secs as i64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Seconds between now and the scheduled retry, for asserting backoff math.
    fn delay_secs(config: &RetryConfig, retry_count: i32) -> i64 {
        let now = Utc::now();
        (calculate_retry_time(config, retry_count) - now).num_seconds()
    }

    #[test]
    fn test_calculate_retry_time_constant() {
        let config = RetryConfig {
            count: 3,
            delay: 10,
            backoff: BackoffStrategy::Constant,
            max_delay: None,
            on_error: None,
        };

        // Constant backoff: always ~`delay` seconds out (1s tolerance).
        assert!((9..=11).contains(&delay_secs(&config, 0)));
    }

    #[test]
    fn test_calculate_retry_time_exponential() {
        let config = RetryConfig {
            count: 3,
            delay: 10,
            backoff: BackoffStrategy::Exponential,
            max_delay: Some(100),
            on_error: None,
        };

        // delay * 2^n: 10, 20, 40 for the first three retries (1s tolerance).
        assert!((9..=11).contains(&delay_secs(&config, 0)));
        assert!((19..=21).contains(&delay_secs(&config, 1)));
        assert!((39..=41).contains(&delay_secs(&config, 2)));
    }

    #[test]
    fn test_calculate_retry_time_exponential_with_max() {
        let config = RetryConfig {
            count: 10,
            delay: 10,
            backoff: BackoffStrategy::Exponential,
            max_delay: Some(100),
            on_error: None,
        };

        // 10 * 2^10 would be 10240s; max_delay caps it at 100s.
        assert!((99..=101).contains(&delay_secs(&config, 10)));
    }

    #[test]
    fn test_with_items_batch_creation() {
        use serde_json::json;

        // 7 items with batch_size=3 -> [1,2,3], [4,5,6], [7].
        let items: Vec<JsonValue> = (1..=7).map(|id| json!({"id": id})).collect();

        let batch_size = 3;
        let batches: Vec<Vec<JsonValue>> = items
            .chunks(batch_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 1); // Last batch can be smaller

        // Batches preserve item order.
        assert_eq!(batches[0][0], json!({"id": 1}));
        assert_eq!(batches[2][0], json!({"id": 7}));
    }

    #[test]
    fn test_with_items_no_batch_size_individual_processing() {
        use serde_json::json;

        // Without batch_size, items are processed individually, in order.
        let items = vec![json!({"id": 1}), json!({"id": 2}), json!({"id": 3})];
        assert_eq!(items.len(), 3);
        for (idx, item) in items.iter().enumerate() {
            assert_eq!(*item, json!({"id": idx + 1}));
        }
    }

    #[test]
    fn test_with_items_batch_vs_individual() {
        use serde_json::json;

        let items = vec![json!({"id": 1}), json!({"id": 2}), json!({"id": 3})];

        // With batch_size: items are grouped into batches (arrays).
        let batch_size = Some(2);
        if let Some(bs) = batch_size {
            let batches: Vec<Vec<JsonValue>> =
                items.chunks(bs).map(|chunk| chunk.to_vec()).collect();

            // 2 batches: [1,2], [3]
            assert_eq!(batches.len(), 2);
            assert_eq!(batches[0], vec![json!({"id": 1}), json!({"id": 2})]);
            assert_eq!(batches[1], vec![json!({"id": 3})]);
        }

        // Without batch_size: each item is a single value, not wrapped in an array.
        let batch_size: Option<usize> = None;
        if batch_size.is_none() {
            for (idx, item) in items.iter().enumerate() {
                assert_eq!(item["id"], idx + 1);
            }
        }
    }
}
|
||||
@@ -1,360 +0,0 @@
|
||||
//! Template engine for workflow variable interpolation
|
||||
//!
|
||||
//! This module provides template rendering using Tera (Jinja2-like syntax)
|
||||
//! with support for multi-scope variable contexts.
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::HashMap;
|
||||
use tera::{Context, Tera};
|
||||
|
||||
/// Result type for template operations
pub type TemplateResult<T> = Result<T, TemplateError>;

/// Errors that can occur during template rendering
#[derive(Debug, thiserror::Error)]
pub enum TemplateError {
    /// Wraps any Tera failure; created automatically via `#[from]` when
    /// `?` is used on a `tera::Error`.
    #[error("Template rendering error: {0}")]
    RenderError(#[from] tera::Error),

    /// Invalid template syntax (carries a descriptive message).
    #[error("Invalid template syntax: {0}")]
    SyntaxError(String),

    /// A referenced variable could not be resolved (carries the variable name).
    #[error("Variable not found: {0}")]
    VariableNotFound(String),

    /// JSON (de)serialization failure; created via `#[from]` when rendered
    /// output is parsed as JSON (see `render_json`).
    #[error("JSON serialization error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// An invalid scope name was supplied (carries the offending name).
    #[error("Invalid scope: {0}")]
    InvalidScope(String),
}
|
||||
|
||||
/// Variable scope priority (higher number = higher priority)
///
/// The explicit discriminant values drive the derived `Ord`, so comparing
/// two scopes compares their priority directly. `VariableContext::get`
/// resolves lookups from highest (`Task`) down to lowest (`System`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum VariableScope {
    /// System-level variables (lowest priority)
    System = 1,
    /// Key-value store variables
    KeyValue = 2,
    /// Pack configuration
    PackConfig = 3,
    /// Workflow parameters (input)
    Parameters = 4,
    /// Workflow vars (defined in workflow)
    Vars = 5,
    /// Task-specific variables (highest priority)
    Task = 6,
}
|
||||
|
||||
/// Template engine with multi-scope variable context
///
/// Currently a zero-sized type: all rendering goes through `Tera::one_off`,
/// so no Tera instance is retained and custom filters are unavailable.
pub struct TemplateEngine {
    // Note: We can't use custom filters with Tera::one_off, so we need to keep tera instance
    // But Tera doesn't expose a way to register templates without files in the new() constructor
    // So we'll just use one_off for now and skip custom filters in basic rendering
}
|
||||
|
||||
impl Default for TemplateEngine {
    /// Equivalent to [`TemplateEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl TemplateEngine {
|
||||
/// Create a new template engine
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Render a template string with the given context
|
||||
pub fn render(&self, template: &str, context: &VariableContext) -> TemplateResult<String> {
|
||||
let tera_context = context.to_tera_context()?;
|
||||
|
||||
// Use one-off template rendering
|
||||
// Note: Custom filters are not supported with one_off rendering
|
||||
Tera::one_off(template, &tera_context, true).map_err(TemplateError::from)
|
||||
}
|
||||
|
||||
/// Render a template and parse result as JSON
|
||||
pub fn render_json(
|
||||
&self,
|
||||
template: &str,
|
||||
context: &VariableContext,
|
||||
) -> TemplateResult<JsonValue> {
|
||||
let rendered = self.render(template, context)?;
|
||||
serde_json::from_str(&rendered).map_err(TemplateError::from)
|
||||
}
|
||||
|
||||
/// Check if a template string contains valid syntax
|
||||
pub fn validate_template(&self, template: &str) -> TemplateResult<()> {
|
||||
Tera::one_off(template, &Context::new(), true)
|
||||
.map(|_| ())
|
||||
.map_err(TemplateError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-scope variable context for template rendering
///
/// Holds one `HashMap` per scope; lookup priority (lowest to highest) is
/// system → kv → pack_config → parameters → vars → task (see `get`).
#[derive(Debug, Clone)]
pub struct VariableContext {
    /// System-level variables
    system: HashMap<String, JsonValue>,
    /// Key-value store variables
    kv: HashMap<String, JsonValue>,
    /// Pack configuration
    pack_config: HashMap<String, JsonValue>,
    /// Workflow parameters (input)
    parameters: HashMap<String, JsonValue>,
    /// Workflow vars
    vars: HashMap<String, JsonValue>,
    /// Task results and metadata
    task: HashMap<String, JsonValue>,
}
|
||||
|
||||
impl Default for VariableContext {
    /// Equivalent to [`VariableContext::new`]: all scopes empty.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl VariableContext {
    /// Create a new empty variable context
    pub fn new() -> Self {
        Self {
            system: HashMap::new(),
            kv: HashMap::new(),
            pack_config: HashMap::new(),
            parameters: HashMap::new(),
            vars: HashMap::new(),
            task: HashMap::new(),
        }
    }

    /// Set system variables (builder-style; replaces the whole scope)
    pub fn with_system(mut self, vars: HashMap<String, JsonValue>) -> Self {
        self.system = vars;
        self
    }

    /// Set key-value store variables (builder-style; replaces the whole scope)
    pub fn with_kv(mut self, vars: HashMap<String, JsonValue>) -> Self {
        self.kv = vars;
        self
    }

    /// Set pack configuration (builder-style; replaces the whole scope)
    pub fn with_pack_config(mut self, config: HashMap<String, JsonValue>) -> Self {
        self.pack_config = config;
        self
    }

    /// Set workflow parameters (builder-style; replaces the whole scope)
    pub fn with_parameters(mut self, params: HashMap<String, JsonValue>) -> Self {
        self.parameters = params;
        self
    }

    /// Set workflow vars (builder-style; replaces the whole scope)
    pub fn with_vars(mut self, vars: HashMap<String, JsonValue>) -> Self {
        self.vars = vars;
        self
    }

    /// Set task variables (builder-style; replaces the whole scope)
    pub fn with_task(mut self, task_vars: HashMap<String, JsonValue>) -> Self {
        self.task = task_vars;
        self
    }

    /// Add a single variable to a scope
    ///
    /// Overwrites any existing value for `key` within that scope; the
    /// previous value returned by `insert` is intentionally discarded.
    pub fn set(&mut self, scope: VariableScope, key: String, value: JsonValue) {
        match scope {
            VariableScope::System => self.system.insert(key, value),
            VariableScope::KeyValue => self.kv.insert(key, value),
            VariableScope::PackConfig => self.pack_config.insert(key, value),
            VariableScope::Parameters => self.parameters.insert(key, value),
            VariableScope::Vars => self.vars.insert(key, value),
            VariableScope::Task => self.task.insert(key, value),
        };
    }

    /// Get a variable from any scope (respects priority)
    pub fn get(&self, key: &str) -> Option<&JsonValue> {
        // Check scopes in priority order (highest to lowest)
        self.task
            .get(key)
            .or_else(|| self.vars.get(key))
            .or_else(|| self.parameters.get(key))
            .or_else(|| self.pack_config.get(key))
            .or_else(|| self.kv.get(key))
            .or_else(|| self.system.get(key))
    }

    /// Convert to Tera context for rendering
    ///
    /// Scopes are exposed as top-level template objects (`system`, `kv`,
    /// `parameters`, `vars`, `task`); pack configuration is nested one level
    /// deeper as `pack.config` to match the `{{ pack.config.* }}` syntax.
    pub fn to_tera_context(&self) -> TemplateResult<Context> {
        let mut context = Context::new();

        // Insert scopes as nested objects
        context.insert("system", &self.system);
        context.insert("kv", &self.kv);
        context.insert("pack", &serde_json::json!({ "config": self.pack_config }));
        context.insert("parameters", &self.parameters);
        context.insert("vars", &self.vars);
        context.insert("task", &self.task);

        Ok(context)
    }

    /// Merge another context into this one (preserves priority)
    ///
    /// Merging is scope-by-scope; for keys present in both contexts within
    /// the same scope, `other`'s value wins (HashMap::extend semantics).
    pub fn merge(&mut self, other: &VariableContext) {
        self.system.extend(other.system.clone());
        self.kv.extend(other.kv.clone());
        self.pack_config.extend(other.pack_config.clone());
        self.parameters.extend(other.parameters.clone());
        self.vars.extend(other.vars.clone());
        self.task.extend(other.task.clone());
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Tests exercising `TemplateEngine` rendering and `VariableContext`
    //! scope handling. Custom-filter tests are absent because rendering uses
    //! `Tera::one_off`, which does not support custom filters.
    use super::*;
    use serde_json::json;

    #[test]
    fn test_basic_template_rendering() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();
        context.set(
            VariableScope::Parameters,
            "name".to_string(),
            json!("World"),
        );

        let result = engine.render("Hello {{ parameters.name }}!", &context);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello World!");
    }

    #[test]
    fn test_scope_priority() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();

        // Set same variable in multiple scopes
        context.set(VariableScope::System, "value".to_string(), json!("system"));
        context.set(VariableScope::Vars, "value".to_string(), json!("vars"));
        context.set(VariableScope::Task, "value".to_string(), json!("task"));

        // Task scope should win (highest priority)
        let result = engine.render("{{ task.value }}", &context);
        assert_eq!(result.unwrap(), "task");
    }

    #[test]
    fn test_nested_variables() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();
        context.set(
            VariableScope::Parameters,
            "config".to_string(),
            json!({"database": {"host": "localhost", "port": 5432}}),
        );

        // Dotted paths traverse nested JSON objects.
        let result = engine.render(
            "postgres://{{ parameters.config.database.host }}:{{ parameters.config.database.port }}",
            &context,
        );
        assert_eq!(result.unwrap(), "postgres://localhost:5432");
    }

    // Note: Custom filter tests are disabled since we're using Tera::one_off
    // which doesn't support custom filters. In production, we would need to
    // use a pre-configured Tera instance with templates registered.

    #[test]
    fn test_json_operations() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();
        context.set(
            VariableScope::Parameters,
            "data".to_string(),
            json!({"key": "value"}),
        );

        // Test accessing JSON properties
        let result = engine.render("{{ parameters.data.key }}", &context);
        assert_eq!(result.unwrap(), "value");
    }

    #[test]
    fn test_conditional_rendering() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();
        context.set(
            VariableScope::Parameters,
            "env".to_string(),
            json!("production"),
        );

        // Tera if/else blocks work against context values.
        let result = engine.render(
            "{% if parameters.env == 'production' %}prod{% else %}dev{% endif %}",
            &context,
        );
        assert_eq!(result.unwrap(), "prod");
    }

    #[test]
    fn test_loop_rendering() {
        let engine = TemplateEngine::new();
        let mut context = VariableContext::new();
        context.set(
            VariableScope::Parameters,
            "items".to_string(),
            json!(["a", "b", "c"]),
        );

        // Tera for-loops iterate JSON arrays in order.
        let result = engine.render(
            "{% for item in parameters.items %}{{ item }}{% endfor %}",
            &context,
        );
        assert_eq!(result.unwrap(), "abc");
    }

    #[test]
    fn test_context_merge() {
        let mut ctx1 = VariableContext::new();
        ctx1.set(VariableScope::Vars, "a".to_string(), json!(1));
        ctx1.set(VariableScope::Vars, "b".to_string(), json!(2));

        let mut ctx2 = VariableContext::new();
        ctx2.set(VariableScope::Vars, "b".to_string(), json!(3));
        ctx2.set(VariableScope::Vars, "c".to_string(), json!(4));

        ctx1.merge(&ctx2);

        assert_eq!(ctx1.get("a"), Some(&json!(1)));
        assert_eq!(ctx1.get("b"), Some(&json!(3))); // ctx2 overwrites
        assert_eq!(ctx1.get("c"), Some(&json!(4)));
    }

    #[test]
    fn test_all_scopes() {
        let engine = TemplateEngine::new();
        // One variable per scope, all rendered in a single template.
        let context = VariableContext::new()
            .with_system(HashMap::from([("sys_var".to_string(), json!("system"))]))
            .with_kv(HashMap::from([("kv_var".to_string(), json!("keyvalue"))]))
            .with_pack_config(HashMap::from([("setting".to_string(), json!("config"))]))
            .with_parameters(HashMap::from([("param".to_string(), json!("parameter"))]))
            .with_vars(HashMap::from([("var".to_string(), json!("variable"))]))
            .with_task(HashMap::from([(
                "result".to_string(),
                json!("task_result"),
            )]));

        let template = "{{ system.sys_var }}-{{ kv.kv_var }}-{{ pack.config.setting }}-{{ parameters.param }}-{{ vars.var }}-{{ task.result }}";
        let result = engine.render(template, &context);
        assert_eq!(
            result.unwrap(),
            "system-keyvalue-config-parameter-variable-task_result"
        );
    }
}
|
||||
@@ -33,5 +33,6 @@ aes-gcm = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
//! so the `ProcessRuntime` uses version-specific interpreter binaries,
|
||||
//! environment commands, etc.
|
||||
|
||||
use attune_common::auth::jwt::{generate_execution_token, JwtConfig};
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::runtime::RuntimeExecutionConfig;
|
||||
use attune_common::models::{runtime::Runtime as RuntimeModel, Action, Execution, ExecutionStatus};
|
||||
@@ -42,6 +43,18 @@ pub struct ActionExecutor {
|
||||
max_stderr_bytes: usize,
|
||||
packs_base_dir: PathBuf,
|
||||
api_url: String,
|
||||
jwt_config: JwtConfig,
|
||||
}
|
||||
|
||||
/// Normalize a server bind address into a connectable URL.
|
||||
///
|
||||
/// When the server binds to `0.0.0.0` (all interfaces) or `::` (IPv6 any),
|
||||
/// we substitute `127.0.0.1` so that actions running on the same host can
|
||||
/// reach the API.
|
||||
fn normalize_api_url(raw_url: &str) -> String {
|
||||
raw_url
|
||||
.replace("://0.0.0.0", "://127.0.0.1")
|
||||
.replace("://[::]", "://127.0.0.1")
|
||||
}
|
||||
|
||||
impl ActionExecutor {
|
||||
@@ -55,7 +68,9 @@ impl ActionExecutor {
|
||||
max_stderr_bytes: usize,
|
||||
packs_base_dir: PathBuf,
|
||||
api_url: String,
|
||||
jwt_config: JwtConfig,
|
||||
) -> Self {
|
||||
let api_url = normalize_api_url(&api_url);
|
||||
Self {
|
||||
pool,
|
||||
runtime_registry,
|
||||
@@ -65,6 +80,7 @@ impl ActionExecutor {
|
||||
max_stderr_bytes,
|
||||
packs_base_dir,
|
||||
api_url,
|
||||
jwt_config,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,9 +292,34 @@ impl ActionExecutor {
|
||||
env.insert("ATTUNE_ACTION".to_string(), execution.action_ref.clone());
|
||||
env.insert("ATTUNE_API_URL".to_string(), self.api_url.clone());
|
||||
|
||||
// TODO: Generate execution-scoped API token
|
||||
// For now, set placeholder to maintain interface compatibility
|
||||
env.insert("ATTUNE_API_TOKEN".to_string(), "".to_string());
|
||||
// Generate execution-scoped API token.
|
||||
// The identity that triggered the execution is derived from the `sub` claim
|
||||
// of the original token; for rule-triggered executions we use identity 1
|
||||
// (the system identity) as a reasonable default.
|
||||
let identity_id: i64 = 1; // System identity fallback
|
||||
// Default timeout is 300s; add 60s grace period for cleanup.
|
||||
// The actual `timeout` variable is computed later in this function,
|
||||
// but the token TTL just needs a reasonable upper bound.
|
||||
let token_ttl = Some(360_i64);
|
||||
match generate_execution_token(
|
||||
identity_id,
|
||||
execution.id,
|
||||
&execution.action_ref,
|
||||
&self.jwt_config,
|
||||
token_ttl,
|
||||
) {
|
||||
Ok(token) => {
|
||||
env.insert("ATTUNE_API_TOKEN".to_string(), token);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to generate execution token for execution {}: {}. \
|
||||
Actions that call back to the API will not authenticate.",
|
||||
execution.id, e
|
||||
);
|
||||
env.insert("ATTUNE_API_TOKEN".to_string(), String::new());
|
||||
}
|
||||
}
|
||||
|
||||
// Add rule and trigger context if execution was triggered by enforcement
|
||||
if let Some(enforcement_id) = execution.enforcement {
|
||||
|
||||
@@ -285,6 +285,17 @@ impl WorkerService {
|
||||
let api_url = std::env::var("ATTUNE_API_URL")
|
||||
.unwrap_or_else(|_| format!("http://{}:{}", config.server.host, config.server.port));
|
||||
|
||||
// Build JWT config for generating execution-scoped tokens
|
||||
let jwt_config = attune_common::auth::jwt::JwtConfig {
|
||||
secret: config
|
||||
.security
|
||||
.jwt_secret
|
||||
.clone()
|
||||
.unwrap_or_else(|| "insecure_default_secret_change_in_production".to_string()),
|
||||
access_token_expiration: config.security.jwt_access_expiration as i64,
|
||||
refresh_token_expiration: config.security.jwt_refresh_expiration as i64,
|
||||
};
|
||||
|
||||
let executor = Arc::new(ActionExecutor::new(
|
||||
pool.clone(),
|
||||
runtime_registry,
|
||||
@@ -294,6 +305,7 @@ impl WorkerService {
|
||||
max_stderr_bytes,
|
||||
packs_base_dir.clone(),
|
||||
api_url,
|
||||
jwt_config,
|
||||
));
|
||||
|
||||
// Initialize heartbeat manager
|
||||
|
||||
Reference in New Issue
Block a user