Proper SQL-side filtering and pagination

This commit is contained in:
2026-03-01 20:43:48 -06:00
parent 6b9d7d6cf2
commit bbe94d75f8
54 changed files with 6692 additions and 928 deletions

View File

@@ -319,6 +319,10 @@ pub struct EnforcementQueryParams {
#[param(example = "core.webhook")]
pub trigger_ref: Option<String>,
/// Filter by rule reference
#[param(example = "core.on_webhook")]
pub rule_ref: Option<String>,
/// Page number (1-indexed)
#[serde(default = "default_page")]
#[param(example = 1, minimum = 1)]

View File

@@ -7,6 +7,7 @@ use utoipa::{IntoParams, ToSchema};
use attune_common::models::enums::ExecutionStatus;
use attune_common::models::execution::WorkflowTaskMetadata;
use attune_common::repositories::execution::ExecutionWithRefs;
/// Request DTO for creating a manual execution
#[derive(Debug, Clone, Deserialize, ToSchema)]
@@ -63,6 +64,12 @@ pub struct ExecutionResponse {
#[schema(value_type = Object, example = json!({"message_id": "1234567890.123456"}))]
pub result: Option<JsonValue>,
/// When the execution actually started running (worker picked it up).
/// Null if the execution hasn't started running yet.
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "2024-01-13T10:31:00Z", nullable = true)]
pub started_at: Option<DateTime<Utc>>,
/// Workflow task metadata (only populated for workflow task executions)
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(value_type = Option<Object>, nullable = true)]
@@ -108,6 +115,12 @@ pub struct ExecutionSummary {
#[schema(example = "core.timer")]
pub trigger_ref: Option<String>,
/// When the execution actually started running (worker picked it up).
/// Null if the execution hasn't started running yet.
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "2024-01-13T10:31:00Z", nullable = true)]
pub started_at: Option<DateTime<Utc>>,
/// Workflow task metadata (only populated for workflow task executions)
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(value_type = Option<Object>, nullable = true)]
@@ -207,6 +220,7 @@ impl From<attune_common::models::execution::Execution> for ExecutionResponse {
result: execution
.result
.map(|r| serde_json::to_value(r).unwrap_or(JsonValue::Null)),
started_at: execution.started_at,
workflow_task: execution.workflow_task,
created: execution.created,
updated: execution.updated,
@@ -225,6 +239,7 @@ impl From<attune_common::models::execution::Execution> for ExecutionSummary {
enforcement: execution.enforcement,
rule_ref: None, // Populated separately via enforcement lookup
trigger_ref: None, // Populated separately via enforcement lookup
started_at: execution.started_at,
workflow_task: execution.workflow_task,
created: execution.created,
updated: execution.updated,
@@ -232,6 +247,26 @@ impl From<attune_common::models::execution::Execution> for ExecutionSummary {
}
}
/// Convert from the joined query result (execution + enforcement refs).
/// `rule_ref` and `trigger_ref` are already populated from the SQL JOIN.
impl From<ExecutionWithRefs> for ExecutionSummary {
fn from(row: ExecutionWithRefs) -> Self {
Self {
id: row.id,
action_ref: row.action_ref,
status: row.status,
parent: row.parent,
enforcement: row.enforcement,
rule_ref: row.rule_ref,
trigger_ref: row.trigger_ref,
started_at: row.started_at,
workflow_task: row.workflow_task,
created: row.created,
updated: row.updated,
}
}
}
/// Serde default for the `page` query parameter.
///
/// Pagination is 1-indexed throughout the API, so an omitted `page`
/// always means the first page.
fn default_page() -> u32 {
    const FIRST_PAGE: u32 = 1;
    FIRST_PAGE
}

View File

@@ -11,10 +11,10 @@ use std::sync::Arc;
use validator::Validate;
use attune_common::repositories::{
action::{ActionRepository, CreateActionInput, UpdateActionInput},
action::{ActionRepository, ActionSearchFilters, CreateActionInput, UpdateActionInput},
pack::PackRepository,
queue_stats::QueueStatsRepository,
Create, Delete, FindByRef, List, Update,
Create, Delete, FindByRef, Update,
};
use crate::{
@@ -47,21 +47,20 @@ pub async fn list_actions(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get all actions (we'll implement pagination in repository later)
let actions = ActionRepository::list(&state.db).await?;
// All filtering and pagination happen in a single SQL query.
let filters = ActionSearchFilters {
pack: None,
query: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = actions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(actions.len());
let result = ActionRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_actions: Vec<ActionSummary> = actions[start..end]
.iter()
.map(|a| ActionSummary::from(a.clone()))
.collect();
let paginated_actions: Vec<ActionSummary> =
result.rows.into_iter().map(ActionSummary::from).collect();
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
let response = PaginatedResponse::new(paginated_actions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -92,21 +91,20 @@ pub async fn list_actions_by_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
// Get actions for this pack
let actions = ActionRepository::find_by_pack(&state.db, pack.id).await?;
// All filtering and pagination happen in a single SQL query.
let filters = ActionSearchFilters {
pack: Some(pack.id),
query: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = actions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(actions.len());
let result = ActionRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_actions: Vec<ActionSummary> = actions[start..end]
.iter()
.map(|a| ActionSummary::from(a.clone()))
.collect();
let paginated_actions: Vec<ActionSummary> =
result.rows.into_iter().map(ActionSummary::from).collect();
let response = PaginatedResponse::new(paginated_actions, &pagination, total);
let response = PaginatedResponse::new(paginated_actions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -16,9 +16,12 @@ use validator::Validate;
use attune_common::{
mq::{EventCreatedPayload, MessageEnvelope, MessageType},
repositories::{
event::{CreateEventInput, EnforcementRepository, EventRepository},
event::{
CreateEventInput, EnforcementRepository, EnforcementSearchFilters, EventRepository,
EventSearchFilters,
},
trigger::TriggerRepository,
Create, FindById, FindByRef, List,
Create, FindById, FindByRef,
},
};
@@ -220,53 +223,27 @@ pub async fn list_events(
State(state): State<Arc<AppState>>,
Query(query): Query<EventQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get events based on filters
let events = if let Some(trigger_id) = query.trigger {
// Filter by trigger ID
EventRepository::find_by_trigger(&state.db, trigger_id).await?
} else if let Some(trigger_ref) = &query.trigger_ref {
// Filter by trigger reference
EventRepository::find_by_trigger_ref(&state.db, trigger_ref).await?
} else {
// Get all events
EventRepository::list(&state.db).await?
// All filtering and pagination happen in a single SQL query.
let filters = EventSearchFilters {
trigger: query.trigger,
trigger_ref: query.trigger_ref.clone(),
source: query.source,
rule_ref: query.rule_ref.clone(),
limit: query.limit(),
offset: query.offset(),
};
// Apply additional filters in memory
let mut filtered_events = events;
let result = EventRepository::search(&state.db, &filters).await?;
if let Some(source_id) = query.source {
filtered_events.retain(|e| e.source == Some(source_id));
}
let paginated_events: Vec<EventSummary> =
result.rows.into_iter().map(EventSummary::from).collect();
if let Some(rule_ref) = &query.rule_ref {
let rule_ref_lower = rule_ref.to_lowercase();
filtered_events.retain(|e| {
e.rule_ref
.as_ref()
.map(|r| r.to_lowercase().contains(&rule_ref_lower))
.unwrap_or(false)
});
}
// Calculate pagination
let total = filtered_events.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(filtered_events.len());
// Get paginated slice
let paginated_events: Vec<EventSummary> = filtered_events[start..end]
.iter()
.map(|event| EventSummary::from(event.clone()))
.collect();
// Convert query params to pagination params for response
let pagination_params = PaginationParams {
page: query.page,
page_size: query.per_page,
};
let response = PaginatedResponse::new(paginated_events, &pagination_params, total);
let response = PaginatedResponse::new(paginated_events, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -319,46 +296,32 @@ pub async fn list_enforcements(
State(state): State<Arc<AppState>>,
Query(query): Query<EnforcementQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get enforcements based on filters
let enforcements = if let Some(status) = query.status {
// Filter by status
EnforcementRepository::find_by_status(&state.db, status).await?
} else if let Some(rule_id) = query.rule {
// Filter by rule ID
EnforcementRepository::find_by_rule(&state.db, rule_id).await?
} else if let Some(event_id) = query.event {
// Filter by event ID
EnforcementRepository::find_by_event(&state.db, event_id).await?
} else {
// Get all enforcements
EnforcementRepository::list(&state.db).await?
// All filtering and pagination happen in a single SQL query.
// Filters are combinable (AND), not mutually exclusive.
let filters = EnforcementSearchFilters {
status: query.status,
rule: query.rule,
event: query.event,
trigger_ref: query.trigger_ref.clone(),
rule_ref: query.rule_ref.clone(),
limit: query.limit(),
offset: query.offset(),
};
// Apply additional filters in memory
let mut filtered_enforcements = enforcements;
let result = EnforcementRepository::search(&state.db, &filters).await?;
if let Some(trigger_ref) = &query.trigger_ref {
filtered_enforcements.retain(|e| e.trigger_ref == *trigger_ref);
}
// Calculate pagination
let total = filtered_enforcements.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(filtered_enforcements.len());
// Get paginated slice
let paginated_enforcements: Vec<EnforcementSummary> = filtered_enforcements[start..end]
.iter()
.map(|enforcement| EnforcementSummary::from(enforcement.clone()))
let paginated_enforcements: Vec<EnforcementSummary> = result
.rows
.into_iter()
.map(EnforcementSummary::from)
.collect();
// Convert query params to pagination params for response
let pagination_params = PaginationParams {
page: query.page,
page_size: query.per_page,
};
let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, total);
let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -18,9 +18,10 @@ use attune_common::models::enums::ExecutionStatus;
use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType};
use attune_common::repositories::{
action::ActionRepository,
execution::{CreateExecutionInput, ExecutionRepository},
Create, EnforcementRepository, FindById, FindByRef, List,
execution::{CreateExecutionInput, ExecutionRepository, ExecutionSearchFilters},
Create, FindById, FindByRef,
};
use sqlx::Row;
use crate::{
auth::middleware::RequireAuth,
@@ -125,117 +126,37 @@ pub async fn list_executions(
RequireAuth(_user): RequireAuth,
Query(query): Query<ExecutionQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get executions based on filters
let executions = if let Some(status) = query.status {
// Filter by status
ExecutionRepository::find_by_status(&state.db, status).await?
} else if let Some(enforcement_id) = query.enforcement {
// Filter by enforcement
ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?
} else {
// Get all executions
ExecutionRepository::list(&state.db).await?
// All filtering, pagination, and the enforcement JOIN happen in a single
// SQL query — no in-memory filtering or post-fetch lookups.
let filters = ExecutionSearchFilters {
status: query.status,
action_ref: query.action_ref.clone(),
pack_name: query.pack_name.clone(),
rule_ref: query.rule_ref.clone(),
trigger_ref: query.trigger_ref.clone(),
executor: query.executor,
result_contains: query.result_contains.clone(),
enforcement: query.enforcement,
parent: query.parent,
top_level_only: query.top_level_only == Some(true),
limit: query.limit(),
offset: query.offset(),
};
// Apply additional filters in memory (could be optimized with database queries)
let mut filtered_executions = executions;
let result = ExecutionRepository::search(&state.db, &filters).await?;
if let Some(action_ref) = &query.action_ref {
filtered_executions.retain(|e| e.action_ref == *action_ref);
}
if let Some(pack_name) = &query.pack_name {
filtered_executions.retain(|e| {
// action_ref format is "pack.action"
e.action_ref.starts_with(&format!("{}.", pack_name))
});
}
if let Some(result_search) = &query.result_contains {
let search_lower = result_search.to_lowercase();
filtered_executions.retain(|e| {
if let Some(result) = &e.result {
// Convert result to JSON string and search case-insensitively
let result_str = serde_json::to_string(result).unwrap_or_default();
result_str.to_lowercase().contains(&search_lower)
} else {
false
}
});
}
if let Some(parent_id) = query.parent {
filtered_executions.retain(|e| e.parent == Some(parent_id));
}
if query.top_level_only == Some(true) {
filtered_executions.retain(|e| e.parent.is_none());
}
if let Some(executor_id) = query.executor {
filtered_executions.retain(|e| e.executor == Some(executor_id));
}
// Fetch enforcements for all executions to populate rule_ref and trigger_ref
let enforcement_ids: Vec<i64> = filtered_executions
.iter()
.filter_map(|e| e.enforcement)
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let enforcement_map: std::collections::HashMap<i64, _> = if !enforcement_ids.is_empty() {
let enforcements = EnforcementRepository::list(&state.db).await?;
enforcements.into_iter().map(|enf| (enf.id, enf)).collect()
} else {
std::collections::HashMap::new()
};
// Filter by rule_ref if specified
if let Some(rule_ref) = &query.rule_ref {
filtered_executions.retain(|e| {
e.enforcement
.and_then(|enf_id| enforcement_map.get(&enf_id))
.map(|enf| enf.rule_ref == *rule_ref)
.unwrap_or(false)
});
}
// Filter by trigger_ref if specified
if let Some(trigger_ref) = &query.trigger_ref {
filtered_executions.retain(|e| {
e.enforcement
.and_then(|enf_id| enforcement_map.get(&enf_id))
.map(|enf| enf.trigger_ref == *trigger_ref)
.unwrap_or(false)
});
}
// Calculate pagination
let total = filtered_executions.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(filtered_executions.len());
// Get paginated slice and populate rule_ref/trigger_ref from enforcements
let paginated_executions: Vec<ExecutionSummary> = filtered_executions[start..end]
.iter()
.map(|e| {
let mut summary = ExecutionSummary::from(e.clone());
if let Some(enf_id) = e.enforcement {
if let Some(enforcement) = enforcement_map.get(&enf_id) {
summary.rule_ref = Some(enforcement.rule_ref.clone());
summary.trigger_ref = Some(enforcement.trigger_ref.clone());
}
}
summary
})
.collect();
// Convert query params to pagination params for response
let pagination_params = PaginationParams {
page: query.page,
page_size: query.per_page,
};
let response = PaginatedResponse::new(paginated_executions, &pagination_params, total);
let response = PaginatedResponse::new(paginated_executions, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -310,21 +231,23 @@ pub async fn list_executions_by_status(
}
};
// Get executions by status
let executions = ExecutionRepository::find_by_status(&state.db, status).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = ExecutionSearchFilters {
status: Some(status),
limit: pagination.limit(),
offset: pagination.offset(),
..Default::default()
};
// Calculate pagination
let total = executions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(executions.len());
let result = ExecutionRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
.iter()
.map(|e| ExecutionSummary::from(e.clone()))
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -350,21 +273,23 @@ pub async fn list_executions_by_enforcement(
Path(enforcement_id): Path<i64>,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get executions by enforcement
let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = ExecutionSearchFilters {
enforcement: Some(enforcement_id),
limit: pagination.limit(),
offset: pagination.offset(),
..Default::default()
};
// Calculate pagination
let total = executions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(executions.len());
let result = ExecutionRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
.iter()
.map(|e| ExecutionSummary::from(e.clone()))
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -384,34 +309,37 @@ pub async fn get_execution_stats(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
) -> ApiResult<impl IntoResponse> {
// Get all executions (limited by repository to 1000)
let executions = ExecutionRepository::list(&state.db).await?;
// Use a single SQL query with COUNT + GROUP BY instead of fetching all rows.
let rows = sqlx::query(
"SELECT status::text AS status, COUNT(*) AS cnt FROM execution GROUP BY status",
)
.fetch_all(&state.db)
.await?;
// Calculate statistics
let total = executions.len();
let completed = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed)
.count();
let failed = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed)
.count();
let running = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running)
.count();
let pending = executions
.iter()
.filter(|e| {
matches!(
e.status,
attune_common::models::enums::ExecutionStatus::Requested
| attune_common::models::enums::ExecutionStatus::Scheduling
| attune_common::models::enums::ExecutionStatus::Scheduled
)
})
.count();
let mut completed: i64 = 0;
let mut failed: i64 = 0;
let mut running: i64 = 0;
let mut pending: i64 = 0;
let mut cancelled: i64 = 0;
let mut timeout: i64 = 0;
let mut abandoned: i64 = 0;
let mut total: i64 = 0;
for row in &rows {
let status: &str = row.get("status");
let cnt: i64 = row.get("cnt");
total += cnt;
match status {
"completed" => completed = cnt,
"failed" => failed = cnt,
"running" => running = cnt,
"requested" | "scheduling" | "scheduled" => pending += cnt,
"cancelled" | "canceling" => cancelled += cnt,
"timeout" => timeout = cnt,
"abandoned" => abandoned = cnt,
_ => {}
}
}
let stats = serde_json::json!({
"total": total,
@@ -419,9 +347,9 @@ pub async fn get_execution_stats(
"failed": failed,
"running": running,
"pending": pending,
"cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(),
"timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(),
"abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(),
"cancelled": cancelled,
"timeout": timeout,
"abandoned": abandoned,
});
let response = ApiResponse::new(stats);

View File

@@ -14,8 +14,10 @@ use attune_common::{
mq::{InquiryRespondedPayload, MessageEnvelope, MessageType},
repositories::{
execution::ExecutionRepository,
inquiry::{CreateInquiryInput, InquiryRepository, UpdateInquiryInput},
Create, Delete, FindById, List, Update,
inquiry::{
CreateInquiryInput, InquiryRepository, InquirySearchFilters, UpdateInquiryInput,
},
Create, Delete, FindById, Update,
},
};
@@ -51,45 +53,30 @@ pub async fn list_inquiries(
State(state): State<Arc<AppState>>,
Query(query): Query<InquiryQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get inquiries based on filters
let inquiries = if let Some(status) = query.status {
// Filter by status
InquiryRepository::find_by_status(&state.db, status).await?
} else if let Some(execution_id) = query.execution {
// Filter by execution
InquiryRepository::find_by_execution(&state.db, execution_id).await?
} else {
// Get all inquiries
InquiryRepository::list(&state.db).await?
// All filtering and pagination happen in a single SQL query.
// Filters are combinable (AND), not mutually exclusive.
let limit = query.limit.unwrap_or(50).min(500) as u32;
let offset = query.offset.unwrap_or(0) as u32;
let filters = InquirySearchFilters {
status: query.status,
execution: query.execution,
assigned_to: query.assigned_to,
limit,
offset,
};
// Apply additional filters in memory
let mut filtered_inquiries = inquiries;
let result = InquiryRepository::search(&state.db, &filters).await?;
if let Some(assigned_to) = query.assigned_to {
filtered_inquiries.retain(|i| i.assigned_to == Some(assigned_to));
}
let paginated_inquiries: Vec<InquirySummary> =
result.rows.into_iter().map(InquirySummary::from).collect();
// Calculate pagination
let total = filtered_inquiries.len() as u64;
let offset = query.offset.unwrap_or(0);
let limit = query.limit.unwrap_or(50).min(500);
let start = offset;
let end = (start + limit).min(filtered_inquiries.len());
// Get paginated slice
let paginated_inquiries: Vec<InquirySummary> = filtered_inquiries[start..end]
.iter()
.map(|inquiry| InquirySummary::from(inquiry.clone()))
.collect();
// Convert to pagination params for response
let pagination_params = PaginationParams {
page: (offset / limit.max(1)) as u32 + 1,
page_size: limit as u32,
page: (offset / limit.max(1)) + 1,
page_size: limit,
};
let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, total);
let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -161,20 +148,21 @@ pub async fn list_inquiries_by_status(
}
};
let inquiries = InquiryRepository::find_by_status(&state.db, status).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = InquirySearchFilters {
status: Some(status),
execution: None,
assigned_to: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = inquiries.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(inquiries.len());
let result = InquiryRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
.iter()
.map(|inquiry| InquirySummary::from(inquiry.clone()))
.collect();
let paginated_inquiries: Vec<InquirySummary> =
result.rows.into_iter().map(InquirySummary::from).collect();
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -209,20 +197,21 @@ pub async fn list_inquiries_by_execution(
ApiError::NotFound(format!("Execution with ID {} not found", execution_id))
})?;
let inquiries = InquiryRepository::find_by_execution(&state.db, execution_id).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = InquirySearchFilters {
status: None,
execution: Some(execution_id),
assigned_to: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = inquiries.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(inquiries.len());
let result = InquiryRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_inquiries: Vec<InquirySummary> = inquiries[start..end]
.iter()
.map(|inquiry| InquirySummary::from(inquiry.clone()))
.collect();
let paginated_inquiries: Vec<InquirySummary> =
result.rows.into_iter().map(InquirySummary::from).collect();
let response = PaginatedResponse::new(paginated_inquiries, &pagination, total);
let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -13,10 +13,10 @@ use validator::Validate;
use attune_common::models::OwnerType;
use attune_common::repositories::{
action::ActionRepository,
key::{CreateKeyInput, KeyRepository, UpdateKeyInput},
key::{CreateKeyInput, KeyRepository, KeySearchFilters, UpdateKeyInput},
pack::PackRepository,
trigger::SensorRepository,
Create, Delete, FindByRef, List, Update,
Create, Delete, FindByRef, Update,
};
use crate::auth::RequireAuth;
@@ -46,40 +46,24 @@ pub async fn list_keys(
State(state): State<Arc<AppState>>,
Query(query): Query<KeyQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get keys based on filters
let keys = if let Some(owner_type) = query.owner_type {
// Filter by owner type
KeyRepository::find_by_owner_type(&state.db, owner_type).await?
} else {
// Get all keys
KeyRepository::list(&state.db).await?
// All filtering and pagination happen in a single SQL query.
let filters = KeySearchFilters {
owner_type: query.owner_type,
owner: query.owner.clone(),
limit: query.limit(),
offset: query.offset(),
};
// Apply additional filters in memory
let mut filtered_keys = keys;
let result = KeyRepository::search(&state.db, &filters).await?;
if let Some(owner) = &query.owner {
filtered_keys.retain(|k| k.owner.as_ref() == Some(owner));
}
let paginated_keys: Vec<KeySummary> = result.rows.into_iter().map(KeySummary::from).collect();
// Calculate pagination
let total = filtered_keys.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(filtered_keys.len());
// Get paginated slice (values redacted in summary)
let paginated_keys: Vec<KeySummary> = filtered_keys[start..end]
.iter()
.map(|key| KeySummary::from(key.clone()))
.collect();
// Convert query params to pagination params for response
let pagination_params = PaginationParams {
page: query.page,
page_size: query.per_page,
};
let response = PaginatedResponse::new(paginated_keys, &pagination_params, total);
let response = PaginatedResponse::new(paginated_keys, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -17,9 +17,9 @@ use attune_common::mq::{
use attune_common::repositories::{
action::ActionRepository,
pack::PackRepository,
rule::{CreateRuleInput, RuleRepository, UpdateRuleInput},
rule::{CreateRuleInput, RuleRepository, RuleSearchFilters, UpdateRuleInput},
trigger::TriggerRepository,
Create, Delete, FindByRef, List, Update,
Create, Delete, FindByRef, Update,
};
use crate::{
@@ -50,21 +50,21 @@ pub async fn list_rules(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get all rules
let rules = RuleRepository::list(&state.db).await?;
let filters = RuleSearchFilters {
pack: None,
action: None,
trigger: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = rules.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(rules.len());
let result = RuleRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_rules: Vec<RuleSummary> = rules[start..end]
.iter()
.map(|r| RuleSummary::from(r.clone()))
.collect();
let paginated_rules: Vec<RuleSummary> =
result.rows.into_iter().map(RuleSummary::from).collect();
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -85,21 +85,21 @@ pub async fn list_enabled_rules(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get enabled rules
let rules = RuleRepository::find_enabled(&state.db).await?;
let filters = RuleSearchFilters {
pack: None,
action: None,
trigger: None,
enabled: Some(true),
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = rules.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(rules.len());
let result = RuleRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_rules: Vec<RuleSummary> = rules[start..end]
.iter()
.map(|r| RuleSummary::from(r.clone()))
.collect();
let paginated_rules: Vec<RuleSummary> =
result.rows.into_iter().map(RuleSummary::from).collect();
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -130,21 +130,21 @@ pub async fn list_rules_by_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
// Get rules for this pack
let rules = RuleRepository::find_by_pack(&state.db, pack.id).await?;
let filters = RuleSearchFilters {
pack: Some(pack.id),
action: None,
trigger: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = rules.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(rules.len());
let result = RuleRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_rules: Vec<RuleSummary> = rules[start..end]
.iter()
.map(|r| RuleSummary::from(r.clone()))
.collect();
let paginated_rules: Vec<RuleSummary> =
result.rows.into_iter().map(RuleSummary::from).collect();
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -175,21 +175,21 @@ pub async fn list_rules_by_action(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
// Get rules for this action
let rules = RuleRepository::find_by_action(&state.db, action.id).await?;
let filters = RuleSearchFilters {
pack: None,
action: Some(action.id),
trigger: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = rules.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(rules.len());
let result = RuleRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_rules: Vec<RuleSummary> = rules[start..end]
.iter()
.map(|r| RuleSummary::from(r.clone()))
.collect();
let paginated_rules: Vec<RuleSummary> =
result.rows.into_iter().map(RuleSummary::from).collect();
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -220,21 +220,21 @@ pub async fn list_rules_by_trigger(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
// Get rules for this trigger
let rules = RuleRepository::find_by_trigger(&state.db, trigger.id).await?;
let filters = RuleSearchFilters {
pack: None,
action: None,
trigger: Some(trigger.id),
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = rules.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(rules.len());
let result = RuleRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_rules: Vec<RuleSummary> = rules[start..end]
.iter()
.map(|r| RuleSummary::from(r.clone()))
.collect();
let paginated_rules: Vec<RuleSummary> =
result.rows.into_iter().map(RuleSummary::from).collect();
let response = PaginatedResponse::new(paginated_rules, &pagination, total);
let response = PaginatedResponse::new(paginated_rules, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -14,10 +14,10 @@ use attune_common::repositories::{
pack::PackRepository,
runtime::RuntimeRepository,
trigger::{
CreateSensorInput, CreateTriggerInput, SensorRepository, TriggerRepository,
UpdateSensorInput, UpdateTriggerInput,
CreateSensorInput, CreateTriggerInput, SensorRepository, SensorSearchFilters,
TriggerRepository, TriggerSearchFilters, UpdateSensorInput, UpdateTriggerInput,
},
Create, Delete, FindByRef, List, Update,
Create, Delete, FindByRef, Update,
};
use crate::{
@@ -54,21 +54,19 @@ pub async fn list_triggers(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get all triggers
let triggers = TriggerRepository::list(&state.db).await?;
let filters = TriggerSearchFilters {
pack: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = triggers.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(triggers.len());
let result = TriggerRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
.iter()
.map(|t| TriggerSummary::from(t.clone()))
.collect();
let paginated_triggers: Vec<TriggerSummary> =
result.rows.into_iter().map(TriggerSummary::from).collect();
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -89,21 +87,19 @@ pub async fn list_enabled_triggers(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get enabled triggers
let triggers = TriggerRepository::find_enabled(&state.db).await?;
let filters = TriggerSearchFilters {
pack: None,
enabled: Some(true),
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = triggers.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(triggers.len());
let result = TriggerRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
.iter()
.map(|t| TriggerSummary::from(t.clone()))
.collect();
let paginated_triggers: Vec<TriggerSummary> =
result.rows.into_iter().map(TriggerSummary::from).collect();
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -134,21 +130,19 @@ pub async fn list_triggers_by_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
// Get triggers for this pack
let triggers = TriggerRepository::find_by_pack(&state.db, pack.id).await?;
let filters = TriggerSearchFilters {
pack: Some(pack.id),
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = triggers.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(triggers.len());
let result = TriggerRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_triggers: Vec<TriggerSummary> = triggers[start..end]
.iter()
.map(|t| TriggerSummary::from(t.clone()))
.collect();
let paginated_triggers: Vec<TriggerSummary> =
result.rows.into_iter().map(TriggerSummary::from).collect();
let response = PaginatedResponse::new(paginated_triggers, &pagination, total);
let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -438,21 +432,20 @@ pub async fn list_sensors(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get all sensors
let sensors = SensorRepository::list(&state.db).await?;
let filters = SensorSearchFilters {
pack: None,
trigger: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = sensors.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(sensors.len());
let result = SensorRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
.iter()
.map(|s| SensorSummary::from(s.clone()))
.collect();
let paginated_sensors: Vec<SensorSummary> =
result.rows.into_iter().map(SensorSummary::from).collect();
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -473,21 +466,20 @@ pub async fn list_enabled_sensors(
RequireAuth(_user): RequireAuth,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get enabled sensors
let sensors = SensorRepository::find_enabled(&state.db).await?;
let filters = SensorSearchFilters {
pack: None,
trigger: None,
enabled: Some(true),
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = sensors.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(sensors.len());
let result = SensorRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
.iter()
.map(|s| SensorSummary::from(s.clone()))
.collect();
let paginated_sensors: Vec<SensorSummary> =
result.rows.into_iter().map(SensorSummary::from).collect();
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -518,21 +510,20 @@ pub async fn list_sensors_by_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
// Get sensors for this pack
let sensors = SensorRepository::find_by_pack(&state.db, pack.id).await?;
let filters = SensorSearchFilters {
pack: Some(pack.id),
trigger: None,
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = sensors.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(sensors.len());
let result = SensorRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
.iter()
.map(|s| SensorSummary::from(s.clone()))
.collect();
let paginated_sensors: Vec<SensorSummary> =
result.rows.into_iter().map(SensorSummary::from).collect();
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -563,21 +554,20 @@ pub async fn list_sensors_by_trigger(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
// Get sensors for this trigger
let sensors = SensorRepository::find_by_trigger(&state.db, trigger.id).await?;
let filters = SensorSearchFilters {
pack: None,
trigger: Some(trigger.id),
enabled: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = sensors.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(sensors.len());
let result = SensorRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_sensors: Vec<SensorSummary> = sensors[start..end]
.iter()
.map(|s| SensorSummary::from(s.clone()))
.collect();
let paginated_sensors: Vec<SensorSummary> =
result.rows.into_iter().map(SensorSummary::from).collect();
let response = PaginatedResponse::new(paginated_sensors, &pagination, total);
let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -16,8 +16,9 @@ use attune_common::repositories::{
pack::PackRepository,
workflow::{
CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository,
WorkflowSearchFilters,
},
Create, Delete, FindByRef, List, Update,
Create, Delete, FindByRef, Update,
};
use crate::{
@@ -54,64 +55,30 @@ pub async fn list_workflows(
// Validate search params
search_params.validate()?;
// Get workflows based on filters
let mut workflows = if let Some(tags_str) = &search_params.tags {
// Filter by tags
let tags: Vec<&str> = tags_str.split(',').map(|s| s.trim()).collect();
let mut results = Vec::new();
for tag in tags {
let mut tag_results = WorkflowDefinitionRepository::find_by_tag(&state.db, tag).await?;
results.append(&mut tag_results);
}
// Remove duplicates by ID
results.sort_by_key(|w| w.id);
results.dedup_by_key(|w| w.id);
results
} else if search_params.enabled == Some(true) {
// Filter by enabled status (only return enabled workflows)
WorkflowDefinitionRepository::find_enabled(&state.db).await?
} else {
// Get all workflows
WorkflowDefinitionRepository::list(&state.db).await?
// Parse comma-separated tags into a Vec if provided
let tags = search_params.tags.as_ref().map(|t| {
t.split(',')
.map(|s| s.trim().to_string())
.collect::<Vec<_>>()
});
// All filtering and pagination happen in a single SQL query.
let filters = WorkflowSearchFilters {
pack: None,
pack_ref: search_params.pack_ref.clone(),
enabled: search_params.enabled,
tags,
search: search_params.search.clone(),
limit: pagination.limit(),
offset: pagination.offset(),
};
// Apply enabled filter if specified and not already filtered by it
if let Some(enabled) = search_params.enabled {
if search_params.tags.is_some() {
// If we filtered by tags, also apply enabled filter
workflows.retain(|w| w.enabled == enabled);
}
}
let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?;
// Apply search filter if provided
if let Some(search_term) = &search_params.search {
let search_lower = search_term.to_lowercase();
workflows.retain(|w| {
w.label.to_lowercase().contains(&search_lower)
|| w.description
.as_ref()
.map(|d| d.to_lowercase().contains(&search_lower))
.unwrap_or(false)
});
}
let paginated_workflows: Vec<WorkflowSummary> =
result.rows.into_iter().map(WorkflowSummary::from).collect();
// Apply pack_ref filter if provided
if let Some(pack_ref) = &search_params.pack_ref {
workflows.retain(|w| w.pack_ref == *pack_ref);
}
// Calculate pagination
let total = workflows.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(workflows.len());
// Get paginated slice
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
.iter()
.map(|w| WorkflowSummary::from(w.clone()))
.collect();
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -138,25 +105,27 @@ pub async fn list_workflows_by_pack(
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Verify pack exists
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
let _pack = PackRepository::find_by_ref(&state.db, &pack_ref)
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
// Get workflows for this pack
let workflows = WorkflowDefinitionRepository::find_by_pack(&state.db, pack.id).await?;
// All filtering and pagination happen in a single SQL query.
let filters = WorkflowSearchFilters {
pack: None,
pack_ref: Some(pack_ref),
enabled: None,
tags: None,
search: None,
limit: pagination.limit(),
offset: pagination.offset(),
};
// Calculate pagination
let total = workflows.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(workflows.len());
let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?;
// Get paginated slice
let paginated_workflows: Vec<WorkflowSummary> = workflows[start..end]
.iter()
.map(|w| WorkflowSummary::from(w.clone()))
.collect();
let paginated_workflows: Vec<WorkflowSummary> =
result.rows.into_iter().map(WorkflowSummary::from).collect();
let response = PaginatedResponse::new(paginated_workflows, &pagination, total);
let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}

View File

@@ -1104,6 +1104,11 @@ pub mod execution {
pub status: ExecutionStatus,
pub result: Option<JsonDict>,
/// When the execution actually started running (worker picked it up).
/// Set when status transitions to `Running`. Used to compute accurate
/// duration that excludes queue/scheduling wait time.
pub started_at: Option<DateTime<Utc>>,
/// Workflow task metadata (only populated for workflow task executions)
///
/// Provides direct access to workflow orchestration state without JOINs.

View File

@@ -8,6 +8,26 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
/// Filters for [`ActionRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct ActionSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Text search across ref, label, description (case-insensitive)
    pub query: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of rows to skip before the first returned row
    pub offset: u32,
}
/// Result of [`ActionRepository::list_search`].
#[derive(Debug)]
pub struct ActionSearchResult {
    /// The requested page of matching actions
    pub rows: Vec<Action>,
    /// Total number of matches before LIMIT/OFFSET was applied
    pub total: u64,
}
/// Repository for Action operations
pub struct ActionRepository;
@@ -287,6 +307,92 @@ impl Delete for ActionRepository {
}
impl ActionRepository {
/// Search actions with all filters pushed into SQL.
///
/// All filter fields are combinable (AND). Pagination is server-side.
pub async fn list_search<'e, E>(
    db: E,
    filters: &ActionSearchFilters,
) -> Result<ActionSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_version_constraint, param_schema, out_schema, workflow_def, is_adhoc, created, updated";

    // Paired builders: one fetches the page of rows, the other the
    // pre-pagination total. Both must receive identical WHERE clauses.
    let mut data_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM action"));
    let mut total_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM action");

    // The first emitted condition is introduced by WHERE, later ones by AND.
    let mut sep = " WHERE ";

    if let Some(pack_id) = filters.pack {
        data_qb.push(sep).push("pack = ").push_bind(pack_id);
        total_qb.push(sep).push("pack = ").push_bind(pack_id);
        sep = " AND ";
    }

    if let Some(ref query) = filters.query {
        // Case-insensitive substring match, OR'd across ref, label and
        // description, wrapped in parens so it composes with other filters.
        let pattern = format!("%{}%", query.to_lowercase());
        data_qb.push(sep).push("(LOWER(ref) LIKE ");
        data_qb.push_bind(pattern.clone());
        data_qb.push(" OR LOWER(label) LIKE ");
        data_qb.push_bind(pattern.clone());
        data_qb.push(" OR LOWER(description) LIKE ");
        data_qb.push_bind(pattern.clone());
        data_qb.push(")");
        total_qb.push(sep).push("(LOWER(ref) LIKE ");
        total_qb.push_bind(pattern.clone());
        total_qb.push(" OR LOWER(label) LIKE ");
        total_qb.push_bind(pattern.clone());
        total_qb.push(" OR LOWER(description) LIKE ");
        total_qb.push_bind(pattern);
        total_qb.push(")");
        sep = " AND ";
    }

    // The final assignment to `sep` may go unread; that is expected.
    let _ = sep;

    // Total matches before pagination, clamped into u64.
    let total: i64 = total_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // Page of rows with deterministic ordering.
    data_qb.push(" ORDER BY ref ASC");
    data_qb.push(" LIMIT ");
    data_qb.push_bind(filters.limit as i64);
    data_qb.push(" OFFSET ");
    data_qb.push_bind(filters.offset as i64);

    let rows: Vec<Action> = data_qb.build_query_as().fetch_all(db).await?;

    Ok(ActionSearchResult { rows, total })
}
/// Find actions by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Action>>
where

View File

@@ -15,6 +15,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
// ============================================================================
// Event Search
// ============================================================================
/// Filters for [`EventRepository::search`].
///
/// All fields are optional. When set, the corresponding WHERE clause is added.
/// Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct EventSearchFilters {
    /// Filter by trigger ID (exact match)
    pub trigger: Option<Id>,
    /// Filter by trigger ref (exact match)
    pub trigger_ref: Option<String>,
    /// Filter by source ID (exact match)
    pub source: Option<Id>,
    /// Filter by rule ref — matched as a case-insensitive substring (LIKE),
    /// unlike the exact-match filters above
    pub rule_ref: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of rows to skip before the first returned row
    pub offset: u32,
}
/// Result of [`EventRepository::search`].
#[derive(Debug)]
pub struct EventSearchResult {
    /// The requested page of matching events (newest first)
    pub rows: Vec<Event>,
    /// Total number of matches before LIMIT/OFFSET was applied
    pub total: u64,
}
// ============================================================================
// Enforcement Search
// ============================================================================
/// Filters for [`EnforcementRepository::search`].
///
/// All fields are optional and combinable. Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct EnforcementSearchFilters {
    /// Filter by rule ID (exact match)
    pub rule: Option<Id>,
    /// Filter by event ID (exact match)
    pub event: Option<Id>,
    /// Filter by enforcement status (exact match)
    pub status: Option<EnforcementStatus>,
    /// Filter by trigger ref (exact match)
    pub trigger_ref: Option<String>,
    /// Filter by rule ref (exact match)
    pub rule_ref: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of rows to skip before the first returned row
    pub offset: u32,
}
/// Result of [`EnforcementRepository::search`].
#[derive(Debug)]
pub struct EnforcementSearchResult {
    /// The requested page of matching enforcements (newest first)
    pub rows: Vec<Enforcement>,
    /// Total number of matches before LIMIT/OFFSET was applied
    pub total: u64,
}
/// Repository for Event operations
pub struct EventRepository;
@@ -173,6 +223,75 @@ impl EventRepository {
Ok(events)
}
/// Search events with all filters pushed into SQL.
///
/// Builds a dynamic query so that every filter, pagination, and the total
/// count are handled in the database — no in-memory filtering or slicing.
pub async fn search<'e, E>(db: E, filters: &EventSearchFilters) -> Result<EventSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, trigger, trigger_ref, config, payload, source, source_ref, rule, rule_ref, created";

    // Paired builders: one returns the page, the other the pre-pagination
    // count. Both must receive identical WHERE clauses.
    let mut data_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM event"));
    let mut total_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM event");

    // The first emitted condition uses WHERE, subsequent ones use AND.
    let mut sep = " WHERE ";

    if let Some(trigger_id) = filters.trigger {
        data_qb.push(sep).push("trigger = ").push_bind(trigger_id);
        total_qb.push(sep).push("trigger = ").push_bind(trigger_id);
        sep = " AND ";
    }
    if let Some(ref trigger_ref) = filters.trigger_ref {
        data_qb.push(sep).push("trigger_ref = ").push_bind(trigger_ref.clone());
        total_qb.push(sep).push("trigger_ref = ").push_bind(trigger_ref.clone());
        sep = " AND ";
    }
    if let Some(source_id) = filters.source {
        data_qb.push(sep).push("source = ").push_bind(source_id);
        total_qb.push(sep).push("source = ").push_bind(source_id);
        sep = " AND ";
    }
    if let Some(ref rule_ref) = filters.rule_ref {
        // rule_ref is a case-insensitive substring match, unlike the
        // exact-equality filters above.
        let pattern = format!("%{}%", rule_ref.to_lowercase());
        data_qb.push(sep).push("LOWER(rule_ref) LIKE ").push_bind(pattern.clone());
        total_qb.push(sep).push("LOWER(rule_ref) LIKE ").push_bind(pattern);
        sep = " AND ";
    }

    // The trailing assignment to `sep` is intentionally allowed to go unused.
    let _ = sep;

    // Total matches before pagination, clamped into u64.
    let total: i64 = total_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // Newest-first page of rows.
    data_qb.push(" ORDER BY created DESC");
    data_qb.push(" LIMIT ");
    data_qb.push_bind(filters.limit as i64);
    data_qb.push(" OFFSET ");
    data_qb.push_bind(filters.offset as i64);

    let rows: Vec<Event> = data_qb.build_query_as().fetch_all(db).await?;

    Ok(EventSearchResult { rows, total })
}
}
// ============================================================================
@@ -425,4 +544,75 @@ impl EnforcementRepository {
Ok(enforcements)
}
/// Search enforcements with all filters pushed into SQL.
///
/// All filter fields are combinable (AND). Pagination is server-side.
pub async fn search<'e, E>(
    db: E,
    filters: &EnforcementSearchFilters,
) -> Result<EnforcementSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, rule, rule_ref, trigger_ref, config, event, status, payload, condition, conditions, created, resolved_at";

    // Paired builders: one returns the page, the other the pre-pagination
    // count. Both must receive identical WHERE clauses.
    let mut data_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM enforcement"));
    let mut total_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM enforcement");

    // The first emitted condition uses WHERE, subsequent ones use AND.
    let mut sep = " WHERE ";

    if let Some(status) = &filters.status {
        data_qb.push(sep).push("status = ").push_bind(status.clone());
        total_qb.push(sep).push("status = ").push_bind(status.clone());
        sep = " AND ";
    }
    if let Some(rule_id) = filters.rule {
        data_qb.push(sep).push("rule = ").push_bind(rule_id);
        total_qb.push(sep).push("rule = ").push_bind(rule_id);
        sep = " AND ";
    }
    if let Some(event_id) = filters.event {
        data_qb.push(sep).push("event = ").push_bind(event_id);
        total_qb.push(sep).push("event = ").push_bind(event_id);
        sep = " AND ";
    }
    if let Some(ref trigger_ref) = filters.trigger_ref {
        data_qb.push(sep).push("trigger_ref = ").push_bind(trigger_ref.clone());
        total_qb.push(sep).push("trigger_ref = ").push_bind(trigger_ref.clone());
        sep = " AND ";
    }
    if let Some(ref rule_ref) = filters.rule_ref {
        data_qb.push(sep).push("rule_ref = ").push_bind(rule_ref.clone());
        total_qb.push(sep).push("rule_ref = ").push_bind(rule_ref.clone());
        sep = " AND ";
    }

    // The trailing assignment to `sep` is intentionally allowed to go unused.
    let _ = sep;

    // Total matches before pagination, clamped into u64.
    let total: i64 = total_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // Newest-first page of rows.
    data_qb.push(" ORDER BY created DESC");
    data_qb.push(" LIMIT ");
    data_qb.push_bind(filters.limit as i64);
    data_qb.push(" OFFSET ");
    data_qb.push_bind(filters.offset as i64);

    let rows: Vec<Enforcement> = data_qb.build_query_as().fetch_all(db).await?;

    Ok(EnforcementSearchResult { rows, total })
}
}

View File

@@ -1,11 +1,71 @@
//! Execution repository for database operations
use chrono::{DateTime, Utc};
use crate::models::{enums::ExecutionStatus, execution::*, Id, JsonDict};
use crate::Result;
use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for the [`ExecutionRepository::search`] query-builder method.
///
/// Every field is optional. When set, the corresponding `WHERE` clause is
/// appended to the query. Pagination (`limit`/`offset`) is always applied.
///
/// Filters that involve the `enforcement` table (`rule_ref`, `trigger_ref`)
/// cause a `LEFT JOIN enforcement` to be added automatically.
#[derive(Debug, Clone, Default)]
pub struct ExecutionSearchFilters {
    /// Exact-match filter on execution status
    pub status: Option<ExecutionStatus>,
    /// Exact-match filter on the action ref
    pub action_ref: Option<String>,
    /// Pack membership filter: matches `action_ref LIKE '<pack_name>.%'`
    pub pack_name: Option<String>,
    /// Exact-match filter on the joined enforcement's rule_ref
    pub rule_ref: Option<String>,
    /// Exact-match filter on the joined enforcement's trigger_ref
    pub trigger_ref: Option<String>,
    /// Exact-match filter on the executor ID
    pub executor: Option<Id>,
    /// Case-insensitive substring match against the result cast to text
    pub result_contains: Option<String>,
    /// Exact-match filter on the enforcement ID
    pub enforcement: Option<Id>,
    /// Exact-match filter on the parent execution ID
    pub parent: Option<Id>,
    /// When true, restrict to executions with no parent (`parent IS NULL`)
    pub top_level_only: bool,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of rows to skip before the first returned row
    pub offset: u32,
}
/// Result of [`ExecutionRepository::search`].
///
/// Includes the matching rows *and* the total count (before LIMIT/OFFSET)
/// so the caller can build pagination metadata without a second round-trip.
#[derive(Debug)]
pub struct ExecutionSearchResult {
    /// The requested page of matching executions (newest first)
    pub rows: Vec<ExecutionWithRefs>,
    /// Total number of matches before LIMIT/OFFSET was applied
    pub total: u64,
}
/// An execution row with optional `rule_ref` / `trigger_ref` populated from
/// the joined `enforcement` table. This avoids a separate in-memory lookup.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct ExecutionWithRefs {
    // — execution columns (same order as SELECT_COLUMNS) —
    pub id: Id,
    pub action: Option<Id>,
    pub action_ref: String,
    pub config: Option<JsonDict>,
    pub env_vars: Option<JsonDict>,
    pub parent: Option<Id>,
    pub enforcement: Option<Id>,
    pub executor: Option<Id>,
    pub status: ExecutionStatus,
    pub result: Option<JsonDict>,
    /// When the execution actually started running (worker picked it up);
    /// None while it is still queued.
    pub started_at: Option<DateTime<Utc>>,
    /// Decoded from a JSON column; `default` presumably falls back to None
    /// when the column is NULL — TODO confirm against sqlx docs.
    #[sqlx(json, default)]
    pub workflow_task: Option<WorkflowTaskMetadata>,
    pub created: DateTime<Utc>,
    pub updated: DateTime<Utc>,
    // — joined from enforcement —
    // Both are None when the execution has no enforcement row (LEFT JOIN miss).
    pub rule_ref: Option<String>,
    pub trigger_ref: Option<String>,
}
/// Column list for SELECT queries on the execution table.
///
/// Defined once to avoid drift between queries and the `Execution` model.
@@ -13,7 +73,7 @@ use super::{Create, Delete, FindById, List, Repository, Update};
/// are NOT in the Rust struct, so `SELECT *` must never be used.
pub const SELECT_COLUMNS: &str = "\
id, action, action_ref, config, env_vars, parent, enforcement, \
executor, status, result, workflow_task, created, updated";
executor, status, result, started_at, workflow_task, created, updated";
pub struct ExecutionRepository;
@@ -43,6 +103,7 @@ pub struct UpdateExecutionInput {
pub status: Option<ExecutionStatus>,
pub result: Option<JsonDict>,
pub executor: Option<Id>,
pub started_at: Option<DateTime<Utc>>,
pub workflow_task: Option<WorkflowTaskMetadata>,
}
@@ -52,6 +113,7 @@ impl From<Execution> for UpdateExecutionInput {
status: Some(execution.status),
result: execution.result,
executor: execution.executor,
started_at: execution.started_at,
workflow_task: execution.workflow_task,
}
}
@@ -146,6 +208,13 @@ impl Update for ExecutionRepository {
query.push("executor = ").push_bind(executor_id);
has_updates = true;
}
if let Some(started_at) = input.started_at {
if has_updates {
query.push(", ");
}
query.push("started_at = ").push_bind(started_at);
has_updates = true;
}
if let Some(workflow_task) = &input.workflow_task {
if has_updates {
query.push(", ");
@@ -239,4 +308,141 @@ impl ExecutionRepository {
.await
.map_err(Into::into)
}
/// Search executions with all filters pushed into SQL.
///
/// Builds a dynamic query with only the WHERE clauses that apply, a
/// LEFT JOIN on `enforcement` (always present, so every row can carry
/// `rule_ref`/`trigger_ref` without a second round-trip), and proper
/// LIMIT/OFFSET so pagination is server-side.
///
/// Returns both the matching page of rows and the total count.
pub async fn search<'e, E>(
    db: E,
    filters: &ExecutionSearchFilters,
) -> Result<ExecutionSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    // Prefix every execution column with the table alias so none of them
    // collide with the joined enforcement columns.
    let prefixed_select = SELECT_COLUMNS
        .split(", ")
        .map(|col| format!("e.{col}"))
        .collect::<Vec<_>>()
        .join(", ");
    let select_clause = format!(
        "{prefixed_select}, enf.rule_ref AS rule_ref, enf.trigger_ref AS trigger_ref"
    );
    let from_clause = "FROM execution e LEFT JOIN enforcement enf ON e.enforcement = enf.id";

    // Two builders share identical WHERE clauses: one fetches the page,
    // the other computes the pre-pagination total.
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_clause} {from_clause}"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT COUNT(*) AS total {from_clause}"));

    // Tracks whether a condition has been emitted yet, so the first one is
    // introduced by WHERE and the rest by AND. (Same pattern as the other
    // repositories' search methods.)
    let mut has_where = false;

    // Append a bound condition (`<prefix> $bind`) to both builders.
    macro_rules! push_condition {
        ($cond_prefix:expr, $value:expr) => {{
            if !has_where {
                qb.push(" WHERE ");
                count_qb.push(" WHERE ");
                has_where = true;
            } else {
                qb.push(" AND ");
                count_qb.push(" AND ");
            }
            qb.push($cond_prefix);
            qb.push_bind($value.clone());
            count_qb.push($cond_prefix);
            count_qb.push_bind($value);
        }};
    }

    // Append a bind-free condition to both builders.
    macro_rules! push_raw_condition {
        ($cond:expr) => {{
            if !has_where {
                qb.push(concat!(" WHERE ", $cond));
                count_qb.push(concat!(" WHERE ", $cond));
                has_where = true;
            } else {
                qb.push(concat!(" AND ", $cond));
                count_qb.push(concat!(" AND ", $cond));
            }
        }};
    }

    if let Some(status) = &filters.status {
        push_condition!("e.status = ", status.clone());
    }
    if let Some(action_ref) = &filters.action_ref {
        push_condition!("e.action_ref = ", action_ref.clone());
    }
    if let Some(pack_name) = &filters.pack_name {
        // Pack membership is encoded in the ref as "<pack>.<action>".
        let pattern = format!("{pack_name}.%");
        push_condition!("e.action_ref LIKE ", pattern);
    }
    if let Some(enforcement_id) = filters.enforcement {
        push_condition!("e.enforcement = ", enforcement_id);
    }
    if let Some(parent_id) = filters.parent {
        push_condition!("e.parent = ", parent_id);
    }
    if filters.top_level_only {
        push_raw_condition!("e.parent IS NULL");
    }
    if let Some(executor_id) = filters.executor {
        push_condition!("e.executor = ", executor_id);
    }
    if let Some(rule_ref) = &filters.rule_ref {
        push_condition!("enf.rule_ref = ", rule_ref.clone());
    }
    if let Some(trigger_ref) = &filters.trigger_ref {
        push_condition!("enf.trigger_ref = ", trigger_ref.clone());
    }
    if let Some(search) = &filters.result_contains {
        // Case-insensitive substring match over the result rendered as text.
        let pattern = format!("%{}%", search.to_lowercase());
        push_condition!("LOWER(e.result::text) LIKE ", pattern);
    }

    // Suppress the unused-assignment warning from the macro's last expansion.
    let _ = has_where;

    // ── COUNT query ──────────────────────────────────────────────────
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // ── Data query with ORDER BY + pagination ────────────────────────
    qb.push(" ORDER BY e.created DESC");
    qb.push(" LIMIT ");
    qb.push_bind(filters.limit as i64);
    qb.push(" OFFSET ");
    qb.push_bind(filters.offset as i64);

    let rows: Vec<ExecutionWithRefs> = qb.build_query_as().fetch_all(db).await?;

    Ok(ExecutionSearchResult { rows, total })
}
}

View File

@@ -7,6 +7,25 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for [`InquiryRepository::search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct InquirySearchFilters {
    /// Filter by inquiry status (exact match)
    pub status: Option<InquiryStatus>,
    /// Filter by execution ID (exact match)
    pub execution: Option<Id>,
    /// Filter by assignee ID (exact match)
    pub assigned_to: Option<Id>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of rows to skip before the first returned row
    pub offset: u32,
}
/// Result of [`InquiryRepository::search`].
#[derive(Debug)]
pub struct InquirySearchResult {
    /// The requested page of matching inquiries (newest first)
    pub rows: Vec<Inquiry>,
    /// Total number of matches before LIMIT/OFFSET was applied
    pub total: u64,
}
pub struct InquiryRepository;
impl Repository for InquiryRepository {
@@ -157,4 +176,66 @@ impl InquiryRepository {
"SELECT id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated FROM inquiry WHERE execution = $1 ORDER BY created DESC"
).bind(execution_id).fetch_all(executor).await.map_err(Into::into)
}
/// Search inquiries with all filters pushed into SQL.
///
/// Every provided filter is combined with AND; pagination is always applied
/// server-side. The data query and the COUNT query are advanced in lockstep
/// so both see an identical WHERE clause with identical bind order.
pub async fn search<'e, E>(db: E, filters: &InquirySearchFilters) -> Result<InquirySearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM inquiry"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM inquiry");
    // The first condition is introduced by " WHERE ", every later one by " AND ".
    let mut sep = " WHERE ";
    if let Some(status) = &filters.status {
        qb.push(sep).push("status = ").push_bind(status.clone());
        count_qb.push(sep).push("status = ").push_bind(status.clone());
        sep = " AND ";
    }
    if let Some(execution_id) = filters.execution {
        qb.push(sep).push("execution = ").push_bind(execution_id);
        count_qb.push(sep).push("execution = ").push_bind(execution_id);
        sep = " AND ";
    }
    if let Some(assigned_to) = filters.assigned_to {
        qb.push(sep).push("assigned_to = ").push_bind(assigned_to);
        count_qb.push(sep).push("assigned_to = ").push_bind(assigned_to);
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Newest-first page of matching rows.
    qb.push(" ORDER BY created DESC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<Inquiry> = qb.build_query_as().fetch_all(db).await?;
    Ok(InquirySearchResult { rows, total })
}
}

View File

@@ -6,6 +6,24 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for [`KeyRepository::search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct KeySearchFilters {
    /// Filter by owner type.
    pub owner_type: Option<OwnerType>,
    /// Filter by exact owner value (`owner` column).
    pub owner: Option<String>,
    /// Page size: maximum number of rows returned.
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row.
    pub offset: u32,
}
/// Result of [`KeyRepository::search`].
#[derive(Debug)]
pub struct KeySearchResult {
    /// The requested page of matching keys.
    pub rows: Vec<Key>,
    /// Total number of rows matching the filters, ignoring pagination
    /// (computed via a COUNT(*) query with the same WHERE clause).
    pub total: u64,
}
/// Repository for Key operations.
pub struct KeyRepository;
impl Repository for KeyRepository {
@@ -165,4 +183,63 @@ impl KeyRepository {
"SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated FROM key WHERE owner_type = $1 ORDER BY ref ASC"
).bind(owner_type).fetch_all(executor).await.map_err(Into::into)
}
/// Search keys with all filters pushed into SQL.
///
/// Every provided filter is combined with AND; pagination is always applied
/// server-side. The data query and the COUNT query are advanced in lockstep
/// so both see an identical WHERE clause with identical bind order.
pub async fn search<'e, E>(db: E, filters: &KeySearchFilters) -> Result<KeySearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM key"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM key");
    // The first condition is introduced by " WHERE ", every later one by " AND ".
    let mut sep = " WHERE ";
    if let Some(owner_type) = &filters.owner_type {
        qb.push(sep).push("owner_type = ").push_bind(owner_type.clone());
        count_qb.push(sep).push("owner_type = ").push_bind(owner_type.clone());
        sep = " AND ";
    }
    if let Some(owner) = &filters.owner {
        qb.push(sep).push("owner = ").push_bind(owner.clone());
        count_qb.push(sep).push("owner = ").push_bind(owner.clone());
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Stable page ordering by ref.
    qb.push(" ORDER BY ref ASC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<Key> = qb.build_query_as().fetch_all(db).await?;
    Ok(KeySearchResult { rows, total })
}
}

View File

@@ -8,6 +8,30 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
/// Filters for [`RuleRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct RuleSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by action ID
    pub action: Option<Id>,
    /// Filter by trigger ID
    pub trigger: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Page size: maximum number of rows returned.
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row.
    pub offset: u32,
}
/// Result of [`RuleRepository::list_search`].
#[derive(Debug)]
pub struct RuleSearchResult {
    /// The requested page of matching rules.
    pub rows: Vec<Rule>,
    /// Total number of rows matching the filters, ignoring pagination
    /// (computed via a COUNT(*) query with the same WHERE clause).
    pub total: u64,
}
/// Input for restoring an ad-hoc rule during pack reinstallation.
/// Unlike `CreateRuleInput`, action and trigger IDs are optional because
/// the referenced entities may not exist yet or may have been removed.
@@ -275,6 +299,71 @@ impl Delete for RuleRepository {
}
impl RuleRepository {
/// Search rules with all filters pushed into SQL.
///
/// Every provided filter is combined with AND; pagination is always applied
/// server-side. The data query and the COUNT query are advanced in lockstep
/// so both see an identical WHERE clause with identical bind order.
pub async fn list_search<'e, E>(db: E, filters: &RuleSearchFilters) -> Result<RuleSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, pack, pack_ref, label, description, action, action_ref, trigger, trigger_ref, conditions, action_params, trigger_params, enabled, is_adhoc, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM rule"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM rule");
    // The first condition is introduced by " WHERE ", every later one by " AND ".
    let mut sep = " WHERE ";
    if let Some(pack_id) = filters.pack {
        qb.push(sep).push("pack = ").push_bind(pack_id);
        count_qb.push(sep).push("pack = ").push_bind(pack_id);
        sep = " AND ";
    }
    if let Some(action_id) = filters.action {
        qb.push(sep).push("action = ").push_bind(action_id);
        count_qb.push(sep).push("action = ").push_bind(action_id);
        sep = " AND ";
    }
    if let Some(trigger_id) = filters.trigger {
        qb.push(sep).push("trigger = ").push_bind(trigger_id);
        count_qb.push(sep).push("trigger = ").push_bind(trigger_id);
        sep = " AND ";
    }
    if let Some(enabled) = filters.enabled {
        qb.push(sep).push("enabled = ").push_bind(enabled);
        count_qb.push(sep).push("enabled = ").push_bind(enabled);
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Stable page ordering by ref.
    qb.push(" ORDER BY ref ASC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<Rule> = qb.build_query_as().fetch_all(db).await?;
    Ok(RuleSearchResult { rows, total })
}
/// Find rules by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Rule>>
where

View File

@@ -9,6 +9,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
// ============================================================================
// Trigger Search
// ============================================================================
/// Filters for [`TriggerRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct TriggerSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Page size: maximum number of rows returned.
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row.
    pub offset: u32,
}
/// Result of [`TriggerRepository::list_search`].
#[derive(Debug)]
pub struct TriggerSearchResult {
    /// The requested page of matching triggers.
    pub rows: Vec<Trigger>,
    /// Total number of rows matching the filters, ignoring pagination
    /// (computed via a COUNT(*) query with the same WHERE clause).
    pub total: u64,
}
// ============================================================================
// Sensor Search
// ============================================================================
/// Filters for [`SensorRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct SensorSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by trigger ID
    pub trigger: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Page size: maximum number of rows returned.
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row.
    pub offset: u32,
}
/// Result of [`SensorRepository::list_search`].
#[derive(Debug)]
pub struct SensorSearchResult {
    /// The requested page of matching sensors.
    pub rows: Vec<Sensor>,
    /// Total number of rows matching the filters, ignoring pagination
    /// (computed via a COUNT(*) query with the same WHERE clause).
    pub total: u64,
}
/// Repository for Trigger operations
pub struct TriggerRepository;
@@ -251,6 +301,68 @@ impl Delete for TriggerRepository {
}
impl TriggerRepository {
/// Search triggers with all filters pushed into SQL.
///
/// Every provided filter is combined with AND; pagination is always applied
/// server-side. The data query and the COUNT query are advanced in lockstep
/// so both see an identical WHERE clause with identical bind order.
pub async fn list_search<'e, E>(
    db: E,
    filters: &TriggerSearchFilters,
) -> Result<TriggerSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, pack, pack_ref, label, description, enabled, param_schema, out_schema, webhook_enabled, webhook_key, webhook_config, is_adhoc, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM trigger"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM trigger");
    // The first condition is introduced by " WHERE ", every later one by " AND ".
    let mut sep = " WHERE ";
    if let Some(pack_id) = filters.pack {
        qb.push(sep).push("pack = ").push_bind(pack_id);
        count_qb.push(sep).push("pack = ").push_bind(pack_id);
        sep = " AND ";
    }
    if let Some(enabled) = filters.enabled {
        qb.push(sep).push("enabled = ").push_bind(enabled);
        count_qb.push(sep).push("enabled = ").push_bind(enabled);
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Stable page ordering by ref.
    qb.push(" ORDER BY ref ASC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<Trigger> = qb.build_query_as().fetch_all(db).await?;
    Ok(TriggerSearchResult { rows, total })
}
/// Find triggers by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Trigger>>
where
@@ -795,6 +907,71 @@ impl Delete for SensorRepository {
}
impl SensorRepository {
/// Search sensors with all filters pushed into SQL.
///
/// Every provided filter is combined with AND; pagination is always applied
/// server-side. The data query and the COUNT query are advanced in lockstep
/// so both see an identical WHERE clause with identical bind order.
pub async fn list_search<'e, E>(
    db: E,
    filters: &SensorSearchFilters,
) -> Result<SensorSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_ref, runtime_version_constraint, trigger, trigger_ref, enabled, param_schema, config, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM sensor"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM sensor");
    // The first condition is introduced by " WHERE ", every later one by " AND ".
    let mut sep = " WHERE ";
    if let Some(pack_id) = filters.pack {
        qb.push(sep).push("pack = ").push_bind(pack_id);
        count_qb.push(sep).push("pack = ").push_bind(pack_id);
        sep = " AND ";
    }
    if let Some(trigger_id) = filters.trigger {
        qb.push(sep).push("trigger = ").push_bind(trigger_id);
        count_qb.push(sep).push("trigger = ").push_bind(trigger_id);
        sep = " AND ";
    }
    if let Some(enabled) = filters.enabled {
        qb.push(sep).push("enabled = ").push_bind(enabled);
        count_qb.push(sep).push("enabled = ").push_bind(enabled);
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Stable page ordering by ref.
    qb.push(" ORDER BY ref ASC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<Sensor> = qb.build_query_as().fetch_all(db).await?;
    Ok(SensorSearchResult { rows, total })
}
/// Find sensors by trigger ID
pub async fn find_by_trigger<'e, E>(executor: E, trigger_id: Id) -> Result<Vec<Sensor>>
where

View File

@@ -6,6 +6,37 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
// ============================================================================
// Workflow Definition Search
// ============================================================================
/// Filters for [`WorkflowDefinitionRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
/// Tag filtering uses `ANY(tags)` for each tag (OR across tags, AND with other filters).
#[derive(Debug, Clone, Default)]
pub struct WorkflowSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by pack reference
    pub pack_ref: Option<String>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Filter by tags (OR across tags — matches if any tag is present)
    pub tags: Option<Vec<String>>,
    /// Text search across label and description (case-insensitive substring)
    pub search: Option<String>,
    /// Page size: maximum number of rows returned.
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row.
    pub offset: u32,
}
/// Result of [`WorkflowDefinitionRepository::list_search`].
#[derive(Debug)]
pub struct WorkflowSearchResult {
    /// The requested page of matching workflow definitions.
    pub rows: Vec<WorkflowDefinition>,
    /// Total number of rows matching the filters, ignoring pagination
    /// (computed via a COUNT(*) query with the same WHERE clause).
    pub total: u64,
}
// ============================================================================
// WORKFLOW DEFINITION REPOSITORY
// ============================================================================
@@ -226,6 +257,102 @@ impl Delete for WorkflowDefinitionRepository {
}
impl WorkflowDefinitionRepository {
/// Search workflow definitions with all filters pushed into SQL.
///
/// All filter fields are combinable (AND). Pagination is server-side.
/// Tags use an OR match — a workflow matches if it contains ANY of the
/// requested tags (via the Postgres array-overlap operator `tags && $n`).
///
/// The free-text `search` filter is a case-insensitive *literal* substring
/// match: LIKE metacharacters (`\`, `%`, `_`) in the user input are escaped
/// so they cannot act as wildcards.
pub async fn list_search<'e, E>(
    db: E,
    filters: &WorkflowSearchFilters,
) -> Result<WorkflowSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    let select_cols = "id, ref, pack, pack_ref, label, description, version, param_schema, out_schema, definition, tags, enabled, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM workflow_definition"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM workflow_definition");
    // The first condition is introduced by " WHERE ", every later one by
    // " AND ". Both builders advance in lockstep so bind order stays identical.
    let mut sep = " WHERE ";
    if let Some(pack_id) = filters.pack {
        qb.push(sep).push("pack = ").push_bind(pack_id);
        count_qb.push(sep).push("pack = ").push_bind(pack_id);
        sep = " AND ";
    }
    if let Some(ref pack_ref) = filters.pack_ref {
        qb.push(sep).push("pack_ref = ").push_bind(pack_ref.clone());
        count_qb.push(sep).push("pack_ref = ").push_bind(pack_ref.clone());
        sep = " AND ";
    }
    if let Some(enabled) = filters.enabled {
        qb.push(sep).push("enabled = ").push_bind(enabled);
        count_qb.push(sep).push("enabled = ").push_bind(enabled);
        sep = " AND ";
    }
    if let Some(ref tags) = filters.tags {
        // An empty tag list is a no-op filter rather than "match nothing".
        if !tags.is_empty() {
            // PostgreSQL array overlap operator: tags && ARRAY[...]
            qb.push(sep).push("tags && ").push_bind(tags.clone());
            count_qb.push(sep).push("tags && ").push_bind(tags.clone());
            sep = " AND ";
        }
    }
    if let Some(ref search) = filters.search {
        // Escape LIKE metacharacters so user-supplied `%`/`_`/`\` match
        // literally instead of acting as wildcards (backslash is the
        // default LIKE escape character in Postgres).
        let escaped = search
            .to_lowercase()
            .replace('\\', "\\\\")
            .replace('%', "\\%")
            .replace('_', "\\_");
        let pattern = format!("%{escaped}%");
        // Search is an OR across label and description, wrapped in parens.
        for builder in [&mut qb, &mut count_qb] {
            builder.push(sep);
            builder.push("(LOWER(label) LIKE ");
            builder.push_bind(pattern.clone());
            builder.push(" OR LOWER(COALESCE(description, '')) LIKE ");
            builder.push_bind(pattern.clone());
            builder.push(")");
        }
        sep = " AND ";
    }
    // Silence the unused-assignment warning from the final branch.
    let _ = sep;
    // Total matching rows (same WHERE clause, no pagination).
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;
    // Stable page ordering by label.
    qb.push(" ORDER BY label ASC")
        .push(" LIMIT ")
        .push_bind(filters.limit as i64)
        .push(" OFFSET ")
        .push_bind(filters.offset as i64);
    let rows: Vec<WorkflowDefinition> = qb.build_query_as().fetch_all(db).await?;
    Ok(WorkflowSearchResult { rows, total })
}
/// Find all workflows for a specific pack by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<WorkflowDefinition>>
where

View File

@@ -0,0 +1,112 @@
//! # Expression AST
//!
//! Defines the abstract syntax tree nodes produced by the parser and consumed
//! by the evaluator.
use std::fmt;
/// A binary operator connecting two sub-expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparison
    Eq,
    Ne,
    Lt,
    Gt,
    Le,
    Ge,
    // Logical
    And,
    Or,
    // Membership
    In,
}

impl fmt::Display for BinaryOp {
    /// Renders the operator in its surface syntax (e.g. `+`, `==`, `and`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its token text once, then emit it in one call.
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Mod => "%",
            Self::Eq => "==",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
            Self::And => "and",
            Self::Or => "or",
            Self::In => "in",
        };
        f.write_str(symbol)
    }
}
/// A unary operator applied to a single sub-expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation: `-x`
    Neg,
    /// Logical negation: `not x`
    Not,
}

impl fmt::Display for UnaryOp {
    /// Renders the operator in its surface syntax (`-` or `not`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Neg => "-",
            Self::Not => "not",
        })
    }
}
/// An expression AST node.
///
/// Produced by the parser and consumed by the evaluator; each variant maps
/// directly to one surface-syntax construct.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal JSON value: number, string, bool, or null.
    Literal(serde_json::Value),
    /// An array literal: `[expr, expr, ...]`
    Array(Vec<Expr>),
    /// A variable reference by name (e.g., `x`, `parameters`, `item`).
    Ident(String),
    /// Binary operation: `left op right`
    BinaryOp {
        op: BinaryOp,
        left: Box<Expr>,
        right: Box<Expr>,
    },
    /// Unary operation: `op operand`
    UnaryOp {
        op: UnaryOp,
        operand: Box<Expr>,
    },
    /// Property access: `expr.field`
    DotAccess {
        object: Box<Expr>,
        field: String,
    },
    /// Index/bracket access: `expr[index_expr]`
    IndexAccess {
        object: Box<Expr>,
        index: Box<Expr>,
    },
    /// Function call: `name(arg1, arg2, ...)`
    FunctionCall {
        name: String,
        args: Vec<Expr>,
    },
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,545 @@
//! # Workflow Expression Engine
//!
//! A complete expression evaluator for workflow templates, supporting arithmetic,
//! comparison, boolean logic, member access, and built-in functions over JSON values.
//!
//! ## Architecture
//!
//! The engine is structured as a classic three-phase interpreter:
//!
//! 1. **Lexer** (`tokenizer.rs`) — converts expression strings into a stream of tokens
//! 2. **Parser** (`parser.rs`) — builds an AST from tokens using recursive descent
//! 3. **Evaluator** (`evaluator.rs`) — walks the AST and produces a `JsonValue` result
//!
//! ## Supported Operators
//!
//! ### Arithmetic
//! - `+` (addition for numbers, concatenation for strings)
//! - `-` (subtraction, unary negation)
//! - `*`, `/`, `%` (multiplication, division, modulo)
//!
//! ### Comparison
//! - `==`, `!=` (equality — works on all types, recursive for objects/arrays)
//! - `>`, `<`, `>=`, `<=` (ordering — numbers and strings only)
//! - Float/int comparisons allowed: `3 == 3.0` → true
//!
//! ### Boolean / Logical
//! - `and`, `or`, `not`
//!
//! ### Membership & Access
//! - `.` — object property access
//! - `[n]` — array index / object bracket access
//! - `in` — membership test (item in list, key in object, substring in string)
//!
//! ## Built-in Functions
//!
//! ### Type conversion
//! - `string(v)`, `number(v)`, `int(v)`, `bool(v)`
//!
//! ### Introspection
//! - `type_of(v)`, `length(v)`, `keys(obj)`, `values(obj)`
//!
//! ### Math
//! - `abs(n)`, `floor(n)`, `ceil(n)`, `round(n)`, `min(a,b)`, `max(a,b)`, `sum(arr)`
//!
//! ### String
//! - `lower(s)`, `upper(s)`, `trim(s)`, `split(s, sep)`, `join(arr, sep)`
//! - `replace(s, old, new)`, `starts_with(s, prefix)`, `ends_with(s, suffix)`
//! - `match(pattern, s)` — regex match
//!
//! ### Collection
//! - `contains(haystack, needle)`, `reversed(v)`, `sort(arr)`, `unique(arr)`
//! - `flat(arr)`, `zip(a, b)`, `range(n)` / `range(start, end)`
//!
//! ### Workflow-specific
//! - `result()`, `succeeded()`, `failed()`, `timed_out()`
mod ast;
mod evaluator;
mod parser;
mod tokenizer;
pub use ast::{BinaryOp, Expr, UnaryOp};
pub use evaluator::{is_truthy, EvalContext, EvalError, EvalResult};
pub use parser::{ParseError, Parser};
pub use tokenizer::{Token, TokenKind, Tokenizer};
use serde_json::Value as JsonValue;
/// Parse and evaluate an expression string against the given context.
///
/// This is the main entry point for the expression engine. It tokenizes the
/// input, parses it into an AST, and evaluates it to produce a `JsonValue`.
///
/// # Errors
///
/// Returns [`EvalError::ParseError`] if the input fails to tokenize or parse,
/// or whatever [`EvalError`] the evaluator raises during evaluation.
pub fn eval_expression(input: &str, ctx: &dyn EvalContext) -> EvalResult<JsonValue> {
    // Lexer and parser errors are folded into a single ParseError variant;
    // `to_string()` replaces the redundant `format!("{}", e)`.
    let tokens = Tokenizer::new(input)
        .tokenize()
        .map_err(|e| EvalError::ParseError(e.to_string()))?;
    let ast = Parser::new(&tokens)
        .parse()
        .map_err(|e| EvalError::ParseError(e.to_string()))?;
    evaluator::eval(&ast, ctx)
}
/// Parse an expression string into an AST without evaluating it.
///
/// Useful for validation or inspection.
///
/// # Errors
///
/// Returns [`ParseError::TokenError`] if tokenization fails, or any
/// [`ParseError`] produced by the parser itself.
pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
    // `to_string()` replaces the redundant `format!("{}", e)`.
    let tokens = Tokenizer::new(input)
        .tokenize()
        .map_err(|e| ParseError::TokenError(e.to_string()))?;
    Parser::new(&tokens).parse()
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::collections::HashMap;
/// A minimal eval context for integration tests.
struct TestContext {
    // Name → value map backing `resolve_variable`.
    variables: HashMap<String, JsonValue>,
}
impl TestContext {
    /// Creates an empty context with no variables defined.
    fn new() -> Self {
        Self {
            variables: HashMap::new(),
        }
    }
    /// Builder-style helper: registers `name` with `value` and returns self.
    fn with_var(mut self, name: &str, value: JsonValue) -> Self {
        self.variables.insert(name.to_string(), value);
        self
    }
}
impl EvalContext for TestContext {
    /// Looks up a variable by name, cloning its value; unknown names become
    /// `EvalError::VariableNotFound`.
    fn resolve_variable(&self, name: &str) -> EvalResult<JsonValue> {
        self.variables
            .get(name)
            .cloned()
            .ok_or_else(|| EvalError::VariableNotFound(name.to_string()))
    }
    // Always answers Ok(None) — presumably "not a workflow-specific
    // function" so the evaluator can fall back to built-ins.
    // NOTE(review): inferred from the Option return type; confirm against
    // the evaluator's handling.
    fn call_workflow_function(
        &self,
        _name: &str,
        _args: &[JsonValue],
    ) -> EvalResult<Option<JsonValue>> {
        Ok(None)
    }
}
// ---------------------------------------------------------------
// Arithmetic
// ---------------------------------------------------------------
// Integer ops stay integer, including truncating modulo.
#[test]
fn test_integer_arithmetic() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("2 + 3", &ctx).unwrap(), json!(5));
    assert_eq!(eval_expression("10 - 4", &ctx).unwrap(), json!(6));
    assert_eq!(eval_expression("3 * 7", &ctx).unwrap(), json!(21));
    assert_eq!(eval_expression("15 / 5", &ctx).unwrap(), json!(3));
    assert_eq!(eval_expression("17 % 5", &ctx).unwrap(), json!(2));
}
#[test]
fn test_float_arithmetic() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("2.5 + 1.5", &ctx).unwrap(), json!(4.0));
    // Expected value computed with the same Rust f64 division, so the
    // comparison is bit-exact.
    assert_eq!(eval_expression("10.0 / 3.0", &ctx).unwrap(), json!(10.0 / 3.0));
}
// Mixing ints and floats promotes to float; integer division promotes
// only when the result is not a whole number.
#[test]
fn test_mixed_int_float() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("2 + 1.5", &ctx).unwrap(), json!(3.5));
    // Integer division yields float when not evenly divisible
    assert_eq!(eval_expression("10 / 4", &ctx).unwrap(), json!(2.5));
    assert_eq!(eval_expression("10 / 4.0", &ctx).unwrap(), json!(2.5));
    // Evenly divisible integer division stays integer
    assert_eq!(eval_expression("10 / 5", &ctx).unwrap(), json!(2));
}
// `*` binds tighter than `+`/`-`; parentheses override.
#[test]
fn test_operator_precedence() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("2 + 3 * 4", &ctx).unwrap(), json!(14));
    assert_eq!(eval_expression("(2 + 3) * 4", &ctx).unwrap(), json!(20));
    assert_eq!(eval_expression("10 - 2 * 3 + 1", &ctx).unwrap(), json!(5));
}
#[test]
fn test_unary_negation() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("-5", &ctx).unwrap(), json!(-5));
    assert_eq!(eval_expression("-2 + 3", &ctx).unwrap(), json!(1));
    assert_eq!(eval_expression("-(2 + 3)", &ctx).unwrap(), json!(-5));
}
// `+` doubles as string concatenation.
#[test]
fn test_string_concatenation() {
    let ctx = TestContext::new();
    assert_eq!(
        eval_expression("\"hello\" + \" \" + \"world\"", &ctx).unwrap(),
        json!("hello world")
    );
}
// ---------------------------------------------------------------
// Comparison
// ---------------------------------------------------------------
#[test]
fn test_number_comparison() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("3 == 3", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3 != 4", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3 > 2", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3 < 2", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("3 >= 3", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3 <= 4", &ctx).unwrap(), json!(true));
}
// Equality compares numeric value across int/float representations.
#[test]
fn test_int_float_equality() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("3 == 3.0", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3.0 == 3", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("3 != 3.1", &ctx).unwrap(), json!(true));
}
// Strings order lexicographically.
#[test]
fn test_string_comparison() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("\"abc\" == \"abc\"", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("\"abc\" < \"abd\"", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("\"abc\" > \"abb\"", &ctx).unwrap(), json!(true));
}
// null equals only null — it is not coerced to 0 or false.
#[test]
fn test_null_equality() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("null == null", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("null != null", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("null == 0", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("null == false", &ctx).unwrap(), json!(false));
}
// Arrays compare element-wise.
#[test]
fn test_array_equality() {
    let ctx = TestContext::new()
        .with_var("a", json!([1, 2, 3]))
        .with_var("b", json!([1, 2, 3]))
        .with_var("c", json!([1, 2, 4]));
    assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
}
// Object equality ignores key order.
#[test]
fn test_object_equality() {
    let ctx = TestContext::new()
        .with_var("a", json!({"x": 1, "y": 2}))
        .with_var("b", json!({"y": 2, "x": 1}))
        .with_var("c", json!({"x": 1, "y": 3}));
    assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
}
// ---------------------------------------------------------------
// Boolean / Logical
// ---------------------------------------------------------------
// Full truth tables for `and`/`or`, plus `not`.
#[test]
fn test_boolean_operators() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("true and true", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("true and false", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("false or true", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("false or false", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("not true", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("not false", &ctx).unwrap(), json!(true));
}
#[test]
fn test_boolean_precedence() {
    let ctx = TestContext::new();
    // `and` binds tighter than `or`
    assert_eq!(
        eval_expression("true or false and false", &ctx).unwrap(),
        json!(true)
    );
    assert_eq!(
        eval_expression("(true or false) and false", &ctx).unwrap(),
        json!(false)
    );
}
// ---------------------------------------------------------------
// Membership & access
// ---------------------------------------------------------------
// Chained `.` access traverses nested objects.
#[test]
fn test_dot_access() {
    let ctx = TestContext::new()
        .with_var("obj", json!({"a": {"b": 42}}));
    assert_eq!(eval_expression("obj.a.b", &ctx).unwrap(), json!(42));
}
// Brackets index arrays by position and objects by string key.
#[test]
fn test_bracket_access() {
    let ctx = TestContext::new()
        .with_var("arr", json!([10, 20, 30]))
        .with_var("obj", json!({"key": "value"}));
    assert_eq!(eval_expression("arr[1]", &ctx).unwrap(), json!(20));
    assert_eq!(eval_expression("obj[\"key\"]", &ctx).unwrap(), json!("value"));
}
// `in` covers element-in-array, key-in-object, and substring-in-string.
#[test]
fn test_in_operator() {
    let ctx = TestContext::new()
        .with_var("arr", json!([1, 2, 3]))
        .with_var("obj", json!({"key": "val"}));
    assert_eq!(eval_expression("2 in arr", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("5 in arr", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("\"key\" in obj", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("\"nope\" in obj", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("\"ell\" in \"hello\"", &ctx).unwrap(), json!(true));
}
// ---------------------------------------------------------------
// Built-in functions
// ---------------------------------------------------------------
#[test]
fn test_length() {
let ctx = TestContext::new()
.with_var("arr", json!([1, 2, 3]))
.with_var("obj", json!({"a": 1, "b": 2}));
assert_eq!(eval_expression("length(arr)", &ctx).unwrap(), json!(3));
assert_eq!(eval_expression("length(\"hello\")", &ctx).unwrap(), json!(5));
assert_eq!(eval_expression("length(obj)", &ctx).unwrap(), json!(2));
}
#[test]
fn test_type_conversions() {
let ctx = TestContext::new();
assert_eq!(eval_expression("string(42)", &ctx).unwrap(), json!("42"));
assert_eq!(eval_expression("number(\"3.14\")", &ctx).unwrap(), json!(3.14));
assert_eq!(eval_expression("int(3.9)", &ctx).unwrap(), json!(3));
assert_eq!(eval_expression("int(\"42\")", &ctx).unwrap(), json!(42));
assert_eq!(eval_expression("bool(1)", &ctx).unwrap(), json!(true));
assert_eq!(eval_expression("bool(0)", &ctx).unwrap(), json!(false));
assert_eq!(eval_expression("bool(\"\")", &ctx).unwrap(), json!(false));
assert_eq!(eval_expression("bool(\"x\")", &ctx).unwrap(), json!(true));
}
#[test]
fn test_type_of() {
let ctx = TestContext::new()
.with_var("arr", json!([1]))
.with_var("obj", json!({}));
assert_eq!(eval_expression("type_of(42)", &ctx).unwrap(), json!("number"));
assert_eq!(eval_expression("type_of(\"hi\")", &ctx).unwrap(), json!("string"));
assert_eq!(eval_expression("type_of(true)", &ctx).unwrap(), json!("bool"));
assert_eq!(eval_expression("type_of(null)", &ctx).unwrap(), json!("null"));
assert_eq!(eval_expression("type_of(arr)", &ctx).unwrap(), json!("array"));
assert_eq!(eval_expression("type_of(obj)", &ctx).unwrap(), json!("object"));
}
#[test]
fn test_keys_values() {
let ctx = TestContext::new()
.with_var("obj", json!({"b": 2, "a": 1}));
let keys = eval_expression("sort(keys(obj))", &ctx).unwrap();
assert_eq!(keys, json!(["a", "b"]));
let values = eval_expression("sort(values(obj))", &ctx).unwrap();
assert_eq!(values, json!([1, 2]));
}
// Numeric builtins (abs/floor/ceil/round/min/max/sum).
#[test]
fn test_math_functions() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("abs(-5)", &ctx).unwrap(), json!(5));
    assert_eq!(eval_expression("floor(3.7)", &ctx).unwrap(), json!(3));
    assert_eq!(eval_expression("ceil(3.2)", &ctx).unwrap(), json!(4));
    assert_eq!(eval_expression("round(3.5)", &ctx).unwrap(), json!(4));
    assert_eq!(eval_expression("min(3, 7)", &ctx).unwrap(), json!(3));
    assert_eq!(eval_expression("max(3, 7)", &ctx).unwrap(), json!(7));
    assert_eq!(eval_expression("sum([1, 2, 3, 4])", &ctx).unwrap(), json!(10));
}
// Case conversion, trimming, search, replace, and split/join builtins.
#[test]
fn test_string_functions() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("lower(\"HELLO\")", &ctx).unwrap(), json!("hello"));
    assert_eq!(eval_expression("upper(\"hello\")", &ctx).unwrap(), json!("HELLO"));
    assert_eq!(eval_expression("trim(\" hi \")", &ctx).unwrap(), json!("hi"));
    assert_eq!(
        eval_expression("replace(\"hello world\", \"world\", \"rust\")", &ctx).unwrap(),
        json!("hello rust")
    );
    assert_eq!(
        eval_expression("starts_with(\"hello\", \"hel\")", &ctx).unwrap(),
        json!(true)
    );
    assert_eq!(
        eval_expression("ends_with(\"hello\", \"llo\")", &ctx).unwrap(),
        json!(true)
    );
    assert_eq!(
        eval_expression("split(\"a,b,c\", \",\")", &ctx).unwrap(),
        json!(["a", "b", "c"])
    );
    assert_eq!(
        eval_expression("join([\"a\", \"b\", \"c\"], \",\")", &ctx).unwrap(),
        json!("a,b,c")
    );
}
// `match(pattern, text)` returns whether the regex matches the text.
#[test]
fn test_regex_match() {
    let ctx = TestContext::new();
    assert_eq!(
        eval_expression("match(\"^hello\", \"hello world\")", &ctx).unwrap(),
        json!(true)
    );
    assert_eq!(
        eval_expression("match(\"^world\", \"hello world\")", &ctx).unwrap(),
        json!(false)
    );
}
// sort/reversed/unique/flat/zip over JSON arrays.
#[test]
fn test_collection_functions() {
    let ctx = TestContext::new()
        .with_var("arr", json!([3, 1, 2]));
    assert_eq!(eval_expression("sort(arr)", &ctx).unwrap(), json!([1, 2, 3]));
    assert_eq!(eval_expression("reversed(arr)", &ctx).unwrap(), json!([2, 1, 3]));
    assert_eq!(
        eval_expression("unique([1, 2, 2, 3, 1])", &ctx).unwrap(),
        json!([1, 2, 3])
    );
    assert_eq!(
        eval_expression("flat([[1, 2], [3, 4]])", &ctx).unwrap(),
        json!([1, 2, 3, 4])
    );
    assert_eq!(
        eval_expression("zip([1, 2], [\"a\", \"b\"])", &ctx).unwrap(),
        json!([[1, "a"], [2, "b"]])
    );
}
// `range(n)` is 0..n; `range(a, b)` is a..b with the end exclusive.
#[test]
fn test_range() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("range(5)", &ctx).unwrap(), json!([0, 1, 2, 3, 4]));
    assert_eq!(eval_expression("range(2, 5)", &ctx).unwrap(), json!([2, 3, 4]));
}
// `reversed` also applies to strings, not just arrays.
#[test]
fn test_reversed_string() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("reversed(\"abc\")", &ctx).unwrap(), json!("cba"));
}
// `contains` covers both array membership and substring search.
#[test]
fn test_contains_function() {
    let ctx = TestContext::new();
    assert_eq!(
        eval_expression("contains([1, 2, 3], 2)", &ctx).unwrap(),
        json!(true)
    );
    assert_eq!(
        eval_expression("contains(\"hello\", \"ell\")", &ctx).unwrap(),
        json!(true)
    );
}
// ---------------------------------------------------------------
// Complex expressions
// ---------------------------------------------------------------
// Boolean combinators compose with function calls and the `in` operator.
#[test]
fn test_complex_expression() {
    let ctx = TestContext::new()
        .with_var("items", json!([1, 2, 3, 4, 5]));
    assert_eq!(
        eval_expression("length(items) > 3 and 5 in items", &ctx).unwrap(),
        json!(true)
    );
}
// Mixed dot and index access chains resolve left to right.
#[test]
fn test_chained_access() {
    let ctx = TestContext::new()
        .with_var("data", json!({"users": [{"name": "Alice"}, {"name": "Bob"}]}));
    assert_eq!(
        eval_expression("data.users[1].name", &ctx).unwrap(),
        json!("Bob")
    );
}
#[test]
fn test_ternary_via_boolean() {
    let ctx = TestContext::new()
        .with_var("x", json!(10));
    // No ternary operator, but boolean expressions work for conditions
    assert_eq!(
        eval_expression("x > 5 and x < 20", &ctx).unwrap(),
        json!(true)
    );
}
// The language is strict about types on `+` and `==`.
#[test]
fn test_no_implicit_type_coercion() {
    let ctx = TestContext::new();
    // String + number should error, not silently coerce
    assert!(eval_expression("\"hello\" + 5", &ctx).is_err());
    // Comparing different types (other than int/float) should return false for ==
    assert_eq!(eval_expression("\"3\" == 3", &ctx).unwrap(), json!(false));
}
// Division/modulo by zero surface as evaluation errors.
#[test]
fn test_division_by_zero() {
    let ctx = TestContext::new();
    assert!(eval_expression("5 / 0", &ctx).is_err());
    assert!(eval_expression("5 % 0", &ctx).is_err());
}
#[test]
fn test_array_literal() {
    let ctx = TestContext::new();
    assert_eq!(
        eval_expression("[1, 2, 3]", &ctx).unwrap(),
        json!([1, 2, 3])
    );
    assert_eq!(
        eval_expression("[\"a\", \"b\"]", &ctx).unwrap(),
        json!(["a", "b"])
    );
}
#[test]
fn test_nested_function_calls() {
    let ctx = TestContext::new();
    assert_eq!(
        eval_expression("length(split(\"a,b,c\", \",\"))", &ctx).unwrap(),
        json!(3)
    );
    assert_eq!(
        eval_expression("join(sort([\"c\", \"a\", \"b\"]), \"-\")", &ctx).unwrap(),
        json!("a-b-c")
    );
}
#[test]
fn test_boolean_literals() {
    let ctx = TestContext::new();
    assert_eq!(eval_expression("true", &ctx).unwrap(), json!(true));
    assert_eq!(eval_expression("false", &ctx).unwrap(), json!(false));
    assert_eq!(eval_expression("null", &ctx).unwrap(), json!(null));
}
}

View File

@@ -0,0 +1,520 @@
//! # Expression Parser
//!
//! Recursive-descent parser that transforms a token stream into an AST.
//!
//! ## Operator Precedence (lowest to highest)
//!
//! 1. `or`
//! 2. `and`
//! 3. `not` (unary)
//! 4. `==`, `!=`, `<`, `>`, `<=`, `>=`, `in`
//! 5. `+`, `-` (addition / subtraction)
//! 6. `*`, `/`, `%`
//! 7. Unary `-`
//! 8. Postfix: `.field`, `[index]`, `(args)`
use super::ast::{BinaryOp, Expr, UnaryOp};
use super::tokenizer::{Token, TokenKind};
use thiserror::Error;
/// Errors produced while parsing a token stream into an AST.
/// Positions refer to the tokenizer's character offsets into the input.
#[derive(Debug, Error)]
pub enum ParseError {
    /// A token appeared where no grammar rule could accept it.
    #[error("Unexpected token {0} at position {1}")]
    UnexpectedToken(String, usize),
    /// A specific token was required (e.g. a closing `)` or `]`).
    #[error("Expected {0}, found {1} at position {2}")]
    Expected(String, String, usize),
    /// Input ended in the middle of an expression.
    #[error("Unexpected end of expression")]
    UnexpectedEof,
    /// Wrapper for a lexer error surfaced through the parser.
    #[error("Token error: {0}")]
    TokenError(String),
}
/// The parser state.
pub struct Parser<'a> {
    /// Token stream being consumed; the tokenizer guarantees a trailing `Eof`.
    tokens: &'a [Token],
    /// Index of the next token to examine.
    pos: usize,
}
impl<'a> Parser<'a> {
pub fn new(tokens: &'a [Token]) -> Self {
Self { tokens, pos: 0 }
}
/// Parse the token stream into a single expression AST.
pub fn parse(&mut self) -> Result<Expr, ParseError> {
let expr = self.parse_or()?;
// We should be at EOF now
if !self.at_end() {
let tok = self.peek();
return Err(ParseError::UnexpectedToken(
format!("{}", tok.kind),
tok.span.0,
));
}
Ok(expr)
}
// ----- Helpers -----
fn peek(&self) -> &Token {
&self.tokens[self.pos.min(self.tokens.len() - 1)]
}
fn at_end(&self) -> bool {
self.peek().kind == TokenKind::Eof
}
fn advance(&mut self) -> &Token {
let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
if self.pos < self.tokens.len() {
self.pos += 1;
}
tok
}
fn expect(&mut self, expected: &TokenKind) -> Result<&Token, ParseError> {
let tok = self.peek();
if std::mem::discriminant(&tok.kind) == std::mem::discriminant(expected) {
Ok(self.advance())
} else {
Err(ParseError::Expected(
format!("{}", expected),
format!("{}", tok.kind),
tok.span.0,
))
}
}
fn check(&self, kind: &TokenKind) -> bool {
std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(kind)
}
// ----- Grammar rules -----
// or_expr = and_expr ( "or" and_expr )*
fn parse_or(&mut self) -> Result<Expr, ParseError> {
let mut left = self.parse_and()?;
while self.peek().kind == TokenKind::Or {
self.advance();
let right = self.parse_and()?;
left = Expr::BinaryOp {
op: BinaryOp::Or,
left: Box::new(left),
right: Box::new(right),
};
}
Ok(left)
}
// and_expr = not_expr ( "and" not_expr )*
fn parse_and(&mut self) -> Result<Expr, ParseError> {
let mut left = self.parse_not()?;
while self.peek().kind == TokenKind::And {
self.advance();
let right = self.parse_not()?;
left = Expr::BinaryOp {
op: BinaryOp::And,
left: Box::new(left),
right: Box::new(right),
};
}
Ok(left)
}
// not_expr = "not" not_expr | comparison
fn parse_not(&mut self) -> Result<Expr, ParseError> {
if self.peek().kind == TokenKind::Not {
self.advance();
let operand = self.parse_not()?;
return Ok(Expr::UnaryOp {
op: UnaryOp::Not,
operand: Box::new(operand),
});
}
self.parse_comparison()
}
// comparison = addition ( ("==" | "!=" | "<" | ">" | "<=" | ">=" | "in") addition )*
fn parse_comparison(&mut self) -> Result<Expr, ParseError> {
let mut left = self.parse_addition()?;
loop {
let op = match self.peek().kind {
TokenKind::EqEq => BinaryOp::Eq,
TokenKind::BangEq => BinaryOp::Ne,
TokenKind::Lt => BinaryOp::Lt,
TokenKind::Gt => BinaryOp::Gt,
TokenKind::LtEq => BinaryOp::Le,
TokenKind::GtEq => BinaryOp::Ge,
TokenKind::In => BinaryOp::In,
_ => break,
};
self.advance();
let right = self.parse_addition()?;
left = Expr::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
};
}
Ok(left)
}
// addition = multiplication ( ("+" | "-") multiplication )*
fn parse_addition(&mut self) -> Result<Expr, ParseError> {
let mut left = self.parse_multiplication()?;
loop {
let op = match self.peek().kind {
TokenKind::Plus => BinaryOp::Add,
TokenKind::Minus => BinaryOp::Sub,
_ => break,
};
self.advance();
let right = self.parse_multiplication()?;
left = Expr::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
};
}
Ok(left)
}
// multiplication = unary ( ("*" | "/" | "%") unary )*
fn parse_multiplication(&mut self) -> Result<Expr, ParseError> {
let mut left = self.parse_unary()?;
loop {
let op = match self.peek().kind {
TokenKind::Star => BinaryOp::Mul,
TokenKind::Slash => BinaryOp::Div,
TokenKind::Percent => BinaryOp::Mod,
_ => break,
};
self.advance();
let right = self.parse_unary()?;
left = Expr::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
};
}
Ok(left)
}
// unary = "-" unary | postfix
fn parse_unary(&mut self) -> Result<Expr, ParseError> {
if self.peek().kind == TokenKind::Minus {
self.advance();
let operand = self.parse_unary()?;
return Ok(Expr::UnaryOp {
op: UnaryOp::Neg,
operand: Box::new(operand),
});
}
self.parse_postfix()
}
// postfix = primary ( "." IDENT | "[" expr "]" | "(" args ")" )*
fn parse_postfix(&mut self) -> Result<Expr, ParseError> {
let mut expr = self.parse_primary()?;
loop {
match self.peek().kind {
TokenKind::Dot => {
self.advance();
// The field after dot
let tok = self.advance().clone();
let field = match &tok.kind {
TokenKind::Ident(name) => name.clone(),
// Allow keywords as field names (e.g., obj.in, obj.and)
TokenKind::And => "and".to_string(),
TokenKind::Or => "or".to_string(),
TokenKind::Not => "not".to_string(),
TokenKind::In => "in".to_string(),
TokenKind::True => "true".to_string(),
TokenKind::False => "false".to_string(),
TokenKind::Null => "null".to_string(),
_ => {
return Err(ParseError::Expected(
"identifier".to_string(),
format!("{}", tok.kind),
tok.span.0,
));
}
};
expr = Expr::DotAccess {
object: Box::new(expr),
field,
};
}
TokenKind::LBracket => {
self.advance();
let index = self.parse_or()?;
self.expect(&TokenKind::RBracket)?;
expr = Expr::IndexAccess {
object: Box::new(expr),
index: Box::new(index),
};
}
TokenKind::LParen => {
// Only if the expression so far is an identifier (function name)
// or a dot-access chain (method-like call).
// For now we handle Ident -> FunctionCall transformation.
if let Expr::Ident(name) = expr {
self.advance();
let args = self.parse_args()?;
self.expect(&TokenKind::RParen)?;
expr = Expr::FunctionCall { name, args };
} else {
break;
}
}
_ => break,
}
}
Ok(expr)
}
// args = ( expr ( "," expr )* )?
fn parse_args(&mut self) -> Result<Vec<Expr>, ParseError> {
let mut args = Vec::new();
if self.check(&TokenKind::RParen) {
return Ok(args);
}
args.push(self.parse_or()?);
while self.peek().kind == TokenKind::Comma {
self.advance();
args.push(self.parse_or()?);
}
Ok(args)
}
// primary = INTEGER | FLOAT | STRING | "true" | "false" | "null"
// | IDENT | "(" expr ")" | "[" elements "]"
fn parse_primary(&mut self) -> Result<Expr, ParseError> {
let tok = self.peek().clone();
match &tok.kind {
TokenKind::Integer(n) => {
let n = *n;
self.advance();
Ok(Expr::Literal(serde_json::json!(n)))
}
TokenKind::Float(f) => {
let f = *f;
self.advance();
Ok(Expr::Literal(serde_json::json!(f)))
}
TokenKind::StringLit(s) => {
let s = s.clone();
self.advance();
Ok(Expr::Literal(serde_json::Value::String(s)))
}
TokenKind::True => {
self.advance();
Ok(Expr::Literal(serde_json::json!(true)))
}
TokenKind::False => {
self.advance();
Ok(Expr::Literal(serde_json::json!(false)))
}
TokenKind::Null => {
self.advance();
Ok(Expr::Literal(serde_json::json!(null)))
}
TokenKind::Ident(name) => {
let name = name.clone();
self.advance();
Ok(Expr::Ident(name))
}
TokenKind::LParen => {
self.advance();
let expr = self.parse_or()?;
self.expect(&TokenKind::RParen)?;
Ok(expr)
}
TokenKind::LBracket => {
self.advance();
let mut elements = Vec::new();
if !self.check(&TokenKind::RBracket) {
elements.push(self.parse_or()?);
while self.peek().kind == TokenKind::Comma {
self.advance();
// Allow trailing comma
if self.check(&TokenKind::RBracket) {
break;
}
elements.push(self.parse_or()?);
}
}
self.expect(&TokenKind::RBracket)?;
Ok(Expr::Array(elements))
}
_ => Err(ParseError::UnexpectedToken(
format!("{}", tok.kind),
tok.span.0,
)),
}
}
}
// Unit tests for the recursive-descent parser. These assert AST shape,
// not evaluation results.
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::tokenizer::Tokenizer;
    /// Tokenize + parse, panicking on any error (tests feed valid input).
    fn parse(input: &str) -> Expr {
        let tokens = Tokenizer::new(input).tokenize().unwrap();
        Parser::new(&tokens).parse().unwrap()
    }
    #[test]
    fn test_simple_add() {
        let ast = parse("2 + 3");
        assert_eq!(
            ast,
            Expr::BinaryOp {
                op: BinaryOp::Add,
                left: Box::new(Expr::Literal(serde_json::json!(2))),
                right: Box::new(Expr::Literal(serde_json::json!(3))),
            }
        );
    }
    #[test]
    fn test_precedence() {
        // 2 + 3 * 4 should parse as 2 + (3 * 4)
        let ast = parse("2 + 3 * 4");
        match ast {
            Expr::BinaryOp {
                op: BinaryOp::Add,
                right,
                ..
            } => {
                assert!(matches!(
                    *right,
                    Expr::BinaryOp {
                        op: BinaryOp::Mul,
                        ..
                    }
                ));
            }
            _ => panic!("Expected Add at top level"),
        }
    }
    #[test]
    fn test_function_call() {
        let ast = parse("length(arr)");
        assert_eq!(
            ast,
            Expr::FunctionCall {
                name: "length".to_string(),
                args: vec![Expr::Ident("arr".to_string())],
            }
        );
    }
    // Chained dot access is left-associative: (obj.field).sub
    #[test]
    fn test_dot_access() {
        let ast = parse("obj.field.sub");
        assert_eq!(
            ast,
            Expr::DotAccess {
                object: Box::new(Expr::DotAccess {
                    object: Box::new(Expr::Ident("obj".to_string())),
                    field: "field".to_string(),
                }),
                field: "sub".to_string(),
            }
        );
    }
    #[test]
    fn test_array_literal() {
        let ast = parse("[1, 2, 3]");
        assert_eq!(
            ast,
            Expr::Array(vec![
                Expr::Literal(serde_json::json!(1)),
                Expr::Literal(serde_json::json!(2)),
                Expr::Literal(serde_json::json!(3)),
            ])
        );
    }
    #[test]
    fn test_bracket_access() {
        let ast = parse("arr[0]");
        assert_eq!(
            ast,
            Expr::IndexAccess {
                object: Box::new(Expr::Ident("arr".to_string())),
                index: Box::new(Expr::Literal(serde_json::json!(0))),
            }
        );
    }
    #[test]
    fn test_not_operator() {
        let ast = parse("not true");
        assert_eq!(
            ast,
            Expr::UnaryOp {
                op: UnaryOp::Not,
                operand: Box::new(Expr::Literal(serde_json::json!(true))),
            }
        );
    }
    #[test]
    fn test_in_operator() {
        let ast = parse("x in arr");
        assert_eq!(
            ast,
            Expr::BinaryOp {
                op: BinaryOp::In,
                left: Box::new(Expr::Ident("x".to_string())),
                right: Box::new(Expr::Ident("arr".to_string())),
            }
        );
    }
    #[test]
    fn test_complex_expression() {
        // Should parse without error
        let _ast = parse("length(items) > 3 and 5 in items");
    }
    #[test]
    fn test_chained_access() {
        // data.users[1].name
        let _ast = parse("data.users[1].name");
    }
    #[test]
    fn test_nested_function() {
        let _ast = parse("length(split(\"a,b,c\", \",\"))");
    }
    #[test]
    fn test_trailing_comma_in_array() {
        let ast = parse("[1, 2, 3,]");
        assert_eq!(
            ast,
            Expr::Array(vec![
                Expr::Literal(serde_json::json!(1)),
                Expr::Literal(serde_json::json!(2)),
                Expr::Literal(serde_json::json!(3)),
            ])
        );
    }
}

View File

@@ -0,0 +1,512 @@
//! # Expression Tokenizer (Lexer)
//!
//! Converts an expression string into a sequence of tokens.
use std::fmt;
use thiserror::Error;
/// A token produced by the lexer.
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    /// What kind of token this is (literal, identifier, operator, …).
    pub kind: TokenKind,
    /// Half-open `(start, end)` character offsets of the token in the input.
    pub span: (usize, usize),
}
impl Token {
pub fn new(kind: TokenKind, start: usize, end: usize) -> Self {
Self {
kind,
span: (start, end),
}
}
}
/// The kind of a token. Every token stream produced by the tokenizer ends
/// with exactly one `Eof`.
#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    // Literals
    Integer(i64),
    Float(f64),
    StringLit(String),
    True,
    False,
    Null,
    // Identifier
    Ident(String),
    // Keywords (also parsed as identifiers initially, then classified)
    And,
    Or,
    Not,
    In,
    // Operators
    Plus,
    Minus,
    Star,
    Slash,
    Percent,
    EqEq,
    BangEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Delimiters
    LParen,
    RParen,
    LBracket,
    RBracket,
    Comma,
    Dot,
    // End of input
    Eof,
}
impl fmt::Display for TokenKind {
    /// Render the token as it would appear in source (payload-carrying
    /// literals print their value; `Eof` prints as "EOF").
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Payload variants format their value directly; everything else
        // maps to a fixed spelling.
        let text = match self {
            TokenKind::Integer(n) => return write!(f, "{}", n),
            TokenKind::Float(x) => return write!(f, "{}", x),
            TokenKind::StringLit(s) => return write!(f, "\"{}\"", s),
            TokenKind::Ident(name) => return write!(f, "{}", name),
            TokenKind::True => "true",
            TokenKind::False => "false",
            TokenKind::Null => "null",
            TokenKind::And => "and",
            TokenKind::Or => "or",
            TokenKind::Not => "not",
            TokenKind::In => "in",
            TokenKind::Plus => "+",
            TokenKind::Minus => "-",
            TokenKind::Star => "*",
            TokenKind::Slash => "/",
            TokenKind::Percent => "%",
            TokenKind::EqEq => "==",
            TokenKind::BangEq => "!=",
            TokenKind::Lt => "<",
            TokenKind::Gt => ">",
            TokenKind::LtEq => "<=",
            TokenKind::GtEq => ">=",
            TokenKind::LParen => "(",
            TokenKind::RParen => ")",
            TokenKind::LBracket => "[",
            TokenKind::RBracket => "]",
            TokenKind::Comma => ",",
            TokenKind::Dot => ".",
            TokenKind::Eof => "EOF",
        };
        f.write_str(text)
    }
}
/// Errors produced during tokenization. Positions are character offsets
/// into the input (the tokenizer indexes a `Vec<char>`).
#[derive(Debug, Error)]
pub enum TokenError {
    #[error("Unexpected character '{0}' at position {1}")]
    UnexpectedChar(char, usize),
    #[error("Unterminated string literal starting at position {0}")]
    UnterminatedString(usize),
    #[error("Invalid number literal at position {0}: {1}")]
    InvalidNumber(usize, String),
}
/// The tokenizer / lexer.
pub struct Tokenizer {
    /// Input decoded to chars up front, so `pos` is a character index
    /// (not a byte offset).
    chars: Vec<char>,
    /// Index of the next unread character.
    pos: usize,
}
impl Tokenizer {
    /// Create a tokenizer over `input`. The string is decoded to a
    /// `Vec<char>` so all positions/spans are character offsets.
    pub fn new(input: &str) -> Self {
        Self {
            chars: input.chars().collect(),
            pos: 0,
        }
    }
    /// Tokenize the entire input and return a vector of tokens.
    ///
    /// The returned vector always ends with exactly one `Eof` token.
    pub fn tokenize(&mut self) -> Result<Vec<Token>, TokenError> {
        let mut tokens = Vec::new();
        loop {
            let tok = self.next_token()?;
            if tok.kind == TokenKind::Eof {
                tokens.push(tok);
                break;
            }
            tokens.push(tok);
        }
        Ok(tokens)
    }
    /// Current character without consuming it (`None` at end of input).
    fn peek(&self) -> Option<char> {
        self.chars.get(self.pos).copied()
    }
    /// Consume and return the current character, if any.
    fn advance(&mut self) -> Option<char> {
        let ch = self.chars.get(self.pos).copied();
        if ch.is_some() {
            self.pos += 1;
        }
        ch
    }
    /// Skip a run of whitespace characters (Unicode whitespace).
    fn skip_whitespace(&mut self) {
        while let Some(ch) = self.peek() {
            if ch.is_whitespace() {
                self.advance();
            } else {
                break;
            }
        }
    }
    /// Produce the next token, or `Eof` once the input is exhausted.
    fn next_token(&mut self) -> Result<Token, TokenError> {
        self.skip_whitespace();
        let start = self.pos;
        let ch = match self.peek() {
            Some(ch) => ch,
            None => return Ok(Token::new(TokenKind::Eof, start, start)),
        };
        // Single-char and multi-char operators/delimiters
        match ch {
            '+' => {
                self.advance();
                Ok(Token::new(TokenKind::Plus, start, self.pos))
            }
            '-' => {
                self.advance();
                Ok(Token::new(TokenKind::Minus, start, self.pos))
            }
            '*' => {
                self.advance();
                Ok(Token::new(TokenKind::Star, start, self.pos))
            }
            '/' => {
                self.advance();
                Ok(Token::new(TokenKind::Slash, start, self.pos))
            }
            '%' => {
                self.advance();
                Ok(Token::new(TokenKind::Percent, start, self.pos))
            }
            '(' => {
                self.advance();
                Ok(Token::new(TokenKind::LParen, start, self.pos))
            }
            ')' => {
                self.advance();
                Ok(Token::new(TokenKind::RParen, start, self.pos))
            }
            '[' => {
                self.advance();
                Ok(Token::new(TokenKind::LBracket, start, self.pos))
            }
            ']' => {
                self.advance();
                Ok(Token::new(TokenKind::RBracket, start, self.pos))
            }
            ',' => {
                self.advance();
                Ok(Token::new(TokenKind::Comma, start, self.pos))
            }
            '.' => {
                self.advance();
                Ok(Token::new(TokenKind::Dot, start, self.pos))
            }
            // `=` is only valid as part of `==` (no assignment operator).
            '=' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    Ok(Token::new(TokenKind::EqEq, start, self.pos))
                } else {
                    Err(TokenError::UnexpectedChar('=', start))
                }
            }
            // `!` is only valid as part of `!=` (boolean negation is `not`).
            '!' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    Ok(Token::new(TokenKind::BangEq, start, self.pos))
                } else {
                    Err(TokenError::UnexpectedChar('!', start))
                }
            }
            '<' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    Ok(Token::new(TokenKind::LtEq, start, self.pos))
                } else {
                    Ok(Token::new(TokenKind::Lt, start, self.pos))
                }
            }
            '>' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    Ok(Token::new(TokenKind::GtEq, start, self.pos))
                } else {
                    Ok(Token::new(TokenKind::Gt, start, self.pos))
                }
            }
            '"' | '\'' => self.read_string(ch),
            c if c.is_ascii_digit() => self.read_number(),
            c if c.is_ascii_alphabetic() || c == '_' => self.read_ident(),
            other => Err(TokenError::UnexpectedChar(other, start)),
        }
    }
    /// Read a quoted string literal. `quote` is the opening quote (`"` or
    /// `'`); only that same character closes the string. Recognized escapes
    /// are `\n`, `\t`, `\r`, `\\`, and the active quote character; any other
    /// escape is kept verbatim (backslash included).
    fn read_string(&mut self, quote: char) -> Result<Token, TokenError> {
        let start = self.pos;
        self.advance(); // consume opening quote
        let mut s = String::new();
        loop {
            match self.advance() {
                Some('\\') => {
                    // Escape sequence
                    match self.advance() {
                        Some('n') => s.push('\n'),
                        Some('t') => s.push('\t'),
                        Some('r') => s.push('\r'),
                        Some('\\') => s.push('\\'),
                        Some(c) if c == quote => s.push(c),
                        Some(c) => {
                            // Unknown escape: preserve it literally.
                            s.push('\\');
                            s.push(c);
                        }
                        None => return Err(TokenError::UnterminatedString(start)),
                    }
                }
                Some(c) if c == quote => {
                    return Ok(Token::new(TokenKind::StringLit(s), start, self.pos));
                }
                Some(c) => s.push(c),
                None => return Err(TokenError::UnterminatedString(start)),
            }
        }
    }
    /// Read an integer or float literal. A `.` only counts as a decimal
    /// point when followed by a digit, so `42.field` lexes as
    /// Integer(42), Dot, Ident("field").
    fn read_number(&mut self) -> Result<Token, TokenError> {
        let start = self.pos;
        let mut num_str = String::new();
        let mut is_float = false;
        while let Some(ch) = self.peek() {
            if ch.is_ascii_digit() {
                num_str.push(ch);
                self.advance();
            } else if ch == '.' && !is_float {
                // Check if this is a decimal point or a method call dot
                // Look ahead to see if next char is a digit
                let next_pos = self.pos + 1;
                if next_pos < self.chars.len() && self.chars[next_pos].is_ascii_digit() {
                    is_float = true;
                    num_str.push(ch);
                    self.advance();
                } else {
                    // It's a dot access, stop number parsing here
                    break;
                }
            } else {
                break;
            }
        }
        if is_float {
            let val: f64 = num_str.parse().map_err(|_| {
                TokenError::InvalidNumber(start, num_str.clone())
            })?;
            Ok(Token::new(TokenKind::Float(val), start, self.pos))
        } else {
            // Integers that overflow i64 are reported as invalid literals.
            let val: i64 = num_str.parse().map_err(|_| {
                TokenError::InvalidNumber(start, num_str.clone())
            })?;
            Ok(Token::new(TokenKind::Integer(val), start, self.pos))
        }
    }
    /// Read an identifier (`[A-Za-z_][A-Za-z0-9_]*`), then classify
    /// reserved words (true/false/null/and/or/not/in) into their dedicated
    /// token kinds.
    fn read_ident(&mut self) -> Result<Token, TokenError> {
        let start = self.pos;
        let mut ident = String::new();
        while let Some(ch) = self.peek() {
            if ch.is_ascii_alphanumeric() || ch == '_' {
                ident.push(ch);
                self.advance();
            } else {
                break;
            }
        }
        let kind = match ident.as_str() {
            "true" => TokenKind::True,
            "false" => TokenKind::False,
            "null" => TokenKind::Null,
            "and" => TokenKind::And,
            "or" => TokenKind::Or,
            "not" => TokenKind::Not,
            "in" => TokenKind::In,
            _ => TokenKind::Ident(ident),
        };
        Ok(Token::new(kind, start, self.pos))
    }
}
// Unit tests for the tokenizer; assertions compare token kinds only
// (spans are not checked here).
#[cfg(test)]
mod tests {
    use super::*;
    /// Tokenize and strip spans, panicking on lexer errors.
    fn tokenize(input: &str) -> Vec<TokenKind> {
        let mut t = Tokenizer::new(input);
        t.tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }
    #[test]
    fn test_simple_expression() {
        let kinds = tokenize("2 + 3");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Integer(2),
                TokenKind::Plus,
                TokenKind::Integer(3),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_comparison() {
        let kinds = tokenize("x >= 10");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("x".to_string()),
                TokenKind::GtEq,
                TokenKind::Integer(10),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_keywords() {
        let kinds = tokenize("true and not false or null in x");
        assert_eq!(
            kinds,
            vec![
                TokenKind::True,
                TokenKind::And,
                TokenKind::Not,
                TokenKind::False,
                TokenKind::Or,
                TokenKind::Null,
                TokenKind::In,
                TokenKind::Ident("x".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    // Double- and single-quoted strings produce the same token kind.
    #[test]
    fn test_string_literals() {
        let kinds = tokenize("\"hello\" + 'world'");
        assert_eq!(
            kinds,
            vec![
                TokenKind::StringLit("hello".to_string()),
                TokenKind::Plus,
                TokenKind::StringLit("world".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_float() {
        let kinds = tokenize("3.14");
        assert_eq!(kinds, vec![TokenKind::Float(3.14), TokenKind::Eof]);
    }
    #[test]
    fn test_dot_access() {
        let kinds = tokenize("obj.field");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("obj".to_string()),
                TokenKind::Dot,
                TokenKind::Ident("field".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_function_call() {
        let kinds = tokenize("length(arr)");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("length".to_string()),
                TokenKind::LParen,
                TokenKind::Ident("arr".to_string()),
                TokenKind::RParen,
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_bracket_access() {
        let kinds = tokenize("arr[0]");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("arr".to_string()),
                TokenKind::LBracket,
                TokenKind::Integer(0),
                TokenKind::RBracket,
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_escape_sequences() {
        let kinds = tokenize(r#""hello\nworld""#);
        assert_eq!(
            kinds,
            vec![
                TokenKind::StringLit("hello\nworld".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_integer_followed_by_dot() {
        // `42.field` - the 42 is an integer and `.field` is separate
        let kinds = tokenize("42.field");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Integer(42),
                TokenKind::Dot,
                TokenKind::Ident("field".to_string()),
                TokenKind::Eof,
            ]
        );
    }
}

View File

@@ -0,0 +1,674 @@
//! # Workflow Expression Validator
//!
//! Static validation of `{{ }}` template expressions in workflow definitions.
//! Catches syntax errors and unresolved variable references **before** the
//! workflow is saved, so users get immediate feedback instead of opaque
//! runtime failures during execution.
//!
//! ## What is validated
//!
//! 1. **Syntax** — every `{{ expr }}` block must parse successfully.
//! 2. **Variable references** — top-level identifiers that are not a known
//! namespace (`parameters`, `workflow`, `task`, `config`, `keystore`,
//! `item`, `index`, `system`, …) must exist in either the workflow's
//! `vars` map or its `param_schema` keys (bare-name fallback targets).
//!
//! ## What is NOT validated
//!
//! - **Type correctness** — e.g. whether `range(parameters.n)` actually
//! receives an integer. That requires runtime values.
//! - **Deep property paths** — e.g. `task.fetch.result.data`. We validate
//! that `task` is a known namespace but not that `fetch` is a real task
//! name (it might not exist yet at save time if tasks are re-ordered).
//! - **Function arity** — built-in functions are not checked for argument
//! count here; the evaluator already reports those errors at runtime.
use std::collections::HashSet;
use serde_json::Value as JsonValue;
use super::expression::{parse_expression, Expr, ParseError};
use super::parser::{PublishDirective, WorkflowDefinition};
// ───────────────────────────────────────────────────────────────────────────
// Public API
// ───────────────────────────────────────────────────────────────────────────
/// A single validation diagnostic.
///
/// Produced by [`validate_workflow_expressions`]; one instance per problem
/// found in a template expression.
#[derive(Debug, Clone)]
pub struct ExpressionWarning {
    /// Human-readable location (e.g. `task 'sleep_1' with_items`).
    pub location: String,
    /// The raw template string that was checked.
    pub expression: String,
    /// What went wrong.
    pub message: String,
}
impl std::fmt::Display for ExpressionWarning {
    /// Render as "<location>: <message> — `<expression>`".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { location, expression, message } = self;
        write!(f, "{}: {} — `{}`", location, message, expression)
    }
}
/// Validate all template expressions in a workflow definition.
///
/// Returns an empty vec on success, or one [`ExpressionWarning`] per problem
/// found. The caller decides whether warnings are fatal (block save) or
/// advisory.
///
/// `param_schema` is the *flat-format* schema (`{ "url": { "type": "string" }, … }`)
/// passed alongside the definition in the save request. Its top-level keys
/// are the declared parameter names.
pub fn validate_workflow_expressions(
    workflow: &WorkflowDefinition,
    param_schema: Option<&JsonValue>,
) -> Vec<ExpressionWarning> {
    // Bare names accepted at the top level of any expression:
    // canonical namespaces + workflow vars + declared parameters.
    let known_names = build_known_names(workflow, param_schema);
    let mut warnings = Vec::new();
    for task in &workflow.tasks {
        let task_loc = format!("task '{}'", task.name);
        // ── with_items expression ────────────────────────────────────
        if let Some(ref expr) = task.with_items {
            validate_template(
                expr,
                &format!("{task_loc} with_items"),
                &known_names,
                &mut warnings,
            );
        }
        // ── task-level when condition ────────────────────────────────
        if let Some(ref expr) = task.when {
            validate_template(
                expr,
                &format!("{task_loc} when"),
                &known_names,
                &mut warnings,
            );
        }
        // ── input templates ──────────────────────────────────────────
        // Inputs are arbitrary JSON, so walk nested values for templates.
        for (key, value) in &task.input {
            collect_json_templates(
                value,
                &format!("{task_loc} input.{key}"),
                &known_names,
                &mut warnings,
            );
        }
        // ── next transitions ─────────────────────────────────────────
        for (ti, transition) in task.next.iter().enumerate() {
            if let Some(ref when_expr) = transition.when {
                validate_template(
                    when_expr,
                    &format!("{task_loc} next[{ti}].when"),
                    &known_names,
                    &mut warnings,
                );
            }
            for directive in &transition.publish {
                match directive {
                    PublishDirective::Simple(map) => {
                        for (pk, pv) in map {
                            validate_template(
                                pv,
                                &format!("{task_loc} next[{ti}].publish.{pk}"),
                                &known_names,
                                &mut warnings,
                            );
                        }
                    }
                    PublishDirective::Key(_) => { /* nothing to validate */ }
                }
            }
        }
        // ── legacy task-level publish ────────────────────────────────
        for directive in &task.publish {
            if let PublishDirective::Simple(map) = directive {
                for (pk, pv) in map {
                    validate_template(
                        pv,
                        &format!("{task_loc} publish.{pk}"),
                        &known_names,
                        &mut warnings,
                    );
                }
            }
        }
    }
    warnings
}
// ───────────────────────────────────────────────────────────────────────────
// Internals
// ───────────────────────────────────────────────────────────────────────────
/// Canonical namespace identifiers that are always valid as top-level names
/// inside `{{ }}` expressions. These are resolved by `WorkflowContext` at
/// runtime and never need to exist in `vars` or `param_schema`.
const CANONICAL_NAMESPACES: &[&str] = &[
    "parameters",
    "workflow",
    "vars",
    "variables",
    "task",
    "tasks",
    "config",
    "keystore",
    "item",
    "index",
    "system",
];
/// Built-in constants that are valid bare identifiers.
/// (The parser lexes these as dedicated tokens, but they are also accepted
/// here in case they surface as bare names.)
const BUILTIN_LITERALS: &[&str] = &["true", "false", "null"];
/// Build the set of bare names that are valid in expressions:
/// canonical namespaces + builtin literals + workflow var names +
/// param_schema keys (+ any workflow-level `parameters` schema keys).
fn build_known_names(
    workflow: &WorkflowDefinition,
    param_schema: Option<&JsonValue>,
) -> HashSet<String> {
    let mut names = HashSet::new();
    // Always-valid runtime namespaces and literal keywords.
    names.extend(CANONICAL_NAMESPACES.iter().map(|s| (*s).to_string()));
    names.extend(BUILTIN_LITERALS.iter().map(|s| (*s).to_string()));
    // Workflow vars
    names.extend(workflow.vars.keys().cloned());
    // Parameter schema keys (flat format: top-level keys are param names)
    if let Some(JsonValue::Object(map)) = param_schema {
        names.extend(map.keys().cloned());
    }
    // Also accept the workflow-level `parameters` schema if present on the
    // definition itself (some loaders put it there).
    if let Some(JsonValue::Object(ref map)) = workflow.parameters {
        names.extend(map.keys().cloned());
    }
    names
}
/// Extract `{{ … }}` blocks from a template string and validate each one.
///
/// Two phases per block: (1) the expression must parse; (2) every bare
/// top-level identifier must be a known name. One warning is appended per
/// problem found.
fn validate_template(
    template: &str,
    location: &str,
    known_names: &HashSet<String>,
    warnings: &mut Vec<ExpressionWarning>,
) {
    for raw_expr in extract_expressions(template) {
        let trimmed = raw_expr.trim();
        if trimmed.is_empty() {
            continue;
        }
        // Phase 1: parse. A syntax error short-circuits the name check.
        let ast = match parse_expression(trimmed) {
            Ok(ast) => ast,
            Err(e) => {
                warnings.push(ExpressionWarning {
                    location: location.to_string(),
                    expression: raw_expr.to_string(),
                    message: format!("syntax error: {e}"),
                });
                continue;
            }
        };
        // Phase 2: check bare-name references against the known set.
        let mut bare_idents = Vec::new();
        collect_bare_idents(&ast, &mut bare_idents);
        for ident in bare_idents.into_iter().filter(|n| !known_names.contains(n)) {
            warnings.push(ExpressionWarning {
                location: location.to_string(),
                expression: raw_expr.to_string(),
                message: format!(
                    "unknown variable '{}'. Use 'parameters.{}' for input parameters, or define it in workflow vars",
                    ident, ident,
                ),
            });
        }
    }
}
/// Recursively walk a JSON value looking for string leaves that contain
/// `{{ }}` templates, validating each one. The `location` string grows a
/// `.key` / `[index]` suffix at every level of nesting.
fn collect_json_templates(
    value: &JsonValue,
    location: &str,
    known_names: &HashSet<String>,
    warnings: &mut Vec<ExpressionWarning>,
) {
    match value {
        JsonValue::Object(map) => {
            for (key, val) in map {
                let child_loc = format!("{location}.{key}");
                collect_json_templates(val, &child_loc, known_names, warnings);
            }
        }
        JsonValue::Array(arr) => {
            for (i, item) in arr.iter().enumerate() {
                let child_loc = format!("{location}[{i}]");
                collect_json_templates(item, &child_loc, known_names, warnings);
            }
        }
        JsonValue::String(s) => validate_template(s, location, known_names, warnings),
        // Numbers, bools, and null carry no templates.
        _ => {}
    }
}
/// Extract the inner expression strings from all `{{ … }}` blocks in a
/// template. Handles nested braces conservatively: each block runs from an
/// opening `{{` to the *first* following `}}`. An unclosed `{{` ends the
/// scan with no further matches.
fn extract_expressions(template: &str) -> Vec<&str> {
    let mut found = Vec::new();
    let mut cursor = 0;
    while let Some(open) = template[cursor..].find("{{") {
        let body_start = cursor + open + 2;
        match template[body_start..].find("}}") {
            Some(close) => {
                found.push(&template[body_start..body_start + close]);
                cursor = body_start + close + 2;
            }
            // Unclosed `{{` — stop scanning.
            None => break,
        }
    }
    found
}
/// Collect bare `Ident` nodes that appear at the *top level* of an
/// expression — i.e. identifiers that are not the right-hand side of a
/// `.field` access (those are field names, not variable references).
///
/// For `DotAccess { object: Ident("parameters"), field: "n" }` we collect
/// `"parameters"` but NOT `"n"`.
///
/// For `FunctionCall { name: "range", args: [Ident("n")] }` we collect
/// `"n"` (it's a bare variable reference used as a function argument).
fn collect_bare_idents(expr: &Expr, out: &mut Vec<String>) {
    match expr {
        // A bare identifier is a variable reference — record it.
        Expr::Ident(name) => out.push(name.clone()),
        // Literals carry no references.
        Expr::Literal(_) => {}
        // Only the object side of a dot access can reference variables;
        // the field component is a member name, not a variable.
        Expr::DotAccess { object, .. } => collect_bare_idents(object, out),
        // Both the object and the index expression may reference variables.
        Expr::IndexAccess { object, index } => {
            collect_bare_idents(object, out);
            collect_bare_idents(index, out);
        }
        Expr::UnaryOp { operand, .. } => collect_bare_idents(operand, out),
        Expr::BinaryOp { left, right, .. } => {
            collect_bare_idents(left, out);
            collect_bare_idents(right, out);
        }
        // Every array element may reference variables.
        Expr::Array(items) => {
            items.iter().for_each(|item| collect_bare_idents(item, out));
        }
        // The function name itself is not a variable reference, but each
        // argument may contain one.
        Expr::FunctionCall { args, .. } => {
            args.iter().for_each(|arg| collect_bare_idents(arg, out));
        }
    }
}
// ───────────────────────────────────────────────────────────────────────────
// Tests
// ───────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    //! Unit tests for the expression-validation helpers:
    //! `extract_expressions`, `collect_bare_idents`, and
    //! `validate_workflow_expressions`.
    use super::*;
    use std::collections::HashMap;
    /// Build a bare-bones `WorkflowDefinition` wrapping `tasks`,
    /// with no parameters, vars, output mapping, or tags.
    fn minimal_workflow(tasks: Vec<super::super::parser::Task>) -> WorkflowDefinition {
        WorkflowDefinition {
            r#ref: "test.wf".to_string(),
            label: "Test".to_string(),
            description: None,
            version: "1.0.0".to_string(),
            parameters: None,
            output: None,
            vars: HashMap::new(),
            tasks,
            output_map: None,
            tags: vec![],
        }
    }
    /// Build a minimal action task named `name` (action `core.echo`,
    /// empty input, no transitions) for tests to mutate as needed.
    fn action_task(name: &str) -> super::super::parser::Task {
        super::super::parser::Task {
            name: name.to_string(),
            r#type: super::super::parser::TaskType::Action,
            action: Some("core.echo".to_string()),
            input: HashMap::new(),
            when: None,
            with_items: None,
            batch_size: None,
            concurrency: None,
            retry: None,
            timeout: None,
            next: vec![],
            on_success: None,
            on_failure: None,
            on_complete: None,
            on_timeout: None,
            decision: vec![],
            publish: vec![],
            join: None,
            tasks: None,
            chart_meta: None,
        }
    }
    // ── extract_expressions ──────────────────────────────────────────
    // Note: the extracted expressions keep their surrounding whitespace.
    #[test]
    fn test_extract_single() {
        let exprs = extract_expressions("{{ parameters.n }}");
        assert_eq!(exprs, vec![" parameters.n "]);
    }
    #[test]
    fn test_extract_multiple() {
        let exprs = extract_expressions("Hello {{ name }}, you have {{ count }} items");
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0].trim(), "name");
        assert_eq!(exprs[1].trim(), "count");
    }
    #[test]
    fn test_extract_no_expressions() {
        let exprs = extract_expressions("plain text");
        assert!(exprs.is_empty());
    }
    // An unterminated `{{` is skipped rather than treated as an error.
    #[test]
    fn test_extract_unclosed() {
        let exprs = extract_expressions("{{ oops");
        assert!(exprs.is_empty());
    }
    // ── collect_bare_idents ──────────────────────────────────────────
    #[test]
    fn test_bare_ident() {
        let ast = parse_expression("n").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["n"]);
    }
    // Only the object of a dot access counts as a variable reference;
    // the field name after the dot must not be collected.
    #[test]
    fn test_dot_access_does_not_collect_field() {
        let ast = parse_expression("parameters.n").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["parameters"]);
    }
    // A bare identifier used as a function argument IS a variable reference.
    #[test]
    fn test_function_arg_collected() {
        let ast = parse_expression("range(n)").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["n"]);
    }
    // Chained dot access collects only the root identifier.
    #[test]
    fn test_nested_dot_access() {
        let ast = parse_expression("task.fetch.result.data").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["task"]);
    }
    // Both operands of a binary op are walked, left before right.
    #[test]
    fn test_binary_op() {
        let ast = parse_expression("parameters.x + workflow.y").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["parameters", "workflow"]);
    }
    // ── validate_workflow_expressions ─────────────────────────────────
    // Namespaced references (`parameters.*`) and loop vars (`item`)
    // should validate cleanly with no extra schema.
    #[test]
    fn test_valid_workflow_no_warnings() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(parameters.n) }}".to_string());
        task.input.insert(
            "message".to_string(),
            serde_json::json!("Hello {{ item }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    // A bare name defined in workflow `vars` is a known name.
    #[test]
    fn test_bare_name_from_vars_ok() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let mut wf = minimal_workflow(vec![task]);
        wf.vars.insert("n".to_string(), serde_json::json!(5));
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    // A bare name declared in the external param schema is a known name.
    #[test]
    fn test_bare_name_from_param_schema_ok() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let wf = minimal_workflow(vec![task]);
        let schema = serde_json::json!({
            "n": { "type": "integer", "required": true }
        });
        let warnings = validate_workflow_expressions(&wf, Some(&schema));
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    // An undefined bare name produces one warning that suggests the
    // `parameters.` prefix and points at the offending location.
    #[test]
    fn test_unknown_bare_name_warning() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'n'"));
        assert!(warnings[0].message.contains("parameters.n"));
        assert!(warnings[0].location.contains("with_items"));
    }
    // Unparseable expressions surface as syntax-error warnings.
    #[test]
    fn test_syntax_error_warning() {
        let mut task = action_task("greet");
        task.input.insert(
            "msg".to_string(),
            serde_json::json!("{{ +++ }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("syntax error"));
    }
    // `when` clauses on task transitions are validated too.
    #[test]
    fn test_transition_when_validated() {
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ bad_var > 3 }}".to_string()),
            publish: vec![],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'bad_var'"));
        assert!(warnings[0].location.contains("next[0].when"));
    }
    // `publish` directives on task transitions are validated too.
    #[test]
    fn test_transition_publish_validated() {
        let mut task = action_task("step1");
        let mut publish_map = HashMap::new();
        publish_map.insert("out".to_string(), "{{ unknown_thing }}".to_string());
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ succeeded() }}".to_string()),
            publish: vec![PublishDirective::Simple(publish_map)],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'unknown_thing'"));
        assert!(warnings[0].location.contains("publish.out"));
    }
    #[test]
    fn test_workflow_functions_no_warning() {
        // succeeded(), failed(), result() etc. are function calls,
        // not variable references — should not produce warnings.
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ succeeded() and result().code == 200 }}".to_string()),
            publish: vec![],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    // Strings without `{{ }}` markers are left alone entirely.
    #[test]
    fn test_plain_text_no_warning() {
        let mut task = action_task("step1");
        task.input.insert(
            "msg".to_string(),
            serde_json::json!("just plain text"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty());
    }
    // Validation collects ALL problems rather than stopping at the first.
    #[test]
    fn test_multiple_errors_collected() {
        let mut task = action_task("step1");
        task.with_items = Some("{{ range(a) }}".to_string());
        task.input.insert(
            "x".to_string(),
            serde_json::json!("{{ b + c }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        // a, b, c are all unknown
        assert_eq!(warnings.len(), 3);
        let names: HashSet<_> = warnings
            .iter()
            .flat_map(|w| {
                // extract the variable name from "unknown variable 'X'"
                w.message
                    .strip_prefix("unknown variable '")
                    .and_then(|s| s.split('\'').next())
                    .map(|s| s.to_string())
            })
            .collect();
        assert!(names.contains("a"));
        assert!(names.contains("b"));
        assert!(names.contains("c"));
    }
    // Bracket indexing: both the indexed object and the index expression
    // are checked for unknown bare names.
    #[test]
    fn test_index_access_validated() {
        let mut task = action_task("step1");
        task.input.insert(
            "val".to_string(),
            serde_json::json!("{{ items[idx] }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        // Both `items` and `idx` are bare unknowns
        assert_eq!(warnings.len(), 2);
    }
    // `true`/`false` are literals, not identifiers — never flagged.
    #[test]
    fn test_builtin_literals_ok() {
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ true and not false }}".to_string()),
            publish: vec![],
            r#do: None,
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
}

View File

@@ -3,6 +3,7 @@
//! This module provides utilities for loading, parsing, validating, and registering
//! workflow definitions from YAML files.
pub mod expression;
pub mod loader;
pub mod pack_service;
pub mod parser;

View File

@@ -356,7 +356,7 @@ async fn test_update_execution_status() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -401,7 +401,7 @@ async fn test_update_execution_result() {
status: Some(ExecutionStatus::Completed),
result: Some(result_data.clone()),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -445,7 +445,7 @@ async fn test_update_execution_executor() {
status: Some(ExecutionStatus::Scheduled),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -492,7 +492,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Scheduling),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -507,7 +507,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Scheduled),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -522,7 +522,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -537,7 +537,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Completed),
result: Some(json!({"success": true})),
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -578,7 +578,7 @@ async fn test_update_execution_failed_status() {
status: Some(ExecutionStatus::Failed),
result: Some(json!({"error": "Connection timeout"})),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -984,7 +984,7 @@ async fn test_execution_timestamps() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -1095,7 +1095,7 @@ async fn test_execution_result_json() {
status: Some(ExecutionStatus::Completed),
result: Some(complex_result.clone()),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)

View File

@@ -244,8 +244,7 @@ impl InquiryHandler {
let update_input = UpdateExecutionInput {
status: None, // Keep current status, let worker handle completion
result: Some(updated_result),
executor: None,
workflow_task: None, // Not updating workflow metadata
..Default::default()
};
ExecutionRepository::update(pool, execution.id, update_input).await?;

View File

@@ -381,10 +381,7 @@ impl RetryManager {
&self.pool,
execution_id,
UpdateExecutionInput {
status: None,
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await?;

View File

@@ -66,6 +66,53 @@ fn extract_workflow_params(config: &Option<JsonValue>) -> JsonValue {
}
}
/// Apply default values from a workflow's `param_schema` to the provided
/// parameters.
///
/// The param_schema uses the flat format where each key maps to an object
/// that may contain a `"default"` field:
///
/// ```json
/// { "n": { "type": "integer", "default": 10 } }
/// ```
///
/// Any parameter that has a default in the schema but is missing (or `null`)
/// in the supplied `params` will be filled in. Parameters already provided
/// by the caller are never overwritten.
fn apply_param_defaults(params: JsonValue, param_schema: &Option<JsonValue>) -> JsonValue {
let schema = match param_schema {
Some(s) if s.is_object() => s,
_ => return params,
};
let mut obj = match params {
JsonValue::Object(m) => m,
_ => return params,
};
if let Some(schema_obj) = schema.as_object() {
for (key, prop) in schema_obj {
// Only fill in missing / null parameters
let needs_default = match obj.get(key) {
None => true,
Some(JsonValue::Null) => true,
_ => false,
};
if needs_default {
if let Some(default_val) = prop.get("default") {
debug!(
"Applying default for parameter '{}': {}",
key, default_val
);
obj.insert(key.clone(), default_val.clone());
}
}
}
}
JsonValue::Object(obj)
}
/// Payload for execution scheduled messages
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ExecutionScheduledPayload {
@@ -316,7 +363,10 @@ impl ExecutionScheduler {
// Build initial workflow context from execution parameters and
// workflow-level vars so that entry-point task inputs are rendered.
// Apply defaults from the workflow's param_schema for any parameters
// that were not supplied by the caller.
let workflow_params = extract_workflow_params(&execution.config);
let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema);
let wf_ctx = WorkflowContext::new(
workflow_params,
definition
@@ -563,7 +613,7 @@ impl ExecutionScheduler {
};
let total = items.len();
let concurrency_limit = task_node.concurrency.unwrap_or(total);
let concurrency_limit = task_node.concurrency.unwrap_or(1);
let dispatch_count = total.min(concurrency_limit);
info!(
@@ -842,6 +892,18 @@ impl ExecutionScheduler {
return Ok(());
}
// Load the workflow definition so we can apply param_schema defaults
let workflow_def =
WorkflowDefinitionRepository::find_by_id(pool, workflow_execution.workflow_def)
.await?
.ok_or_else(|| {
anyhow::anyhow!(
"Workflow definition {} not found for workflow_execution {}",
workflow_execution.workflow_def,
workflow_execution_id
)
})?;
// Rebuild the task graph from the stored JSON
let graph: TaskGraph = serde_json::from_value(workflow_execution.task_graph.clone())
.map_err(|e| {
@@ -897,7 +959,7 @@ impl ExecutionScheduler {
let concurrency_limit = graph
.get_task(task_name)
.and_then(|n| n.concurrency)
.unwrap_or(usize::MAX);
.unwrap_or(1);
let free_slots =
concurrency_limit.saturating_sub(in_flight_count.0 as usize);
@@ -995,6 +1057,7 @@ impl ExecutionScheduler {
// results so that successor task inputs can be rendered.
// -----------------------------------------------------------------
let workflow_params = extract_workflow_params(&parent_execution.config);
let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema);
// Collect results from all completed children of this workflow
let child_executions =
@@ -1619,11 +1682,11 @@ mod tests {
let dispatch_count = total.min(concurrency_limit);
assert_eq!(dispatch_count, 3);
// No concurrency limit → dispatch all
// No concurrency limit → default to serial (1 at a time)
let concurrency: Option<usize> = None;
let concurrency_limit = concurrency.unwrap_or(total);
let concurrency_limit = concurrency.unwrap_or(1);
let dispatch_count = total.min(concurrency_limit);
assert_eq!(dispatch_count, 20);
assert_eq!(dispatch_count, 1);
// Concurrency exceeds total → dispatch all
let concurrency: Option<usize> = Some(50);

View File

@@ -3,6 +3,33 @@
//! This module manages workflow execution context, including variables,
//! template rendering, and data flow between tasks.
//!
//! ## Canonical Namespaces
//!
//! All data accessible inside `{{ }}` template expressions is organised into
//! well-defined, non-overlapping namespaces:
//!
//! | Namespace | Example | Description |
//! |-----------|---------|-------------|
//! | `parameters` | `{{ parameters.url }}` | Immutable workflow input parameters |
//! | `workflow` | `{{ workflow.counter }}` | Mutable workflow-scoped variables (set via `publish`) |
//! | `task` | `{{ task.fetch.result.data }}` | Completed task results keyed by task name |
//! | `config` | `{{ config.api_token }}` | Pack configuration values (read-only) |
//! | `keystore` | `{{ keystore.secret_key }}` | Encrypted secrets from the key store (read-only) |
//! | `item` | `{{ item }}` or `{{ item.name }}` | Current element in a `with_items` loop |
//! | `index` | `{{ index }}` | Zero-based iteration index in a `with_items` loop |
//! | `system` | `{{ system.workflow_start }}` | System-provided variables |
//!
//! ### Backward-compatible aliases
//!
//! The following aliases resolve to the same data as their canonical form and
//! are kept for backward compatibility with existing workflow definitions:
//!
//! - `vars` / `variables` → same as `workflow`
//! - `tasks` → same as `task`
//!
//! Bare variable names (e.g. `{{ my_var }}`) also resolve against the
//! `workflow` variable store as a last-resort fallback.
//!
//! ## Function-call expressions
//!
//! Templates support Orquesta-style function calls:
@@ -19,6 +46,9 @@
//! expression instead of stringifying it. This means `"{{ item }}"` resolving
//! to integer `5` stays as `5`, not the string `"5"`.
use attune_common::workflow::expression::{
self, is_truthy, EvalContext, EvalError, EvalResult as ExprResult,
};
use dashmap::DashMap;
use serde_json::{json, Value as JsonValue};
use std::collections::HashMap;
@@ -63,18 +93,25 @@ pub enum TaskOutcome {
/// not the underlying data, making it O(1) instead of O(context_size).
#[derive(Debug, Clone)]
pub struct WorkflowContext {
/// Workflow-level variables (shared via Arc)
/// Mutable workflow-scoped variables. Canonical namespace: `workflow`.
/// Also accessible as `vars`, `variables`, or bare names (fallback).
variables: Arc<DashMap<String, JsonValue>>,
/// Workflow input parameters (shared via Arc)
/// Immutable workflow input parameters. Canonical namespace: `parameters`.
parameters: Arc<JsonValue>,
/// Task results (shared via Arc, keyed by task name)
/// Completed task results keyed by task name. Canonical namespace: `task`.
task_results: Arc<DashMap<String, JsonValue>>,
/// System variables (shared via Arc)
/// System-provided variables. Canonical namespace: `system`.
system: Arc<DashMap<String, JsonValue>>,
/// Pack configuration values (read-only). Canonical namespace: `config`.
pack_config: Arc<JsonValue>,
/// Encrypted keystore values (read-only). Canonical namespace: `keystore`.
keystore: Arc<JsonValue>,
/// Current item (for with-items iteration) - per-item data
current_item: Option<JsonValue>,
@@ -89,7 +126,11 @@ pub struct WorkflowContext {
}
impl WorkflowContext {
/// Create a new workflow context
/// Create a new workflow context.
///
/// `parameters` — the immutable input parameters for this workflow run.
/// `initial_vars` — initial workflow-scoped variables (from the workflow
/// definition's `vars` section).
pub fn new(parameters: JsonValue, initial_vars: HashMap<String, JsonValue>) -> Self {
let system = DashMap::new();
system.insert("workflow_start".to_string(), json!(chrono::Utc::now()));
@@ -104,6 +145,8 @@ impl WorkflowContext {
parameters: Arc::new(parameters),
task_results: Arc::new(DashMap::new()),
system: Arc::new(system),
pack_config: Arc::new(JsonValue::Null),
keystore: Arc::new(JsonValue::Null),
current_item: None,
current_index: None,
last_task_result: None,
@@ -142,6 +185,8 @@ impl WorkflowContext {
parameters: Arc::new(parameters),
task_results: Arc::new(results),
system: Arc::new(system),
pack_config: Arc::new(JsonValue::Null),
keystore: Arc::new(JsonValue::Null),
current_item: None,
current_index: None,
last_task_result: None,
@@ -149,28 +194,38 @@ impl WorkflowContext {
}
}
/// Set a variable
/// Set a workflow-scoped variable (accessible as `workflow.<name>`).
pub fn set_var(&mut self, name: &str, value: JsonValue) {
self.variables.insert(name.to_string(), value);
}
/// Get a variable
/// Get a workflow-scoped variable by name.
pub fn get_var(&self, name: &str) -> Option<JsonValue> {
self.variables.get(name).map(|entry| entry.value().clone())
}
/// Store a task result
/// Store a completed task's result (accessible as `task.<name>.*`).
pub fn set_task_result(&mut self, task_name: &str, result: JsonValue) {
self.task_results.insert(task_name.to_string(), result);
}
/// Get a task result
/// Get a task result by task name.
pub fn get_task_result(&self, task_name: &str) -> Option<JsonValue> {
self.task_results
.get(task_name)
.map(|entry| entry.value().clone())
}
/// Set the pack configuration (accessible as `config.<key>`).
pub fn set_pack_config(&mut self, config: JsonValue) {
self.pack_config = Arc::new(config);
}
/// Set the keystore secrets (accessible as `keystore.<key>`).
pub fn set_keystore(&mut self, secrets: JsonValue) {
self.keystore = Arc::new(secrets);
}
/// Set current item for iteration
pub fn set_current_item(&mut self, item: JsonValue, index: usize) {
self.current_item = Some(item);
@@ -299,220 +354,55 @@ impl WorkflowContext {
}
}
/// Evaluate a template expression
fn evaluate_expression(&self, expr: &str) -> ContextResult<JsonValue> {
// ---------------------------------------------------------------
// Function-call expressions: result(), succeeded(), failed(), timed_out()
// ---------------------------------------------------------------
// We handle these *before* splitting on `.` because the function
// name contains parentheses which would confuse the dot-split.
//
// Supported patterns:
// result() → last task result
// result().foo.bar → nested access into result
// result().data.items → nested access into result
// succeeded() → boolean
// failed() → boolean
// timed_out() → boolean
// ---------------------------------------------------------------
if let Some(result_val) = self.try_evaluate_function_call(expr)? {
return Ok(result_val);
}
// ---------------------------------------------------------------
// Dot-path expressions
// ---------------------------------------------------------------
let parts: Vec<&str> = expr.split('.').collect();
if parts.is_empty() {
return Err(ContextError::InvalidExpression(expr.to_string()));
}
match parts[0] {
"parameters" => self.get_nested_value(&self.parameters, &parts[1..]),
"vars" | "variables" => {
if parts.len() < 2 {
return Err(ContextError::InvalidExpression(expr.to_string()));
}
let var_name = parts[1];
if let Some(entry) = self.variables.get(var_name) {
let value = entry.value().clone();
drop(entry);
if parts.len() > 2 {
self.get_nested_value(&value, &parts[2..])
} else {
Ok(value)
}
} else {
Err(ContextError::VariableNotFound(var_name.to_string()))
}
}
"task" | "tasks" => {
if parts.len() < 2 {
return Err(ContextError::InvalidExpression(expr.to_string()));
}
let task_name = parts[1];
if let Some(entry) = self.task_results.get(task_name) {
let result = entry.value().clone();
drop(entry);
if parts.len() > 2 {
self.get_nested_value(&result, &parts[2..])
} else {
Ok(result)
}
} else {
Err(ContextError::VariableNotFound(format!(
"task.{}",
task_name
)))
}
}
"item" => {
if let Some(ref item) = self.current_item {
if parts.len() > 1 {
self.get_nested_value(item, &parts[1..])
} else {
Ok(item.clone())
}
} else {
Err(ContextError::VariableNotFound("item".to_string()))
}
}
"index" => {
if let Some(index) = self.current_index {
Ok(json!(index))
} else {
Err(ContextError::VariableNotFound("index".to_string()))
}
}
"system" => {
if parts.len() < 2 {
return Err(ContextError::InvalidExpression(expr.to_string()));
}
let key = parts[1];
if let Some(entry) = self.system.get(key) {
Ok(entry.value().clone())
} else {
Err(ContextError::VariableNotFound(format!("system.{}", key)))
}
}
// Direct variable reference (e.g., `number_list` published by a
// previous task's transition)
var_name => {
if let Some(entry) = self.variables.get(var_name) {
let value = entry.value().clone();
drop(entry);
if parts.len() > 1 {
self.get_nested_value(&value, &parts[1..])
} else {
Ok(value)
}
} else {
Err(ContextError::VariableNotFound(var_name.to_string()))
}
}
}
}
/// Try to evaluate `expr` as a function-call expression.
/// Evaluate a template expression using the expression engine.
///
/// Returns `Ok(Some(value))` if the expression starts with a recognised
/// function call, `Ok(None)` if it does not match, or `Err` on failure.
fn try_evaluate_function_call(&self, expr: &str) -> ContextResult<Option<JsonValue>> {
// succeeded()
if expr == "succeeded()" {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::Succeeded)
.unwrap_or(false);
return Ok(Some(json!(val)));
}
// failed()
if expr == "failed()" {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::Failed)
.unwrap_or(false);
return Ok(Some(json!(val)));
}
// timed_out()
if expr == "timed_out()" {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::TimedOut)
.unwrap_or(false);
return Ok(Some(json!(val)));
}
// result() or result().path.to.field
if expr == "result()" || expr.starts_with("result().") {
let base = self.last_task_result.clone().unwrap_or(JsonValue::Null);
if expr == "result()" {
return Ok(Some(base));
}
// Strip "result()." prefix and navigate the remaining path
let rest = &expr["result().".len()..];
let path_parts: Vec<&str> = rest.split('.').collect();
let val = self.get_nested_value(&base, &path_parts)?;
return Ok(Some(val));
}
Ok(None)
/// Supports the full expression language including arithmetic, comparison,
/// boolean logic, member access, and built-in functions. Falls back to
/// legacy dot-path resolution for simple variable references when the
/// expression engine cannot parse the input.
fn evaluate_expression(&self, expr: &str) -> ContextResult<JsonValue> {
// Use the expression engine for all expressions. It handles:
// - Dot-path access: parameters.config.port
// - Bracket access: arr[0], obj["key"]
// - Arithmetic: 2 + 3, length(items) * 2
// - Comparison: x > 5, status == "ok"
// - Boolean logic: x > 0 and x < 10
// - Function calls: length(arr), result(), succeeded()
// - Membership: "key" in obj, 5 in arr
expression::eval_expression(expr, self).map_err(|e| match e {
EvalError::VariableNotFound(name) => ContextError::VariableNotFound(name),
EvalError::TypeError(msg) => ContextError::TypeConversion(msg),
EvalError::ParseError(msg) => ContextError::InvalidExpression(msg),
other => ContextError::InvalidExpression(format!("{}", other)),
})
}
/// Get nested value from JSON
fn get_nested_value(&self, value: &JsonValue, path: &[&str]) -> ContextResult<JsonValue> {
let mut current = value;
for key in path {
match current {
JsonValue::Object(obj) => {
current = obj
.get(*key)
.ok_or_else(|| ContextError::VariableNotFound(key.to_string()))?;
}
JsonValue::Array(arr) => {
let index: usize = key.parse().map_err(|_| {
ContextError::InvalidExpression(format!("Invalid array index: {}", key))
})?;
current = arr.get(index).ok_or_else(|| {
ContextError::InvalidExpression(format!(
"Array index out of bounds: {}",
index
))
})?;
}
_ => {
return Err(ContextError::InvalidExpression(format!(
"Cannot access property '{}' on non-object/array value",
key
)));
}
}
}
Ok(current.clone())
}
/// Evaluate a conditional expression (for 'when' clauses)
/// Evaluate a conditional expression (for 'when' clauses).
///
/// Uses the full expression engine so conditions can contain comparisons,
/// boolean operators, function calls, and arithmetic. For example:
///
/// ```text
/// succeeded()
/// result().status == "ok"
/// length(items) > 3 and "admin" in roles
/// not failed()
/// ```
pub fn evaluate_condition(&self, condition: &str) -> ContextResult<bool> {
// For now, simple boolean evaluation
// TODO: Support more complex expressions (comparisons, logical operators)
let rendered = self.render_template(condition)?;
// Try to parse as boolean
match rendered.trim().to_lowercase().as_str() {
"true" | "1" | "yes" => Ok(true),
"false" | "0" | "no" | "" => Ok(false),
other => {
// Try to evaluate as truthy/falsy
Ok(!other.is_empty())
// Try the expression engine first — it handles complex conditions
// like `result().code == 200 and succeeded()`.
match expression::eval_expression(condition, self) {
Ok(val) => Ok(is_truthy(&val)),
Err(_) => {
// Fall back to template rendering for backward compat with
// simple template conditions like `{{ succeeded() }}` (though
// bare expressions are preferred going forward).
let rendered = self.render_template(condition)?;
match rendered.trim().to_lowercase().as_str() {
"true" | "1" | "yes" => Ok(true),
"false" | "0" | "no" | "" => Ok(false),
_ => Ok(!rendered.trim().is_empty()),
}
}
}
}
@@ -574,6 +464,8 @@ impl WorkflowContext {
"parameters": self.parameters.as_ref(),
"task_results": task_results,
"system": system,
"pack_config": self.pack_config.as_ref(),
"keystore": self.keystore.as_ref(),
})
}
@@ -602,11 +494,16 @@ impl WorkflowContext {
}
}
let pack_config = data["pack_config"].clone();
let keystore = data["keystore"].clone();
Ok(Self {
variables: Arc::new(variables),
parameters: Arc::new(parameters),
task_results: Arc::new(task_results),
system: Arc::new(system),
pack_config: Arc::new(pack_config),
keystore: Arc::new(keystore),
current_item: None,
current_index: None,
last_task_result: None,
@@ -626,10 +523,122 @@ fn value_to_string(value: &JsonValue) -> String {
}
}
// ---------------------------------------------------------------
// EvalContext implementation — bridges the expression engine into
// the WorkflowContext's variable resolution and workflow functions.
// ---------------------------------------------------------------
impl EvalContext for WorkflowContext {
fn resolve_variable(&self, name: &str) -> ExprResult<JsonValue> {
match name {
// ── Canonical namespaces ──────────────────────────────
"parameters" => Ok(self.parameters.as_ref().clone()),
// `workflow` is the canonical name for mutable vars.
// `vars` and `variables` are backward-compatible aliases.
"workflow" | "vars" | "variables" => {
let map: serde_json::Map<String, JsonValue> = self
.variables
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
Ok(JsonValue::Object(map))
}
// `task` (alias: `tasks`) — completed task results.
"task" | "tasks" => {
let map: serde_json::Map<String, JsonValue> = self
.task_results
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
Ok(JsonValue::Object(map))
}
// `config` — pack configuration (read-only).
"config" => Ok(self.pack_config.as_ref().clone()),
// `keystore` — encrypted secrets (read-only).
"keystore" => Ok(self.keystore.as_ref().clone()),
// ── Iteration context ────────────────────────────────
"item" => self
.current_item
.clone()
.ok_or_else(|| EvalError::VariableNotFound("item".to_string())),
"index" => self
.current_index
.map(|i| json!(i))
.ok_or_else(|| EvalError::VariableNotFound("index".to_string())),
// ── System variables ──────────────────────────────────
"system" => {
let map: serde_json::Map<String, JsonValue> = self
.system
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
Ok(JsonValue::Object(map))
}
// ── Bare-name fallback ───────────────────────────────
// Resolve against workflow variables last so that
// `{{ my_var }}` still works as shorthand for
// `{{ workflow.my_var }}`.
_ => {
if let Some(entry) = self.variables.get(name) {
Ok(entry.value().clone())
} else {
Err(EvalError::VariableNotFound(name.to_string()))
}
}
}
}
fn call_workflow_function(
&self,
name: &str,
_args: &[JsonValue],
) -> ExprResult<Option<JsonValue>> {
match name {
"succeeded" => {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::Succeeded)
.unwrap_or(false);
Ok(Some(json!(val)))
}
"failed" => {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::Failed)
.unwrap_or(false);
Ok(Some(json!(val)))
}
"timed_out" => {
let val = self
.last_task_outcome
.map(|o| o == TaskOutcome::TimedOut)
.unwrap_or(false);
Ok(Some(json!(val)))
}
"result" => {
let base = self.last_task_result.clone().unwrap_or(JsonValue::Null);
Ok(Some(base))
}
_ => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// ---------------------------------------------------------------
// parameters namespace
// ---------------------------------------------------------------
#[test]
fn test_basic_template_rendering() {
let params = json!({
@@ -641,28 +650,6 @@ mod tests {
assert_eq!(result, "Hello World!");
}
#[test]
fn test_variable_access() {
    // Bare variable names render straight from workflow variables.
    let vars = HashMap::from([("greeting".to_string(), json!("Hello"))]);
    let ctx = WorkflowContext::new(json!({}), vars);
    assert_eq!(ctx.render_template("{{ greeting }} World").unwrap(), "Hello World");
}
#[test]
fn test_task_result_access() {
    // Task results are reachable via the `task.<name>` namespace.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_task_result("task1", json!({"status": "success"}));
    let rendered = ctx.render_template("Status: {{ task.task1.status }}").unwrap();
    assert_eq!(rendered, "Status: success");
}
#[test]
fn test_nested_value_access() {
let params = json!({
@@ -680,6 +667,143 @@ mod tests {
assert_eq!(result, "Port: 8080");
}
// ---------------------------------------------------------------
// workflow namespace (canonical) + vars/variables aliases
// ---------------------------------------------------------------
#[test]
fn test_workflow_namespace_canonical() {
    // Canonical access path for mutable vars: workflow.<name>
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("greeting", json!("Hello"));
    assert_eq!(
        ctx.render_template("{{ workflow.greeting }} World").unwrap(),
        "Hello World"
    );
}
#[test]
fn test_workflow_namespace_vars_alias() {
    // Backward-compat alias: vars.<name>
    let ctx = WorkflowContext::new(
        json!({}),
        HashMap::from([("greeting".to_string(), json!("Hello"))]),
    );
    assert_eq!(ctx.render_template("{{ vars.greeting }} World").unwrap(), "Hello World");
}
#[test]
fn test_workflow_namespace_variables_alias() {
    // Backward-compat alias: variables.<name>
    let ctx = WorkflowContext::new(
        json!({}),
        HashMap::from([("greeting".to_string(), json!("Hello"))]),
    );
    assert_eq!(
        ctx.render_template("{{ variables.greeting }} World").unwrap(),
        "Hello World"
    );
}
#[test]
fn test_variable_access_bare_name_fallback() {
    // A bare identifier falls back to workflow-variable lookup.
    let ctx = WorkflowContext::new(
        json!({}),
        HashMap::from([("greeting".to_string(), json!("Hello"))]),
    );
    assert_eq!(ctx.render_template("{{ greeting }} World").unwrap(), "Hello World");
}
// ---------------------------------------------------------------
// task namespace
// ---------------------------------------------------------------
#[test]
fn test_task_result_access() {
    // A stored task result is readable through `task.<name>.<field>`.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_task_result("task1", json!({"status": "success"}));
    assert_eq!(
        ctx.render_template("Status: {{ task.task1.status }}").unwrap(),
        "Status: success"
    );
}
#[test]
fn test_task_result_deep_access() {
    // Deeply nested task-result fields resolve via dotted paths.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_task_result("fetch", json!({"result": {"data": {"id": 42}}}));
    assert_eq!(
        ctx.evaluate_expression("task.fetch.result.data.id").unwrap(),
        json!(42)
    );
}
#[test]
fn test_task_result_stdout() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_task_result("run_cmd", json!({"result": {"stdout": "hello world"}}));
    assert_eq!(
        ctx.evaluate_expression("task.run_cmd.result.stdout").unwrap(),
        json!("hello world")
    );
}
// ---------------------------------------------------------------
// config namespace (pack configuration)
// ---------------------------------------------------------------
#[test]
fn test_config_namespace() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_pack_config(
        json!({"api_token": "tok_abc123", "base_url": "https://api.example.com"}),
    );
    // Direct expression access…
    assert_eq!(
        ctx.evaluate_expression("config.api_token").unwrap(),
        json!("tok_abc123")
    );
    // …and access through template interpolation.
    assert_eq!(
        ctx.render_template("URL: {{ config.base_url }}").unwrap(),
        "URL: https://api.example.com"
    );
}
#[test]
fn test_config_namespace_nested() {
    // Nested config objects resolve via dotted paths.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_pack_config(json!({"slack": {"webhook_url": "https://hooks.slack.com/xxx"}}));
    assert_eq!(
        ctx.evaluate_expression("config.slack.webhook_url").unwrap(),
        json!("https://hooks.slack.com/xxx")
    );
}
// ---------------------------------------------------------------
// keystore namespace (encrypted secrets)
// ---------------------------------------------------------------
#[test]
fn test_keystore_namespace() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_keystore(json!({"secret_key": "s3cr3t", "db_password": "hunter2"}));
    // Each keystore entry resolves by name.
    for (expr, expected) in [
        ("keystore.secret_key", json!("s3cr3t")),
        ("keystore.db_password", json!("hunter2")),
    ] {
        assert_eq!(ctx.evaluate_expression(expr).unwrap(), expected);
    }
}
#[test]
fn test_keystore_bracket_access() {
    // Keys containing spaces require bracket syntax.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_keystore(json!({"My Secret Key": "value123"}));
    let val = ctx
        .evaluate_expression(r#"keystore["My Secret Key"]"#)
        .unwrap();
    assert_eq!(val, json!("value123"));
}
// ---------------------------------------------------------------
// item / index (with_items iteration)
// ---------------------------------------------------------------
#[test]
fn test_item_context() {
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
@@ -691,6 +815,10 @@ mod tests {
assert_eq!(result, "Item: item1, Index: 0");
}
// ---------------------------------------------------------------
// Condition evaluation
// ---------------------------------------------------------------
#[test]
fn test_condition_evaluation() {
let params = json!({"enabled": true});
@@ -700,6 +828,133 @@ mod tests {
assert!(!ctx.evaluate_condition("false").unwrap());
}
#[test]
fn test_condition_with_comparison() {
    let ctx = WorkflowContext::new(json!({"count": 10}), HashMap::new());
    // All standard comparison operators against a parameter.
    for cond in [
        "parameters.count > 5",
        "parameters.count == 10",
        "parameters.count >= 10",
        "parameters.count != 99",
    ] {
        assert!(ctx.evaluate_condition(cond).unwrap(), "expected true: {}", cond);
    }
    assert!(!ctx.evaluate_condition("parameters.count < 5").unwrap());
}
#[test]
fn test_condition_with_boolean_operators() {
    let ctx = WorkflowContext::new(json!({"x": 10, "y": 20}), HashMap::new());
    // `and`, `or`, and `not` combine comparisons.
    for cond in [
        "parameters.x > 5 and parameters.y > 15",
        "parameters.x > 50 or parameters.y > 15",
        "not parameters.x > 50",
    ] {
        assert!(ctx.evaluate_condition(cond).unwrap(), "expected true: {}", cond);
    }
    assert!(!ctx
        .evaluate_condition("parameters.x > 5 and parameters.y > 25")
        .unwrap());
}
#[test]
fn test_condition_with_in_operator() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("roles", json!(["admin", "user"]));
    // Bare-name fallback resolves `roles` against workflow variables.
    assert!(ctx.evaluate_condition(r#""admin" in roles"#).unwrap());
    assert!(!ctx.evaluate_condition(r#""root" in roles"#).unwrap());
    // The canonical workflow namespace works as well.
    assert!(ctx.evaluate_condition(r#""admin" in workflow.roles"#).unwrap());
}
#[test]
fn test_condition_with_function_calls() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_last_task_outcome(json!({"status": "ok", "code": 200}), TaskOutcome::Succeeded);
    assert!(ctx.evaluate_condition("succeeded()").unwrap());
    assert!(!ctx.evaluate_condition("failed()").unwrap());
    // result() exposes the last task's payload for further comparison.
    assert!(ctx
        .evaluate_condition("succeeded() and result().code == 200")
        .unwrap());
    assert!(!ctx
        .evaluate_condition("succeeded() and result().code == 404")
        .unwrap());
}
#[test]
fn test_condition_with_length() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("items", json!([1, 2, 3, 4, 5]));
    // length() counts array elements.
    assert!(ctx.evaluate_condition("length(items) == 5").unwrap());
    assert!(ctx.evaluate_condition("length(items) > 3").unwrap());
    assert!(!ctx.evaluate_condition("length(items) > 10").unwrap());
}
#[test]
fn test_condition_with_config() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_pack_config(json!({"retries": 3}));
    // Pack config values participate in conditions.
    for cond in ["config.retries > 0", "config.retries == 3"] {
        assert!(ctx.evaluate_condition(cond).unwrap(), "expected true: {}", cond);
    }
}
// ---------------------------------------------------------------
// Expression engine in templates
// ---------------------------------------------------------------
#[test]
fn test_expression_arithmetic() {
    // Arithmetic inside a placeholder keeps its numeric JSON type.
    let ctx = WorkflowContext::new(json!({"x": 10}), HashMap::new());
    let rendered = ctx
        .render_json(&json!({"result": "{{ parameters.x + 5 }}"}))
        .unwrap();
    assert_eq!(rendered["result"], json!(15));
}
#[test]
fn test_expression_string_concat() {
    // `+` concatenates strings inside an expression.
    let params = json!({"first": "Hello", "second": "World"});
    let ctx = WorkflowContext::new(params, HashMap::new());
    let input = json!({"msg": r#"{{ parameters.first + " " + parameters.second }}"#});
    assert_eq!(ctx.render_json(&input).unwrap()["msg"], json!("Hello World"));
}
#[test]
fn test_expression_nested_functions() {
    // Function calls compose: length(split(...)).
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("data", json!("a,b,c"));
    let input = json!({"count": r#"{{ length(split(data, ",")) }}"#});
    assert_eq!(ctx.render_json(&input).unwrap()["count"], json!(3));
}
#[test]
fn test_expression_bracket_access() {
    // Array indexing with bracket syntax.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("arr", json!([10, 20, 30]));
    let rendered = ctx.render_json(&json!({"second": "{{ arr[1] }}"})).unwrap();
    assert_eq!(rendered["second"], json!(20));
}
#[test]
fn test_expression_type_conversion() {
    // int() conversion: 3.9 -> 3.
    let ctx = WorkflowContext::new(json!({}), HashMap::new());
    let rendered = ctx.render_json(&json!({"val": "{{ int(3.9) }}"})).unwrap();
    assert_eq!(rendered["val"], json!(3));
}
// ---------------------------------------------------------------
// render_json type-preserving behaviour
// ---------------------------------------------------------------
#[test]
fn test_render_json() {
let params = json!({"name": "test"});
@@ -769,6 +1024,10 @@ mod tests {
assert!(result["ok"].is_boolean());
}
// ---------------------------------------------------------------
// result() / succeeded() / failed() / timed_out()
// ---------------------------------------------------------------
#[test]
fn test_result_function() {
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
@@ -813,6 +1072,10 @@ mod tests {
assert_eq!(ctx.evaluate_expression("timed_out()").unwrap(), json!(true));
}
// ---------------------------------------------------------------
// Publish
// ---------------------------------------------------------------
#[test]
fn test_publish_with_result_function() {
let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
@@ -846,6 +1109,28 @@ mod tests {
assert_eq!(ctx.get_var("my_var").unwrap(), result);
}
#[test]
fn test_published_var_accessible_via_workflow_namespace() {
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_var("counter", json!(42));
    // The same variable is reachable via the canonical namespace, the
    // backward-compat alias, and the bare-name fallback.
    for expr in ["workflow.counter", "vars.counter", "counter"] {
        assert_eq!(ctx.evaluate_expression(expr).unwrap(), json!(42));
    }
}
// ---------------------------------------------------------------
// Rebuild / Export / Import round-trip
// ---------------------------------------------------------------
#[test]
fn test_rebuild_context() {
let stored_vars = json!({"number_list": [0, 1, 2]});
@@ -868,17 +1153,31 @@ mod tests {
let mut ctx = WorkflowContext::new(json!({"key": "value"}), HashMap::new());
ctx.set_var("test", json!("data"));
ctx.set_task_result("task1", json!({"result": "ok"}));
ctx.set_pack_config(json!({"setting": "val"}));
ctx.set_keystore(json!({"secret": "hidden"}));
let exported = ctx.export();
let _imported = WorkflowContext::import(exported).unwrap();
let imported = WorkflowContext::import(exported).unwrap();
assert_eq!(ctx.get_var("test").unwrap(), json!("data"));
assert_eq!(imported.get_var("test").unwrap(), json!("data"));
assert_eq!(
ctx.get_task_result("task1").unwrap(),
imported.get_task_result("task1").unwrap(),
json!({"result": "ok"})
);
assert_eq!(
imported.evaluate_expression("config.setting").unwrap(),
json!("val")
);
assert_eq!(
imported.evaluate_expression("keystore.secret").unwrap(),
json!("hidden")
);
}
// ---------------------------------------------------------------
// with_items type preservation
// ---------------------------------------------------------------
#[test]
fn test_with_items_integer_type_preservation() {
// Simulates the sleep_2 task from the hello_workflow:
@@ -902,4 +1201,40 @@ mod tests {
assert_eq!(rendered["message"], json!("Sleeping for 3 seconds "));
assert!(rendered["message"].is_string());
}
// ---------------------------------------------------------------
// Cross-namespace expressions
// ---------------------------------------------------------------
#[test]
fn test_cross_namespace_expression() {
    let mut ctx = WorkflowContext::new(json!({"limit": 5}), HashMap::new());
    ctx.set_var("items", json!([1, 2, 3]));
    ctx.set_pack_config(json!({"multiplier": 2}));
    // Conditions may mix workflow variables with parameters…
    assert!(ctx
        .evaluate_condition("length(workflow.items) < parameters.limit")
        .unwrap());
    // …and expressions may mix parameters with pack config.
    assert_eq!(
        ctx.evaluate_expression("parameters.limit * config.multiplier")
            .unwrap(),
        json!(10)
    );
}
#[test]
fn test_keystore_in_template() {
    // Keystore values interpolate into rendered strings.
    let mut ctx = WorkflowContext::new(json!({}), HashMap::new());
    ctx.set_keystore(json!({"api_key": "abc-123"}));
    let rendered = ctx
        .render_json(&json!({"auth": "Bearer {{ keystore.api_key }}"}))
        .unwrap();
    assert_eq!(rendered["auth"], json!("Bearer abc-123"));
}
#[test]
fn test_config_null_when_not_set() {
    // With no pack config installed, `config` resolves to JSON null.
    let ctx = WorkflowContext::new(json!({}), HashMap::new());
    assert_eq!(ctx.evaluate_expression("config").unwrap(), json!(null));
}
}

View File

@@ -654,8 +654,7 @@ impl ActionExecutor {
let input = UpdateExecutionInput {
status: Some(ExecutionStatus::Completed),
result: Some(result_data),
executor: None,
workflow_task: None, // Not updating workflow metadata
..Default::default()
};
ExecutionRepository::update(&self.pool, execution_id, input).await?;
@@ -755,8 +754,7 @@ impl ActionExecutor {
let input = UpdateExecutionInput {
status: Some(ExecutionStatus::Failed),
result: Some(result_data),
executor: None,
workflow_task: None, // Not updating workflow metadata
..Default::default()
};
ExecutionRepository::update(&self.pool, execution_id, input).await?;
@@ -775,11 +773,16 @@ impl ActionExecutor {
execution_id, status
);
let started_at = if status == ExecutionStatus::Running {
Some(chrono::Utc::now())
} else {
None
};
let input = UpdateExecutionInput {
status: Some(status),
result: None,
executor: None,
workflow_task: None, // Not updating workflow metadata
started_at,
..Default::default()
};
ExecutionRepository::update(&self.pool, execution_id, input).await?;