proper sql filtering
This commit is contained in:
@@ -319,6 +319,10 @@ pub struct EnforcementQueryParams {
|
||||
#[param(example = "core.webhook")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// Filter by rule reference
|
||||
#[param(example = "core.on_webhook")]
|
||||
pub rule_ref: Option<String>,
|
||||
|
||||
/// Page number (1-indexed)
|
||||
#[serde(default = "default_page")]
|
||||
#[param(example = 1, minimum = 1)]
|
||||
|
||||
@@ -7,6 +7,7 @@ use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use attune_common::models::enums::ExecutionStatus;
|
||||
use attune_common::models::execution::WorkflowTaskMetadata;
|
||||
use attune_common::repositories::execution::ExecutionWithRefs;
|
||||
|
||||
/// Request DTO for creating a manual execution
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
@@ -63,6 +64,12 @@ pub struct ExecutionResponse {
|
||||
#[schema(value_type = Object, example = json!({"message_id": "1234567890.123456"}))]
|
||||
pub result: Option<JsonValue>,
|
||||
|
||||
/// When the execution actually started running (worker picked it up).
|
||||
/// Null if the execution hasn't started running yet.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "2024-01-13T10:31:00Z", nullable = true)]
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Workflow task metadata (only populated for workflow task executions)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Option<Object>, nullable = true)]
|
||||
@@ -108,6 +115,12 @@ pub struct ExecutionSummary {
|
||||
#[schema(example = "core.timer")]
|
||||
pub trigger_ref: Option<String>,
|
||||
|
||||
/// When the execution actually started running (worker picked it up).
|
||||
/// Null if the execution hasn't started running yet.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "2024-01-13T10:31:00Z", nullable = true)]
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Workflow task metadata (only populated for workflow task executions)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Option<Object>, nullable = true)]
|
||||
@@ -207,6 +220,7 @@ impl From<attune_common::models::execution::Execution> for ExecutionResponse {
|
||||
result: execution
|
||||
.result
|
||||
.map(|r| serde_json::to_value(r).unwrap_or(JsonValue::Null)),
|
||||
started_at: execution.started_at,
|
||||
workflow_task: execution.workflow_task,
|
||||
created: execution.created,
|
||||
updated: execution.updated,
|
||||
@@ -225,6 +239,7 @@ impl From<attune_common::models::execution::Execution> for ExecutionSummary {
|
||||
enforcement: execution.enforcement,
|
||||
rule_ref: None, // Populated separately via enforcement lookup
|
||||
trigger_ref: None, // Populated separately via enforcement lookup
|
||||
started_at: execution.started_at,
|
||||
workflow_task: execution.workflow_task,
|
||||
created: execution.created,
|
||||
updated: execution.updated,
|
||||
@@ -232,6 +247,26 @@ impl From<attune_common::models::execution::Execution> for ExecutionSummary {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from the joined query result (execution + enforcement refs).
|
||||
/// `rule_ref` and `trigger_ref` are already populated from the SQL JOIN.
|
||||
impl From<ExecutionWithRefs> for ExecutionSummary {
|
||||
fn from(row: ExecutionWithRefs) -> Self {
|
||||
Self {
|
||||
id: row.id,
|
||||
action_ref: row.action_ref,
|
||||
status: row.status,
|
||||
parent: row.parent,
|
||||
enforcement: row.enforcement,
|
||||
rule_ref: row.rule_ref,
|
||||
trigger_ref: row.trigger_ref,
|
||||
started_at: row.started_at,
|
||||
workflow_task: row.workflow_task,
|
||||
created: row.created,
|
||||
updated: row.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_page() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
@@ -11,10 +11,10 @@ use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
action::{ActionRepository, CreateActionInput, UpdateActionInput},
|
||||
action::{ActionRepository, ActionSearchFilters, CreateActionInput, UpdateActionInput},
|
||||
pack::PackRepository,
|
||||
queue_stats::QueueStatsRepository,
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
Create, Delete, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -47,21 +47,20 @@ pub async fn list_actions(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all actions (we'll implement pagination in repository later)
|
||||
let actions = ActionRepository::list(&state.db).await?;
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = ActionSearchFilters {
|
||||
pack: None,
|
||||
query: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = actions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(actions.len());
|
||||
let result = ActionRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_actions: Vec<ActionSummary> = actions[start..end]
|
||||
.iter()
|
||||
.map(|a| ActionSummary::from(a.clone()))
|
||||
.collect();
|
||||
let paginated_actions: Vec<ActionSummary> =
|
||||
result.rows.into_iter().map(ActionSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -92,21 +91,20 @@ pub async fn list_actions_by_pack(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get actions for this pack
|
||||
let actions = ActionRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = ActionSearchFilters {
|
||||
pack: Some(pack.id),
|
||||
query: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = actions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(actions.len());
|
||||
let result = ActionRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_actions: Vec<ActionSummary> = actions[start..end]
|
||||
.iter()
|
||||
.map(|a| ActionSummary::from(a.clone()))
|
||||
.collect();
|
||||
let paginated_actions: Vec<ActionSummary> =
|
||||
result.rows.into_iter().map(ActionSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_actions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -16,9 +16,12 @@ use validator::Validate;
|
||||
use attune_common::{
|
||||
mq::{EventCreatedPayload, MessageEnvelope, MessageType},
|
||||
repositories::{
|
||||
event::{CreateEventInput, EnforcementRepository, EventRepository},
|
||||
event::{
|
||||
CreateEventInput, EnforcementRepository, EnforcementSearchFilters, EventRepository,
|
||||
EventSearchFilters,
|
||||
},
|
||||
trigger::TriggerRepository,
|
||||
Create, FindById, FindByRef, List,
|
||||
Create, FindById, FindByRef,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -220,53 +223,27 @@ pub async fn list_events(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<EventQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get events based on filters
|
||||
let events = if let Some(trigger_id) = query.trigger {
|
||||
// Filter by trigger ID
|
||||
EventRepository::find_by_trigger(&state.db, trigger_id).await?
|
||||
} else if let Some(trigger_ref) = &query.trigger_ref {
|
||||
// Filter by trigger reference
|
||||
EventRepository::find_by_trigger_ref(&state.db, trigger_ref).await?
|
||||
} else {
|
||||
// Get all events
|
||||
EventRepository::list(&state.db).await?
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = EventSearchFilters {
|
||||
trigger: query.trigger,
|
||||
trigger_ref: query.trigger_ref.clone(),
|
||||
source: query.source,
|
||||
rule_ref: query.rule_ref.clone(),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_events = events;
|
||||
let result = EventRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(source_id) = query.source {
|
||||
filtered_events.retain(|e| e.source == Some(source_id));
|
||||
}
|
||||
let paginated_events: Vec<EventSummary> =
|
||||
result.rows.into_iter().map(EventSummary::from).collect();
|
||||
|
||||
if let Some(rule_ref) = &query.rule_ref {
|
||||
let rule_ref_lower = rule_ref.to_lowercase();
|
||||
filtered_events.retain(|e| {
|
||||
e.rule_ref
|
||||
.as_ref()
|
||||
.map(|r| r.to_lowercase().contains(&rule_ref_lower))
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_events.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_events.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_events: Vec<EventSummary> = filtered_events[start..end]
|
||||
.iter()
|
||||
.map(|event| EventSummary::from(event.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_events, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_events, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -319,46 +296,32 @@ pub async fn list_enforcements(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<EnforcementQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enforcements based on filters
|
||||
let enforcements = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
EnforcementRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(rule_id) = query.rule {
|
||||
// Filter by rule ID
|
||||
EnforcementRepository::find_by_rule(&state.db, rule_id).await?
|
||||
} else if let Some(event_id) = query.event {
|
||||
// Filter by event ID
|
||||
EnforcementRepository::find_by_event(&state.db, event_id).await?
|
||||
} else {
|
||||
// Get all enforcements
|
||||
EnforcementRepository::list(&state.db).await?
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
// Filters are combinable (AND), not mutually exclusive.
|
||||
let filters = EnforcementSearchFilters {
|
||||
status: query.status,
|
||||
rule: query.rule,
|
||||
event: query.event,
|
||||
trigger_ref: query.trigger_ref.clone(),
|
||||
rule_ref: query.rule_ref.clone(),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_enforcements = enforcements;
|
||||
let result = EnforcementRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(trigger_ref) = &query.trigger_ref {
|
||||
filtered_enforcements.retain(|e| e.trigger_ref == *trigger_ref);
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_enforcements.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_enforcements.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_enforcements: Vec<EnforcementSummary> = filtered_enforcements[start..end]
|
||||
.iter()
|
||||
.map(|enforcement| EnforcementSummary::from(enforcement.clone()))
|
||||
let paginated_enforcements: Vec<EnforcementSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(EnforcementSummary::from)
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -18,9 +18,10 @@ use attune_common::models::enums::ExecutionStatus;
|
||||
use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType};
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
execution::{CreateExecutionInput, ExecutionRepository},
|
||||
Create, EnforcementRepository, FindById, FindByRef, List,
|
||||
execution::{CreateExecutionInput, ExecutionRepository, ExecutionSearchFilters},
|
||||
Create, FindById, FindByRef,
|
||||
};
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
@@ -125,117 +126,37 @@ pub async fn list_executions(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(query): Query<ExecutionQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions based on filters
|
||||
let executions = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
ExecutionRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(enforcement_id) = query.enforcement {
|
||||
// Filter by enforcement
|
||||
ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?
|
||||
} else {
|
||||
// Get all executions
|
||||
ExecutionRepository::list(&state.db).await?
|
||||
// All filtering, pagination, and the enforcement JOIN happen in a single
|
||||
// SQL query — no in-memory filtering or post-fetch lookups.
|
||||
let filters = ExecutionSearchFilters {
|
||||
status: query.status,
|
||||
action_ref: query.action_ref.clone(),
|
||||
pack_name: query.pack_name.clone(),
|
||||
rule_ref: query.rule_ref.clone(),
|
||||
trigger_ref: query.trigger_ref.clone(),
|
||||
executor: query.executor,
|
||||
result_contains: query.result_contains.clone(),
|
||||
enforcement: query.enforcement,
|
||||
parent: query.parent,
|
||||
top_level_only: query.top_level_only == Some(true),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
// Apply additional filters in memory (could be optimized with database queries)
|
||||
let mut filtered_executions = executions;
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(action_ref) = &query.action_ref {
|
||||
filtered_executions.retain(|e| e.action_ref == *action_ref);
|
||||
}
|
||||
|
||||
if let Some(pack_name) = &query.pack_name {
|
||||
filtered_executions.retain(|e| {
|
||||
// action_ref format is "pack.action"
|
||||
e.action_ref.starts_with(&format!("{}.", pack_name))
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(result_search) = &query.result_contains {
|
||||
let search_lower = result_search.to_lowercase();
|
||||
filtered_executions.retain(|e| {
|
||||
if let Some(result) = &e.result {
|
||||
// Convert result to JSON string and search case-insensitively
|
||||
let result_str = serde_json::to_string(result).unwrap_or_default();
|
||||
result_str.to_lowercase().contains(&search_lower)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(parent_id) = query.parent {
|
||||
filtered_executions.retain(|e| e.parent == Some(parent_id));
|
||||
}
|
||||
|
||||
if query.top_level_only == Some(true) {
|
||||
filtered_executions.retain(|e| e.parent.is_none());
|
||||
}
|
||||
|
||||
if let Some(executor_id) = query.executor {
|
||||
filtered_executions.retain(|e| e.executor == Some(executor_id));
|
||||
}
|
||||
|
||||
// Fetch enforcements for all executions to populate rule_ref and trigger_ref
|
||||
let enforcement_ids: Vec<i64> = filtered_executions
|
||||
.iter()
|
||||
.filter_map(|e| e.enforcement)
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let enforcement_map: std::collections::HashMap<i64, _> = if !enforcement_ids.is_empty() {
|
||||
let enforcements = EnforcementRepository::list(&state.db).await?;
|
||||
enforcements.into_iter().map(|enf| (enf.id, enf)).collect()
|
||||
} else {
|
||||
std::collections::HashMap::new()
|
||||
};
|
||||
|
||||
// Filter by rule_ref if specified
|
||||
if let Some(rule_ref) = &query.rule_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.rule_ref == *rule_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Filter by trigger_ref if specified
|
||||
if let Some(trigger_ref) = &query.trigger_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.trigger_ref == *trigger_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_executions.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_executions.len());
|
||||
|
||||
// Get paginated slice and populate rule_ref/trigger_ref from enforcements
|
||||
let paginated_executions: Vec<ExecutionSummary> = filtered_executions[start..end]
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let mut summary = ExecutionSummary::from(e.clone());
|
||||
if let Some(enf_id) = e.enforcement {
|
||||
if let Some(enforcement) = enforcement_map.get(&enf_id) {
|
||||
summary.rule_ref = Some(enforcement.rule_ref.clone());
|
||||
summary.trigger_ref = Some(enforcement.trigger_ref.clone());
|
||||
}
|
||||
}
|
||||
summary
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -310,21 +231,23 @@ pub async fn list_executions_by_status(
|
||||
}
|
||||
};
|
||||
|
||||
// Get executions by status
|
||||
let executions = ExecutionRepository::find_by_status(&state.db, status).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = ExecutionSearchFilters {
|
||||
status: Some(status),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -350,21 +273,23 @@ pub async fn list_executions_by_enforcement(
|
||||
Path(enforcement_id): Path<i64>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions by enforcement
|
||||
let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = ExecutionSearchFilters {
|
||||
enforcement: Some(enforcement_id),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -384,34 +309,37 @@ pub async fn get_execution_stats(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all executions (limited by repository to 1000)
|
||||
let executions = ExecutionRepository::list(&state.db).await?;
|
||||
// Use a single SQL query with COUNT + GROUP BY instead of fetching all rows.
|
||||
let rows = sqlx::query(
|
||||
"SELECT status::text AS status, COUNT(*) AS cnt FROM execution GROUP BY status",
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
// Calculate statistics
|
||||
let total = executions.len();
|
||||
let completed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed)
|
||||
.count();
|
||||
let failed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed)
|
||||
.count();
|
||||
let running = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running)
|
||||
.count();
|
||||
let pending = executions
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
matches!(
|
||||
e.status,
|
||||
attune_common::models::enums::ExecutionStatus::Requested
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduling
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduled
|
||||
)
|
||||
})
|
||||
.count();
|
||||
let mut completed: i64 = 0;
|
||||
let mut failed: i64 = 0;
|
||||
let mut running: i64 = 0;
|
||||
let mut pending: i64 = 0;
|
||||
let mut cancelled: i64 = 0;
|
||||
let mut timeout: i64 = 0;
|
||||
let mut abandoned: i64 = 0;
|
||||
let mut total: i64 = 0;
|
||||
|
||||
for row in &rows {
|
||||
let status: &str = row.get("status");
|
||||
let cnt: i64 = row.get("cnt");
|
||||
total += cnt;
|
||||
match status {
|
||||
"completed" => completed = cnt,
|
||||
"failed" => failed = cnt,
|
||||
"running" => running = cnt,
|
||||
"requested" | "scheduling" | "scheduled" => pending += cnt,
|
||||
"cancelled" | "canceling" => cancelled += cnt,
|
||||
"timeout" => timeout = cnt,
|
||||
"abandoned" => abandoned = cnt,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let stats = serde_json::json!({
|
||||
"total": total,
|
||||
@@ -419,9 +347,9 @@ pub async fn get_execution_stats(
|
||||
"failed": failed,
|
||||
"running": running,
|
||||
"pending": pending,
|
||||
"cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(),
|
||||
"timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(),
|
||||
"abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(),
|
||||
"cancelled": cancelled,
|
||||
"timeout": timeout,
|
||||
"abandoned": abandoned,
|
||||
});
|
||||
|
||||
let response = ApiResponse::new(stats);
|
||||
|
||||
@@ -14,8 +14,10 @@ use attune_common::{
|
||||
mq::{InquiryRespondedPayload, MessageEnvelope, MessageType},
|
||||
repositories::{
|
||||
execution::ExecutionRepository,
|
||||
inquiry::{CreateInquiryInput, InquiryRepository, UpdateInquiryInput},
|
||||
Create, Delete, FindById, List, Update,
|
||||
inquiry::{
|
||||
CreateInquiryInput, InquiryRepository, InquirySearchFilters, UpdateInquiryInput,
|
||||
},
|
||||
Create, Delete, FindById, Update,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -51,45 +53,30 @@ pub async fn list_inquiries(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<InquiryQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get inquiries based on filters
|
||||
let inquiries = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
InquiryRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(execution_id) = query.execution {
|
||||
// Filter by execution
|
||||
InquiryRepository::find_by_execution(&state.db, execution_id).await?
|
||||
} else {
|
||||
// Get all inquiries
|
||||
InquiryRepository::list(&state.db).await?
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
// Filters are combinable (AND), not mutually exclusive.
|
||||
let limit = query.limit.unwrap_or(50).min(500) as u32;
|
||||
let offset = query.offset.unwrap_or(0) as u32;
|
||||
|
||||
let filters = InquirySearchFilters {
|
||||
status: query.status,
|
||||
execution: query.execution,
|
||||
assigned_to: query.assigned_to,
|
||||
limit,
|
||||
offset,
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_inquiries = inquiries;
|
||||
let result = InquiryRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(assigned_to) = query.assigned_to {
|
||||
filtered_inquiries.retain(|i| i.assigned_to == Some(assigned_to));
|
||||
}
|
||||
let paginated_inquiries: Vec<InquirySummary> =
|
||||
result.rows.into_iter().map(InquirySummary::from).collect();
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_inquiries.len() as u64;
|
||||
let offset = query.offset.unwrap_or(0);
|
||||
let limit = query.limit.unwrap_or(50).min(500);
|
||||
let start = offset;
|
||||
let end = (start + limit).min(filtered_inquiries.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = filtered_inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: (offset / limit.max(1)) as u32 + 1,
|
||||
page_size: limit as u32,
|
||||
page: (offset / limit.max(1)) + 1,
|
||||
page_size: limit,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -161,20 +148,21 @@ pub async fn list_inquiries_by_status(
|
||||
}
|
||||
};
|
||||
|
||||
let inquiries = InquiryRepository::find_by_status(&state.db, status).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = InquirySearchFilters {
|
||||
status: Some(status),
|
||||
execution: None,
|
||||
assigned_to: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = inquiries.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(inquiries.len());
|
||||
let result = InquiryRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
let paginated_inquiries: Vec<InquirySummary> =
|
||||
result.rows.into_iter().map(InquirySummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -209,20 +197,21 @@ pub async fn list_inquiries_by_execution(
|
||||
ApiError::NotFound(format!("Execution with ID {} not found", execution_id))
|
||||
})?;
|
||||
|
||||
let inquiries = InquiryRepository::find_by_execution(&state.db, execution_id).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = InquirySearchFilters {
|
||||
status: None,
|
||||
execution: Some(execution_id),
|
||||
assigned_to: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = inquiries.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(inquiries.len());
|
||||
let result = InquiryRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
|
||||
.iter()
|
||||
.map(|inquiry| InquirySummary::from(inquiry.clone()))
|
||||
.collect();
|
||||
let paginated_inquiries: Vec<InquirySummary> =
|
||||
result.rows.into_iter().map(InquirySummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -13,10 +13,10 @@ use validator::Validate;
|
||||
use attune_common::models::OwnerType;
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
key::{CreateKeyInput, KeyRepository, UpdateKeyInput},
|
||||
key::{CreateKeyInput, KeyRepository, KeySearchFilters, UpdateKeyInput},
|
||||
pack::PackRepository,
|
||||
trigger::SensorRepository,
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
Create, Delete, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::auth::RequireAuth;
|
||||
@@ -46,40 +46,24 @@ pub async fn list_keys(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<KeyQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get keys based on filters
|
||||
let keys = if let Some(owner_type) = query.owner_type {
|
||||
// Filter by owner type
|
||||
KeyRepository::find_by_owner_type(&state.db, owner_type).await?
|
||||
} else {
|
||||
// Get all keys
|
||||
KeyRepository::list(&state.db).await?
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = KeySearchFilters {
|
||||
owner_type: query.owner_type,
|
||||
owner: query.owner.clone(),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
// Apply additional filters in memory
|
||||
let mut filtered_keys = keys;
|
||||
let result = KeyRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(owner) = &query.owner {
|
||||
filtered_keys.retain(|k| k.owner.as_ref() == Some(owner));
|
||||
}
|
||||
let paginated_keys: Vec<KeySummary> = result.rows.into_iter().map(KeySummary::from).collect();
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_keys.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_keys.len());
|
||||
|
||||
// Get paginated slice (values redacted in summary)
|
||||
let paginated_keys: Vec<KeySummary> = filtered_keys[start..end]
|
||||
.iter()
|
||||
.map(|key| KeySummary::from(key.clone()))
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_keys, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_keys, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ use attune_common::mq::{
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
pack::PackRepository,
|
||||
rule::{CreateRuleInput, RuleRepository, UpdateRuleInput},
|
||||
rule::{CreateRuleInput, RuleRepository, RuleSearchFilters, UpdateRuleInput},
|
||||
trigger::TriggerRepository,
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
Create, Delete, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -50,21 +50,21 @@ pub async fn list_rules(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all rules
|
||||
let rules = RuleRepository::list(&state.db).await?;
|
||||
let filters = RuleSearchFilters {
|
||||
pack: None,
|
||||
action: None,
|
||||
trigger: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
let result = RuleRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
let paginated_rules: Vec<RuleSummary> =
|
||||
result.rows.into_iter().map(RuleSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -85,21 +85,21 @@ pub async fn list_enabled_rules(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled rules
|
||||
let rules = RuleRepository::find_enabled(&state.db).await?;
|
||||
let filters = RuleSearchFilters {
|
||||
pack: None,
|
||||
action: None,
|
||||
trigger: None,
|
||||
enabled: Some(true),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
let result = RuleRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
let paginated_rules: Vec<RuleSummary> =
|
||||
result.rows.into_iter().map(RuleSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -130,21 +130,21 @@ pub async fn list_rules_by_pack(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get rules for this pack
|
||||
let rules = RuleRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
let filters = RuleSearchFilters {
|
||||
pack: Some(pack.id),
|
||||
action: None,
|
||||
trigger: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
let result = RuleRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
let paginated_rules: Vec<RuleSummary> =
|
||||
result.rows.into_iter().map(RuleSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -175,21 +175,21 @@ pub async fn list_rules_by_action(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
|
||||
|
||||
// Get rules for this action
|
||||
let rules = RuleRepository::find_by_action(&state.db, action.id).await?;
|
||||
let filters = RuleSearchFilters {
|
||||
pack: None,
|
||||
action: Some(action.id),
|
||||
trigger: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
let result = RuleRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
let paginated_rules: Vec<RuleSummary> =
|
||||
result.rows.into_iter().map(RuleSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -220,21 +220,21 @@ pub async fn list_rules_by_trigger(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Get rules for this trigger
|
||||
let rules = RuleRepository::find_by_trigger(&state.db, trigger.id).await?;
|
||||
let filters = RuleSearchFilters {
|
||||
pack: None,
|
||||
action: None,
|
||||
trigger: Some(trigger.id),
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = rules.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(rules.len());
|
||||
let result = RuleRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_rules: Vec<RuleSummary> = rules[start..end]
|
||||
.iter()
|
||||
.map(|r| RuleSummary::from(r.clone()))
|
||||
.collect();
|
||||
let paginated_rules: Vec<RuleSummary> =
|
||||
result.rows.into_iter().map(RuleSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
runtime::RuntimeRepository,
|
||||
trigger::{
|
||||
CreateSensorInput, CreateTriggerInput, SensorRepository, TriggerRepository,
|
||||
UpdateSensorInput, UpdateTriggerInput,
|
||||
CreateSensorInput, CreateTriggerInput, SensorRepository, SensorSearchFilters,
|
||||
TriggerRepository, TriggerSearchFilters, UpdateSensorInput, UpdateTriggerInput,
|
||||
},
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
Create, Delete, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -54,21 +54,19 @@ pub async fn list_triggers(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all triggers
|
||||
let triggers = TriggerRepository::list(&state.db).await?;
|
||||
let filters = TriggerSearchFilters {
|
||||
pack: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
let result = TriggerRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
let paginated_triggers: Vec<TriggerSummary> =
|
||||
result.rows.into_iter().map(TriggerSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -89,21 +87,19 @@ pub async fn list_enabled_triggers(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled triggers
|
||||
let triggers = TriggerRepository::find_enabled(&state.db).await?;
|
||||
let filters = TriggerSearchFilters {
|
||||
pack: None,
|
||||
enabled: Some(true),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
let result = TriggerRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
let paginated_triggers: Vec<TriggerSummary> =
|
||||
result.rows.into_iter().map(TriggerSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -134,21 +130,19 @@ pub async fn list_triggers_by_pack(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get triggers for this pack
|
||||
let triggers = TriggerRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
let filters = TriggerSearchFilters {
|
||||
pack: Some(pack.id),
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = triggers.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(triggers.len());
|
||||
let result = TriggerRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
|
||||
.iter()
|
||||
.map(|t| TriggerSummary::from(t.clone()))
|
||||
.collect();
|
||||
let paginated_triggers: Vec<TriggerSummary> =
|
||||
result.rows.into_iter().map(TriggerSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -438,21 +432,20 @@ pub async fn list_sensors(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all sensors
|
||||
let sensors = SensorRepository::list(&state.db).await?;
|
||||
let filters = SensorSearchFilters {
|
||||
pack: None,
|
||||
trigger: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
let result = SensorRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
let paginated_sensors: Vec<SensorSummary> =
|
||||
result.rows.into_iter().map(SensorSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -473,21 +466,20 @@ pub async fn list_enabled_sensors(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get enabled sensors
|
||||
let sensors = SensorRepository::find_enabled(&state.db).await?;
|
||||
let filters = SensorSearchFilters {
|
||||
pack: None,
|
||||
trigger: None,
|
||||
enabled: Some(true),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
let result = SensorRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
let paginated_sensors: Vec<SensorSummary> =
|
||||
result.rows.into_iter().map(SensorSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -518,21 +510,20 @@ pub async fn list_sensors_by_pack(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get sensors for this pack
|
||||
let sensors = SensorRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
let filters = SensorSearchFilters {
|
||||
pack: Some(pack.id),
|
||||
trigger: None,
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
let result = SensorRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
let paginated_sensors: Vec<SensorSummary> =
|
||||
result.rows.into_iter().map(SensorSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -563,21 +554,20 @@ pub async fn list_sensors_by_trigger(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
// Get sensors for this trigger
|
||||
let sensors = SensorRepository::find_by_trigger(&state.db, trigger.id).await?;
|
||||
let filters = SensorSearchFilters {
|
||||
pack: None,
|
||||
trigger: Some(trigger.id),
|
||||
enabled: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = sensors.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(sensors.len());
|
||||
let result = SensorRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
|
||||
.iter()
|
||||
.map(|s| SensorSummary::from(s.clone()))
|
||||
.collect();
|
||||
let paginated_sensors: Vec<SensorSummary> =
|
||||
result.rows.into_iter().map(SensorSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -16,8 +16,9 @@ use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
workflow::{
|
||||
CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository,
|
||||
WorkflowSearchFilters,
|
||||
},
|
||||
Create, Delete, FindByRef, List, Update,
|
||||
Create, Delete, FindByRef, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -54,64 +55,30 @@ pub async fn list_workflows(
|
||||
// Validate search params
|
||||
search_params.validate()?;
|
||||
|
||||
// Get workflows based on filters
|
||||
let mut workflows = if let Some(tags_str) = &search_params.tags {
|
||||
// Filter by tags
|
||||
let tags: Vec<&str> = tags_str.split(',').map(|s| s.trim()).collect();
|
||||
let mut results = Vec::new();
|
||||
for tag in tags {
|
||||
let mut tag_results = WorkflowDefinitionRepository::find_by_tag(&state.db, tag).await?;
|
||||
results.append(&mut tag_results);
|
||||
}
|
||||
// Remove duplicates by ID
|
||||
results.sort_by_key(|w| w.id);
|
||||
results.dedup_by_key(|w| w.id);
|
||||
results
|
||||
} else if search_params.enabled == Some(true) {
|
||||
// Filter by enabled status (only return enabled workflows)
|
||||
WorkflowDefinitionRepository::find_enabled(&state.db).await?
|
||||
} else {
|
||||
// Get all workflows
|
||||
WorkflowDefinitionRepository::list(&state.db).await?
|
||||
// Parse comma-separated tags into a Vec if provided
|
||||
let tags = search_params.tags.as_ref().map(|t| {
|
||||
t.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = WorkflowSearchFilters {
|
||||
pack: None,
|
||||
pack_ref: search_params.pack_ref.clone(),
|
||||
enabled: search_params.enabled,
|
||||
tags,
|
||||
search: search_params.search.clone(),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Apply enabled filter if specified and not already filtered by it
|
||||
if let Some(enabled) = search_params.enabled {
|
||||
if search_params.tags.is_some() {
|
||||
// If we filtered by tags, also apply enabled filter
|
||||
workflows.retain(|w| w.enabled == enabled);
|
||||
}
|
||||
}
|
||||
let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Apply search filter if provided
|
||||
if let Some(search_term) = &search_params.search {
|
||||
let search_lower = search_term.to_lowercase();
|
||||
workflows.retain(|w| {
|
||||
w.label.to_lowercase().contains(&search_lower)
|
||||
|| w.description
|
||||
.as_ref()
|
||||
.map(|d| d.to_lowercase().contains(&search_lower))
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
let paginated_workflows: Vec<WorkflowSummary> =
|
||||
result.rows.into_iter().map(WorkflowSummary::from).collect();
|
||||
|
||||
// Apply pack_ref filter if provided
|
||||
if let Some(pack_ref) = &search_params.pack_ref {
|
||||
workflows.retain(|w| w.pack_ref == *pack_ref);
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = workflows.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(workflows.len());
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
|
||||
.iter()
|
||||
.map(|w| WorkflowSummary::from(w.clone()))
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -138,25 +105,27 @@ pub async fn list_workflows_by_pack(
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify pack exists
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
let _pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
// Get workflows for this pack
|
||||
let workflows = WorkflowDefinitionRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
// All filtering and pagination happen in a single SQL query.
|
||||
let filters = WorkflowSearchFilters {
|
||||
pack: None,
|
||||
pack_ref: Some(pack_ref),
|
||||
enabled: None,
|
||||
tags: None,
|
||||
search: None,
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = workflows.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(workflows.len());
|
||||
let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
|
||||
.iter()
|
||||
.map(|w| WorkflowSummary::from(w.clone()))
|
||||
.collect();
|
||||
let paginated_workflows: Vec<WorkflowSummary> =
|
||||
result.rows.into_iter().map(WorkflowSummary::from).collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
@@ -1104,6 +1104,11 @@ pub mod execution {
|
||||
pub status: ExecutionStatus,
|
||||
pub result: Option<JsonDict>,
|
||||
|
||||
/// When the execution actually started running (worker picked it up).
|
||||
/// Set when status transitions to `Running`. Used to compute accurate
|
||||
/// duration that excludes queue/scheduling wait time.
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Workflow task metadata (only populated for workflow task executions)
|
||||
///
|
||||
/// Provides direct access to workflow orchestration state without JOINs.
|
||||
|
||||
@@ -8,6 +8,26 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
|
||||
/// Filters for [`ActionRepository::list_search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ActionSearchFilters {
|
||||
/// Filter by pack ID
|
||||
pub pack: Option<Id>,
|
||||
/// Text search across ref, label, description (case-insensitive)
|
||||
pub query: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`ActionRepository::list_search`].
|
||||
#[derive(Debug)]
|
||||
pub struct ActionSearchResult {
|
||||
pub rows: Vec<Action>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// Repository for Action operations
|
||||
pub struct ActionRepository;
|
||||
|
||||
@@ -287,6 +307,92 @@ impl Delete for ActionRepository {
|
||||
}
|
||||
|
||||
impl ActionRepository {
|
||||
/// Search actions with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn list_search<'e, E>(
|
||||
db: E,
|
||||
filters: &ActionSearchFilters,
|
||||
) -> Result<ActionSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_version_constraint, param_schema, out_schema, workflow_def, is_adhoc, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM action"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM action");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(pack_id) = filters.pack {
|
||||
push_condition!("pack = ", pack_id);
|
||||
}
|
||||
if let Some(ref query) = filters.query {
|
||||
let pattern = format!("%{}%", query.to_lowercase());
|
||||
// Search needs an OR across multiple columns, wrapped in parens
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push("(LOWER(ref) LIKE ");
|
||||
qb.push_bind(pattern.clone());
|
||||
qb.push(" OR LOWER(label) LIKE ");
|
||||
qb.push_bind(pattern.clone());
|
||||
qb.push(" OR LOWER(description) LIKE ");
|
||||
qb.push_bind(pattern.clone());
|
||||
qb.push(")");
|
||||
|
||||
count_qb.push("(LOWER(ref) LIKE ");
|
||||
count_qb.push_bind(pattern.clone());
|
||||
count_qb.push(" OR LOWER(label) LIKE ");
|
||||
count_qb.push_bind(pattern.clone());
|
||||
count_qb.push(" OR LOWER(description) LIKE ");
|
||||
count_qb.push_bind(pattern);
|
||||
count_qb.push(")");
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY ref ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Action> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(ActionSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find actions by pack ID
|
||||
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Action>>
|
||||
where
|
||||
|
||||
@@ -15,6 +15,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
|
||||
// ============================================================================
|
||||
// Event Search
|
||||
// ============================================================================
|
||||
|
||||
/// Filters for [`EventRepository::search`].
|
||||
///
|
||||
/// All fields are optional. When set, the corresponding WHERE clause is added.
|
||||
/// Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EventSearchFilters {
|
||||
pub trigger: Option<Id>,
|
||||
pub trigger_ref: Option<String>,
|
||||
pub source: Option<Id>,
|
||||
pub rule_ref: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`EventRepository::search`].
|
||||
#[derive(Debug)]
|
||||
pub struct EventSearchResult {
|
||||
pub rows: Vec<Event>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Enforcement Search
|
||||
// ============================================================================
|
||||
|
||||
/// Filters for [`EnforcementRepository::search`].
|
||||
///
|
||||
/// All fields are optional and combinable. Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EnforcementSearchFilters {
|
||||
pub rule: Option<Id>,
|
||||
pub event: Option<Id>,
|
||||
pub status: Option<EnforcementStatus>,
|
||||
pub trigger_ref: Option<String>,
|
||||
pub rule_ref: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`EnforcementRepository::search`].
|
||||
#[derive(Debug)]
|
||||
pub struct EnforcementSearchResult {
|
||||
pub rows: Vec<Enforcement>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// Repository for Event operations
|
||||
pub struct EventRepository;
|
||||
|
||||
@@ -173,6 +223,75 @@ impl EventRepository {
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Search events with all filters pushed into SQL.
|
||||
///
|
||||
/// Builds a dynamic query so that every filter, pagination, and the total
|
||||
/// count are handled in the database — no in-memory filtering or slicing.
|
||||
pub async fn search<'e, E>(db: E, filters: &EventSearchFilters) -> Result<EventSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, trigger, trigger_ref, config, payload, source, source_ref, rule, rule_ref, created";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM event"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM event");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(trigger_id) = filters.trigger {
|
||||
push_condition!("trigger = ", trigger_id);
|
||||
}
|
||||
if let Some(ref trigger_ref) = filters.trigger_ref {
|
||||
push_condition!("trigger_ref = ", trigger_ref.clone());
|
||||
}
|
||||
if let Some(source_id) = filters.source {
|
||||
push_condition!("source = ", source_id);
|
||||
}
|
||||
if let Some(ref rule_ref) = filters.rule_ref {
|
||||
push_condition!(
|
||||
"LOWER(rule_ref) LIKE ",
|
||||
format!("%{}%", rule_ref.to_lowercase())
|
||||
);
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY created DESC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Event> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(EventSearchResult { rows, total })
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -425,4 +544,75 @@ impl EnforcementRepository {
|
||||
|
||||
Ok(enforcements)
|
||||
}
|
||||
|
||||
/// Search enforcements with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn search<'e, E>(
|
||||
db: E,
|
||||
filters: &EnforcementSearchFilters,
|
||||
) -> Result<EnforcementSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, rule, rule_ref, trigger_ref, config, event, status, payload, condition, conditions, created, resolved_at";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM enforcement"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM enforcement");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(status) = &filters.status {
|
||||
push_condition!("status = ", status.clone());
|
||||
}
|
||||
if let Some(rule_id) = filters.rule {
|
||||
push_condition!("rule = ", rule_id);
|
||||
}
|
||||
if let Some(event_id) = filters.event {
|
||||
push_condition!("event = ", event_id);
|
||||
}
|
||||
if let Some(ref trigger_ref) = filters.trigger_ref {
|
||||
push_condition!("trigger_ref = ", trigger_ref.clone());
|
||||
}
|
||||
if let Some(ref rule_ref) = filters.rule_ref {
|
||||
push_condition!("rule_ref = ", rule_ref.clone());
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY created DESC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Enforcement> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(EnforcementSearchResult { rows, total })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,71 @@
|
||||
//! Execution repository for database operations
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
use crate::models::{enums::ExecutionStatus, execution::*, Id, JsonDict};
|
||||
use crate::Result;
|
||||
use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
|
||||
/// Filters for the [`ExecutionRepository::search`] query-builder method.
|
||||
///
|
||||
/// Every field is optional. When set, the corresponding `WHERE` clause is
|
||||
/// appended to the query. Pagination (`limit`/`offset`) is always applied.
|
||||
///
|
||||
/// Filters that involve the `enforcement` table (`rule_ref`, `trigger_ref`)
|
||||
/// cause a `LEFT JOIN enforcement` to be added automatically.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ExecutionSearchFilters {
|
||||
pub status: Option<ExecutionStatus>,
|
||||
pub action_ref: Option<String>,
|
||||
pub pack_name: Option<String>,
|
||||
pub rule_ref: Option<String>,
|
||||
pub trigger_ref: Option<String>,
|
||||
pub executor: Option<Id>,
|
||||
pub result_contains: Option<String>,
|
||||
pub enforcement: Option<Id>,
|
||||
pub parent: Option<Id>,
|
||||
pub top_level_only: bool,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`ExecutionRepository::search`].
|
||||
///
|
||||
/// Includes the matching rows *and* the total count (before LIMIT/OFFSET)
|
||||
/// so the caller can build pagination metadata without a second round-trip.
|
||||
#[derive(Debug)]
|
||||
pub struct ExecutionSearchResult {
|
||||
pub rows: Vec<ExecutionWithRefs>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// An execution row with optional `rule_ref` / `trigger_ref` populated from
|
||||
/// the joined `enforcement` table. This avoids a separate in-memory lookup.
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
pub struct ExecutionWithRefs {
|
||||
// — execution columns (same order as SELECT_COLUMNS) —
|
||||
pub id: Id,
|
||||
pub action: Option<Id>,
|
||||
pub action_ref: String,
|
||||
pub config: Option<JsonDict>,
|
||||
pub env_vars: Option<JsonDict>,
|
||||
pub parent: Option<Id>,
|
||||
pub enforcement: Option<Id>,
|
||||
pub executor: Option<Id>,
|
||||
pub status: ExecutionStatus,
|
||||
pub result: Option<JsonDict>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
#[sqlx(json, default)]
|
||||
pub workflow_task: Option<WorkflowTaskMetadata>,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
// — joined from enforcement —
|
||||
pub rule_ref: Option<String>,
|
||||
pub trigger_ref: Option<String>,
|
||||
}
|
||||
|
||||
/// Column list for SELECT queries on the execution table.
|
||||
///
|
||||
/// Defined once to avoid drift between queries and the `Execution` model.
|
||||
@@ -13,7 +73,7 @@ use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
/// are NOT in the Rust struct, so `SELECT *` must never be used.
|
||||
pub const SELECT_COLUMNS: &str = "\
|
||||
id, action, action_ref, config, env_vars, parent, enforcement, \
|
||||
executor, status, result, workflow_task, created, updated";
|
||||
executor, status, result, started_at, workflow_task, created, updated";
|
||||
|
||||
pub struct ExecutionRepository;
|
||||
|
||||
@@ -43,6 +103,7 @@ pub struct UpdateExecutionInput {
|
||||
pub status: Option<ExecutionStatus>,
|
||||
pub result: Option<JsonDict>,
|
||||
pub executor: Option<Id>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub workflow_task: Option<WorkflowTaskMetadata>,
|
||||
}
|
||||
|
||||
@@ -52,6 +113,7 @@ impl From<Execution> for UpdateExecutionInput {
|
||||
status: Some(execution.status),
|
||||
result: execution.result,
|
||||
executor: execution.executor,
|
||||
started_at: execution.started_at,
|
||||
workflow_task: execution.workflow_task,
|
||||
}
|
||||
}
|
||||
@@ -146,6 +208,13 @@ impl Update for ExecutionRepository {
|
||||
query.push("executor = ").push_bind(executor_id);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(started_at) = input.started_at {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("started_at = ").push_bind(started_at);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(workflow_task) = &input.workflow_task {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
@@ -239,4 +308,141 @@ impl ExecutionRepository {
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Search executions with all filters pushed into SQL.
|
||||
///
|
||||
/// Builds a dynamic query with only the WHERE clauses that apply,
|
||||
/// a LEFT JOIN on `enforcement` when `rule_ref` or `trigger_ref` filters
|
||||
/// are present (or always, to populate those columns on the result),
|
||||
/// and proper LIMIT/OFFSET so pagination is server-side.
|
||||
///
|
||||
/// Returns both the matching page of rows and the total count.
|
||||
pub async fn search<'e, E>(
|
||||
db: E,
|
||||
filters: &ExecutionSearchFilters,
|
||||
) -> Result<ExecutionSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
// We always LEFT JOIN enforcement so we can return rule_ref/trigger_ref
|
||||
// on every row without a second round-trip.
|
||||
let prefixed_select = SELECT_COLUMNS
|
||||
.split(", ")
|
||||
.map(|col| format!("e.{col}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let select_clause = format!(
|
||||
"{prefixed_select}, enf.rule_ref AS rule_ref, enf.trigger_ref AS trigger_ref"
|
||||
);
|
||||
|
||||
let from_clause = "FROM execution e LEFT JOIN enforcement enf ON e.enforcement = enf.id";
|
||||
|
||||
// ── Build WHERE clauses ──────────────────────────────────────────
|
||||
let mut conditions: Vec<String> = Vec::new();
|
||||
|
||||
// We'll collect bind values to push into the QueryBuilder afterwards.
|
||||
// Because QueryBuilder doesn't let us interleave raw SQL and binds in
|
||||
// arbitrary order easily, we build the SQL string with numbered $N
|
||||
// placeholders and then bind in order.
|
||||
|
||||
// Track the next placeholder index ($1, $2, …).
|
||||
// We can't use QueryBuilder's push_bind because we need the COUNT(*)
|
||||
// query to share the same WHERE clause text. Instead we build the
|
||||
// clause once and execute both queries with manual binds.
|
||||
|
||||
// ── Use QueryBuilder for the data query ──────────────────────────
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_clause} {from_clause}"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT COUNT(*) AS total {from_clause}"));
|
||||
|
||||
// Helper: append the same condition to both builders.
|
||||
// We need a tiny state machine since push_bind moves the value.
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
let needs_where = conditions.is_empty();
|
||||
conditions.push(String::new()); // just to track count
|
||||
if needs_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! push_raw_condition {
|
||||
($cond:expr) => {{
|
||||
let needs_where = conditions.is_empty();
|
||||
conditions.push(String::new());
|
||||
if needs_where {
|
||||
qb.push(concat!(" WHERE ", $cond));
|
||||
count_qb.push(concat!(" WHERE ", $cond));
|
||||
} else {
|
||||
qb.push(concat!(" AND ", $cond));
|
||||
count_qb.push(concat!(" AND ", $cond));
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(status) = &filters.status {
|
||||
push_condition!("e.status = ", status.clone());
|
||||
}
|
||||
if let Some(action_ref) = &filters.action_ref {
|
||||
push_condition!("e.action_ref = ", action_ref.clone());
|
||||
}
|
||||
if let Some(pack_name) = &filters.pack_name {
|
||||
let pattern = format!("{pack_name}.%");
|
||||
push_condition!("e.action_ref LIKE ", pattern);
|
||||
}
|
||||
if let Some(enforcement_id) = filters.enforcement {
|
||||
push_condition!("e.enforcement = ", enforcement_id);
|
||||
}
|
||||
if let Some(parent_id) = filters.parent {
|
||||
push_condition!("e.parent = ", parent_id);
|
||||
}
|
||||
if filters.top_level_only {
|
||||
push_raw_condition!("e.parent IS NULL");
|
||||
}
|
||||
if let Some(executor_id) = filters.executor {
|
||||
push_condition!("e.executor = ", executor_id);
|
||||
}
|
||||
if let Some(rule_ref) = &filters.rule_ref {
|
||||
push_condition!("enf.rule_ref = ", rule_ref.clone());
|
||||
}
|
||||
if let Some(trigger_ref) = &filters.trigger_ref {
|
||||
push_condition!("enf.trigger_ref = ", trigger_ref.clone());
|
||||
}
|
||||
if let Some(search) = &filters.result_contains {
|
||||
let pattern = format!("%{}%", search.to_lowercase());
|
||||
push_condition!("LOWER(e.result::text) LIKE ", pattern);
|
||||
}
|
||||
|
||||
// ── COUNT query ──────────────────────────────────────────────────
|
||||
let total: i64 = count_qb
|
||||
.build_query_scalar()
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// ── Data query with ORDER BY + pagination ────────────────────────
|
||||
qb.push(" ORDER BY e.created DESC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<ExecutionWithRefs> = qb
|
||||
.build_query_as()
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
Ok(ExecutionSearchResult { rows, total })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,25 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
|
||||
/// Filters for [`InquiryRepository::search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct InquirySearchFilters {
|
||||
pub status: Option<InquiryStatus>,
|
||||
pub execution: Option<Id>,
|
||||
pub assigned_to: Option<Id>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`InquiryRepository::search`].
|
||||
#[derive(Debug)]
|
||||
pub struct InquirySearchResult {
|
||||
pub rows: Vec<Inquiry>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
pub struct InquiryRepository;
|
||||
|
||||
impl Repository for InquiryRepository {
|
||||
@@ -157,4 +176,66 @@ impl InquiryRepository {
|
||||
"SELECT id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated FROM inquiry WHERE execution = $1 ORDER BY created DESC"
|
||||
).bind(execution_id).fetch_all(executor).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Search inquiries with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn search<'e, E>(db: E, filters: &InquirySearchFilters) -> Result<InquirySearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM inquiry"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM inquiry");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(status) = &filters.status {
|
||||
push_condition!("status = ", status.clone());
|
||||
}
|
||||
if let Some(execution_id) = filters.execution {
|
||||
push_condition!("execution = ", execution_id);
|
||||
}
|
||||
if let Some(assigned_to) = filters.assigned_to {
|
||||
push_condition!("assigned_to = ", assigned_to);
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY created DESC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Inquiry> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(InquirySearchResult { rows, total })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,24 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
|
||||
/// Filters for [`KeyRepository::search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct KeySearchFilters {
|
||||
pub owner_type: Option<OwnerType>,
|
||||
pub owner: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`KeyRepository::search`].
|
||||
#[derive(Debug)]
|
||||
pub struct KeySearchResult {
|
||||
pub rows: Vec<Key>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
pub struct KeyRepository;
|
||||
|
||||
impl Repository for KeyRepository {
|
||||
@@ -165,4 +183,63 @@ impl KeyRepository {
|
||||
"SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated FROM key WHERE owner_type = $1 ORDER BY ref ASC"
|
||||
).bind(owner_type).fetch_all(executor).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Search keys with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn search<'e, E>(db: E, filters: &KeySearchFilters) -> Result<KeySearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM key"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM key");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(ref owner_type) = filters.owner_type {
|
||||
push_condition!("owner_type = ", owner_type.clone());
|
||||
}
|
||||
if let Some(ref owner) = filters.owner {
|
||||
push_condition!("owner = ", owner.clone());
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY ref ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Key> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(KeySearchResult { rows, total })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,30 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
|
||||
/// Filters for [`RuleRepository::list_search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RuleSearchFilters {
|
||||
/// Filter by pack ID
|
||||
pub pack: Option<Id>,
|
||||
/// Filter by action ID
|
||||
pub action: Option<Id>,
|
||||
/// Filter by trigger ID
|
||||
pub trigger: Option<Id>,
|
||||
/// Filter by enabled status
|
||||
pub enabled: Option<bool>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`RuleRepository::list_search`].
|
||||
#[derive(Debug)]
|
||||
pub struct RuleSearchResult {
|
||||
pub rows: Vec<Rule>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// Input for restoring an ad-hoc rule during pack reinstallation.
|
||||
/// Unlike `CreateRuleInput`, action and trigger IDs are optional because
|
||||
/// the referenced entities may not exist yet or may have been removed.
|
||||
@@ -275,6 +299,71 @@ impl Delete for RuleRepository {
|
||||
}
|
||||
|
||||
impl RuleRepository {
|
||||
/// Search rules with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn list_search<'e, E>(db: E, filters: &RuleSearchFilters) -> Result<RuleSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, pack, pack_ref, label, description, action, action_ref, trigger, trigger_ref, conditions, action_params, trigger_params, enabled, is_adhoc, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM rule"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM rule");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(pack_id) = filters.pack {
|
||||
push_condition!("pack = ", pack_id);
|
||||
}
|
||||
if let Some(action_id) = filters.action {
|
||||
push_condition!("action = ", action_id);
|
||||
}
|
||||
if let Some(trigger_id) = filters.trigger {
|
||||
push_condition!("trigger = ", trigger_id);
|
||||
}
|
||||
if let Some(enabled) = filters.enabled {
|
||||
push_condition!("enabled = ", enabled);
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY ref ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Rule> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(RuleSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find rules by pack ID
|
||||
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Rule>>
|
||||
where
|
||||
|
||||
@@ -9,6 +9,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
|
||||
// ============================================================================
|
||||
// Trigger Search
|
||||
// ============================================================================
|
||||
|
||||
/// Filters for [`TriggerRepository::list_search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TriggerSearchFilters {
|
||||
/// Filter by pack ID
|
||||
pub pack: Option<Id>,
|
||||
/// Filter by enabled status
|
||||
pub enabled: Option<bool>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`TriggerRepository::list_search`].
|
||||
#[derive(Debug)]
|
||||
pub struct TriggerSearchResult {
|
||||
pub rows: Vec<Trigger>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Sensor Search
|
||||
// ============================================================================
|
||||
|
||||
/// Filters for [`SensorRepository::list_search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SensorSearchFilters {
|
||||
/// Filter by pack ID
|
||||
pub pack: Option<Id>,
|
||||
/// Filter by trigger ID
|
||||
pub trigger: Option<Id>,
|
||||
/// Filter by enabled status
|
||||
pub enabled: Option<bool>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`SensorRepository::list_search`].
|
||||
#[derive(Debug)]
|
||||
pub struct SensorSearchResult {
|
||||
pub rows: Vec<Sensor>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// Repository for Trigger operations
|
||||
pub struct TriggerRepository;
|
||||
|
||||
@@ -251,6 +301,68 @@ impl Delete for TriggerRepository {
|
||||
}
|
||||
|
||||
impl TriggerRepository {
|
||||
/// Search triggers with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn list_search<'e, E>(
|
||||
db: E,
|
||||
filters: &TriggerSearchFilters,
|
||||
) -> Result<TriggerSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, pack, pack_ref, label, description, enabled, param_schema, out_schema, webhook_enabled, webhook_key, webhook_config, is_adhoc, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM trigger"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM trigger");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(pack_id) = filters.pack {
|
||||
push_condition!("pack = ", pack_id);
|
||||
}
|
||||
if let Some(enabled) = filters.enabled {
|
||||
push_condition!("enabled = ", enabled);
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY ref ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Trigger> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(TriggerSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find triggers by pack ID
|
||||
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Trigger>>
|
||||
where
|
||||
@@ -795,6 +907,71 @@ impl Delete for SensorRepository {
|
||||
}
|
||||
|
||||
impl SensorRepository {
|
||||
/// Search sensors with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
pub async fn list_search<'e, E>(
|
||||
db: E,
|
||||
filters: &SensorSearchFilters,
|
||||
) -> Result<SensorSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_ref, runtime_version_constraint, trigger, trigger_ref, enabled, param_schema, config, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM sensor"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM sensor");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(pack_id) = filters.pack {
|
||||
push_condition!("pack = ", pack_id);
|
||||
}
|
||||
if let Some(trigger_id) = filters.trigger {
|
||||
push_condition!("trigger = ", trigger_id);
|
||||
}
|
||||
if let Some(enabled) = filters.enabled {
|
||||
push_condition!("enabled = ", enabled);
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY ref ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<Sensor> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(SensorSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find sensors by trigger ID
|
||||
pub async fn find_by_trigger<'e, E>(executor: E, trigger_id: Id) -> Result<Vec<Sensor>>
|
||||
where
|
||||
|
||||
@@ -6,6 +6,37 @@ use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
|
||||
// ============================================================================
|
||||
// Workflow Definition Search
|
||||
// ============================================================================
|
||||
|
||||
/// Filters for [`WorkflowDefinitionRepository::list_search`].
|
||||
///
|
||||
/// All fields are optional and combinable (AND). Pagination is always applied.
|
||||
/// Tag filtering uses `ANY(tags)` for each tag (OR across tags, AND with other filters).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct WorkflowSearchFilters {
|
||||
/// Filter by pack ID
|
||||
pub pack: Option<Id>,
|
||||
/// Filter by pack reference
|
||||
pub pack_ref: Option<String>,
|
||||
/// Filter by enabled status
|
||||
pub enabled: Option<bool>,
|
||||
/// Filter by tags (OR across tags — matches if any tag is present)
|
||||
pub tags: Option<Vec<String>>,
|
||||
/// Text search across label and description (case-insensitive substring)
|
||||
pub search: Option<String>,
|
||||
pub limit: u32,
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Result of [`WorkflowDefinitionRepository::list_search`].
|
||||
#[derive(Debug)]
|
||||
pub struct WorkflowSearchResult {
|
||||
pub rows: Vec<WorkflowDefinition>,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WORKFLOW DEFINITION REPOSITORY
|
||||
// ============================================================================
|
||||
@@ -226,6 +257,102 @@ impl Delete for WorkflowDefinitionRepository {
|
||||
}
|
||||
|
||||
impl WorkflowDefinitionRepository {
|
||||
/// Search workflow definitions with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
/// Tags use an OR match — a workflow matches if it contains ANY of the
|
||||
/// requested tags (via `tags && ARRAY[...]`).
|
||||
pub async fn list_search<'e, E>(
|
||||
db: E,
|
||||
filters: &WorkflowSearchFilters,
|
||||
) -> Result<WorkflowSearchResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let select_cols = "id, ref, pack, pack_ref, label, description, version, param_schema, out_schema, definition, tags, enabled, created, updated";
|
||||
|
||||
let mut qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new(format!("SELECT {select_cols} FROM workflow_definition"));
|
||||
let mut count_qb: QueryBuilder<'_, Postgres> =
|
||||
QueryBuilder::new("SELECT COUNT(*) FROM workflow_definition");
|
||||
|
||||
let mut has_where = false;
|
||||
|
||||
macro_rules! push_condition {
|
||||
($cond_prefix:expr, $value:expr) => {{
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push($cond_prefix);
|
||||
qb.push_bind($value.clone());
|
||||
count_qb.push($cond_prefix);
|
||||
count_qb.push_bind($value);
|
||||
}};
|
||||
}
|
||||
|
||||
if let Some(pack_id) = filters.pack {
|
||||
push_condition!("pack = ", pack_id);
|
||||
}
|
||||
if let Some(ref pack_ref) = filters.pack_ref {
|
||||
push_condition!("pack_ref = ", pack_ref.clone());
|
||||
}
|
||||
if let Some(enabled) = filters.enabled {
|
||||
push_condition!("enabled = ", enabled);
|
||||
}
|
||||
if let Some(ref tags) = filters.tags {
|
||||
if !tags.is_empty() {
|
||||
// Use PostgreSQL array overlap operator: tags && ARRAY[...]
|
||||
push_condition!("tags && ", tags.clone());
|
||||
}
|
||||
}
|
||||
if let Some(ref search) = filters.search {
|
||||
let pattern = format!("%{}%", search.to_lowercase());
|
||||
// Search needs an OR across multiple columns, wrapped in parens
|
||||
if !has_where {
|
||||
qb.push(" WHERE ");
|
||||
count_qb.push(" WHERE ");
|
||||
has_where = true;
|
||||
} else {
|
||||
qb.push(" AND ");
|
||||
count_qb.push(" AND ");
|
||||
}
|
||||
qb.push("(LOWER(label) LIKE ");
|
||||
qb.push_bind(pattern.clone());
|
||||
qb.push(" OR LOWER(COALESCE(description, '')) LIKE ");
|
||||
qb.push_bind(pattern.clone());
|
||||
qb.push(")");
|
||||
|
||||
count_qb.push("(LOWER(label) LIKE ");
|
||||
count_qb.push_bind(pattern.clone());
|
||||
count_qb.push(" OR LOWER(COALESCE(description, '')) LIKE ");
|
||||
count_qb.push_bind(pattern);
|
||||
count_qb.push(")");
|
||||
}
|
||||
|
||||
// Suppress unused-assignment warning from the macro's last expansion.
|
||||
let _ = has_where;
|
||||
|
||||
// Count
|
||||
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
|
||||
let total = total.max(0) as u64;
|
||||
|
||||
// Data query
|
||||
qb.push(" ORDER BY label ASC");
|
||||
qb.push(" LIMIT ");
|
||||
qb.push_bind(filters.limit as i64);
|
||||
qb.push(" OFFSET ");
|
||||
qb.push_bind(filters.offset as i64);
|
||||
|
||||
let rows: Vec<WorkflowDefinition> = qb.build_query_as().fetch_all(db).await?;
|
||||
|
||||
Ok(WorkflowSearchResult { rows, total })
|
||||
}
|
||||
|
||||
/// Find all workflows for a specific pack by pack ID
|
||||
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<WorkflowDefinition>>
|
||||
where
|
||||
|
||||
112
crates/common/src/workflow/expression/ast.rs
Normal file
112
crates/common/src/workflow/expression/ast.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
//! # Expression AST
|
||||
//!
|
||||
//! Defines the abstract syntax tree nodes produced by the parser and consumed
|
||||
//! by the evaluator.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// A binary operator connecting two sub-expressions.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum BinaryOp {
|
||||
// Arithmetic
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
Mod,
|
||||
// Comparison
|
||||
Eq,
|
||||
Ne,
|
||||
Lt,
|
||||
Gt,
|
||||
Le,
|
||||
Ge,
|
||||
// Logical
|
||||
And,
|
||||
Or,
|
||||
// Membership
|
||||
In,
|
||||
}
|
||||
|
||||
impl fmt::Display for BinaryOp {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
BinaryOp::Add => write!(f, "+"),
|
||||
BinaryOp::Sub => write!(f, "-"),
|
||||
BinaryOp::Mul => write!(f, "*"),
|
||||
BinaryOp::Div => write!(f, "/"),
|
||||
BinaryOp::Mod => write!(f, "%"),
|
||||
BinaryOp::Eq => write!(f, "=="),
|
||||
BinaryOp::Ne => write!(f, "!="),
|
||||
BinaryOp::Lt => write!(f, "<"),
|
||||
BinaryOp::Gt => write!(f, ">"),
|
||||
BinaryOp::Le => write!(f, "<="),
|
||||
BinaryOp::Ge => write!(f, ">="),
|
||||
BinaryOp::And => write!(f, "and"),
|
||||
BinaryOp::Or => write!(f, "or"),
|
||||
BinaryOp::In => write!(f, "in"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A unary operator applied to a single sub-expression.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum UnaryOp {
|
||||
/// Arithmetic negation: `-x`
|
||||
Neg,
|
||||
/// Logical negation: `not x`
|
||||
Not,
|
||||
}
|
||||
|
||||
impl fmt::Display for UnaryOp {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
UnaryOp::Neg => write!(f, "-"),
|
||||
UnaryOp::Not => write!(f, "not"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An expression AST node.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Expr {
|
||||
/// A literal JSON value: number, string, bool, or null.
|
||||
Literal(serde_json::Value),
|
||||
|
||||
/// An array literal: `[expr, expr, ...]`
|
||||
Array(Vec<Expr>),
|
||||
|
||||
/// A variable reference by name (e.g., `x`, `parameters`, `item`).
|
||||
Ident(String),
|
||||
|
||||
/// Binary operation: `left op right`
|
||||
BinaryOp {
|
||||
op: BinaryOp,
|
||||
left: Box<Expr>,
|
||||
right: Box<Expr>,
|
||||
},
|
||||
|
||||
/// Unary operation: `op operand`
|
||||
UnaryOp {
|
||||
op: UnaryOp,
|
||||
operand: Box<Expr>,
|
||||
},
|
||||
|
||||
/// Property access: `expr.field`
|
||||
DotAccess {
|
||||
object: Box<Expr>,
|
||||
field: String,
|
||||
},
|
||||
|
||||
/// Index/bracket access: `expr[index_expr]`
|
||||
IndexAccess {
|
||||
object: Box<Expr>,
|
||||
index: Box<Expr>,
|
||||
},
|
||||
|
||||
/// Function call: `name(arg1, arg2, ...)`
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: Vec<Expr>,
|
||||
},
|
||||
}
|
||||
1316
crates/common/src/workflow/expression/evaluator.rs
Normal file
1316
crates/common/src/workflow/expression/evaluator.rs
Normal file
File diff suppressed because it is too large
Load Diff
545
crates/common/src/workflow/expression/mod.rs
Normal file
545
crates/common/src/workflow/expression/mod.rs
Normal file
@@ -0,0 +1,545 @@
|
||||
//! # Workflow Expression Engine
|
||||
//!
|
||||
//! A complete expression evaluator for workflow templates, supporting arithmetic,
|
||||
//! comparison, boolean logic, member access, and built-in functions over JSON values.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The engine is structured as a classic three-phase interpreter:
|
||||
//!
|
||||
//! 1. **Lexer** (`tokenizer.rs`) — converts expression strings into a stream of tokens
|
||||
//! 2. **Parser** (`parser.rs`) — builds an AST from tokens using recursive descent
|
||||
//! 3. **Evaluator** (`evaluator.rs`) — walks the AST and produces a `JsonValue` result
|
||||
//!
|
||||
//! ## Supported Operators
|
||||
//!
|
||||
//! ### Arithmetic
|
||||
//! - `+` (addition for numbers, concatenation for strings)
|
||||
//! - `-` (subtraction, unary negation)
|
||||
//! - `*`, `/`, `%` (multiplication, division, modulo)
|
||||
//!
|
||||
//! ### Comparison
|
||||
//! - `==`, `!=` (equality — works on all types, recursive for objects/arrays)
|
||||
//! - `>`, `<`, `>=`, `<=` (ordering — numbers and strings only)
|
||||
//! - Float/int comparisons allowed: `3 == 3.0` → true
|
||||
//!
|
||||
//! ### Boolean / Logical
|
||||
//! - `and`, `or`, `not`
|
||||
//!
|
||||
//! ### Membership & Access
|
||||
//! - `.` — object property access
|
||||
//! - `[n]` — array index / object bracket access
|
||||
//! - `in` — membership test (item in list, key in object, substring in string)
|
||||
//!
|
||||
//! ## Built-in Functions
|
||||
//!
|
||||
//! ### Type conversion
|
||||
//! - `string(v)`, `number(v)`, `int(v)`, `bool(v)`
|
||||
//!
|
||||
//! ### Introspection
|
||||
//! - `type_of(v)`, `length(v)`, `keys(obj)`, `values(obj)`
|
||||
//!
|
||||
//! ### Math
|
||||
//! - `abs(n)`, `floor(n)`, `ceil(n)`, `round(n)`, `min(a,b)`, `max(a,b)`, `sum(arr)`
|
||||
//!
|
||||
//! ### String
|
||||
//! - `lower(s)`, `upper(s)`, `trim(s)`, `split(s, sep)`, `join(arr, sep)`
|
||||
//! - `replace(s, old, new)`, `starts_with(s, prefix)`, `ends_with(s, suffix)`
|
||||
//! - `match(pattern, s)` — regex match
|
||||
//!
|
||||
//! ### Collection
|
||||
//! - `contains(haystack, needle)`, `reversed(v)`, `sort(arr)`, `unique(arr)`
|
||||
//! - `flat(arr)`, `zip(a, b)`, `range(n)` / `range(start, end)`
|
||||
//!
|
||||
//! ### Workflow-specific
|
||||
//! - `result()`, `succeeded()`, `failed()`, `timed_out()`
|
||||
|
||||
mod ast;
|
||||
mod evaluator;
|
||||
mod parser;
|
||||
mod tokenizer;
|
||||
|
||||
pub use ast::{BinaryOp, Expr, UnaryOp};
|
||||
pub use evaluator::{is_truthy, EvalContext, EvalError, EvalResult};
|
||||
pub use parser::{ParseError, Parser};
|
||||
pub use tokenizer::{Token, TokenKind, Tokenizer};
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
/// Parse and evaluate an expression string against the given context.
|
||||
///
|
||||
/// This is the main entry point for the expression engine. It tokenizes the
|
||||
/// input, parses it into an AST, and evaluates it to produce a `JsonValue`.
|
||||
pub fn eval_expression(input: &str, ctx: &dyn EvalContext) -> EvalResult<JsonValue> {
|
||||
let tokens = Tokenizer::new(input).tokenize().map_err(|e| {
|
||||
EvalError::ParseError(format!("{}", e))
|
||||
})?;
|
||||
let ast = Parser::new(&tokens).parse().map_err(|e| {
|
||||
EvalError::ParseError(format!("{}", e))
|
||||
})?;
|
||||
evaluator::eval(&ast, ctx)
|
||||
}
|
||||
|
||||
/// Parse an expression string into an AST without evaluating it.
|
||||
///
|
||||
/// Useful for validation or inspection.
|
||||
pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
|
||||
let tokens = Tokenizer::new(input).tokenize().map_err(|e| {
|
||||
ParseError::TokenError(format!("{}", e))
|
||||
})?;
|
||||
Parser::new(&tokens).parse()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A minimal eval context for integration tests.
|
||||
struct TestContext {
|
||||
variables: HashMap<String, JsonValue>,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
variables: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_var(mut self, name: &str, value: JsonValue) -> Self {
|
||||
self.variables.insert(name.to_string(), value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalContext for TestContext {
|
||||
fn resolve_variable(&self, name: &str) -> EvalResult<JsonValue> {
|
||||
self.variables
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| EvalError::VariableNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
fn call_workflow_function(
|
||||
&self,
|
||||
_name: &str,
|
||||
_args: &[JsonValue],
|
||||
) -> EvalResult<Option<JsonValue>> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Arithmetic
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_integer_arithmetic() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("2 + 3", &ctx).unwrap(), json!(5));
|
||||
assert_eq!(eval_expression("10 - 4", &ctx).unwrap(), json!(6));
|
||||
assert_eq!(eval_expression("3 * 7", &ctx).unwrap(), json!(21));
|
||||
assert_eq!(eval_expression("15 / 5", &ctx).unwrap(), json!(3));
|
||||
assert_eq!(eval_expression("17 % 5", &ctx).unwrap(), json!(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_float_arithmetic() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("2.5 + 1.5", &ctx).unwrap(), json!(4.0));
|
||||
assert_eq!(eval_expression("10.0 / 3.0", &ctx).unwrap(), json!(10.0 / 3.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_int_float() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("2 + 1.5", &ctx).unwrap(), json!(3.5));
|
||||
// Integer division yields float when not evenly divisible
|
||||
assert_eq!(eval_expression("10 / 4", &ctx).unwrap(), json!(2.5));
|
||||
assert_eq!(eval_expression("10 / 4.0", &ctx).unwrap(), json!(2.5));
|
||||
// Evenly divisible integer division stays integer
|
||||
assert_eq!(eval_expression("10 / 5", &ctx).unwrap(), json!(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operator_precedence() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("2 + 3 * 4", &ctx).unwrap(), json!(14));
|
||||
assert_eq!(eval_expression("(2 + 3) * 4", &ctx).unwrap(), json!(20));
|
||||
assert_eq!(eval_expression("10 - 2 * 3 + 1", &ctx).unwrap(), json!(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_negation() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("-5", &ctx).unwrap(), json!(-5));
|
||||
assert_eq!(eval_expression("-2 + 3", &ctx).unwrap(), json!(1));
|
||||
assert_eq!(eval_expression("-(2 + 3)", &ctx).unwrap(), json!(-5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_concatenation() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(
|
||||
eval_expression("\"hello\" + \" \" + \"world\"", &ctx).unwrap(),
|
||||
json!("hello world")
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Comparison
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_number_comparison() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("3 == 3", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3 != 4", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3 > 2", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3 < 2", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("3 >= 3", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3 <= 4", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_int_float_equality() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("3 == 3.0", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3.0 == 3", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("3 != 3.1", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_comparison() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("\"abc\" == \"abc\"", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("\"abc\" < \"abd\"", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("\"abc\" > \"abb\"", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_equality() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("null == null", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("null != null", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("null == 0", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("null == false", &ctx).unwrap(), json!(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_array_equality() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("a", json!([1, 2, 3]))
|
||||
.with_var("b", json!([1, 2, 3]))
|
||||
.with_var("c", json!([1, 2, 4]));
|
||||
assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_object_equality() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("a", json!({"x": 1, "y": 2}))
|
||||
.with_var("b", json!({"y": 2, "x": 1}))
|
||||
.with_var("c", json!({"x": 1, "y": 3}));
|
||||
assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Boolean / Logical
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_boolean_operators() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("true and true", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("true and false", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("false or true", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("false or false", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("not true", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("not false", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boolean_precedence() {
|
||||
let ctx = TestContext::new();
|
||||
// `and` binds tighter than `or`
|
||||
assert_eq!(
|
||||
eval_expression("true or false and false", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("(true or false) and false", &ctx).unwrap(),
|
||||
json!(false)
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Membership & access
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_dot_access() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("obj", json!({"a": {"b": 42}}));
|
||||
assert_eq!(eval_expression("obj.a.b", &ctx).unwrap(), json!(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bracket_access() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("arr", json!([10, 20, 30]))
|
||||
.with_var("obj", json!({"key": "value"}));
|
||||
assert_eq!(eval_expression("arr[1]", &ctx).unwrap(), json!(20));
|
||||
assert_eq!(eval_expression("obj[\"key\"]", &ctx).unwrap(), json!("value"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_in_operator() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("arr", json!([1, 2, 3]))
|
||||
.with_var("obj", json!({"key": "val"}));
|
||||
assert_eq!(eval_expression("2 in arr", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("5 in arr", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("\"key\" in obj", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("\"nope\" in obj", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("\"ell\" in \"hello\"", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Built-in functions
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_length() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("arr", json!([1, 2, 3]))
|
||||
.with_var("obj", json!({"a": 1, "b": 2}));
|
||||
assert_eq!(eval_expression("length(arr)", &ctx).unwrap(), json!(3));
|
||||
assert_eq!(eval_expression("length(\"hello\")", &ctx).unwrap(), json!(5));
|
||||
assert_eq!(eval_expression("length(obj)", &ctx).unwrap(), json!(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_type_conversions() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("string(42)", &ctx).unwrap(), json!("42"));
|
||||
assert_eq!(eval_expression("number(\"3.14\")", &ctx).unwrap(), json!(3.14));
|
||||
assert_eq!(eval_expression("int(3.9)", &ctx).unwrap(), json!(3));
|
||||
assert_eq!(eval_expression("int(\"42\")", &ctx).unwrap(), json!(42));
|
||||
assert_eq!(eval_expression("bool(1)", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("bool(0)", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("bool(\"\")", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("bool(\"x\")", &ctx).unwrap(), json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_type_of() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("arr", json!([1]))
|
||||
.with_var("obj", json!({}));
|
||||
assert_eq!(eval_expression("type_of(42)", &ctx).unwrap(), json!("number"));
|
||||
assert_eq!(eval_expression("type_of(\"hi\")", &ctx).unwrap(), json!("string"));
|
||||
assert_eq!(eval_expression("type_of(true)", &ctx).unwrap(), json!("bool"));
|
||||
assert_eq!(eval_expression("type_of(null)", &ctx).unwrap(), json!("null"));
|
||||
assert_eq!(eval_expression("type_of(arr)", &ctx).unwrap(), json!("array"));
|
||||
assert_eq!(eval_expression("type_of(obj)", &ctx).unwrap(), json!("object"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keys_values() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("obj", json!({"b": 2, "a": 1}));
|
||||
let keys = eval_expression("sort(keys(obj))", &ctx).unwrap();
|
||||
assert_eq!(keys, json!(["a", "b"]));
|
||||
let values = eval_expression("sort(values(obj))", &ctx).unwrap();
|
||||
assert_eq!(values, json!([1, 2]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_math_functions() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("abs(-5)", &ctx).unwrap(), json!(5));
|
||||
assert_eq!(eval_expression("floor(3.7)", &ctx).unwrap(), json!(3));
|
||||
assert_eq!(eval_expression("ceil(3.2)", &ctx).unwrap(), json!(4));
|
||||
assert_eq!(eval_expression("round(3.5)", &ctx).unwrap(), json!(4));
|
||||
assert_eq!(eval_expression("min(3, 7)", &ctx).unwrap(), json!(3));
|
||||
assert_eq!(eval_expression("max(3, 7)", &ctx).unwrap(), json!(7));
|
||||
assert_eq!(eval_expression("sum([1, 2, 3, 4])", &ctx).unwrap(), json!(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_functions() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("lower(\"HELLO\")", &ctx).unwrap(), json!("hello"));
|
||||
assert_eq!(eval_expression("upper(\"hello\")", &ctx).unwrap(), json!("HELLO"));
|
||||
assert_eq!(eval_expression("trim(\" hi \")", &ctx).unwrap(), json!("hi"));
|
||||
assert_eq!(
|
||||
eval_expression("replace(\"hello world\", \"world\", \"rust\")", &ctx).unwrap(),
|
||||
json!("hello rust")
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("starts_with(\"hello\", \"hel\")", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("ends_with(\"hello\", \"llo\")", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("split(\"a,b,c\", \",\")", &ctx).unwrap(),
|
||||
json!(["a", "b", "c"])
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("join([\"a\", \"b\", \"c\"], \",\")", &ctx).unwrap(),
|
||||
json!("a,b,c")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_match() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(
|
||||
eval_expression("match(\"^hello\", \"hello world\")", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("match(\"^world\", \"hello world\")", &ctx).unwrap(),
|
||||
json!(false)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collection_functions() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("arr", json!([3, 1, 2]));
|
||||
assert_eq!(eval_expression("sort(arr)", &ctx).unwrap(), json!([1, 2, 3]));
|
||||
assert_eq!(eval_expression("reversed(arr)", &ctx).unwrap(), json!([2, 1, 3]));
|
||||
assert_eq!(
|
||||
eval_expression("unique([1, 2, 2, 3, 1])", &ctx).unwrap(),
|
||||
json!([1, 2, 3])
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("flat([[1, 2], [3, 4]])", &ctx).unwrap(),
|
||||
json!([1, 2, 3, 4])
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("zip([1, 2], [\"a\", \"b\"])", &ctx).unwrap(),
|
||||
json!([[1, "a"], [2, "b"]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("range(5)", &ctx).unwrap(), json!([0, 1, 2, 3, 4]));
|
||||
assert_eq!(eval_expression("range(2, 5)", &ctx).unwrap(), json!([2, 3, 4]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reversed_string() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("reversed(\"abc\")", &ctx).unwrap(), json!("cba"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contains_function() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(
|
||||
eval_expression("contains([1, 2, 3], 2)", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("contains(\"hello\", \"ell\")", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Complex expressions
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_complex_expression() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("items", json!([1, 2, 3, 4, 5]));
|
||||
assert_eq!(
|
||||
eval_expression("length(items) > 3 and 5 in items", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chained_access() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("data", json!({"users": [{"name": "Alice"}, {"name": "Bob"}]}));
|
||||
assert_eq!(
|
||||
eval_expression("data.users[1].name", &ctx).unwrap(),
|
||||
json!("Bob")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ternary_via_boolean() {
|
||||
let ctx = TestContext::new()
|
||||
.with_var("x", json!(10));
|
||||
// No ternary operator, but boolean expressions work for conditions
|
||||
assert_eq!(
|
||||
eval_expression("x > 5 and x < 20", &ctx).unwrap(),
|
||||
json!(true)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_implicit_type_coercion() {
|
||||
let ctx = TestContext::new();
|
||||
// String + number should error, not silently coerce
|
||||
assert!(eval_expression("\"hello\" + 5", &ctx).is_err());
|
||||
// Comparing different types (other than int/float) should return false for ==
|
||||
assert_eq!(eval_expression("\"3\" == 3", &ctx).unwrap(), json!(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_division_by_zero() {
|
||||
let ctx = TestContext::new();
|
||||
assert!(eval_expression("5 / 0", &ctx).is_err());
|
||||
assert!(eval_expression("5 % 0", &ctx).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_array_literal() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(
|
||||
eval_expression("[1, 2, 3]", &ctx).unwrap(),
|
||||
json!([1, 2, 3])
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("[\"a\", \"b\"]", &ctx).unwrap(),
|
||||
json!(["a", "b"])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_function_calls() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(
|
||||
eval_expression("length(split(\"a,b,c\", \",\"))", &ctx).unwrap(),
|
||||
json!(3)
|
||||
);
|
||||
assert_eq!(
|
||||
eval_expression("join(sort([\"c\", \"a\", \"b\"]), \"-\")", &ctx).unwrap(),
|
||||
json!("a-b-c")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boolean_literals() {
|
||||
let ctx = TestContext::new();
|
||||
assert_eq!(eval_expression("true", &ctx).unwrap(), json!(true));
|
||||
assert_eq!(eval_expression("false", &ctx).unwrap(), json!(false));
|
||||
assert_eq!(eval_expression("null", &ctx).unwrap(), json!(null));
|
||||
}
|
||||
}
|
||||
520
crates/common/src/workflow/expression/parser.rs
Normal file
520
crates/common/src/workflow/expression/parser.rs
Normal file
@@ -0,0 +1,520 @@
|
||||
//! # Expression Parser
|
||||
//!
|
||||
//! Recursive-descent parser that transforms a token stream into an AST.
|
||||
//!
|
||||
//! ## Operator Precedence (lowest to highest)
|
||||
//!
|
||||
//! 1. `or`
|
||||
//! 2. `and`
|
||||
//! 3. `not` (unary)
|
||||
//! 4. `==`, `!=`, `<`, `>`, `<=`, `>=`, `in`
|
||||
//! 5. `+`, `-` (addition / subtraction)
|
||||
//! 6. `*`, `/`, `%`
|
||||
//! 7. Unary `-`
|
||||
//! 8. Postfix: `.field`, `[index]`, `(args)`
|
||||
|
||||
use super::ast::{BinaryOp, Expr, UnaryOp};
|
||||
use super::tokenizer::{Token, TokenKind};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ParseError {
|
||||
#[error("Unexpected token {0} at position {1}")]
|
||||
UnexpectedToken(String, usize),
|
||||
|
||||
#[error("Expected {0}, found {1} at position {2}")]
|
||||
Expected(String, String, usize),
|
||||
|
||||
#[error("Unexpected end of expression")]
|
||||
UnexpectedEof,
|
||||
|
||||
#[error("Token error: {0}")]
|
||||
TokenError(String),
|
||||
}
|
||||
|
||||
/// The parser state.
|
||||
pub struct Parser<'a> {
|
||||
tokens: &'a [Token],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
pub fn new(tokens: &'a [Token]) -> Self {
|
||||
Self { tokens, pos: 0 }
|
||||
}
|
||||
|
||||
/// Parse the token stream into a single expression AST.
|
||||
pub fn parse(&mut self) -> Result<Expr, ParseError> {
|
||||
let expr = self.parse_or()?;
|
||||
// We should be at EOF now
|
||||
if !self.at_end() {
|
||||
let tok = self.peek();
|
||||
return Err(ParseError::UnexpectedToken(
|
||||
format!("{}", tok.kind),
|
||||
tok.span.0,
|
||||
));
|
||||
}
|
||||
Ok(expr)
|
||||
}
|
||||
|
||||
// ----- Helpers -----
|
||||
|
||||
fn peek(&self) -> &Token {
|
||||
&self.tokens[self.pos.min(self.tokens.len() - 1)]
|
||||
}
|
||||
|
||||
fn at_end(&self) -> bool {
|
||||
self.peek().kind == TokenKind::Eof
|
||||
}
|
||||
|
||||
fn advance(&mut self) -> &Token {
|
||||
let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
|
||||
if self.pos < self.tokens.len() {
|
||||
self.pos += 1;
|
||||
}
|
||||
tok
|
||||
}
|
||||
|
||||
fn expect(&mut self, expected: &TokenKind) -> Result<&Token, ParseError> {
|
||||
let tok = self.peek();
|
||||
if std::mem::discriminant(&tok.kind) == std::mem::discriminant(expected) {
|
||||
Ok(self.advance())
|
||||
} else {
|
||||
Err(ParseError::Expected(
|
||||
format!("{}", expected),
|
||||
format!("{}", tok.kind),
|
||||
tok.span.0,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&self, kind: &TokenKind) -> bool {
|
||||
std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(kind)
|
||||
}
|
||||
|
||||
// ----- Grammar rules -----
|
||||
|
||||
// or_expr = and_expr ( "or" and_expr )*
|
||||
fn parse_or(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut left = self.parse_and()?;
|
||||
while self.peek().kind == TokenKind::Or {
|
||||
self.advance();
|
||||
let right = self.parse_and()?;
|
||||
left = Expr::BinaryOp {
|
||||
op: BinaryOp::Or,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
};
|
||||
}
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
// and_expr = not_expr ( "and" not_expr )*
|
||||
fn parse_and(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut left = self.parse_not()?;
|
||||
while self.peek().kind == TokenKind::And {
|
||||
self.advance();
|
||||
let right = self.parse_not()?;
|
||||
left = Expr::BinaryOp {
|
||||
op: BinaryOp::And,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
};
|
||||
}
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
// not_expr = "not" not_expr | comparison
|
||||
fn parse_not(&mut self) -> Result<Expr, ParseError> {
|
||||
if self.peek().kind == TokenKind::Not {
|
||||
self.advance();
|
||||
let operand = self.parse_not()?;
|
||||
return Ok(Expr::UnaryOp {
|
||||
op: UnaryOp::Not,
|
||||
operand: Box::new(operand),
|
||||
});
|
||||
}
|
||||
self.parse_comparison()
|
||||
}
|
||||
|
||||
// comparison = addition ( ("==" | "!=" | "<" | ">" | "<=" | ">=" | "in") addition )*
|
||||
fn parse_comparison(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut left = self.parse_addition()?;
|
||||
|
||||
loop {
|
||||
let op = match self.peek().kind {
|
||||
TokenKind::EqEq => BinaryOp::Eq,
|
||||
TokenKind::BangEq => BinaryOp::Ne,
|
||||
TokenKind::Lt => BinaryOp::Lt,
|
||||
TokenKind::Gt => BinaryOp::Gt,
|
||||
TokenKind::LtEq => BinaryOp::Le,
|
||||
TokenKind::GtEq => BinaryOp::Ge,
|
||||
TokenKind::In => BinaryOp::In,
|
||||
_ => break,
|
||||
};
|
||||
self.advance();
|
||||
let right = self.parse_addition()?;
|
||||
left = Expr::BinaryOp {
|
||||
op,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
// addition = multiplication ( ("+" | "-") multiplication )*
|
||||
fn parse_addition(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut left = self.parse_multiplication()?;
|
||||
|
||||
loop {
|
||||
let op = match self.peek().kind {
|
||||
TokenKind::Plus => BinaryOp::Add,
|
||||
TokenKind::Minus => BinaryOp::Sub,
|
||||
_ => break,
|
||||
};
|
||||
self.advance();
|
||||
let right = self.parse_multiplication()?;
|
||||
left = Expr::BinaryOp {
|
||||
op,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
// multiplication = unary ( ("*" | "/" | "%") unary )*
|
||||
fn parse_multiplication(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut left = self.parse_unary()?;
|
||||
|
||||
loop {
|
||||
let op = match self.peek().kind {
|
||||
TokenKind::Star => BinaryOp::Mul,
|
||||
TokenKind::Slash => BinaryOp::Div,
|
||||
TokenKind::Percent => BinaryOp::Mod,
|
||||
_ => break,
|
||||
};
|
||||
self.advance();
|
||||
let right = self.parse_unary()?;
|
||||
left = Expr::BinaryOp {
|
||||
op,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
// unary = "-" unary | postfix
|
||||
fn parse_unary(&mut self) -> Result<Expr, ParseError> {
|
||||
if self.peek().kind == TokenKind::Minus {
|
||||
self.advance();
|
||||
let operand = self.parse_unary()?;
|
||||
return Ok(Expr::UnaryOp {
|
||||
op: UnaryOp::Neg,
|
||||
operand: Box::new(operand),
|
||||
});
|
||||
}
|
||||
self.parse_postfix()
|
||||
}
|
||||
|
||||
// postfix = primary ( "." IDENT | "[" expr "]" | "(" args ")" )*
|
||||
fn parse_postfix(&mut self) -> Result<Expr, ParseError> {
|
||||
let mut expr = self.parse_primary()?;
|
||||
|
||||
loop {
|
||||
match self.peek().kind {
|
||||
TokenKind::Dot => {
|
||||
self.advance();
|
||||
// The field after dot
|
||||
let tok = self.advance().clone();
|
||||
let field = match &tok.kind {
|
||||
TokenKind::Ident(name) => name.clone(),
|
||||
// Allow keywords as field names (e.g., obj.in, obj.and)
|
||||
TokenKind::And => "and".to_string(),
|
||||
TokenKind::Or => "or".to_string(),
|
||||
TokenKind::Not => "not".to_string(),
|
||||
TokenKind::In => "in".to_string(),
|
||||
TokenKind::True => "true".to_string(),
|
||||
TokenKind::False => "false".to_string(),
|
||||
TokenKind::Null => "null".to_string(),
|
||||
_ => {
|
||||
return Err(ParseError::Expected(
|
||||
"identifier".to_string(),
|
||||
format!("{}", tok.kind),
|
||||
tok.span.0,
|
||||
));
|
||||
}
|
||||
};
|
||||
expr = Expr::DotAccess {
|
||||
object: Box::new(expr),
|
||||
field,
|
||||
};
|
||||
}
|
||||
TokenKind::LBracket => {
|
||||
self.advance();
|
||||
let index = self.parse_or()?;
|
||||
self.expect(&TokenKind::RBracket)?;
|
||||
expr = Expr::IndexAccess {
|
||||
object: Box::new(expr),
|
||||
index: Box::new(index),
|
||||
};
|
||||
}
|
||||
TokenKind::LParen => {
|
||||
// Only if the expression so far is an identifier (function name)
|
||||
// or a dot-access chain (method-like call).
|
||||
// For now we handle Ident -> FunctionCall transformation.
|
||||
if let Expr::Ident(name) = expr {
|
||||
self.advance();
|
||||
let args = self.parse_args()?;
|
||||
self.expect(&TokenKind::RParen)?;
|
||||
expr = Expr::FunctionCall { name, args };
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(expr)
|
||||
}
|
||||
|
||||
// args = ( expr ( "," expr )* )?
|
||||
fn parse_args(&mut self) -> Result<Vec<Expr>, ParseError> {
|
||||
let mut args = Vec::new();
|
||||
if self.check(&TokenKind::RParen) {
|
||||
return Ok(args);
|
||||
}
|
||||
args.push(self.parse_or()?);
|
||||
while self.peek().kind == TokenKind::Comma {
|
||||
self.advance();
|
||||
args.push(self.parse_or()?);
|
||||
}
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
// primary = INTEGER | FLOAT | STRING | "true" | "false" | "null"
|
||||
// | IDENT | "(" expr ")" | "[" elements "]"
|
||||
fn parse_primary(&mut self) -> Result<Expr, ParseError> {
|
||||
let tok = self.peek().clone();
|
||||
match &tok.kind {
|
||||
TokenKind::Integer(n) => {
|
||||
let n = *n;
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::json!(n)))
|
||||
}
|
||||
TokenKind::Float(f) => {
|
||||
let f = *f;
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::json!(f)))
|
||||
}
|
||||
TokenKind::StringLit(s) => {
|
||||
let s = s.clone();
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::Value::String(s)))
|
||||
}
|
||||
TokenKind::True => {
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::json!(true)))
|
||||
}
|
||||
TokenKind::False => {
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::json!(false)))
|
||||
}
|
||||
TokenKind::Null => {
|
||||
self.advance();
|
||||
Ok(Expr::Literal(serde_json::json!(null)))
|
||||
}
|
||||
TokenKind::Ident(name) => {
|
||||
let name = name.clone();
|
||||
self.advance();
|
||||
Ok(Expr::Ident(name))
|
||||
}
|
||||
TokenKind::LParen => {
|
||||
self.advance();
|
||||
let expr = self.parse_or()?;
|
||||
self.expect(&TokenKind::RParen)?;
|
||||
Ok(expr)
|
||||
}
|
||||
TokenKind::LBracket => {
|
||||
self.advance();
|
||||
let mut elements = Vec::new();
|
||||
if !self.check(&TokenKind::RBracket) {
|
||||
elements.push(self.parse_or()?);
|
||||
while self.peek().kind == TokenKind::Comma {
|
||||
self.advance();
|
||||
// Allow trailing comma
|
||||
if self.check(&TokenKind::RBracket) {
|
||||
break;
|
||||
}
|
||||
elements.push(self.parse_or()?);
|
||||
}
|
||||
}
|
||||
self.expect(&TokenKind::RBracket)?;
|
||||
Ok(Expr::Array(elements))
|
||||
}
|
||||
_ => Err(ParseError::UnexpectedToken(
|
||||
format!("{}", tok.kind),
|
||||
tok.span.0,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::super::tokenizer::Tokenizer;
|
||||
|
||||
fn parse(input: &str) -> Expr {
|
||||
let tokens = Tokenizer::new(input).tokenize().unwrap();
|
||||
Parser::new(&tokens).parse().unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_add() {
|
||||
let ast = parse("2 + 3");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::BinaryOp {
|
||||
op: BinaryOp::Add,
|
||||
left: Box::new(Expr::Literal(serde_json::json!(2))),
|
||||
right: Box::new(Expr::Literal(serde_json::json!(3))),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_precedence() {
|
||||
// 2 + 3 * 4 should parse as 2 + (3 * 4)
|
||||
let ast = parse("2 + 3 * 4");
|
||||
match ast {
|
||||
Expr::BinaryOp {
|
||||
op: BinaryOp::Add,
|
||||
right,
|
||||
..
|
||||
} => {
|
||||
assert!(matches!(
|
||||
*right,
|
||||
Expr::BinaryOp {
|
||||
op: BinaryOp::Mul,
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
_ => panic!("Expected Add at top level"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call() {
|
||||
let ast = parse("length(arr)");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::FunctionCall {
|
||||
name: "length".to_string(),
|
||||
args: vec![Expr::Ident("arr".to_string())],
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot_access() {
|
||||
let ast = parse("obj.field.sub");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::DotAccess {
|
||||
object: Box::new(Expr::DotAccess {
|
||||
object: Box::new(Expr::Ident("obj".to_string())),
|
||||
field: "field".to_string(),
|
||||
}),
|
||||
field: "sub".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_array_literal() {
|
||||
let ast = parse("[1, 2, 3]");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::Array(vec![
|
||||
Expr::Literal(serde_json::json!(1)),
|
||||
Expr::Literal(serde_json::json!(2)),
|
||||
Expr::Literal(serde_json::json!(3)),
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bracket_access() {
|
||||
let ast = parse("arr[0]");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::IndexAccess {
|
||||
object: Box::new(Expr::Ident("arr".to_string())),
|
||||
index: Box::new(Expr::Literal(serde_json::json!(0))),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_operator() {
|
||||
let ast = parse("not true");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::UnaryOp {
|
||||
op: UnaryOp::Not,
|
||||
operand: Box::new(Expr::Literal(serde_json::json!(true))),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_in_operator() {
|
||||
let ast = parse("x in arr");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::BinaryOp {
|
||||
op: BinaryOp::In,
|
||||
left: Box::new(Expr::Ident("x".to_string())),
|
||||
right: Box::new(Expr::Ident("arr".to_string())),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complex_expression() {
|
||||
// Should parse without error
|
||||
let _ast = parse("length(items) > 3 and 5 in items");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chained_access() {
|
||||
// data.users[1].name
|
||||
let _ast = parse("data.users[1].name");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_function() {
|
||||
let _ast = parse("length(split(\"a,b,c\", \",\"))");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trailing_comma_in_array() {
|
||||
let ast = parse("[1, 2, 3,]");
|
||||
assert_eq!(
|
||||
ast,
|
||||
Expr::Array(vec![
|
||||
Expr::Literal(serde_json::json!(1)),
|
||||
Expr::Literal(serde_json::json!(2)),
|
||||
Expr::Literal(serde_json::json!(3)),
|
||||
])
|
||||
);
|
||||
}
|
||||
}
|
||||
512
crates/common/src/workflow/expression/tokenizer.rs
Normal file
512
crates/common/src/workflow/expression/tokenizer.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
//! # Expression Tokenizer (Lexer)
|
||||
//!
|
||||
//! Converts an expression string into a sequence of tokens.
|
||||
|
||||
use std::fmt;
|
||||
use thiserror::Error;
|
||||
|
||||
/// A token produced by the lexer.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Token {
|
||||
pub kind: TokenKind,
|
||||
pub span: (usize, usize),
|
||||
}
|
||||
|
||||
impl Token {
|
||||
pub fn new(kind: TokenKind, start: usize, end: usize) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
span: (start, end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The kind of a token.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum TokenKind {
|
||||
// Literals
|
||||
Integer(i64),
|
||||
Float(f64),
|
||||
StringLit(String),
|
||||
True,
|
||||
False,
|
||||
Null,
|
||||
|
||||
// Identifier
|
||||
Ident(String),
|
||||
|
||||
// Keywords (also parsed as identifiers initially, then classified)
|
||||
And,
|
||||
Or,
|
||||
Not,
|
||||
In,
|
||||
|
||||
// Operators
|
||||
Plus,
|
||||
Minus,
|
||||
Star,
|
||||
Slash,
|
||||
Percent,
|
||||
EqEq,
|
||||
BangEq,
|
||||
Lt,
|
||||
Gt,
|
||||
LtEq,
|
||||
GtEq,
|
||||
|
||||
// Delimiters
|
||||
LParen,
|
||||
RParen,
|
||||
LBracket,
|
||||
RBracket,
|
||||
Comma,
|
||||
Dot,
|
||||
|
||||
// End of input
|
||||
Eof,
|
||||
}
|
||||
|
||||
impl fmt::Display for TokenKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TokenKind::Integer(n) => write!(f, "{}", n),
|
||||
TokenKind::Float(n) => write!(f, "{}", n),
|
||||
TokenKind::StringLit(s) => write!(f, "\"{}\"", s),
|
||||
TokenKind::True => write!(f, "true"),
|
||||
TokenKind::False => write!(f, "false"),
|
||||
TokenKind::Null => write!(f, "null"),
|
||||
TokenKind::Ident(s) => write!(f, "{}", s),
|
||||
TokenKind::And => write!(f, "and"),
|
||||
TokenKind::Or => write!(f, "or"),
|
||||
TokenKind::Not => write!(f, "not"),
|
||||
TokenKind::In => write!(f, "in"),
|
||||
TokenKind::Plus => write!(f, "+"),
|
||||
TokenKind::Minus => write!(f, "-"),
|
||||
TokenKind::Star => write!(f, "*"),
|
||||
TokenKind::Slash => write!(f, "/"),
|
||||
TokenKind::Percent => write!(f, "%"),
|
||||
TokenKind::EqEq => write!(f, "=="),
|
||||
TokenKind::BangEq => write!(f, "!="),
|
||||
TokenKind::Lt => write!(f, "<"),
|
||||
TokenKind::Gt => write!(f, ">"),
|
||||
TokenKind::LtEq => write!(f, "<="),
|
||||
TokenKind::GtEq => write!(f, ">="),
|
||||
TokenKind::LParen => write!(f, "("),
|
||||
TokenKind::RParen => write!(f, ")"),
|
||||
TokenKind::LBracket => write!(f, "["),
|
||||
TokenKind::RBracket => write!(f, "]"),
|
||||
TokenKind::Comma => write!(f, ","),
|
||||
TokenKind::Dot => write!(f, "."),
|
||||
TokenKind::Eof => write!(f, "EOF"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenError {
|
||||
#[error("Unexpected character '{0}' at position {1}")]
|
||||
UnexpectedChar(char, usize),
|
||||
|
||||
#[error("Unterminated string literal starting at position {0}")]
|
||||
UnterminatedString(usize),
|
||||
|
||||
#[error("Invalid number literal at position {0}: {1}")]
|
||||
InvalidNumber(usize, String),
|
||||
}
|
||||
|
||||
/// The tokenizer / lexer.
|
||||
pub struct Tokenizer {
|
||||
chars: Vec<char>,
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn new(input: &str) -> Self {
|
||||
Self {
|
||||
chars: input.chars().collect(),
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Tokenize the entire input and return a vector of tokens.
|
||||
pub fn tokenize(&mut self) -> Result<Vec<Token>, TokenError> {
|
||||
let mut tokens = Vec::new();
|
||||
loop {
|
||||
let tok = self.next_token()?;
|
||||
if tok.kind == TokenKind::Eof {
|
||||
tokens.push(tok);
|
||||
break;
|
||||
}
|
||||
tokens.push(tok);
|
||||
}
|
||||
Ok(tokens)
|
||||
}
|
||||
|
||||
fn peek(&self) -> Option<char> {
|
||||
self.chars.get(self.pos).copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) -> Option<char> {
|
||||
let ch = self.chars.get(self.pos).copied();
|
||||
if ch.is_some() {
|
||||
self.pos += 1;
|
||||
}
|
||||
ch
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Result<Token, TokenError> {
|
||||
self.skip_whitespace();
|
||||
|
||||
let start = self.pos;
|
||||
|
||||
let ch = match self.peek() {
|
||||
Some(ch) => ch,
|
||||
None => return Ok(Token::new(TokenKind::Eof, start, start)),
|
||||
};
|
||||
|
||||
// Single-char and multi-char operators/delimiters
|
||||
match ch {
|
||||
'+' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Plus, start, self.pos))
|
||||
}
|
||||
'-' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Minus, start, self.pos))
|
||||
}
|
||||
'*' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Star, start, self.pos))
|
||||
}
|
||||
'/' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Slash, start, self.pos))
|
||||
}
|
||||
'%' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Percent, start, self.pos))
|
||||
}
|
||||
'(' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::LParen, start, self.pos))
|
||||
}
|
||||
')' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::RParen, start, self.pos))
|
||||
}
|
||||
'[' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::LBracket, start, self.pos))
|
||||
}
|
||||
']' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::RBracket, start, self.pos))
|
||||
}
|
||||
',' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Comma, start, self.pos))
|
||||
}
|
||||
'.' => {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::Dot, start, self.pos))
|
||||
}
|
||||
'=' => {
|
||||
self.advance();
|
||||
if self.peek() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::EqEq, start, self.pos))
|
||||
} else {
|
||||
Err(TokenError::UnexpectedChar('=', start))
|
||||
}
|
||||
}
|
||||
'!' => {
|
||||
self.advance();
|
||||
if self.peek() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::BangEq, start, self.pos))
|
||||
} else {
|
||||
Err(TokenError::UnexpectedChar('!', start))
|
||||
}
|
||||
}
|
||||
'<' => {
|
||||
self.advance();
|
||||
if self.peek() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::LtEq, start, self.pos))
|
||||
} else {
|
||||
Ok(Token::new(TokenKind::Lt, start, self.pos))
|
||||
}
|
||||
}
|
||||
'>' => {
|
||||
self.advance();
|
||||
if self.peek() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::new(TokenKind::GtEq, start, self.pos))
|
||||
} else {
|
||||
Ok(Token::new(TokenKind::Gt, start, self.pos))
|
||||
}
|
||||
}
|
||||
'"' | '\'' => self.read_string(ch),
|
||||
c if c.is_ascii_digit() => self.read_number(),
|
||||
c if c.is_ascii_alphabetic() || c == '_' => self.read_ident(),
|
||||
other => Err(TokenError::UnexpectedChar(other, start)),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_string(&mut self, quote: char) -> Result<Token, TokenError> {
|
||||
let start = self.pos;
|
||||
self.advance(); // consume opening quote
|
||||
let mut s = String::new();
|
||||
loop {
|
||||
match self.advance() {
|
||||
Some('\\') => {
|
||||
// Escape sequence
|
||||
match self.advance() {
|
||||
Some('n') => s.push('\n'),
|
||||
Some('t') => s.push('\t'),
|
||||
Some('r') => s.push('\r'),
|
||||
Some('\\') => s.push('\\'),
|
||||
Some(c) if c == quote => s.push(c),
|
||||
Some(c) => {
|
||||
s.push('\\');
|
||||
s.push(c);
|
||||
}
|
||||
None => return Err(TokenError::UnterminatedString(start)),
|
||||
}
|
||||
}
|
||||
Some(c) if c == quote => {
|
||||
return Ok(Token::new(TokenKind::StringLit(s), start, self.pos));
|
||||
}
|
||||
Some(c) => s.push(c),
|
||||
None => return Err(TokenError::UnterminatedString(start)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_number(&mut self) -> Result<Token, TokenError> {
|
||||
let start = self.pos;
|
||||
let mut num_str = String::new();
|
||||
let mut is_float = false;
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
num_str.push(ch);
|
||||
self.advance();
|
||||
} else if ch == '.' && !is_float {
|
||||
// Check if this is a decimal point or a method call dot
|
||||
// Look ahead to see if next char is a digit
|
||||
let next_pos = self.pos + 1;
|
||||
if next_pos < self.chars.len() && self.chars[next_pos].is_ascii_digit() {
|
||||
is_float = true;
|
||||
num_str.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
// It's a dot access, stop number parsing here
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if is_float {
|
||||
let val: f64 = num_str.parse().map_err(|_| {
|
||||
TokenError::InvalidNumber(start, num_str.clone())
|
||||
})?;
|
||||
Ok(Token::new(TokenKind::Float(val), start, self.pos))
|
||||
} else {
|
||||
let val: i64 = num_str.parse().map_err(|_| {
|
||||
TokenError::InvalidNumber(start, num_str.clone())
|
||||
})?;
|
||||
Ok(Token::new(TokenKind::Integer(val), start, self.pos))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_ident(&mut self) -> Result<Token, TokenError> {
|
||||
let start = self.pos;
|
||||
let mut ident = String::new();
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_alphanumeric() || ch == '_' {
|
||||
ident.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let kind = match ident.as_str() {
|
||||
"true" => TokenKind::True,
|
||||
"false" => TokenKind::False,
|
||||
"null" => TokenKind::Null,
|
||||
"and" => TokenKind::And,
|
||||
"or" => TokenKind::Or,
|
||||
"not" => TokenKind::Not,
|
||||
"in" => TokenKind::In,
|
||||
_ => TokenKind::Ident(ident),
|
||||
};
|
||||
|
||||
Ok(Token::new(kind, start, self.pos))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn tokenize(input: &str) -> Vec<TokenKind> {
|
||||
let mut t = Tokenizer::new(input);
|
||||
t.tokenize()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|t| t.kind)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_expression() {
|
||||
let kinds = tokenize("2 + 3");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Integer(2),
|
||||
TokenKind::Plus,
|
||||
TokenKind::Integer(3),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comparison() {
|
||||
let kinds = tokenize("x >= 10");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Ident("x".to_string()),
|
||||
TokenKind::GtEq,
|
||||
TokenKind::Integer(10),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords() {
|
||||
let kinds = tokenize("true and not false or null in x");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::True,
|
||||
TokenKind::And,
|
||||
TokenKind::Not,
|
||||
TokenKind::False,
|
||||
TokenKind::Or,
|
||||
TokenKind::Null,
|
||||
TokenKind::In,
|
||||
TokenKind::Ident("x".to_string()),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_literals() {
|
||||
let kinds = tokenize("\"hello\" + 'world'");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::StringLit("hello".to_string()),
|
||||
TokenKind::Plus,
|
||||
TokenKind::StringLit("world".to_string()),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_float() {
|
||||
let kinds = tokenize("3.14");
|
||||
assert_eq!(kinds, vec![TokenKind::Float(3.14), TokenKind::Eof]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot_access() {
|
||||
let kinds = tokenize("obj.field");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Ident("obj".to_string()),
|
||||
TokenKind::Dot,
|
||||
TokenKind::Ident("field".to_string()),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call() {
|
||||
let kinds = tokenize("length(arr)");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Ident("length".to_string()),
|
||||
TokenKind::LParen,
|
||||
TokenKind::Ident("arr".to_string()),
|
||||
TokenKind::RParen,
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bracket_access() {
|
||||
let kinds = tokenize("arr[0]");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Ident("arr".to_string()),
|
||||
TokenKind::LBracket,
|
||||
TokenKind::Integer(0),
|
||||
TokenKind::RBracket,
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escape_sequences() {
|
||||
let kinds = tokenize(r#""hello\nworld""#);
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::StringLit("hello\nworld".to_string()),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_integer_followed_by_dot() {
|
||||
// `42.field` - the 42 is an integer and `.field` is separate
|
||||
let kinds = tokenize("42.field");
|
||||
assert_eq!(
|
||||
kinds,
|
||||
vec![
|
||||
TokenKind::Integer(42),
|
||||
TokenKind::Dot,
|
||||
TokenKind::Ident("field".to_string()),
|
||||
TokenKind::Eof,
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
674
crates/common/src/workflow/expression_validator.rs
Normal file
674
crates/common/src/workflow/expression_validator.rs
Normal file
@@ -0,0 +1,674 @@
|
||||
//! # Workflow Expression Validator
|
||||
//!
|
||||
//! Static validation of `{{ }}` template expressions in workflow definitions.
|
||||
//! Catches syntax errors and unresolved variable references **before** the
|
||||
//! workflow is saved, so users get immediate feedback instead of opaque
|
||||
//! runtime failures during execution.
|
||||
//!
|
||||
//! ## What is validated
|
||||
//!
|
||||
//! 1. **Syntax** — every `{{ expr }}` block must parse successfully.
|
||||
//! 2. **Variable references** — top-level identifiers that are not a known
|
||||
//! namespace (`parameters`, `workflow`, `task`, `config`, `keystore`,
|
||||
//! `item`, `index`, `system`, …) must exist in either the workflow's
|
||||
//! `vars` map or its `param_schema` keys (bare-name fallback targets).
|
||||
//!
|
||||
//! ## What is NOT validated
|
||||
//!
|
||||
//! - **Type correctness** — e.g. whether `range(parameters.n)` actually
|
||||
//! receives an integer. That requires runtime values.
|
||||
//! - **Deep property paths** — e.g. `task.fetch.result.data`. We validate
|
||||
//! that `task` is a known namespace but not that `fetch` is a real task
|
||||
//! name (it might not exist yet at save time if tasks are re-ordered).
|
||||
//! - **Function arity** — built-in functions are not checked for argument
|
||||
//! count here; the evaluator already reports those errors at runtime.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use super::expression::{parse_expression, Expr, ParseError};
|
||||
use super::parser::{PublishDirective, WorkflowDefinition};
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
// Public API
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A single validation diagnostic.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExpressionWarning {
|
||||
/// Human-readable location (e.g. `task 'sleep_1' with_items`).
|
||||
pub location: String,
|
||||
/// The raw template string that was checked.
|
||||
pub expression: String,
|
||||
/// What went wrong.
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ExpressionWarning {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}: {} — `{}`", self.location, self.message, self.expression)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate all template expressions in a workflow definition.
|
||||
///
|
||||
/// Returns an empty vec on success, or one [`ExpressionWarning`] per problem
|
||||
/// found. The caller decides whether warnings are fatal (block save) or
|
||||
/// advisory.
|
||||
///
|
||||
/// `param_schema` is the *flat-format* schema (`{ "url": { "type": "string" }, … }`)
|
||||
/// passed alongside the definition in the save request. Its top-level keys
|
||||
/// are the declared parameter names.
|
||||
pub fn validate_workflow_expressions(
|
||||
workflow: &WorkflowDefinition,
|
||||
param_schema: Option<&JsonValue>,
|
||||
) -> Vec<ExpressionWarning> {
|
||||
let known_names = build_known_names(workflow, param_schema);
|
||||
let mut warnings = Vec::new();
|
||||
|
||||
for task in &workflow.tasks {
|
||||
let task_loc = format!("task '{}'", task.name);
|
||||
|
||||
// ── with_items expression ────────────────────────────────────
|
||||
if let Some(ref expr) = task.with_items {
|
||||
validate_template(
|
||||
expr,
|
||||
&format!("{task_loc} with_items"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
|
||||
// ── task-level when condition ────────────────────────────────
|
||||
if let Some(ref expr) = task.when {
|
||||
validate_template(
|
||||
expr,
|
||||
&format!("{task_loc} when"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
|
||||
// ── input templates ──────────────────────────────────────────
|
||||
for (key, value) in &task.input {
|
||||
collect_json_templates(
|
||||
value,
|
||||
&format!("{task_loc} input.{key}"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
|
||||
// ── next transitions ─────────────────────────────────────────
|
||||
for (ti, transition) in task.next.iter().enumerate() {
|
||||
if let Some(ref when_expr) = transition.when {
|
||||
validate_template(
|
||||
when_expr,
|
||||
&format!("{task_loc} next[{ti}].when"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
|
||||
for directive in &transition.publish {
|
||||
match directive {
|
||||
PublishDirective::Simple(map) => {
|
||||
for (pk, pv) in map {
|
||||
validate_template(
|
||||
pv,
|
||||
&format!("{task_loc} next[{ti}].publish.{pk}"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
}
|
||||
PublishDirective::Key(_) => { /* nothing to validate */ }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── legacy task-level publish ────────────────────────────────
|
||||
for directive in &task.publish {
|
||||
if let PublishDirective::Simple(map) = directive {
|
||||
for (pk, pv) in map {
|
||||
validate_template(
|
||||
pv,
|
||||
&format!("{task_loc} publish.{pk}"),
|
||||
&known_names,
|
||||
&mut warnings,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
warnings
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
// Internals
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Canonical namespace identifiers that are always valid as top-level names
|
||||
/// inside `{{ }}` expressions. These are resolved by `WorkflowContext` at
|
||||
/// runtime and never need to exist in `vars` or `param_schema`.
|
||||
const CANONICAL_NAMESPACES: &[&str] = &[
|
||||
"parameters",
|
||||
"workflow",
|
||||
"vars",
|
||||
"variables",
|
||||
"task",
|
||||
"tasks",
|
||||
"config",
|
||||
"keystore",
|
||||
"item",
|
||||
"index",
|
||||
"system",
|
||||
];
|
||||
|
||||
/// Built-in constants that are valid bare identifiers.
|
||||
const BUILTIN_LITERALS: &[&str] = &["true", "false", "null"];
|
||||
|
||||
/// Build the set of bare names that are valid in expressions:
|
||||
/// canonical namespaces + workflow var names + param_schema keys.
|
||||
fn build_known_names(
|
||||
workflow: &WorkflowDefinition,
|
||||
param_schema: Option<&JsonValue>,
|
||||
) -> HashSet<String> {
|
||||
let mut names: HashSet<String> = CANONICAL_NAMESPACES
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect();
|
||||
|
||||
for lit in BUILTIN_LITERALS {
|
||||
names.insert((*lit).to_string());
|
||||
}
|
||||
|
||||
// Workflow vars
|
||||
for key in workflow.vars.keys() {
|
||||
names.insert(key.clone());
|
||||
}
|
||||
|
||||
// Parameter schema keys (flat format: top-level keys are param names)
|
||||
if let Some(JsonValue::Object(map)) = param_schema {
|
||||
for key in map.keys() {
|
||||
names.insert(key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Also accept the workflow-level `parameters` schema if present on the
|
||||
// definition itself (some loaders put it there).
|
||||
if let Some(JsonValue::Object(ref map)) = workflow.parameters {
|
||||
for key in map.keys() {
|
||||
names.insert(key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
names
|
||||
}
|
||||
|
||||
/// Extract `{{ … }}` blocks from a template string and validate each one.
|
||||
fn validate_template(
|
||||
template: &str,
|
||||
location: &str,
|
||||
known_names: &HashSet<String>,
|
||||
warnings: &mut Vec<ExpressionWarning>,
|
||||
) {
|
||||
for raw_expr in extract_expressions(template) {
|
||||
let trimmed = raw_expr.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Phase 1: parse
|
||||
match parse_expression(trimmed) {
|
||||
Err(e) => {
|
||||
warnings.push(ExpressionWarning {
|
||||
location: location.to_string(),
|
||||
expression: raw_expr.to_string(),
|
||||
message: format!("syntax error: {e}"),
|
||||
});
|
||||
}
|
||||
Ok(ast) => {
|
||||
// Phase 2: check bare-name references
|
||||
let mut bare_idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut bare_idents);
|
||||
|
||||
for ident in bare_idents {
|
||||
if !known_names.contains(&ident) {
|
||||
warnings.push(ExpressionWarning {
|
||||
location: location.to_string(),
|
||||
expression: raw_expr.to_string(),
|
||||
message: format!(
|
||||
"unknown variable '{}'. Use 'parameters.{}' for input \
|
||||
parameters, or define it in workflow vars",
|
||||
ident, ident,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively walk a JSON value looking for string leaves that contain
|
||||
/// `{{ }}` templates.
|
||||
fn collect_json_templates(
|
||||
value: &JsonValue,
|
||||
location: &str,
|
||||
known_names: &HashSet<String>,
|
||||
warnings: &mut Vec<ExpressionWarning>,
|
||||
) {
|
||||
match value {
|
||||
JsonValue::String(s) => {
|
||||
validate_template(s, location, known_names, warnings);
|
||||
}
|
||||
JsonValue::Array(arr) => {
|
||||
for (i, item) in arr.iter().enumerate() {
|
||||
collect_json_templates(
|
||||
item,
|
||||
&format!("{location}[{i}]"),
|
||||
known_names,
|
||||
warnings,
|
||||
);
|
||||
}
|
||||
}
|
||||
JsonValue::Object(map) => {
|
||||
for (key, val) in map {
|
||||
collect_json_templates(
|
||||
val,
|
||||
&format!("{location}.{key}"),
|
||||
known_names,
|
||||
warnings,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => { /* numbers, bools, null — nothing to validate */ }
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the inner expression strings from all `{{ … }}` blocks in a
|
||||
/// template. Handles nested braces conservatively (takes everything between
|
||||
/// the outermost `{{` and `}}`).
|
||||
fn extract_expressions(template: &str) -> Vec<&str> {
|
||||
let mut results = Vec::new();
|
||||
let mut rest = template;
|
||||
|
||||
while let Some(start) = rest.find("{{") {
|
||||
let after_open = start + 2;
|
||||
if let Some(end) = rest[after_open..].find("}}") {
|
||||
results.push(&rest[after_open..after_open + end]);
|
||||
rest = &rest[after_open + end + 2..];
|
||||
} else {
|
||||
// Unclosed `{{` — skip
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Collect bare `Ident` nodes that appear at the *top level* of an
|
||||
/// expression — i.e. identifiers that are not the right-hand side of a
|
||||
/// `.field` access (those are field names, not variable references).
|
||||
///
|
||||
/// For `DotAccess { object: Ident("parameters"), field: "n" }` we collect
|
||||
/// `"parameters"` but NOT `"n"`.
|
||||
///
|
||||
/// For `FunctionCall { name: "range", args: [Ident("n")] }` we collect
|
||||
/// `"n"` (it's a bare variable reference used as a function argument).
|
||||
fn collect_bare_idents(expr: &Expr, out: &mut Vec<String>) {
|
||||
match expr {
|
||||
Expr::Ident(name) => {
|
||||
out.push(name.clone());
|
||||
}
|
||||
Expr::Literal(_) => {}
|
||||
Expr::Array(items) => {
|
||||
for item in items {
|
||||
collect_bare_idents(item, out);
|
||||
}
|
||||
}
|
||||
Expr::BinaryOp { left, right, .. } => {
|
||||
collect_bare_idents(left, out);
|
||||
collect_bare_idents(right, out);
|
||||
}
|
||||
Expr::UnaryOp { operand, .. } => {
|
||||
collect_bare_idents(operand, out);
|
||||
}
|
||||
Expr::DotAccess { object, .. } => {
|
||||
// Only recurse into the object side — the field name is not a
|
||||
// variable reference.
|
||||
collect_bare_idents(object, out);
|
||||
}
|
||||
Expr::IndexAccess { object, index } => {
|
||||
collect_bare_idents(object, out);
|
||||
collect_bare_idents(index, out);
|
||||
}
|
||||
Expr::FunctionCall { args, .. } => {
|
||||
// Function name itself is not a variable reference.
|
||||
for arg in args {
|
||||
collect_bare_idents(arg, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
// Tests
|
||||
// ───────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn minimal_workflow(tasks: Vec<super::super::parser::Task>) -> WorkflowDefinition {
|
||||
WorkflowDefinition {
|
||||
r#ref: "test.wf".to_string(),
|
||||
label: "Test".to_string(),
|
||||
description: None,
|
||||
version: "1.0.0".to_string(),
|
||||
parameters: None,
|
||||
output: None,
|
||||
vars: HashMap::new(),
|
||||
tasks,
|
||||
output_map: None,
|
||||
tags: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn action_task(name: &str) -> super::super::parser::Task {
|
||||
super::super::parser::Task {
|
||||
name: name.to_string(),
|
||||
r#type: super::super::parser::TaskType::Action,
|
||||
action: Some("core.echo".to_string()),
|
||||
input: HashMap::new(),
|
||||
when: None,
|
||||
with_items: None,
|
||||
batch_size: None,
|
||||
concurrency: None,
|
||||
retry: None,
|
||||
timeout: None,
|
||||
next: vec![],
|
||||
on_success: None,
|
||||
on_failure: None,
|
||||
on_complete: None,
|
||||
on_timeout: None,
|
||||
decision: vec![],
|
||||
publish: vec![],
|
||||
join: None,
|
||||
tasks: None,
|
||||
chart_meta: None,
|
||||
}
|
||||
}
|
||||
|
||||
// ── extract_expressions ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_extract_single() {
|
||||
let exprs = extract_expressions("{{ parameters.n }}");
|
||||
assert_eq!(exprs, vec![" parameters.n "]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_multiple() {
|
||||
let exprs = extract_expressions("Hello {{ name }}, you have {{ count }} items");
|
||||
assert_eq!(exprs.len(), 2);
|
||||
assert_eq!(exprs[0].trim(), "name");
|
||||
assert_eq!(exprs[1].trim(), "count");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_no_expressions() {
|
||||
let exprs = extract_expressions("plain text");
|
||||
assert!(exprs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_unclosed() {
|
||||
let exprs = extract_expressions("{{ oops");
|
||||
assert!(exprs.is_empty());
|
||||
}
|
||||
|
||||
// ── collect_bare_idents ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_bare_ident() {
|
||||
let ast = parse_expression("n").unwrap();
|
||||
let mut idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut idents);
|
||||
assert_eq!(idents, vec!["n"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot_access_does_not_collect_field() {
|
||||
let ast = parse_expression("parameters.n").unwrap();
|
||||
let mut idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut idents);
|
||||
assert_eq!(idents, vec!["parameters"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_arg_collected() {
|
||||
let ast = parse_expression("range(n)").unwrap();
|
||||
let mut idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut idents);
|
||||
assert_eq!(idents, vec!["n"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_dot_access() {
|
||||
let ast = parse_expression("task.fetch.result.data").unwrap();
|
||||
let mut idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut idents);
|
||||
assert_eq!(idents, vec!["task"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_op() {
|
||||
let ast = parse_expression("parameters.x + workflow.y").unwrap();
|
||||
let mut idents = Vec::new();
|
||||
collect_bare_idents(&ast, &mut idents);
|
||||
assert_eq!(idents, vec!["parameters", "workflow"]);
|
||||
}
|
||||
|
||||
// ── validate_workflow_expressions ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_valid_workflow_no_warnings() {
|
||||
let mut task = action_task("greet");
|
||||
task.with_items = Some("{{ range(parameters.n) }}".to_string());
|
||||
task.input.insert(
|
||||
"message".to_string(),
|
||||
serde_json::json!("Hello {{ item }}"),
|
||||
);
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bare_name_from_vars_ok() {
|
||||
let mut task = action_task("greet");
|
||||
task.with_items = Some("{{ range(n) }}".to_string());
|
||||
|
||||
let mut wf = minimal_workflow(vec![task]);
|
||||
wf.vars.insert("n".to_string(), serde_json::json!(5));
|
||||
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bare_name_from_param_schema_ok() {
|
||||
let mut task = action_task("greet");
|
||||
task.with_items = Some("{{ range(n) }}".to_string());
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let schema = serde_json::json!({
|
||||
"n": { "type": "integer", "required": true }
|
||||
});
|
||||
|
||||
let warnings = validate_workflow_expressions(&wf, Some(&schema));
|
||||
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_bare_name_warning() {
|
||||
let mut task = action_task("greet");
|
||||
task.with_items = Some("{{ range(n) }}".to_string());
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
assert_eq!(warnings.len(), 1);
|
||||
assert!(warnings[0].message.contains("unknown variable 'n'"));
|
||||
assert!(warnings[0].message.contains("parameters.n"));
|
||||
assert!(warnings[0].location.contains("with_items"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_syntax_error_warning() {
|
||||
let mut task = action_task("greet");
|
||||
task.input.insert(
|
||||
"msg".to_string(),
|
||||
serde_json::json!("{{ +++ }}"),
|
||||
);
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
assert_eq!(warnings.len(), 1);
|
||||
assert!(warnings[0].message.contains("syntax error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transition_when_validated() {
|
||||
let mut task = action_task("step1");
|
||||
task.next = vec![super::super::parser::TaskTransition {
|
||||
when: Some("{{ bad_var > 3 }}".to_string()),
|
||||
publish: vec![],
|
||||
r#do: Some(vec!["step2".to_string()]),
|
||||
chart_meta: None,
|
||||
}];
|
||||
|
||||
let wf = minimal_workflow(vec![task, action_task("step2")]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
assert_eq!(warnings.len(), 1);
|
||||
assert!(warnings[0].message.contains("unknown variable 'bad_var'"));
|
||||
assert!(warnings[0].location.contains("next[0].when"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transition_publish_validated() {
|
||||
let mut task = action_task("step1");
|
||||
let mut publish_map = HashMap::new();
|
||||
publish_map.insert("out".to_string(), "{{ unknown_thing }}".to_string());
|
||||
task.next = vec![super::super::parser::TaskTransition {
|
||||
when: Some("{{ succeeded() }}".to_string()),
|
||||
publish: vec![PublishDirective::Simple(publish_map)],
|
||||
r#do: Some(vec!["step2".to_string()]),
|
||||
chart_meta: None,
|
||||
}];
|
||||
|
||||
let wf = minimal_workflow(vec![task, action_task("step2")]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
assert_eq!(warnings.len(), 1);
|
||||
assert!(warnings[0].message.contains("unknown variable 'unknown_thing'"));
|
||||
assert!(warnings[0].location.contains("publish.out"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workflow_functions_no_warning() {
|
||||
// succeeded(), failed(), result() etc. are function calls,
|
||||
// not variable references — should not produce warnings.
|
||||
let mut task = action_task("step1");
|
||||
task.next = vec![super::super::parser::TaskTransition {
|
||||
when: Some("{{ succeeded() and result().code == 200 }}".to_string()),
|
||||
publish: vec![],
|
||||
r#do: Some(vec!["step2".to_string()]),
|
||||
chart_meta: None,
|
||||
}];
|
||||
|
||||
let wf = minimal_workflow(vec![task, action_task("step2")]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_text_no_warning() {
|
||||
let mut task = action_task("step1");
|
||||
task.input.insert(
|
||||
"msg".to_string(),
|
||||
serde_json::json!("just plain text"),
|
||||
);
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
assert!(warnings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_errors_collected() {
|
||||
let mut task = action_task("step1");
|
||||
task.with_items = Some("{{ range(a) }}".to_string());
|
||||
task.input.insert(
|
||||
"x".to_string(),
|
||||
serde_json::json!("{{ b + c }}"),
|
||||
);
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
// a, b, c are all unknown
|
||||
assert_eq!(warnings.len(), 3);
|
||||
let names: HashSet<_> = warnings
|
||||
.iter()
|
||||
.flat_map(|w| {
|
||||
// extract the variable name from "unknown variable 'X'"
|
||||
w.message
|
||||
.strip_prefix("unknown variable '")
|
||||
.and_then(|s| s.split('\'').next())
|
||||
.map(|s| s.to_string())
|
||||
})
|
||||
.collect();
|
||||
assert!(names.contains("a"));
|
||||
assert!(names.contains("b"));
|
||||
assert!(names.contains("c"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_access_validated() {
|
||||
let mut task = action_task("step1");
|
||||
task.input.insert(
|
||||
"val".to_string(),
|
||||
serde_json::json!("{{ items[idx] }}"),
|
||||
);
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
|
||||
// Both `items` and `idx` are bare unknowns
|
||||
assert_eq!(warnings.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_literals_ok() {
|
||||
let mut task = action_task("step1");
|
||||
task.next = vec![super::super::parser::TaskTransition {
|
||||
when: Some("{{ true and not false }}".to_string()),
|
||||
publish: vec![],
|
||||
r#do: None,
|
||||
chart_meta: None,
|
||||
}];
|
||||
|
||||
let wf = minimal_workflow(vec![task]);
|
||||
let warnings = validate_workflow_expressions(&wf, None);
|
||||
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
//! This module provides utilities for loading, parsing, validating, and registering
|
||||
//! workflow definitions from YAML files.
|
||||
|
||||
pub mod expression;
|
||||
pub mod loader;
|
||||
pub mod pack_service;
|
||||
pub mod parser;
|
||||
|
||||
@@ -356,7 +356,7 @@ async fn test_update_execution_status() {
|
||||
status: Some(ExecutionStatus::Running),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
@@ -401,7 +401,7 @@ async fn test_update_execution_result() {
|
||||
status: Some(ExecutionStatus::Completed),
|
||||
result: Some(result_data.clone()),
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
@@ -445,7 +445,7 @@ async fn test_update_execution_executor() {
|
||||
status: Some(ExecutionStatus::Scheduled),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
@@ -492,7 +492,7 @@ async fn test_update_execution_status_transitions() {
|
||||
status: Some(ExecutionStatus::Scheduling),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -507,7 +507,7 @@ async fn test_update_execution_status_transitions() {
|
||||
status: Some(ExecutionStatus::Scheduled),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -522,7 +522,7 @@ async fn test_update_execution_status_transitions() {
|
||||
status: Some(ExecutionStatus::Running),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -537,7 +537,7 @@ async fn test_update_execution_status_transitions() {
|
||||
status: Some(ExecutionStatus::Completed),
|
||||
result: Some(json!({"success": true})),
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -578,7 +578,7 @@ async fn test_update_execution_failed_status() {
|
||||
status: Some(ExecutionStatus::Failed),
|
||||
result: Some(json!({"error": "Connection timeout"})),
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
@@ -984,7 +984,7 @@ async fn test_execution_timestamps() {
|
||||
status: Some(ExecutionStatus::Running),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
@@ -1095,7 +1095,7 @@ async fn test_execution_result_json() {
|
||||
status: Some(ExecutionStatus::Completed),
|
||||
result: Some(complex_result.clone()),
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let updated = ExecutionRepository::update(&pool, created.id, update)
|
||||
|
||||
@@ -244,8 +244,7 @@ impl InquiryHandler {
|
||||
let update_input = UpdateExecutionInput {
|
||||
status: None, // Keep current status, let worker handle completion
|
||||
result: Some(updated_result),
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
ExecutionRepository::update(pool, execution.id, update_input).await?;
|
||||
|
||||
@@ -381,10 +381,7 @@ impl RetryManager {
|
||||
&self.pool,
|
||||
execution_id,
|
||||
UpdateExecutionInput {
|
||||
status: None,
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -66,6 +66,53 @@ fn extract_workflow_params(config: &Option<JsonValue>) -> JsonValue {
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply default values from a workflow's `param_schema` to the provided
|
||||
/// parameters.
|
||||
///
|
||||
/// The param_schema uses the flat format where each key maps to an object
|
||||
/// that may contain a `"default"` field:
|
||||
///
|
||||
/// ```json
|
||||
/// { "n": { "type": "integer", "default": 10 } }
|
||||
/// ```
|
||||
///
|
||||
/// Any parameter that has a default in the schema but is missing (or `null`)
|
||||
/// in the supplied `params` will be filled in. Parameters already provided
|
||||
/// by the caller are never overwritten.
|
||||
fn apply_param_defaults(params: JsonValue, param_schema: &Option<JsonValue>) -> JsonValue {
|
||||
let schema = match param_schema {
|
||||
Some(s) if s.is_object() => s,
|
||||
_ => return params,
|
||||
};
|
||||
|
||||
let mut obj = match params {
|
||||
JsonValue::Object(m) => m,
|
||||
_ => return params,
|
||||
};
|
||||
|
||||
if let Some(schema_obj) = schema.as_object() {
|
||||
for (key, prop) in schema_obj {
|
||||
// Only fill in missing / null parameters
|
||||
let needs_default = match obj.get(key) {
|
||||
None => true,
|
||||
Some(JsonValue::Null) => true,
|
||||
_ => false,
|
||||
};
|
||||
if needs_default {
|
||||
if let Some(default_val) = prop.get("default") {
|
||||
debug!(
|
||||
"Applying default for parameter '{}': {}",
|
||||
key, default_val
|
||||
);
|
||||
obj.insert(key.clone(), default_val.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
JsonValue::Object(obj)
|
||||
}
|
||||
|
||||
/// Payload for execution scheduled messages
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ExecutionScheduledPayload {
|
||||
@@ -316,7 +363,10 @@ impl ExecutionScheduler {
|
||||
|
||||
// Build initial workflow context from execution parameters and
|
||||
// workflow-level vars so that entry-point task inputs are rendered.
|
||||
// Apply defaults from the workflow's param_schema for any parameters
|
||||
// that were not supplied by the caller.
|
||||
let workflow_params = extract_workflow_params(&execution.config);
|
||||
let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema);
|
||||
let wf_ctx = WorkflowContext::new(
|
||||
workflow_params,
|
||||
definition
|
||||
@@ -563,7 +613,7 @@ impl ExecutionScheduler {
|
||||
};
|
||||
|
||||
let total = items.len();
|
||||
let concurrency_limit = task_node.concurrency.unwrap_or(total);
|
||||
let concurrency_limit = task_node.concurrency.unwrap_or(1);
|
||||
let dispatch_count = total.min(concurrency_limit);
|
||||
|
||||
info!(
|
||||
@@ -842,6 +892,18 @@ impl ExecutionScheduler {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Load the workflow definition so we can apply param_schema defaults
|
||||
let workflow_def =
|
||||
WorkflowDefinitionRepository::find_by_id(pool, workflow_execution.workflow_def)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Workflow definition {} not found for workflow_execution {}",
|
||||
workflow_execution.workflow_def,
|
||||
workflow_execution_id
|
||||
)
|
||||
})?;
|
||||
|
||||
// Rebuild the task graph from the stored JSON
|
||||
let graph: TaskGraph = serde_json::from_value(workflow_execution.task_graph.clone())
|
||||
.map_err(|e| {
|
||||
@@ -897,7 +959,7 @@ impl ExecutionScheduler {
|
||||
let concurrency_limit = graph
|
||||
.get_task(task_name)
|
||||
.and_then(|n| n.concurrency)
|
||||
.unwrap_or(usize::MAX);
|
||||
.unwrap_or(1);
|
||||
|
||||
let free_slots =
|
||||
concurrency_limit.saturating_sub(in_flight_count.0 as usize);
|
||||
@@ -995,6 +1057,7 @@ impl ExecutionScheduler {
|
||||
// results so that successor task inputs can be rendered.
|
||||
// -----------------------------------------------------------------
|
||||
let workflow_params = extract_workflow_params(&parent_execution.config);
|
||||
let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema);
|
||||
|
||||
// Collect results from all completed children of this workflow
|
||||
let child_executions =
|
||||
@@ -1619,11 +1682,11 @@ mod tests {
|
||||
let dispatch_count = total.min(concurrency_limit);
|
||||
assert_eq!(dispatch_count, 3);
|
||||
|
||||
// No concurrency limit → dispatch all
|
||||
// No concurrency limit → default to serial (1 at a time)
|
||||
let concurrency: Option<usize> = None;
|
||||
let concurrency_limit = concurrency.unwrap_or(total);
|
||||
let concurrency_limit = concurrency.unwrap_or(1);
|
||||
let dispatch_count = total.min(concurrency_limit);
|
||||
assert_eq!(dispatch_count, 20);
|
||||
assert_eq!(dispatch_count, 1);
|
||||
|
||||
// Concurrency exceeds total → dispatch all
|
||||
let concurrency: Option<usize> = Some(50);
|
||||
|
||||
@@ -3,6 +3,33 @@
|
||||
//! This module manages workflow execution context, including variables,
|
||||
//! template rendering, and data flow between tasks.
|
||||
//!
|
||||
//! ## Canonical Namespaces
|
||||
//!
|
||||
//! All data accessible inside `{{ }}` template expressions is organised into
|
||||
//! well-defined, non-overlapping namespaces:
|
||||
//!
|
||||
//! | Namespace | Example | Description |
|
||||
//! |-----------|---------|-------------|
|
||||
//! | `parameters` | `{{ parameters.url }}` | Immutable workflow input parameters |
|
||||
//! | `workflow` | `{{ workflow.counter }}` | Mutable workflow-scoped variables (set via `publish`) |
|
||||
//! | `task` | `{{ task.fetch.result.data }}` | Completed task results keyed by task name |
|
||||
//! | `config` | `{{ config.api_token }}` | Pack configuration values (read-only) |
|
||||
//! | `keystore` | `{{ keystore.secret_key }}` | Encrypted secrets from the key store (read-only) |
|
||||
//! | `item` | `{{ item }}` or `{{ item.name }}` | Current element in a `with_items` loop |
|
||||
//! | `index` | `{{ index }}` | Zero-based iteration index in a `with_items` loop |
|
||||
//! | `system` | `{{ system.workflow_start }}` | System-provided variables |
|
||||
//!
|
||||
//! ### Backward-compatible aliases
|
||||
//!
|
||||
//! The following aliases resolve to the same data as their canonical form and
|
||||
//! are kept for backward compatibility with existing workflow definitions:
|
||||
//!
|
||||
//! - `vars` / `variables` → same as `workflow`
|
||||
//! - `tasks` → same as `task`
|
||||
//!
|
||||
//! Bare variable names (e.g. `{{ my_var }}`) also resolve against the
|
||||
//! `workflow` variable store as a last-resort fallback.
|
||||
//!
|
||||
//! ## Function-call expressions
|
||||
//!
|
||||
//! Templates support Orquesta-style function calls:
|
||||
@@ -19,6 +46,9 @@
|
||||
//! expression instead of stringifying it. This means `"{{ item }}"` resolving
|
||||
//! to integer `5` stays as `5`, not the string `"5"`.
|
||||
|
||||
use attune_common::workflow::expression::{
|
||||
self, is_truthy, EvalContext, EvalError, EvalResult as ExprResult,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use std::collections::HashMap;
|
||||
@@ -63,18 +93,25 @@ pub enum TaskOutcome {
|
||||
/// not the underlying data, making it O(1) instead of O(context_size).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorkflowContext {
|
||||
/// Workflow-level variables (shared via Arc)
|
||||
/// Mutable workflow-scoped variables. Canonical namespace: `workflow`.
|
||||
/// Also accessible as `vars`, `variables`, or bare names (fallback).
|
||||
variables: Arc<DashMap<String, JsonValue>>,
|
||||
|
||||
/// Workflow input parameters (shared via Arc)
|
||||
/// Immutable workflow input parameters. Canonical namespace: `parameters`.
|
||||
parameters: Arc<JsonValue>,
|
||||
|
||||
/// Task results (shared via Arc, keyed by task name)
|
||||
/// Completed task results keyed by task name. Canonical namespace: `task`.
|
||||
task_results: Arc<DashMap<String, JsonValue>>,
|
||||
|
||||
/// System variables (shared via Arc)
|
||||
/// System-provided variables. Canonical namespace: `system`.
|
||||
system: Arc<DashMap<String, JsonValue>>,
|
||||
|
||||
/// Pack configuration values (read-only). Canonical namespace: `config`.
|
||||
pack_config: Arc<JsonValue>,
|
||||
|
||||
/// Encrypted keystore values (read-only). Canonical namespace: `keystore`.
|
||||
keystore: Arc<JsonValue>,
|
||||
|
||||
/// Current item (for with-items iteration) - per-item data
|
||||
current_item: Option<JsonValue>,
|
||||
|
||||
@@ -89,7 +126,11 @@ pub struct WorkflowContext {
|
||||
}
|
||||
|
||||
impl WorkflowContext {
|
||||
/// Create a new workflow context
|
||||
/// Create a new workflow context.
|
||||
///
|
||||
/// `parameters` — the immutable input parameters for this workflow run.
|
||||
/// `initial_vars` — initial workflow-scoped variables (from the workflow
|
||||
/// definition's `vars` section).
|
||||
pub fn new(parameters: JsonValue, initial_vars: HashMap<String, JsonValue>) -> Self {
|
||||
let system = DashMap::new();
|
||||
system.insert("workflow_start".to_string(), json!(chrono::Utc::now()));
|
||||
@@ -104,6 +145,8 @@ impl WorkflowContext {
|
||||
parameters: Arc::new(parameters),
|
||||
task_results: Arc::new(DashMap::new()),
|
||||
system: Arc::new(system),
|
||||
pack_config: Arc::new(JsonValue::Null),
|
||||
keystore: Arc::new(JsonValue::Null),
|
||||
current_item: None,
|
||||
current_index: None,
|
||||
last_task_result: None,
|
||||
@@ -142,6 +185,8 @@ impl WorkflowContext {
|
||||
parameters: Arc::new(parameters),
|
||||
task_results: Arc::new(results),
|
||||
system: Arc::new(system),
|
||||
pack_config: Arc::new(JsonValue::Null),
|
||||
keystore: Arc::new(JsonValue::Null),
|
||||
current_item: None,
|
||||
current_index: None,
|
||||
last_task_result: None,
|
||||
@@ -149,28 +194,38 @@ impl WorkflowContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a variable
|
||||
/// Set a workflow-scoped variable (accessible as `workflow.<name>`).
|
||||
pub fn set_var(&mut self, name: &str, value: JsonValue) {
|
||||
self.variables.insert(name.to_string(), value);
|
||||
}
|
||||
|
||||
/// Get a variable
|
||||
/// Get a workflow-scoped variable by name.
|
||||
pub fn get_var(&self, name: &str) -> Option<JsonValue> {
|
||||
self.variables.get(name).map(|entry| entry.value().clone())
|
||||
}
|
||||
|
||||
/// Store a task result
|
||||
/// Store a completed task's result (accessible as `task.<name>.*`).
|
||||
pub fn set_task_result(&mut self, task_name: &str, result: JsonValue) {
|
||||
self.task_results.insert(task_name.to_string(), result);
|
||||
}
|
||||
|
||||
/// Get a task result
|
||||
/// Get a task result by task name.
|
||||
pub fn get_task_result(&self, task_name: &str) -> Option<JsonValue> {
|
||||
self.task_results
|
||||
.get(task_name)
|
||||
.map(|entry| entry.value().clone())
|
||||
}
|
||||
|
||||
/// Set the pack configuration (accessible as `config.<key>`).
|
||||
pub fn set_pack_config(&mut self, config: JsonValue) {
|
||||
self.pack_config = Arc::new(config);
|
||||
}
|
||||
|
||||
/// Set the keystore secrets (accessible as `keystore.<key>`).
|
||||
pub fn set_keystore(&mut self, secrets: JsonValue) {
|
||||
self.keystore = Arc::new(secrets);
|
||||
}
|
||||
|
||||
/// Set current item for iteration
|
||||
pub fn set_current_item(&mut self, item: JsonValue, index: usize) {
|
||||
self.current_item = Some(item);
|
||||
@@ -299,220 +354,55 @@ impl WorkflowContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a template expression
|
||||
fn evaluate_expression(&self, expr: &str) -> ContextResult<JsonValue> {
|
||||
// ---------------------------------------------------------------
|
||||
// Function-call expressions: result(), succeeded(), failed(), timed_out()
|
||||
// ---------------------------------------------------------------
|
||||
// We handle these *before* splitting on `.` because the function
|
||||
// name contains parentheses which would confuse the dot-split.
|
||||
//
|
||||
// Supported patterns:
|
||||
// result() → last task result
|
||||
// result().foo.bar → nested access into result
|
||||
// result().data.items → nested access into result
|
||||
// succeeded() → boolean
|
||||
// failed() → boolean
|
||||
// timed_out() → boolean
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if let Some(result_val) = self.try_evaluate_function_call(expr)? {
|
||||
return Ok(result_val);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Dot-path expressions
|
||||
// ---------------------------------------------------------------
|
||||
let parts: Vec<&str> = expr.split('.').collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return Err(ContextError::InvalidExpression(expr.to_string()));
|
||||
}
|
||||
|
||||
match parts[0] {
|
||||
"parameters" => self.get_nested_value(&self.parameters, &parts[1..]),
|
||||
"vars" | "variables" => {
|
||||
if parts.len() < 2 {
|
||||
return Err(ContextError::InvalidExpression(expr.to_string()));
|
||||
}
|
||||
let var_name = parts[1];
|
||||
if let Some(entry) = self.variables.get(var_name) {
|
||||
let value = entry.value().clone();
|
||||
drop(entry);
|
||||
if parts.len() > 2 {
|
||||
self.get_nested_value(&value, &parts[2..])
|
||||
} else {
|
||||
Ok(value)
|
||||
}
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound(var_name.to_string()))
|
||||
}
|
||||
}
|
||||
"task" | "tasks" => {
|
||||
if parts.len() < 2 {
|
||||
return Err(ContextError::InvalidExpression(expr.to_string()));
|
||||
}
|
||||
let task_name = parts[1];
|
||||
if let Some(entry) = self.task_results.get(task_name) {
|
||||
let result = entry.value().clone();
|
||||
drop(entry);
|
||||
if parts.len() > 2 {
|
||||
self.get_nested_value(&result, &parts[2..])
|
||||
} else {
|
||||
Ok(result)
|
||||
}
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound(format!(
|
||||
"task.{}",
|
||||
task_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
"item" => {
|
||||
if let Some(ref item) = self.current_item {
|
||||
if parts.len() > 1 {
|
||||
self.get_nested_value(item, &parts[1..])
|
||||
} else {
|
||||
Ok(item.clone())
|
||||
}
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound("item".to_string()))
|
||||
}
|
||||
}
|
||||
"index" => {
|
||||
if let Some(index) = self.current_index {
|
||||
Ok(json!(index))
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound("index".to_string()))
|
||||
}
|
||||
}
|
||||
"system" => {
|
||||
if parts.len() < 2 {
|
||||
return Err(ContextError::InvalidExpression(expr.to_string()));
|
||||
}
|
||||
let key = parts[1];
|
||||
if let Some(entry) = self.system.get(key) {
|
||||
Ok(entry.value().clone())
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound(format!("system.{}", key)))
|
||||
}
|
||||
}
|
||||
// Direct variable reference (e.g., `number_list` published by a
|
||||
// previous task's transition)
|
||||
var_name => {
|
||||
if let Some(entry) = self.variables.get(var_name) {
|
||||
let value = entry.value().clone();
|
||||
drop(entry);
|
||||
if parts.len() > 1 {
|
||||
self.get_nested_value(&value, &parts[1..])
|
||||
} else {
|
||||
Ok(value)
|
||||
}
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound(var_name.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to evaluate `expr` as a function-call expression.
|
||||
/// Evaluate a template expression using the expression engine.
|
||||
///
|
||||
/// Returns `Ok(Some(value))` if the expression starts with a recognised
|
||||
/// function call, `Ok(None)` if it does not match, or `Err` on failure.
|
||||
fn try_evaluate_function_call(&self, expr: &str) -> ContextResult<Option<JsonValue>> {
|
||||
// succeeded()
|
||||
if expr == "succeeded()" {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::Succeeded)
|
||||
.unwrap_or(false);
|
||||
return Ok(Some(json!(val)));
|
||||
}
|
||||
|
||||
// failed()
|
||||
if expr == "failed()" {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::Failed)
|
||||
.unwrap_or(false);
|
||||
return Ok(Some(json!(val)));
|
||||
}
|
||||
|
||||
// timed_out()
|
||||
if expr == "timed_out()" {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::TimedOut)
|
||||
.unwrap_or(false);
|
||||
return Ok(Some(json!(val)));
|
||||
}
|
||||
|
||||
// result() or result().path.to.field
|
||||
if expr == "result()" || expr.starts_with("result().") {
|
||||
let base = self.last_task_result.clone().unwrap_or(JsonValue::Null);
|
||||
|
||||
if expr == "result()" {
|
||||
return Ok(Some(base));
|
||||
}
|
||||
|
||||
// Strip "result()." prefix and navigate the remaining path
|
||||
let rest = &expr["result().".len()..];
|
||||
let path_parts: Vec<&str> = rest.split('.').collect();
|
||||
let val = self.get_nested_value(&base, &path_parts)?;
|
||||
return Ok(Some(val));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
/// Supports the full expression language including arithmetic, comparison,
|
||||
/// boolean logic, member access, and built-in functions. Falls back to
|
||||
/// legacy dot-path resolution for simple variable references when the
|
||||
/// expression engine cannot parse the input.
|
||||
fn evaluate_expression(&self, expr: &str) -> ContextResult<JsonValue> {
|
||||
// Use the expression engine for all expressions. It handles:
|
||||
// - Dot-path access: parameters.config.port
|
||||
// - Bracket access: arr[0], obj["key"]
|
||||
// - Arithmetic: 2 + 3, length(items) * 2
|
||||
// - Comparison: x > 5, status == "ok"
|
||||
// - Boolean logic: x > 0 and x < 10
|
||||
// - Function calls: length(arr), result(), succeeded()
|
||||
// - Membership: "key" in obj, 5 in arr
|
||||
expression::eval_expression(expr, self).map_err(|e| match e {
|
||||
EvalError::VariableNotFound(name) => ContextError::VariableNotFound(name),
|
||||
EvalError::TypeError(msg) => ContextError::TypeConversion(msg),
|
||||
EvalError::ParseError(msg) => ContextError::InvalidExpression(msg),
|
||||
other => ContextError::InvalidExpression(format!("{}", other)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get nested value from JSON
|
||||
fn get_nested_value(&self, value: &JsonValue, path: &[&str]) -> ContextResult<JsonValue> {
|
||||
let mut current = value;
|
||||
|
||||
for key in path {
|
||||
match current {
|
||||
JsonValue::Object(obj) => {
|
||||
current = obj
|
||||
.get(*key)
|
||||
.ok_or_else(|| ContextError::VariableNotFound(key.to_string()))?;
|
||||
}
|
||||
JsonValue::Array(arr) => {
|
||||
let index: usize = key.parse().map_err(|_| {
|
||||
ContextError::InvalidExpression(format!("Invalid array index: {}", key))
|
||||
})?;
|
||||
current = arr.get(index).ok_or_else(|| {
|
||||
ContextError::InvalidExpression(format!(
|
||||
"Array index out of bounds: {}",
|
||||
index
|
||||
))
|
||||
})?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ContextError::InvalidExpression(format!(
|
||||
"Cannot access property '{}' on non-object/array value",
|
||||
key
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(current.clone())
|
||||
}
|
||||
|
||||
/// Evaluate a conditional expression (for 'when' clauses)
|
||||
/// Evaluate a conditional expression (for 'when' clauses).
|
||||
///
|
||||
/// Uses the full expression engine so conditions can contain comparisons,
|
||||
/// boolean operators, function calls, and arithmetic. For example:
|
||||
///
|
||||
/// ```text
|
||||
/// succeeded()
|
||||
/// result().status == "ok"
|
||||
/// length(items) > 3 and "admin" in roles
|
||||
/// not failed()
|
||||
/// ```
|
||||
pub fn evaluate_condition(&self, condition: &str) -> ContextResult<bool> {
|
||||
// For now, simple boolean evaluation
|
||||
// TODO: Support more complex expressions (comparisons, logical operators)
|
||||
|
||||
let rendered = self.render_template(condition)?;
|
||||
|
||||
// Try to parse as boolean
|
||||
match rendered.trim().to_lowercase().as_str() {
|
||||
"true" | "1" | "yes" => Ok(true),
|
||||
"false" | "0" | "no" | "" => Ok(false),
|
||||
other => {
|
||||
// Try to evaluate as truthy/falsy
|
||||
Ok(!other.is_empty())
|
||||
// Try the expression engine first — it handles complex conditions
|
||||
// like `result().code == 200 and succeeded()`.
|
||||
match expression::eval_expression(condition, self) {
|
||||
Ok(val) => Ok(is_truthy(&val)),
|
||||
Err(_) => {
|
||||
// Fall back to template rendering for backward compat with
|
||||
// simple template conditions like `{{ succeeded() }}` (though
|
||||
// bare expressions are preferred going forward).
|
||||
let rendered = self.render_template(condition)?;
|
||||
match rendered.trim().to_lowercase().as_str() {
|
||||
"true" | "1" | "yes" => Ok(true),
|
||||
"false" | "0" | "no" | "" => Ok(false),
|
||||
_ => Ok(!rendered.trim().is_empty()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -574,6 +464,8 @@ impl WorkflowContext {
|
||||
"parameters": self.parameters.as_ref(),
|
||||
"task_results": task_results,
|
||||
"system": system,
|
||||
"pack_config": self.pack_config.as_ref(),
|
||||
"keystore": self.keystore.as_ref(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -602,11 +494,16 @@ impl WorkflowContext {
|
||||
}
|
||||
}
|
||||
|
||||
let pack_config = data["pack_config"].clone();
|
||||
let keystore = data["keystore"].clone();
|
||||
|
||||
Ok(Self {
|
||||
variables: Arc::new(variables),
|
||||
parameters: Arc::new(parameters),
|
||||
task_results: Arc::new(task_results),
|
||||
system: Arc::new(system),
|
||||
pack_config: Arc::new(pack_config),
|
||||
keystore: Arc::new(keystore),
|
||||
current_item: None,
|
||||
current_index: None,
|
||||
last_task_result: None,
|
||||
@@ -626,10 +523,122 @@ fn value_to_string(value: &JsonValue) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// EvalContext implementation — bridges the expression engine into
|
||||
// the WorkflowContext's variable resolution and workflow functions.
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
impl EvalContext for WorkflowContext {
|
||||
fn resolve_variable(&self, name: &str) -> ExprResult<JsonValue> {
|
||||
match name {
|
||||
// ── Canonical namespaces ──────────────────────────────
|
||||
"parameters" => Ok(self.parameters.as_ref().clone()),
|
||||
|
||||
// `workflow` is the canonical name for mutable vars.
|
||||
// `vars` and `variables` are backward-compatible aliases.
|
||||
"workflow" | "vars" | "variables" => {
|
||||
let map: serde_json::Map<String, JsonValue> = self
|
||||
.variables
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect();
|
||||
Ok(JsonValue::Object(map))
|
||||
}
|
||||
|
||||
// `task` (alias: `tasks`) — completed task results.
|
||||
"task" | "tasks" => {
|
||||
let map: serde_json::Map<String, JsonValue> = self
|
||||
.task_results
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect();
|
||||
Ok(JsonValue::Object(map))
|
||||
}
|
||||
|
||||
// `config` — pack configuration (read-only).
|
||||
"config" => Ok(self.pack_config.as_ref().clone()),
|
||||
|
||||
// `keystore` — encrypted secrets (read-only).
|
||||
"keystore" => Ok(self.keystore.as_ref().clone()),
|
||||
|
||||
// ── Iteration context ────────────────────────────────
|
||||
"item" => self
|
||||
.current_item
|
||||
.clone()
|
||||
.ok_or_else(|| EvalError::VariableNotFound("item".to_string())),
|
||||
"index" => self
|
||||
.current_index
|
||||
.map(|i| json!(i))
|
||||
.ok_or_else(|| EvalError::VariableNotFound("index".to_string())),
|
||||
|
||||
// ── System variables ──────────────────────────────────
|
||||
"system" => {
|
||||
let map: serde_json::Map<String, JsonValue> = self
|
||||
.system
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect();
|
||||
Ok(JsonValue::Object(map))
|
||||
}
|
||||
|
||||
// ── Bare-name fallback ───────────────────────────────
|
||||
// Resolve against workflow variables last so that
|
||||
// `{{ my_var }}` still works as shorthand for
|
||||
// `{{ workflow.my_var }}`.
|
||||
_ => {
|
||||
if let Some(entry) = self.variables.get(name) {
|
||||
Ok(entry.value().clone())
|
||||
} else {
|
||||
Err(EvalError::VariableNotFound(name.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn call_workflow_function(
|
||||
&self,
|
||||
name: &str,
|
||||
_args: &[JsonValue],
|
||||
) -> ExprResult<Option<JsonValue>> {
|
||||
match name {
|
||||
"succeeded" => {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::Succeeded)
|
||||
.unwrap_or(false);
|
||||
Ok(Some(json!(val)))
|
||||
}
|
||||
"failed" => {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::Failed)
|
||||
.unwrap_or(false);
|
||||
Ok(Some(json!(val)))
|
||||
}
|
||||
"timed_out" => {
|
||||
let val = self
|
||||
.last_task_outcome
|
||||
.map(|o| o == TaskOutcome::TimedOut)
|
||||
.unwrap_or(false);
|
||||
Ok(Some(json!(val)))
|
||||
}
|
||||
"result" => {
|
||||
let base = self.last_task_result.clone().unwrap_or(JsonValue::Null);
|
||||
Ok(Some(base))
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// parameters namespace
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_basic_template_rendering() {
|
||||
let params = json!({
|
||||
@@ -641,28 +650,6 @@ mod tests {
|
||||
assert_eq!(result, "Hello World!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_variable_access() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("greeting".to_string(), json!("Hello"));
|
||||
|
||||
let ctx = WorkflowContext::new(json!({}), vars);
|
||||
|
||||
let result = ctx.render_template("{{ greeting }} World").unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_task_result_access() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("task1", json!({"status": "success"}));
|
||||
|
||||
let result = ctx
|
||||
.render_template("Status: {{ task.task1.status }}")
|
||||
.unwrap();
|
||||
assert_eq!(result, "Status: success");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_value_access() {
|
||||
let params = json!({
|
||||
@@ -680,6 +667,143 @@ mod tests {
|
||||
assert_eq!(result, "Port: 8080");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// workflow namespace (canonical) + vars/variables aliases
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_workflow_namespace_canonical() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("greeting", json!("Hello"));
|
||||
|
||||
// Canonical: workflow.<name>
|
||||
let result = ctx.render_template("{{ workflow.greeting }} World").unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workflow_namespace_vars_alias() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("greeting".to_string(), json!("Hello"));
|
||||
let ctx = WorkflowContext::new(json!({}), vars);
|
||||
|
||||
// Backward-compat alias: vars.<name>
|
||||
let result = ctx.render_template("{{ vars.greeting }} World").unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workflow_namespace_variables_alias() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("greeting".to_string(), json!("Hello"));
|
||||
let ctx = WorkflowContext::new(json!({}), vars);
|
||||
|
||||
// Backward-compat alias: variables.<name>
|
||||
let result = ctx.render_template("{{ variables.greeting }} World").unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_variable_access_bare_name_fallback() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("greeting".to_string(), json!("Hello"));
|
||||
|
||||
let ctx = WorkflowContext::new(json!({}), vars);
|
||||
|
||||
// Bare name falls back to workflow variables
|
||||
let result = ctx.render_template("{{ greeting }} World").unwrap();
|
||||
assert_eq!(result, "Hello World");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// task namespace
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_task_result_access() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("task1", json!({"status": "success"}));
|
||||
|
||||
let result = ctx
|
||||
.render_template("Status: {{ task.task1.status }}")
|
||||
.unwrap();
|
||||
assert_eq!(result, "Status: success");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_task_result_deep_access() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("fetch", json!({"result": {"data": {"id": 42}}}));
|
||||
|
||||
let val = ctx.evaluate_expression("task.fetch.result.data.id").unwrap();
|
||||
assert_eq!(val, json!(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_task_result_stdout() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_task_result("run_cmd", json!({"result": {"stdout": "hello world"}}));
|
||||
|
||||
let val = ctx.evaluate_expression("task.run_cmd.result.stdout").unwrap();
|
||||
assert_eq!(val, json!("hello world"));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// config namespace (pack configuration)
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_config_namespace() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_pack_config(json!({"api_token": "tok_abc123", "base_url": "https://api.example.com"}));
|
||||
|
||||
let val = ctx.evaluate_expression("config.api_token").unwrap();
|
||||
assert_eq!(val, json!("tok_abc123"));
|
||||
|
||||
let result = ctx
|
||||
.render_template("URL: {{ config.base_url }}")
|
||||
.unwrap();
|
||||
assert_eq!(result, "URL: https://api.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_namespace_nested() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_pack_config(json!({"slack": {"webhook_url": "https://hooks.slack.com/xxx"}}));
|
||||
|
||||
let val = ctx.evaluate_expression("config.slack.webhook_url").unwrap();
|
||||
assert_eq!(val, json!("https://hooks.slack.com/xxx"));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// keystore namespace (encrypted secrets)
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_keystore_namespace() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_keystore(json!({"secret_key": "s3cr3t", "db_password": "hunter2"}));
|
||||
|
||||
let val = ctx.evaluate_expression("keystore.secret_key").unwrap();
|
||||
assert_eq!(val, json!("s3cr3t"));
|
||||
|
||||
let val = ctx.evaluate_expression("keystore.db_password").unwrap();
|
||||
assert_eq!(val, json!("hunter2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keystore_bracket_access() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_keystore(json!({"My Secret Key": "value123"}));
|
||||
|
||||
let val = ctx.evaluate_expression("keystore[\"My Secret Key\"]").unwrap();
|
||||
assert_eq!(val, json!("value123"));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// item / index (with_items iteration)
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_item_context() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
@@ -691,6 +815,10 @@ mod tests {
|
||||
assert_eq!(result, "Item: item1, Index: 0");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Condition evaluation
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_condition_evaluation() {
|
||||
let params = json!({"enabled": true});
|
||||
@@ -700,6 +828,133 @@ mod tests {
|
||||
assert!(!ctx.evaluate_condition("false").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_comparison() {
|
||||
let ctx = WorkflowContext::new(json!({"count": 10}), HashMap::new());
|
||||
assert!(ctx.evaluate_condition("parameters.count > 5").unwrap());
|
||||
assert!(!ctx.evaluate_condition("parameters.count < 5").unwrap());
|
||||
assert!(ctx.evaluate_condition("parameters.count == 10").unwrap());
|
||||
assert!(ctx.evaluate_condition("parameters.count >= 10").unwrap());
|
||||
assert!(ctx.evaluate_condition("parameters.count != 99").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_boolean_operators() {
|
||||
let ctx = WorkflowContext::new(json!({"x": 10, "y": 20}), HashMap::new());
|
||||
assert!(ctx
|
||||
.evaluate_condition("parameters.x > 5 and parameters.y > 15")
|
||||
.unwrap());
|
||||
assert!(!ctx
|
||||
.evaluate_condition("parameters.x > 5 and parameters.y > 25")
|
||||
.unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("parameters.x > 50 or parameters.y > 15")
|
||||
.unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("not parameters.x > 50")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_in_operator() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("roles", json!(["admin", "user"]));
|
||||
// Via bare-name fallback
|
||||
assert!(ctx.evaluate_condition("\"admin\" in roles").unwrap());
|
||||
assert!(!ctx.evaluate_condition("\"root\" in roles").unwrap());
|
||||
// Via canonical workflow namespace
|
||||
assert!(ctx.evaluate_condition("\"admin\" in workflow.roles").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_function_calls() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_last_task_outcome(
|
||||
json!({"status": "ok", "code": 200}),
|
||||
TaskOutcome::Succeeded,
|
||||
);
|
||||
assert!(ctx.evaluate_condition("succeeded()").unwrap());
|
||||
assert!(!ctx.evaluate_condition("failed()").unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("succeeded() and result().code == 200")
|
||||
.unwrap());
|
||||
assert!(!ctx
|
||||
.evaluate_condition("succeeded() and result().code == 404")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_length() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("items", json!([1, 2, 3, 4, 5]));
|
||||
assert!(ctx.evaluate_condition("length(items) > 3").unwrap());
|
||||
assert!(!ctx.evaluate_condition("length(items) > 10").unwrap());
|
||||
assert!(ctx
|
||||
.evaluate_condition("length(items) == 5")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_condition_with_config() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_pack_config(json!({"retries": 3}));
|
||||
assert!(ctx.evaluate_condition("config.retries > 0").unwrap());
|
||||
assert!(ctx.evaluate_condition("config.retries == 3").unwrap());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Expression engine in templates
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_expression_arithmetic() {
|
||||
let ctx = WorkflowContext::new(json!({"x": 10}), HashMap::new());
|
||||
let input = json!({"result": "{{ parameters.x + 5 }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["result"], json!(15));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expression_string_concat() {
|
||||
let ctx = WorkflowContext::new(
|
||||
json!({"first": "Hello", "second": "World"}),
|
||||
HashMap::new(),
|
||||
);
|
||||
let input = json!({"msg": "{{ parameters.first + \" \" + parameters.second }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["msg"], json!("Hello World"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expression_nested_functions() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("data", json!("a,b,c"));
|
||||
let input = json!({"count": "{{ length(split(data, \",\")) }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["count"], json!(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expression_bracket_access() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("arr", json!([10, 20, 30]));
|
||||
let input = json!({"second": "{{ arr[1] }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["second"], json!(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expression_type_conversion() {
|
||||
let ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
let input = json!({"val": "{{ int(3.9) }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["val"], json!(3));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// render_json type-preserving behaviour
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_render_json() {
|
||||
let params = json!({"name": "test"});
|
||||
@@ -769,6 +1024,10 @@ mod tests {
|
||||
assert!(result["ok"].is_boolean());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// result() / succeeded() / failed() / timed_out()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_result_function() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
@@ -813,6 +1072,10 @@ mod tests {
|
||||
assert_eq!(ctx.evaluate_expression("timed_out()").unwrap(), json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Publish
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_publish_with_result_function() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
@@ -846,6 +1109,28 @@ mod tests {
|
||||
assert_eq!(ctx.get_var("my_var").unwrap(), result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_published_var_accessible_via_workflow_namespace() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_var("counter", json!(42));
|
||||
|
||||
// Via canonical namespace
|
||||
let val = ctx.evaluate_expression("workflow.counter").unwrap();
|
||||
assert_eq!(val, json!(42));
|
||||
|
||||
// Via backward-compat alias
|
||||
let val = ctx.evaluate_expression("vars.counter").unwrap();
|
||||
assert_eq!(val, json!(42));
|
||||
|
||||
// Via bare-name fallback
|
||||
let val = ctx.evaluate_expression("counter").unwrap();
|
||||
assert_eq!(val, json!(42));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Rebuild / Export / Import round-trip
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_context() {
|
||||
let stored_vars = json!({"number_list": [0, 1, 2]});
|
||||
@@ -868,17 +1153,31 @@ mod tests {
|
||||
let mut ctx = WorkflowContext::new(json!({"key": "value"}), HashMap::new());
|
||||
ctx.set_var("test", json!("data"));
|
||||
ctx.set_task_result("task1", json!({"result": "ok"}));
|
||||
ctx.set_pack_config(json!({"setting": "val"}));
|
||||
ctx.set_keystore(json!({"secret": "hidden"}));
|
||||
|
||||
let exported = ctx.export();
|
||||
let _imported = WorkflowContext::import(exported).unwrap();
|
||||
let imported = WorkflowContext::import(exported).unwrap();
|
||||
|
||||
assert_eq!(ctx.get_var("test").unwrap(), json!("data"));
|
||||
assert_eq!(imported.get_var("test").unwrap(), json!("data"));
|
||||
assert_eq!(
|
||||
ctx.get_task_result("task1").unwrap(),
|
||||
imported.get_task_result("task1").unwrap(),
|
||||
json!({"result": "ok"})
|
||||
);
|
||||
assert_eq!(
|
||||
imported.evaluate_expression("config.setting").unwrap(),
|
||||
json!("val")
|
||||
);
|
||||
assert_eq!(
|
||||
imported.evaluate_expression("keystore.secret").unwrap(),
|
||||
json!("hidden")
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// with_items type preservation
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_with_items_integer_type_preservation() {
|
||||
// Simulates the sleep_2 task from the hello_workflow:
|
||||
@@ -902,4 +1201,40 @@ mod tests {
|
||||
assert_eq!(rendered["message"], json!("Sleeping for 3 seconds "));
|
||||
assert!(rendered["message"].is_string());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Cross-namespace expressions
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_cross_namespace_expression() {
|
||||
let mut ctx = WorkflowContext::new(json!({"limit": 5}), HashMap::new());
|
||||
ctx.set_var("items", json!([1, 2, 3]));
|
||||
ctx.set_pack_config(json!({"multiplier": 2}));
|
||||
|
||||
assert!(ctx
|
||||
.evaluate_condition("length(workflow.items) < parameters.limit")
|
||||
.unwrap());
|
||||
let val = ctx
|
||||
.evaluate_expression("parameters.limit * config.multiplier")
|
||||
.unwrap();
|
||||
assert_eq!(val, json!(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keystore_in_template() {
|
||||
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
ctx.set_keystore(json!({"api_key": "abc-123"}));
|
||||
|
||||
let input = json!({"auth": "Bearer {{ keystore.api_key }}"});
|
||||
let result = ctx.render_json(&input).unwrap();
|
||||
assert_eq!(result["auth"], json!("Bearer abc-123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_null_when_not_set() {
|
||||
let ctx = WorkflowContext::new(json!({}), HashMap::new());
|
||||
let val = ctx.evaluate_expression("config").unwrap();
|
||||
assert_eq!(val, json!(null));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -654,8 +654,7 @@ impl ActionExecutor {
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(ExecutionStatus::Completed),
|
||||
result: Some(result_data),
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
@@ -755,8 +754,7 @@ impl ActionExecutor {
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(ExecutionStatus::Failed),
|
||||
result: Some(result_data),
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
@@ -775,11 +773,16 @@ impl ActionExecutor {
|
||||
execution_id, status
|
||||
);
|
||||
|
||||
let started_at = if status == ExecutionStatus::Running {
|
||||
Some(chrono::Utc::now())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(status),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
started_at,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
|
||||
Reference in New Issue
Block a user