Proper SQL filtering

This commit is contained in:
2026-03-01 20:43:48 -06:00
parent 6b9d7d6cf2
commit bbe94d75f8
54 changed files with 6692 additions and 928 deletions

View File

@@ -1104,6 +1104,11 @@ pub mod execution {
pub status: ExecutionStatus,
pub result: Option<JsonDict>,
/// When the execution actually started running (worker picked it up).
/// Set when status transitions to `Running`. Used to compute accurate
/// duration that excludes queue/scheduling wait time.
pub started_at: Option<DateTime<Utc>>,
/// Workflow task metadata (only populated for workflow task executions)
///
/// Provides direct access to workflow orchestration state without JOINs.

View File

@@ -8,6 +8,26 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
/// Filters for [`ActionRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct ActionSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Text search across ref, label, description (case-insensitive)
    pub query: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`ActionRepository::list_search`].
#[derive(Debug)]
pub struct ActionSearchResult {
    /// The requested page of matching actions
    pub rows: Vec<Action>,
    /// Total number of matches before LIMIT/OFFSET (for pagination metadata)
    pub total: u64,
}

/// Repository for Action operations
pub struct ActionRepository;
@@ -287,6 +307,92 @@ impl Delete for ActionRepository {
}
impl ActionRepository {
/// Search actions with all filters pushed into SQL.
///
/// All filter fields are combinable (AND). Pagination is server-side.
/// The free-text `query` filter matches `ref`, `label`, or `description`
/// case-insensitively; LIKE metacharacters in the user input are escaped
/// so they match literally (e.g. searching "100%" no longer acts as a
/// wildcard).
pub async fn list_search<'e, E>(
    db: E,
    filters: &ActionSearchFilters,
) -> Result<ActionSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    // Escape `\`, `%` and `_` (backslash first!) so user-supplied text is
    // matched literally by LIKE. Postgres' default ESCAPE char is `\`.
    fn escape_like(s: &str) -> String {
        s.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
    }

    let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_version_constraint, param_schema, out_schema, workflow_def, is_adhoc, created, updated";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM action"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM action");

    let mut has_where = false;
    // Appends " WHERE " before the first condition and " AND " afterwards,
    // to BOTH builders so the data and count queries share the same filter.
    macro_rules! push_separator {
        () => {{
            if !has_where {
                qb.push(" WHERE ");
                count_qb.push(" WHERE ");
                has_where = true;
            } else {
                qb.push(" AND ");
                count_qb.push(" AND ");
            }
        }};
    }
    // Separator + `<prefix> $N` appended to both builders.
    macro_rules! push_condition {
        ($cond_prefix:expr, $value:expr) => {{
            push_separator!();
            qb.push($cond_prefix);
            qb.push_bind($value.clone());
            count_qb.push($cond_prefix);
            count_qb.push_bind($value);
        }};
    }

    if let Some(pack_id) = filters.pack {
        push_condition!("pack = ", pack_id);
    }
    if let Some(ref query) = filters.query {
        // Escape LIKE wildcards so user input matches literally.
        let pattern = format!("%{}%", escape_like(&query.to_lowercase()));
        // Free-text search is an OR across multiple columns, wrapped in
        // parens so it ANDs correctly with the other filters.
        push_separator!();
        qb.push("(LOWER(ref) LIKE ");
        qb.push_bind(pattern.clone());
        qb.push(" OR LOWER(label) LIKE ");
        qb.push_bind(pattern.clone());
        qb.push(" OR LOWER(description) LIKE ");
        qb.push_bind(pattern.clone());
        qb.push(")");
        count_qb.push("(LOWER(ref) LIKE ");
        count_qb.push_bind(pattern.clone());
        count_qb.push(" OR LOWER(label) LIKE ");
        count_qb.push_bind(pattern.clone());
        count_qb.push(" OR LOWER(description) LIKE ");
        count_qb.push_bind(pattern);
        count_qb.push(")");
    }
    // Suppress unused-assignment warning from the macro's last expansion.
    let _ = has_where;

    // Total match count (before LIMIT/OFFSET) for pagination metadata.
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // Page of data.
    qb.push(" ORDER BY ref ASC");
    qb.push(" LIMIT ");
    qb.push_bind(filters.limit as i64);
    qb.push(" OFFSET ");
    qb.push_bind(filters.offset as i64);
    let rows: Vec<Action> = qb.build_query_as().fetch_all(db).await?;

    Ok(ActionSearchResult { rows, total })
}
/// Find actions by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Action>>
where

View File

@@ -15,6 +15,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
// ============================================================================
// Event Search
// ============================================================================
/// Filters for [`EventRepository::search`].
///
/// All fields are optional. When set, the corresponding WHERE clause is added.
/// Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct EventSearchFilters {
    /// Filter by trigger ID (exact match)
    pub trigger: Option<Id>,
    /// Filter by trigger reference (exact match)
    pub trigger_ref: Option<String>,
    /// Filter by source ID (exact match)
    pub source: Option<Id>,
    /// Case-insensitive substring match on the rule reference
    pub rule_ref: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`EventRepository::search`].
#[derive(Debug)]
pub struct EventSearchResult {
    /// The requested page of matching events
    pub rows: Vec<Event>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}
// ============================================================================
// Enforcement Search
// ============================================================================
/// Filters for [`EnforcementRepository::search`].
///
/// All fields are optional and combinable. Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct EnforcementSearchFilters {
    /// Filter by rule ID (exact match)
    pub rule: Option<Id>,
    /// Filter by event ID (exact match)
    pub event: Option<Id>,
    /// Filter by enforcement status
    pub status: Option<EnforcementStatus>,
    /// Filter by trigger reference (exact match)
    pub trigger_ref: Option<String>,
    /// Filter by rule reference (exact match)
    pub rule_ref: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`EnforcementRepository::search`].
#[derive(Debug)]
pub struct EnforcementSearchResult {
    /// The requested page of matching enforcements
    pub rows: Vec<Enforcement>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}

/// Repository for Event operations
pub struct EventRepository;
@@ -173,6 +223,75 @@ impl EventRepository {
Ok(events)
}
/// Search events with all filters pushed into SQL.
///
/// Builds a dynamic query so that every filter, pagination, and the total
/// count are handled in the database — no in-memory filtering or slicing.
/// The `rule_ref` filter is a case-insensitive substring match; LIKE
/// metacharacters in the user input are escaped so they match literally.
pub async fn search<'e, E>(db: E, filters: &EventSearchFilters) -> Result<EventSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    // Escape `\`, `%` and `_` (backslash first!) so user-supplied text is
    // matched literally by LIKE.
    fn escape_like(s: &str) -> String {
        s.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
    }

    let select_cols = "id, trigger, trigger_ref, config, payload, source, source_ref, rule, rule_ref, created";
    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_cols} FROM event"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM event");

    // " WHERE " before the first condition, " AND " before the rest —
    // applied to both builders so they share the same filter set.
    let mut has_where = false;
    macro_rules! push_condition {
        ($cond_prefix:expr, $value:expr) => {{
            if !has_where {
                qb.push(" WHERE ");
                count_qb.push(" WHERE ");
                has_where = true;
            } else {
                qb.push(" AND ");
                count_qb.push(" AND ");
            }
            qb.push($cond_prefix);
            qb.push_bind($value.clone());
            count_qb.push($cond_prefix);
            count_qb.push_bind($value);
        }};
    }

    if let Some(trigger_id) = filters.trigger {
        push_condition!("trigger = ", trigger_id);
    }
    if let Some(ref trigger_ref) = filters.trigger_ref {
        push_condition!("trigger_ref = ", trigger_ref.clone());
    }
    if let Some(source_id) = filters.source {
        push_condition!("source = ", source_id);
    }
    if let Some(ref rule_ref) = filters.rule_ref {
        // Case-insensitive substring match; wildcards in the user input
        // (e.g. "100%" or "a_b") are escaped so they match literally.
        push_condition!(
            "LOWER(rule_ref) LIKE ",
            format!("%{}%", escape_like(&rule_ref.to_lowercase()))
        );
    }
    // Suppress unused-assignment warning from the macro's last expansion.
    let _ = has_where;

    // Total match count (before LIMIT/OFFSET) for pagination metadata.
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // Page of data, newest first.
    qb.push(" ORDER BY created DESC");
    qb.push(" LIMIT ");
    qb.push_bind(filters.limit as i64);
    qb.push(" OFFSET ");
    qb.push_bind(filters.offset as i64);
    let rows: Vec<Event> = qb.build_query_as().fetch_all(db).await?;

    Ok(EventSearchResult { rows, total })
}
}
// ============================================================================
@@ -425,4 +544,75 @@ impl EnforcementRepository {
Ok(enforcements)
}
/// Search enforcements with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn search<'e, E>(
    db: E,
    filters: &EnforcementSearchFilters,
) -> Result<EnforcementSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, rule, rule_ref, trigger_ref, config, event, status, payload, condition, conditions, created, resolved_at";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM enforcement"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM enforcement");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(status) = &filters.status {
        let j = joiner();
        data_q.push(j).push("status = ").push_bind(status.clone());
        count_q.push(j).push("status = ").push_bind(status.clone());
    }
    if let Some(rule_id) = filters.rule {
        let j = joiner();
        data_q.push(j).push("rule = ").push_bind(rule_id);
        count_q.push(j).push("rule = ").push_bind(rule_id);
    }
    if let Some(event_id) = filters.event {
        let j = joiner();
        data_q.push(j).push("event = ").push_bind(event_id);
        count_q.push(j).push("event = ").push_bind(event_id);
    }
    if let Some(trigger_ref) = &filters.trigger_ref {
        let j = joiner();
        data_q.push(j).push("trigger_ref = ").push_bind(trigger_ref.clone());
        count_q.push(j).push("trigger_ref = ").push_bind(trigger_ref.clone());
    }
    if let Some(rule_ref) = &filters.rule_ref {
        let j = joiner();
        data_q.push(j).push("rule_ref = ").push_bind(rule_ref.clone());
        count_q.push(j).push("rule_ref = ").push_bind(rule_ref.clone());
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, newest first.
    data_q.push(" ORDER BY created DESC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Enforcement> = data_q.build_query_as().fetch_all(db).await?;

    Ok(EnforcementSearchResult { rows, total })
}
}

View File

@@ -1,11 +1,71 @@
//! Execution repository for database operations
use chrono::{DateTime, Utc};
use crate::models::{enums::ExecutionStatus, execution::*, Id, JsonDict};
use crate::Result;
use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for the [`ExecutionRepository::search`] query-builder method.
///
/// Every field is optional. When set, the corresponding `WHERE` clause is
/// appended to the query. Pagination (`limit`/`offset`) is always applied.
///
/// `search` always LEFT JOINs the `enforcement` table (to populate
/// `rule_ref`/`trigger_ref` on every row), so the `rule_ref` and
/// `trigger_ref` filters simply add conditions on the joined columns.
#[derive(Debug, Clone, Default)]
pub struct ExecutionSearchFilters {
    /// Filter by execution status
    pub status: Option<ExecutionStatus>,
    /// Exact match on the execution's action reference
    pub action_ref: Option<String>,
    /// Matches executions whose action ref starts with `"<pack_name>."`
    pub pack_name: Option<String>,
    /// Exact match on the joined enforcement's rule reference
    pub rule_ref: Option<String>,
    /// Exact match on the joined enforcement's trigger reference
    pub trigger_ref: Option<String>,
    /// Filter by executor ID
    pub executor: Option<Id>,
    /// Case-insensitive substring match against the serialized JSON result
    pub result_contains: Option<String>,
    /// Filter by enforcement ID
    pub enforcement: Option<Id>,
    /// Filter by parent execution ID
    pub parent: Option<Id>,
    /// When true, only executions with no parent (`parent IS NULL`)
    pub top_level_only: bool,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`ExecutionRepository::search`].
///
/// Includes the matching rows *and* the total count (before LIMIT/OFFSET)
/// so the caller can build pagination metadata without a second round-trip.
#[derive(Debug)]
pub struct ExecutionSearchResult {
    /// The requested page of matching executions
    pub rows: Vec<ExecutionWithRefs>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}

/// An execution row with optional `rule_ref` / `trigger_ref` populated from
/// the joined `enforcement` table. This avoids a separate in-memory lookup.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct ExecutionWithRefs {
    // — execution columns (same order as SELECT_COLUMNS) —
    pub id: Id,
    pub action: Option<Id>,
    pub action_ref: String,
    pub config: Option<JsonDict>,
    pub env_vars: Option<JsonDict>,
    pub parent: Option<Id>,
    pub enforcement: Option<Id>,
    pub executor: Option<Id>,
    pub status: ExecutionStatus,
    pub result: Option<JsonDict>,
    pub started_at: Option<DateTime<Utc>>,
    // Stored as a JSON column; `default` makes a NULL column decode to None.
    #[sqlx(json, default)]
    pub workflow_task: Option<WorkflowTaskMetadata>,
    pub created: DateTime<Utc>,
    pub updated: DateTime<Utc>,
    // — joined from enforcement (NULL when the execution has no enforcement) —
    pub rule_ref: Option<String>,
    pub trigger_ref: Option<String>,
}
/// Column list for SELECT queries on the execution table.
///
/// Defined once to avoid drift between queries and the `Execution` model.
@@ -13,7 +73,7 @@ use super::{Create, Delete, FindById, List, Repository, Update};
/// are NOT in the Rust struct, so `SELECT *` must never be used.
pub const SELECT_COLUMNS: &str = "\
id, action, action_ref, config, env_vars, parent, enforcement, \
executor, status, result, workflow_task, created, updated";
executor, status, result, started_at, workflow_task, created, updated";
pub struct ExecutionRepository;
@@ -43,6 +103,7 @@ pub struct UpdateExecutionInput {
pub status: Option<ExecutionStatus>,
pub result: Option<JsonDict>,
pub executor: Option<Id>,
pub started_at: Option<DateTime<Utc>>,
pub workflow_task: Option<WorkflowTaskMetadata>,
}
@@ -52,6 +113,7 @@ impl From<Execution> for UpdateExecutionInput {
status: Some(execution.status),
result: execution.result,
executor: execution.executor,
started_at: execution.started_at,
workflow_task: execution.workflow_task,
}
}
@@ -146,6 +208,13 @@ impl Update for ExecutionRepository {
query.push("executor = ").push_bind(executor_id);
has_updates = true;
}
if let Some(started_at) = input.started_at {
if has_updates {
query.push(", ");
}
query.push("started_at = ").push_bind(started_at);
has_updates = true;
}
if let Some(workflow_task) = &input.workflow_task {
if has_updates {
query.push(", ");
@@ -239,4 +308,141 @@ impl ExecutionRepository {
.await
.map_err(Into::into)
}
/// Search executions with all filters pushed into SQL.
///
/// Builds a dynamic query with only the WHERE clauses that apply and
/// proper LIMIT/OFFSET so pagination is server-side. The `enforcement`
/// table is always LEFT JOINed so `rule_ref`/`trigger_ref` can be
/// returned on every row (and filtered on) without a second round-trip.
///
/// User-supplied text used in LIKE patterns (`pack_name`,
/// `result_contains`) has its LIKE metacharacters escaped so it matches
/// literally.
///
/// Returns both the matching page of rows and the total count.
pub async fn search<'e, E>(
    db: E,
    filters: &ExecutionSearchFilters,
) -> Result<ExecutionSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    // Escape `\`, `%` and `_` (backslash first!) so user-supplied text is
    // matched literally by LIKE.
    fn escape_like(s: &str) -> String {
        s.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
    }

    // Prefix every execution column with the `e.` alias so they don't
    // collide with enforcement columns after the JOIN.
    let prefixed_select = SELECT_COLUMNS
        .split(", ")
        .map(|col| format!("e.{col}"))
        .collect::<Vec<_>>()
        .join(", ");
    let select_clause = format!(
        "{prefixed_select}, enf.rule_ref AS rule_ref, enf.trigger_ref AS trigger_ref"
    );
    let from_clause = "FROM execution e LEFT JOIN enforcement enf ON e.enforcement = enf.id";

    let mut qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {select_clause} {from_clause}"));
    let mut count_qb: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT COUNT(*) AS total {from_clause}"));

    // " WHERE " before the first condition, " AND " before the rest —
    // applied to both builders so the data and count queries share the
    // same filter text (consistent with the other repositories).
    let mut has_where = false;
    macro_rules! push_separator {
        () => {{
            if !has_where {
                qb.push(" WHERE ");
                count_qb.push(" WHERE ");
                has_where = true;
            } else {
                qb.push(" AND ");
                count_qb.push(" AND ");
            }
        }};
    }
    // Separator + `<prefix> $N` bound into both builders.
    macro_rules! push_condition {
        ($cond_prefix:expr, $value:expr) => {{
            push_separator!();
            qb.push($cond_prefix);
            qb.push_bind($value.clone());
            count_qb.push($cond_prefix);
            count_qb.push_bind($value);
        }};
    }
    // Separator + a raw SQL condition with no bind parameter.
    macro_rules! push_raw_condition {
        ($cond:expr) => {{
            push_separator!();
            qb.push($cond);
            count_qb.push($cond);
        }};
    }

    if let Some(status) = &filters.status {
        push_condition!("e.status = ", status.clone());
    }
    if let Some(action_ref) = &filters.action_ref {
        push_condition!("e.action_ref = ", action_ref.clone());
    }
    if let Some(pack_name) = &filters.pack_name {
        // Prefix match on "<pack>."; escape the pack name so any LIKE
        // metacharacters in it are treated literally.
        let pattern = format!("{}.%", escape_like(pack_name));
        push_condition!("e.action_ref LIKE ", pattern);
    }
    if let Some(enforcement_id) = filters.enforcement {
        push_condition!("e.enforcement = ", enforcement_id);
    }
    if let Some(parent_id) = filters.parent {
        push_condition!("e.parent = ", parent_id);
    }
    if filters.top_level_only {
        push_raw_condition!("e.parent IS NULL");
    }
    if let Some(executor_id) = filters.executor {
        push_condition!("e.executor = ", executor_id);
    }
    if let Some(rule_ref) = &filters.rule_ref {
        push_condition!("enf.rule_ref = ", rule_ref.clone());
    }
    if let Some(trigger_ref) = &filters.trigger_ref {
        push_condition!("enf.trigger_ref = ", trigger_ref.clone());
    }
    if let Some(search) = &filters.result_contains {
        // Case-insensitive substring match against the serialized JSON
        // result; wildcards in the user text are escaped.
        let pattern = format!("%{}%", escape_like(&search.to_lowercase()));
        push_condition!("LOWER(e.result::text) LIKE ", pattern);
    }
    // Suppress unused-assignment warning from the macro's last expansion.
    let _ = has_where;

    // ── COUNT query (before LIMIT/OFFSET) ────────────────────────────
    let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
    let total = total.max(0) as u64;

    // ── Data query with ORDER BY + pagination ────────────────────────
    qb.push(" ORDER BY e.created DESC");
    qb.push(" LIMIT ");
    qb.push_bind(filters.limit as i64);
    qb.push(" OFFSET ");
    qb.push_bind(filters.offset as i64);
    let rows: Vec<ExecutionWithRefs> = qb.build_query_as().fetch_all(db).await?;

    Ok(ExecutionSearchResult { rows, total })
}
}

View File

@@ -7,6 +7,25 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for [`InquiryRepository::search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct InquirySearchFilters {
    /// Filter by inquiry status
    pub status: Option<InquiryStatus>,
    /// Filter by owning execution ID
    pub execution: Option<Id>,
    /// Filter by assignee ID
    pub assigned_to: Option<Id>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`InquiryRepository::search`].
#[derive(Debug)]
pub struct InquirySearchResult {
    /// The requested page of matching inquiries
    pub rows: Vec<Inquiry>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}

pub struct InquiryRepository;
impl Repository for InquiryRepository {
@@ -157,4 +176,66 @@ impl InquiryRepository {
"SELECT id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated FROM inquiry WHERE execution = $1 ORDER BY created DESC"
).bind(execution_id).fetch_all(executor).await.map_err(Into::into)
}
/// Search inquiries with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn search<'e, E>(db: E, filters: &InquirySearchFilters) -> Result<InquirySearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM inquiry"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM inquiry");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(status) = &filters.status {
        let j = joiner();
        data_q.push(j).push("status = ").push_bind(status.clone());
        count_q.push(j).push("status = ").push_bind(status.clone());
    }
    if let Some(execution_id) = filters.execution {
        let j = joiner();
        data_q.push(j).push("execution = ").push_bind(execution_id);
        count_q.push(j).push("execution = ").push_bind(execution_id);
    }
    if let Some(assignee) = filters.assigned_to {
        let j = joiner();
        data_q.push(j).push("assigned_to = ").push_bind(assignee);
        count_q.push(j).push("assigned_to = ").push_bind(assignee);
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, newest first.
    data_q.push(" ORDER BY created DESC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Inquiry> = data_q.build_query_as().fetch_all(db).await?;

    Ok(InquirySearchResult { rows, total })
}
}

View File

@@ -6,6 +6,24 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
/// Filters for [`KeyRepository::search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct KeySearchFilters {
    /// Filter by owner type
    pub owner_type: Option<OwnerType>,
    /// Filter by owner (exact match)
    pub owner: Option<String>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`KeyRepository::search`].
#[derive(Debug)]
pub struct KeySearchResult {
    /// The requested page of matching keys
    pub rows: Vec<Key>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}

pub struct KeyRepository;
impl Repository for KeyRepository {
@@ -165,4 +183,63 @@ impl KeyRepository {
"SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated FROM key WHERE owner_type = $1 ORDER BY ref ASC"
).bind(owner_type).fetch_all(executor).await.map_err(Into::into)
}
/// Search keys with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn search<'e, E>(db: E, filters: &KeySearchFilters) -> Result<KeySearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM key"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM key");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(owner_type) = &filters.owner_type {
        let j = joiner();
        data_q.push(j).push("owner_type = ").push_bind(owner_type.clone());
        count_q.push(j).push("owner_type = ").push_bind(owner_type.clone());
    }
    if let Some(owner) = &filters.owner {
        let j = joiner();
        data_q.push(j).push("owner = ").push_bind(owner.clone());
        count_q.push(j).push("owner = ").push_bind(owner.clone());
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, ordered by ref.
    data_q.push(" ORDER BY ref ASC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Key> = data_q.build_query_as().fetch_all(db).await?;

    Ok(KeySearchResult { rows, total })
}
}

View File

@@ -8,6 +8,30 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
/// Filters for [`RuleRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct RuleSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by action ID
    pub action: Option<Id>,
    /// Filter by trigger ID
    pub trigger: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`RuleRepository::list_search`].
#[derive(Debug)]
pub struct RuleSearchResult {
    /// The requested page of matching rules
    pub rows: Vec<Rule>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}
/// Input for restoring an ad-hoc rule during pack reinstallation.
/// Unlike `CreateRuleInput`, action and trigger IDs are optional because
/// the referenced entities may not exist yet or may have been removed.
@@ -275,6 +299,71 @@ impl Delete for RuleRepository {
}
impl RuleRepository {
/// Search rules with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn list_search<'e, E>(db: E, filters: &RuleSearchFilters) -> Result<RuleSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, ref, pack, pack_ref, label, description, action, action_ref, trigger, trigger_ref, conditions, action_params, trigger_params, enabled, is_adhoc, created, updated";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM rule"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM rule");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(pack_id) = filters.pack {
        let j = joiner();
        data_q.push(j).push("pack = ").push_bind(pack_id);
        count_q.push(j).push("pack = ").push_bind(pack_id);
    }
    if let Some(action_id) = filters.action {
        let j = joiner();
        data_q.push(j).push("action = ").push_bind(action_id);
        count_q.push(j).push("action = ").push_bind(action_id);
    }
    if let Some(trigger_id) = filters.trigger {
        let j = joiner();
        data_q.push(j).push("trigger = ").push_bind(trigger_id);
        count_q.push(j).push("trigger = ").push_bind(trigger_id);
    }
    if let Some(enabled) = filters.enabled {
        let j = joiner();
        data_q.push(j).push("enabled = ").push_bind(enabled);
        count_q.push(j).push("enabled = ").push_bind(enabled);
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, ordered by ref.
    data_q.push(" ORDER BY ref ASC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Rule> = data_q.build_query_as().fetch_all(db).await?;

    Ok(RuleSearchResult { rows, total })
}
/// Find rules by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Rule>>
where

View File

@@ -9,6 +9,56 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
// ============================================================================
// Trigger Search
// ============================================================================
/// Filters for [`TriggerRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct TriggerSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`TriggerRepository::list_search`].
#[derive(Debug)]
pub struct TriggerSearchResult {
    /// The requested page of matching triggers
    pub rows: Vec<Trigger>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}
// ============================================================================
// Sensor Search
// ============================================================================
/// Filters for [`SensorRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
#[derive(Debug, Clone, Default)]
pub struct SensorSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by trigger ID
    pub trigger: Option<Id>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Maximum number of rows to return (page size)
    pub limit: u32,
    /// Number of matching rows to skip before the first returned row
    pub offset: u32,
}

/// Result of [`SensorRepository::list_search`].
#[derive(Debug)]
pub struct SensorSearchResult {
    /// The requested page of matching sensors
    pub rows: Vec<Sensor>,
    /// Total number of matches before LIMIT/OFFSET
    pub total: u64,
}

/// Repository for Trigger operations
pub struct TriggerRepository;
@@ -251,6 +301,68 @@ impl Delete for TriggerRepository {
}
impl TriggerRepository {
/// Search triggers with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn list_search<'e, E>(
    db: E,
    filters: &TriggerSearchFilters,
) -> Result<TriggerSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, ref, pack, pack_ref, label, description, enabled, param_schema, out_schema, webhook_enabled, webhook_key, webhook_config, is_adhoc, created, updated";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM trigger"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM trigger");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(pack_id) = filters.pack {
        let j = joiner();
        data_q.push(j).push("pack = ").push_bind(pack_id);
        count_q.push(j).push("pack = ").push_bind(pack_id);
    }
    if let Some(enabled) = filters.enabled {
        let j = joiner();
        data_q.push(j).push("enabled = ").push_bind(enabled);
        count_q.push(j).push("enabled = ").push_bind(enabled);
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, ordered by ref.
    data_q.push(" ORDER BY ref ASC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Trigger> = data_q.build_query_as().fetch_all(db).await?;

    Ok(TriggerSearchResult { rows, total })
}
/// Find triggers by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<Trigger>>
where
@@ -795,6 +907,71 @@ impl Delete for SensorRepository {
}
impl SensorRepository {
/// Search sensors with every filter applied in SQL.
///
/// Filters combine with AND; LIMIT/OFFSET pagination happens in the
/// database, and the total match count (before pagination) is returned
/// alongside the requested page.
pub async fn list_search<'e, E>(
    db: E,
    filters: &SensorSearchFilters,
) -> Result<SensorSearchResult>
where
    E: Executor<'e, Database = Postgres> + Copy + 'e,
{
    const COLS: &str = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_ref, runtime_version_constraint, trigger, trigger_ref, enabled, param_schema, config, created, updated";

    let mut data_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new(format!("SELECT {COLS} FROM sensor"));
    let mut count_q: QueryBuilder<'_, Postgres> =
        QueryBuilder::new("SELECT COUNT(*) FROM sensor");

    // Yields " WHERE " the first time it is called and " AND " afterwards,
    // keeping the data and count queries' filter text identical.
    let mut first_filter = true;
    let mut joiner = || {
        if first_filter {
            first_filter = false;
            " WHERE "
        } else {
            " AND "
        }
    };

    if let Some(pack_id) = filters.pack {
        let j = joiner();
        data_q.push(j).push("pack = ").push_bind(pack_id);
        count_q.push(j).push("pack = ").push_bind(pack_id);
    }
    if let Some(trigger_id) = filters.trigger {
        let j = joiner();
        data_q.push(j).push("trigger = ").push_bind(trigger_id);
        count_q.push(j).push("trigger = ").push_bind(trigger_id);
    }
    if let Some(enabled) = filters.enabled {
        let j = joiner();
        data_q.push(j).push("enabled = ").push_bind(enabled);
        count_q.push(j).push("enabled = ").push_bind(enabled);
    }

    // Count first (the count query carries no ORDER BY / LIMIT).
    let total = count_q
        .build_query_scalar::<i64>()
        .fetch_one(db)
        .await?
        .max(0) as u64;

    // Then the requested page, ordered by ref.
    data_q.push(" ORDER BY ref ASC");
    data_q.push(" LIMIT ").push_bind(filters.limit as i64);
    data_q.push(" OFFSET ").push_bind(filters.offset as i64);
    let rows: Vec<Sensor> = data_q.build_query_as().fetch_all(db).await?;

    Ok(SensorSearchResult { rows, total })
}
/// Find sensors by trigger ID
pub async fn find_by_trigger<'e, E>(executor: E, trigger_id: Id) -> Result<Vec<Sensor>>
where

View File

@@ -6,6 +6,37 @@ use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
// ============================================================================
// Workflow Definition Search
// ============================================================================
/// Filters for [`WorkflowDefinitionRepository::list_search`].
///
/// All fields are optional and combinable (AND). Pagination is always applied.
/// Tag filtering uses the array-overlap operator (`tags && ARRAY[...]`):
/// OR across the requested tags, AND with all other filters.
#[derive(Debug, Clone, Default)]
pub struct WorkflowSearchFilters {
    /// Filter by pack ID
    pub pack: Option<Id>,
    /// Filter by pack reference
    pub pack_ref: Option<String>,
    /// Filter by enabled status
    pub enabled: Option<bool>,
    /// Filter by tags (OR across tags — matches if any tag is present)
    pub tags: Option<Vec<String>>,
    /// Text search across label and description (case-insensitive substring)
    pub search: Option<String>,
    /// Maximum number of rows to return (SQL `LIMIT`).
    pub limit: u32,
    /// Number of matching rows to skip before the page starts (SQL `OFFSET`).
    pub offset: u32,
}
/// Result of [`WorkflowDefinitionRepository::list_search`].
#[derive(Debug)]
pub struct WorkflowSearchResult {
    /// The requested page of matching workflow definitions.
    pub rows: Vec<WorkflowDefinition>,
    /// Total number of rows matching the filters, ignoring pagination.
    pub total: u64,
}
// ============================================================================
// WORKFLOW DEFINITION REPOSITORY
// ============================================================================
@@ -226,6 +257,102 @@ impl Delete for WorkflowDefinitionRepository {
}
impl WorkflowDefinitionRepository {
/// Search workflow definitions with all filters pushed into SQL.
///
/// All filter fields are combinable (AND). Pagination is server-side.
/// Tags use an OR match — a workflow matches if it contains ANY of the
/// requested tags (via `tags && ARRAY[...]`).
pub async fn list_search<'e, E>(
db: E,
filters: &WorkflowSearchFilters,
) -> Result<WorkflowSearchResult>
where
E: Executor<'e, Database = Postgres> + Copy + 'e,
{
let select_cols = "id, ref, pack, pack_ref, label, description, version, param_schema, out_schema, definition, tags, enabled, created, updated";
let mut qb: QueryBuilder<'_, Postgres> =
QueryBuilder::new(format!("SELECT {select_cols} FROM workflow_definition"));
let mut count_qb: QueryBuilder<'_, Postgres> =
QueryBuilder::new("SELECT COUNT(*) FROM workflow_definition");
let mut has_where = false;
macro_rules! push_condition {
($cond_prefix:expr, $value:expr) => {{
if !has_where {
qb.push(" WHERE ");
count_qb.push(" WHERE ");
has_where = true;
} else {
qb.push(" AND ");
count_qb.push(" AND ");
}
qb.push($cond_prefix);
qb.push_bind($value.clone());
count_qb.push($cond_prefix);
count_qb.push_bind($value);
}};
}
if let Some(pack_id) = filters.pack {
push_condition!("pack = ", pack_id);
}
if let Some(ref pack_ref) = filters.pack_ref {
push_condition!("pack_ref = ", pack_ref.clone());
}
if let Some(enabled) = filters.enabled {
push_condition!("enabled = ", enabled);
}
if let Some(ref tags) = filters.tags {
if !tags.is_empty() {
// Use PostgreSQL array overlap operator: tags && ARRAY[...]
push_condition!("tags && ", tags.clone());
}
}
if let Some(ref search) = filters.search {
let pattern = format!("%{}%", search.to_lowercase());
// Search needs an OR across multiple columns, wrapped in parens
if !has_where {
qb.push(" WHERE ");
count_qb.push(" WHERE ");
has_where = true;
} else {
qb.push(" AND ");
count_qb.push(" AND ");
}
qb.push("(LOWER(label) LIKE ");
qb.push_bind(pattern.clone());
qb.push(" OR LOWER(COALESCE(description, '')) LIKE ");
qb.push_bind(pattern.clone());
qb.push(")");
count_qb.push("(LOWER(label) LIKE ");
count_qb.push_bind(pattern.clone());
count_qb.push(" OR LOWER(COALESCE(description, '')) LIKE ");
count_qb.push_bind(pattern);
count_qb.push(")");
}
// Suppress unused-assignment warning from the macro's last expansion.
let _ = has_where;
// Count
let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?;
let total = total.max(0) as u64;
// Data query
qb.push(" ORDER BY label ASC");
qb.push(" LIMIT ");
qb.push_bind(filters.limit as i64);
qb.push(" OFFSET ");
qb.push_bind(filters.offset as i64);
let rows: Vec<WorkflowDefinition> = qb.build_query_as().fetch_all(db).await?;
Ok(WorkflowSearchResult { rows, total })
}
/// Find all workflows for a specific pack by pack ID
pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Vec<WorkflowDefinition>>
where

View File

@@ -0,0 +1,112 @@
//! # Expression AST
//!
//! Defines the abstract syntax tree nodes produced by the parser and consumed
//! by the evaluator.
use std::fmt;
/// A binary operator connecting two sub-expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    // Arithmetic
    /// `+` — addition (and string concatenation per the evaluator).
    Add,
    /// `-` — subtraction.
    Sub,
    /// `*` — multiplication.
    Mul,
    /// `/` — division.
    Div,
    /// `%` — modulo.
    Mod,
    // Comparison
    /// `==` — equality.
    Eq,
    /// `!=` — inequality.
    Ne,
    /// `<` — less than.
    Lt,
    /// `>` — greater than.
    Gt,
    /// `<=` — less than or equal.
    Le,
    /// `>=` — greater than or equal.
    Ge,
    // Logical
    /// `and` — logical conjunction.
    And,
    /// `or` — logical disjunction.
    Or,
    // Membership
    /// `in` — membership test.
    In,
}
impl fmt::Display for BinaryOp {
    /// Render the operator exactly as it appears in source expressions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each operator to its surface syntax, then emit in one call.
        let symbol = match self {
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mul => "*",
            BinaryOp::Div => "/",
            BinaryOp::Mod => "%",
            BinaryOp::Eq => "==",
            BinaryOp::Ne => "!=",
            BinaryOp::Lt => "<",
            BinaryOp::Gt => ">",
            BinaryOp::Le => "<=",
            BinaryOp::Ge => ">=",
            BinaryOp::And => "and",
            BinaryOp::Or => "or",
            BinaryOp::In => "in",
        };
        f.write_str(symbol)
    }
}
/// A unary operator applied to a single sub-expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation: `-x`
    Neg,
    /// Logical negation: `not x`
    Not,
}
impl fmt::Display for UnaryOp {
    /// Render the operator exactly as it appears in source expressions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            UnaryOp::Neg => "-",
            UnaryOp::Not => "not",
        };
        f.write_str(symbol)
    }
}
/// An expression AST node.
///
/// Sub-expressions are boxed so the enum has a fixed size regardless of
/// nesting depth.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal JSON value: number, string, bool, or null.
    Literal(serde_json::Value),
    /// An array literal: `[expr, expr, ...]`
    Array(Vec<Expr>),
    /// A variable reference by name (e.g., `x`, `parameters`, `item`).
    Ident(String),
    /// Binary operation: `left op right`
    BinaryOp {
        op: BinaryOp,
        left: Box<Expr>,
        right: Box<Expr>,
    },
    /// Unary operation: `op operand`
    UnaryOp {
        op: UnaryOp,
        operand: Box<Expr>,
    },
    /// Property access: `expr.field`
    DotAccess {
        object: Box<Expr>,
        field: String,
    },
    /// Index/bracket access: `expr[index_expr]`
    IndexAccess {
        object: Box<Expr>,
        index: Box<Expr>,
    },
    /// Function call: `name(arg1, arg2, ...)`
    FunctionCall {
        name: String,
        args: Vec<Expr>,
    },
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,545 @@
//! # Workflow Expression Engine
//!
//! A complete expression evaluator for workflow templates, supporting arithmetic,
//! comparison, boolean logic, member access, and built-in functions over JSON values.
//!
//! ## Architecture
//!
//! The engine is structured as a classic three-phase interpreter:
//!
//! 1. **Lexer** (`tokenizer.rs`) — converts expression strings into a stream of tokens
//! 2. **Parser** (`parser.rs`) — builds an AST from tokens using recursive descent
//! 3. **Evaluator** (`evaluator.rs`) — walks the AST and produces a `JsonValue` result
//!
//! ## Supported Operators
//!
//! ### Arithmetic
//! - `+` (addition for numbers, concatenation for strings)
//! - `-` (subtraction, unary negation)
//! - `*`, `/`, `%` (multiplication, division, modulo)
//!
//! ### Comparison
//! - `==`, `!=` (equality — works on all types, recursive for objects/arrays)
//! - `>`, `<`, `>=`, `<=` (ordering — numbers and strings only)
//! - Float/int comparisons allowed: `3 == 3.0` → true
//!
//! ### Boolean / Logical
//! - `and`, `or`, `not`
//!
//! ### Membership & Access
//! - `.` — object property access
//! - `[n]` — array index / object bracket access
//! - `in` — membership test (item in list, key in object, substring in string)
//!
//! ## Built-in Functions
//!
//! ### Type conversion
//! - `string(v)`, `number(v)`, `int(v)`, `bool(v)`
//!
//! ### Introspection
//! - `type_of(v)`, `length(v)`, `keys(obj)`, `values(obj)`
//!
//! ### Math
//! - `abs(n)`, `floor(n)`, `ceil(n)`, `round(n)`, `min(a,b)`, `max(a,b)`, `sum(arr)`
//!
//! ### String
//! - `lower(s)`, `upper(s)`, `trim(s)`, `split(s, sep)`, `join(arr, sep)`
//! - `replace(s, old, new)`, `starts_with(s, prefix)`, `ends_with(s, suffix)`
//! - `match(pattern, s)` — regex match
//!
//! ### Collection
//! - `contains(haystack, needle)`, `reversed(v)`, `sort(arr)`, `unique(arr)`
//! - `flat(arr)`, `zip(a, b)`, `range(n)` / `range(start, end)`
//!
//! ### Workflow-specific
//! - `result()`, `succeeded()`, `failed()`, `timed_out()`
mod ast;
mod evaluator;
mod parser;
mod tokenizer;
pub use ast::{BinaryOp, Expr, UnaryOp};
pub use evaluator::{is_truthy, EvalContext, EvalError, EvalResult};
pub use parser::{ParseError, Parser};
pub use tokenizer::{Token, TokenKind, Tokenizer};
use serde_json::Value as JsonValue;
/// Parse and evaluate an expression string against the given context.
///
/// This is the main entry point for the expression engine. It tokenizes the
/// input, parses it into an AST, and evaluates it to produce a `JsonValue`.
/// Lexer and parser failures are surfaced as [`EvalError::ParseError`].
pub fn eval_expression(input: &str, ctx: &dyn EvalContext) -> EvalResult<JsonValue> {
    let tokens = match Tokenizer::new(input).tokenize() {
        Ok(tokens) => tokens,
        Err(err) => return Err(EvalError::ParseError(err.to_string())),
    };
    let ast = match Parser::new(&tokens).parse() {
        Ok(ast) => ast,
        Err(err) => return Err(EvalError::ParseError(err.to_string())),
    };
    evaluator::eval(&ast, ctx)
}
/// Parse an expression string into an AST without evaluating it.
///
/// Useful for validation or inspection. Lexer failures are wrapped as
/// [`ParseError::TokenError`].
pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
    let tokens = Tokenizer::new(input)
        .tokenize()
        .map_err(|err| ParseError::TokenError(err.to_string()))?;
    Parser::new(&tokens).parse()
}
/// Integration tests exercising the full tokenize → parse → evaluate pipeline
/// through [`eval_expression`] with a minimal in-memory context.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// A minimal eval context for integration tests.
    struct TestContext {
        // name -> value bindings served by `resolve_variable`.
        variables: HashMap<String, JsonValue>,
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                variables: HashMap::new(),
            }
        }

        // Builder-style helper: bind `name` to `value` and return self.
        fn with_var(mut self, name: &str, value: JsonValue) -> Self {
            self.variables.insert(name.to_string(), value);
            self
        }
    }

    impl EvalContext for TestContext {
        fn resolve_variable(&self, name: &str) -> EvalResult<JsonValue> {
            self.variables
                .get(name)
                .cloned()
                .ok_or_else(|| EvalError::VariableNotFound(name.to_string()))
        }

        // No workflow functions in tests; `None` means "not handled here".
        fn call_workflow_function(
            &self,
            _name: &str,
            _args: &[JsonValue],
        ) -> EvalResult<Option<JsonValue>> {
            Ok(None)
        }
    }

    // ---------------------------------------------------------------
    // Arithmetic
    // ---------------------------------------------------------------

    #[test]
    fn test_integer_arithmetic() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("2 + 3", &ctx).unwrap(), json!(5));
        assert_eq!(eval_expression("10 - 4", &ctx).unwrap(), json!(6));
        assert_eq!(eval_expression("3 * 7", &ctx).unwrap(), json!(21));
        assert_eq!(eval_expression("15 / 5", &ctx).unwrap(), json!(3));
        assert_eq!(eval_expression("17 % 5", &ctx).unwrap(), json!(2));
    }

    #[test]
    fn test_float_arithmetic() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("2.5 + 1.5", &ctx).unwrap(), json!(4.0));
        assert_eq!(eval_expression("10.0 / 3.0", &ctx).unwrap(), json!(10.0 / 3.0));
    }

    #[test]
    fn test_mixed_int_float() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("2 + 1.5", &ctx).unwrap(), json!(3.5));
        // Integer division yields float when not evenly divisible
        assert_eq!(eval_expression("10 / 4", &ctx).unwrap(), json!(2.5));
        assert_eq!(eval_expression("10 / 4.0", &ctx).unwrap(), json!(2.5));
        // Evenly divisible integer division stays integer
        assert_eq!(eval_expression("10 / 5", &ctx).unwrap(), json!(2));
    }

    #[test]
    fn test_operator_precedence() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("2 + 3 * 4", &ctx).unwrap(), json!(14));
        assert_eq!(eval_expression("(2 + 3) * 4", &ctx).unwrap(), json!(20));
        assert_eq!(eval_expression("10 - 2 * 3 + 1", &ctx).unwrap(), json!(5));
    }

    #[test]
    fn test_unary_negation() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("-5", &ctx).unwrap(), json!(-5));
        assert_eq!(eval_expression("-2 + 3", &ctx).unwrap(), json!(1));
        assert_eq!(eval_expression("-(2 + 3)", &ctx).unwrap(), json!(-5));
    }

    #[test]
    fn test_string_concatenation() {
        let ctx = TestContext::new();
        assert_eq!(
            eval_expression("\"hello\" + \" \" + \"world\"", &ctx).unwrap(),
            json!("hello world")
        );
    }

    // ---------------------------------------------------------------
    // Comparison
    // ---------------------------------------------------------------

    #[test]
    fn test_number_comparison() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("3 == 3", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3 != 4", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3 > 2", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3 < 2", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("3 >= 3", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3 <= 4", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_int_float_equality() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("3 == 3.0", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3.0 == 3", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("3 != 3.1", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_string_comparison() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("\"abc\" == \"abc\"", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("\"abc\" < \"abd\"", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("\"abc\" > \"abb\"", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_null_equality() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("null == null", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("null != null", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("null == 0", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("null == false", &ctx).unwrap(), json!(false));
    }

    #[test]
    fn test_array_equality() {
        let ctx = TestContext::new()
            .with_var("a", json!([1, 2, 3]))
            .with_var("b", json!([1, 2, 3]))
            .with_var("c", json!([1, 2, 4]));
        assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_object_equality() {
        // Equality is key-order independent.
        let ctx = TestContext::new()
            .with_var("a", json!({"x": 1, "y": 2}))
            .with_var("b", json!({"y": 2, "x": 1}))
            .with_var("c", json!({"x": 1, "y": 3}));
        assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true));
    }

    // ---------------------------------------------------------------
    // Boolean / Logical
    // ---------------------------------------------------------------

    #[test]
    fn test_boolean_operators() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("true and true", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("true and false", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("false or true", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("false or false", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("not true", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("not false", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_boolean_precedence() {
        let ctx = TestContext::new();
        // `and` binds tighter than `or`
        assert_eq!(
            eval_expression("true or false and false", &ctx).unwrap(),
            json!(true)
        );
        assert_eq!(
            eval_expression("(true or false) and false", &ctx).unwrap(),
            json!(false)
        );
    }

    // ---------------------------------------------------------------
    // Membership & access
    // ---------------------------------------------------------------

    #[test]
    fn test_dot_access() {
        let ctx = TestContext::new()
            .with_var("obj", json!({"a": {"b": 42}}));
        assert_eq!(eval_expression("obj.a.b", &ctx).unwrap(), json!(42));
    }

    #[test]
    fn test_bracket_access() {
        let ctx = TestContext::new()
            .with_var("arr", json!([10, 20, 30]))
            .with_var("obj", json!({"key": "value"}));
        assert_eq!(eval_expression("arr[1]", &ctx).unwrap(), json!(20));
        assert_eq!(eval_expression("obj[\"key\"]", &ctx).unwrap(), json!("value"));
    }

    #[test]
    fn test_in_operator() {
        // `in` covers arrays (element), objects (key), and strings (substring).
        let ctx = TestContext::new()
            .with_var("arr", json!([1, 2, 3]))
            .with_var("obj", json!({"key": "val"}));
        assert_eq!(eval_expression("2 in arr", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("5 in arr", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("\"key\" in obj", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("\"nope\" in obj", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("\"ell\" in \"hello\"", &ctx).unwrap(), json!(true));
    }

    // ---------------------------------------------------------------
    // Built-in functions
    // ---------------------------------------------------------------

    #[test]
    fn test_length() {
        let ctx = TestContext::new()
            .with_var("arr", json!([1, 2, 3]))
            .with_var("obj", json!({"a": 1, "b": 2}));
        assert_eq!(eval_expression("length(arr)", &ctx).unwrap(), json!(3));
        assert_eq!(eval_expression("length(\"hello\")", &ctx).unwrap(), json!(5));
        assert_eq!(eval_expression("length(obj)", &ctx).unwrap(), json!(2));
    }

    #[test]
    fn test_type_conversions() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("string(42)", &ctx).unwrap(), json!("42"));
        assert_eq!(eval_expression("number(\"3.14\")", &ctx).unwrap(), json!(3.14));
        assert_eq!(eval_expression("int(3.9)", &ctx).unwrap(), json!(3));
        assert_eq!(eval_expression("int(\"42\")", &ctx).unwrap(), json!(42));
        assert_eq!(eval_expression("bool(1)", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("bool(0)", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("bool(\"\")", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("bool(\"x\")", &ctx).unwrap(), json!(true));
    }

    #[test]
    fn test_type_of() {
        let ctx = TestContext::new()
            .with_var("arr", json!([1]))
            .with_var("obj", json!({}));
        assert_eq!(eval_expression("type_of(42)", &ctx).unwrap(), json!("number"));
        assert_eq!(eval_expression("type_of(\"hi\")", &ctx).unwrap(), json!("string"));
        assert_eq!(eval_expression("type_of(true)", &ctx).unwrap(), json!("bool"));
        assert_eq!(eval_expression("type_of(null)", &ctx).unwrap(), json!("null"));
        assert_eq!(eval_expression("type_of(arr)", &ctx).unwrap(), json!("array"));
        assert_eq!(eval_expression("type_of(obj)", &ctx).unwrap(), json!("object"));
    }

    #[test]
    fn test_keys_values() {
        // Sort first — object key iteration order is not asserted directly.
        let ctx = TestContext::new()
            .with_var("obj", json!({"b": 2, "a": 1}));
        let keys = eval_expression("sort(keys(obj))", &ctx).unwrap();
        assert_eq!(keys, json!(["a", "b"]));
        let values = eval_expression("sort(values(obj))", &ctx).unwrap();
        assert_eq!(values, json!([1, 2]));
    }

    #[test]
    fn test_math_functions() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("abs(-5)", &ctx).unwrap(), json!(5));
        assert_eq!(eval_expression("floor(3.7)", &ctx).unwrap(), json!(3));
        assert_eq!(eval_expression("ceil(3.2)", &ctx).unwrap(), json!(4));
        assert_eq!(eval_expression("round(3.5)", &ctx).unwrap(), json!(4));
        assert_eq!(eval_expression("min(3, 7)", &ctx).unwrap(), json!(3));
        assert_eq!(eval_expression("max(3, 7)", &ctx).unwrap(), json!(7));
        assert_eq!(eval_expression("sum([1, 2, 3, 4])", &ctx).unwrap(), json!(10));
    }

    #[test]
    fn test_string_functions() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("lower(\"HELLO\")", &ctx).unwrap(), json!("hello"));
        assert_eq!(eval_expression("upper(\"hello\")", &ctx).unwrap(), json!("HELLO"));
        assert_eq!(eval_expression("trim(\" hi \")", &ctx).unwrap(), json!("hi"));
        assert_eq!(
            eval_expression("replace(\"hello world\", \"world\", \"rust\")", &ctx).unwrap(),
            json!("hello rust")
        );
        assert_eq!(
            eval_expression("starts_with(\"hello\", \"hel\")", &ctx).unwrap(),
            json!(true)
        );
        assert_eq!(
            eval_expression("ends_with(\"hello\", \"llo\")", &ctx).unwrap(),
            json!(true)
        );
        assert_eq!(
            eval_expression("split(\"a,b,c\", \",\")", &ctx).unwrap(),
            json!(["a", "b", "c"])
        );
        assert_eq!(
            eval_expression("join([\"a\", \"b\", \"c\"], \",\")", &ctx).unwrap(),
            json!("a,b,c")
        );
    }

    #[test]
    fn test_regex_match() {
        let ctx = TestContext::new();
        assert_eq!(
            eval_expression("match(\"^hello\", \"hello world\")", &ctx).unwrap(),
            json!(true)
        );
        assert_eq!(
            eval_expression("match(\"^world\", \"hello world\")", &ctx).unwrap(),
            json!(false)
        );
    }

    #[test]
    fn test_collection_functions() {
        let ctx = TestContext::new()
            .with_var("arr", json!([3, 1, 2]));
        assert_eq!(eval_expression("sort(arr)", &ctx).unwrap(), json!([1, 2, 3]));
        assert_eq!(eval_expression("reversed(arr)", &ctx).unwrap(), json!([2, 1, 3]));
        assert_eq!(
            eval_expression("unique([1, 2, 2, 3, 1])", &ctx).unwrap(),
            json!([1, 2, 3])
        );
        assert_eq!(
            eval_expression("flat([[1, 2], [3, 4]])", &ctx).unwrap(),
            json!([1, 2, 3, 4])
        );
        assert_eq!(
            eval_expression("zip([1, 2], [\"a\", \"b\"])", &ctx).unwrap(),
            json!([[1, "a"], [2, "b"]])
        );
    }

    #[test]
    fn test_range() {
        // range(n) is 0..n; range(a, b) is a..b (end-exclusive).
        let ctx = TestContext::new();
        assert_eq!(eval_expression("range(5)", &ctx).unwrap(), json!([0, 1, 2, 3, 4]));
        assert_eq!(eval_expression("range(2, 5)", &ctx).unwrap(), json!([2, 3, 4]));
    }

    #[test]
    fn test_reversed_string() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("reversed(\"abc\")", &ctx).unwrap(), json!("cba"));
    }

    #[test]
    fn test_contains_function() {
        let ctx = TestContext::new();
        assert_eq!(
            eval_expression("contains([1, 2, 3], 2)", &ctx).unwrap(),
            json!(true)
        );
        assert_eq!(
            eval_expression("contains(\"hello\", \"ell\")", &ctx).unwrap(),
            json!(true)
        );
    }

    // ---------------------------------------------------------------
    // Complex expressions
    // ---------------------------------------------------------------

    #[test]
    fn test_complex_expression() {
        let ctx = TestContext::new()
            .with_var("items", json!([1, 2, 3, 4, 5]));
        assert_eq!(
            eval_expression("length(items) > 3 and 5 in items", &ctx).unwrap(),
            json!(true)
        );
    }

    #[test]
    fn test_chained_access() {
        let ctx = TestContext::new()
            .with_var("data", json!({"users": [{"name": "Alice"}, {"name": "Bob"}]}));
        assert_eq!(
            eval_expression("data.users[1].name", &ctx).unwrap(),
            json!("Bob")
        );
    }

    #[test]
    fn test_ternary_via_boolean() {
        let ctx = TestContext::new()
            .with_var("x", json!(10));
        // No ternary operator, but boolean expressions work for conditions
        assert_eq!(
            eval_expression("x > 5 and x < 20", &ctx).unwrap(),
            json!(true)
        );
    }

    #[test]
    fn test_no_implicit_type_coercion() {
        let ctx = TestContext::new();
        // String + number should error, not silently coerce
        assert!(eval_expression("\"hello\" + 5", &ctx).is_err());
        // Comparing different types (other than int/float) should return false for ==
        assert_eq!(eval_expression("\"3\" == 3", &ctx).unwrap(), json!(false));
    }

    #[test]
    fn test_division_by_zero() {
        let ctx = TestContext::new();
        assert!(eval_expression("5 / 0", &ctx).is_err());
        assert!(eval_expression("5 % 0", &ctx).is_err());
    }

    #[test]
    fn test_array_literal() {
        let ctx = TestContext::new();
        assert_eq!(
            eval_expression("[1, 2, 3]", &ctx).unwrap(),
            json!([1, 2, 3])
        );
        assert_eq!(
            eval_expression("[\"a\", \"b\"]", &ctx).unwrap(),
            json!(["a", "b"])
        );
    }

    #[test]
    fn test_nested_function_calls() {
        let ctx = TestContext::new();
        assert_eq!(
            eval_expression("length(split(\"a,b,c\", \",\"))", &ctx).unwrap(),
            json!(3)
        );
        assert_eq!(
            eval_expression("join(sort([\"c\", \"a\", \"b\"]), \"-\")", &ctx).unwrap(),
            json!("a-b-c")
        );
    }

    #[test]
    fn test_boolean_literals() {
        let ctx = TestContext::new();
        assert_eq!(eval_expression("true", &ctx).unwrap(), json!(true));
        assert_eq!(eval_expression("false", &ctx).unwrap(), json!(false));
        assert_eq!(eval_expression("null", &ctx).unwrap(), json!(null));
    }
}

View File

@@ -0,0 +1,520 @@
//! # Expression Parser
//!
//! Recursive-descent parser that transforms a token stream into an AST.
//!
//! ## Operator Precedence (lowest to highest)
//!
//! 1. `or`
//! 2. `and`
//! 3. `not` (unary)
//! 4. `==`, `!=`, `<`, `>`, `<=`, `>=`, `in`
//! 5. `+`, `-` (addition / subtraction)
//! 6. `*`, `/`, `%`
//! 7. Unary `-`
//! 8. Postfix: `.field`, `[index]`, `(args)`
use super::ast::{BinaryOp, Expr, UnaryOp};
use super::tokenizer::{Token, TokenKind};
use thiserror::Error;
/// Errors produced while turning a token stream into an AST.
///
/// Positions are offsets into the original expression string, taken from
/// token spans.
#[derive(Debug, Error)]
pub enum ParseError {
    /// A token appeared where no grammar rule could accept it.
    #[error("Unexpected token {0} at position {1}")]
    UnexpectedToken(String, usize),
    /// A specific token was required (e.g. a closing `)`) but another was found.
    #[error("Expected {0}, found {1} at position {2}")]
    Expected(String, String, usize),
    /// Input ended in the middle of an expression.
    #[error("Unexpected end of expression")]
    UnexpectedEof,
    /// The tokenizer failed before parsing could begin.
    #[error("Token error: {0}")]
    TokenError(String),
}
/// The parser state.
pub struct Parser<'a> {
    /// The token stream being parsed (borrowed from the caller).
    tokens: &'a [Token],
    /// Index of the next token to consume.
    pos: usize,
}
impl<'a> Parser<'a> {
    /// Create a parser positioned at the start of `tokens`.
    pub fn new(tokens: &'a [Token]) -> Self {
        Self { tokens, pos: 0 }
    }

    /// Parse the token stream into a single expression AST.
    pub fn parse(&mut self) -> Result<Expr, ParseError> {
        let expr = self.parse_or()?;
        // We should be at EOF now
        if !self.at_end() {
            let tok = self.peek();
            return Err(ParseError::UnexpectedToken(
                format!("{}", tok.kind),
                tok.span.0,
            ));
        }
        Ok(expr)
    }

    // ----- Helpers -----

    // Current token without consuming it. The index is clamped to the last
    // token, so reads past the end keep returning the final token.
    // NOTE(review): indexes `len() - 1`, so this assumes a non-empty token
    // slice (presumably the tokenizer always appends Eof) — would panic on
    // an empty slice; confirm against Tokenizer.
    fn peek(&self) -> &Token {
        &self.tokens[self.pos.min(self.tokens.len() - 1)]
    }

    // True once the cursor has reached the Eof token.
    fn at_end(&self) -> bool {
        self.peek().kind == TokenKind::Eof
    }

    // Return the current token and advance the cursor (saturating at the end).
    fn advance(&mut self) -> &Token {
        let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
        if self.pos < self.tokens.len() {
            self.pos += 1;
        }
        tok
    }

    // Consume the next token if it matches `expected` — compared by enum
    // variant only (discriminant), ignoring any payload — else error.
    fn expect(&mut self, expected: &TokenKind) -> Result<&Token, ParseError> {
        let tok = self.peek();
        if std::mem::discriminant(&tok.kind) == std::mem::discriminant(expected) {
            Ok(self.advance())
        } else {
            Err(ParseError::Expected(
                format!("{}", expected),
                format!("{}", tok.kind),
                tok.span.0,
            ))
        }
    }

    // Non-consuming variant-level check of the next token.
    fn check(&self, kind: &TokenKind) -> bool {
        std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(kind)
    }

    // ----- Grammar rules (one method per precedence level, lowest first) -----

    // or_expr = and_expr ( "or" and_expr )*
    fn parse_or(&mut self) -> Result<Expr, ParseError> {
        let mut left = self.parse_and()?;
        while self.peek().kind == TokenKind::Or {
            self.advance();
            let right = self.parse_and()?;
            // Left-associative: fold each new operand into the running tree.
            left = Expr::BinaryOp {
                op: BinaryOp::Or,
                left: Box::new(left),
                right: Box::new(right),
            };
        }
        Ok(left)
    }

    // and_expr = not_expr ( "and" not_expr )*
    fn parse_and(&mut self) -> Result<Expr, ParseError> {
        let mut left = self.parse_not()?;
        while self.peek().kind == TokenKind::And {
            self.advance();
            let right = self.parse_not()?;
            left = Expr::BinaryOp {
                op: BinaryOp::And,
                left: Box::new(left),
                right: Box::new(right),
            };
        }
        Ok(left)
    }

    // not_expr = "not" not_expr | comparison
    fn parse_not(&mut self) -> Result<Expr, ParseError> {
        if self.peek().kind == TokenKind::Not {
            self.advance();
            // Recurse so `not not x` parses naturally.
            let operand = self.parse_not()?;
            return Ok(Expr::UnaryOp {
                op: UnaryOp::Not,
                operand: Box::new(operand),
            });
        }
        self.parse_comparison()
    }

    // comparison = addition ( ("==" | "!=" | "<" | ">" | "<=" | ">=" | "in") addition )*
    fn parse_comparison(&mut self) -> Result<Expr, ParseError> {
        let mut left = self.parse_addition()?;
        loop {
            let op = match self.peek().kind {
                TokenKind::EqEq => BinaryOp::Eq,
                TokenKind::BangEq => BinaryOp::Ne,
                TokenKind::Lt => BinaryOp::Lt,
                TokenKind::Gt => BinaryOp::Gt,
                TokenKind::LtEq => BinaryOp::Le,
                TokenKind::GtEq => BinaryOp::Ge,
                TokenKind::In => BinaryOp::In,
                _ => break,
            };
            self.advance();
            let right = self.parse_addition()?;
            left = Expr::BinaryOp {
                op,
                left: Box::new(left),
                right: Box::new(right),
            };
        }
        Ok(left)
    }

    // addition = multiplication ( ("+" | "-") multiplication )*
    fn parse_addition(&mut self) -> Result<Expr, ParseError> {
        let mut left = self.parse_multiplication()?;
        loop {
            let op = match self.peek().kind {
                TokenKind::Plus => BinaryOp::Add,
                TokenKind::Minus => BinaryOp::Sub,
                _ => break,
            };
            self.advance();
            let right = self.parse_multiplication()?;
            left = Expr::BinaryOp {
                op,
                left: Box::new(left),
                right: Box::new(right),
            };
        }
        Ok(left)
    }

    // multiplication = unary ( ("*" | "/" | "%") unary )*
    fn parse_multiplication(&mut self) -> Result<Expr, ParseError> {
        let mut left = self.parse_unary()?;
        loop {
            let op = match self.peek().kind {
                TokenKind::Star => BinaryOp::Mul,
                TokenKind::Slash => BinaryOp::Div,
                TokenKind::Percent => BinaryOp::Mod,
                _ => break,
            };
            self.advance();
            let right = self.parse_unary()?;
            left = Expr::BinaryOp {
                op,
                left: Box::new(left),
                right: Box::new(right),
            };
        }
        Ok(left)
    }

    // unary = "-" unary | postfix
    fn parse_unary(&mut self) -> Result<Expr, ParseError> {
        if self.peek().kind == TokenKind::Minus {
            self.advance();
            let operand = self.parse_unary()?;
            return Ok(Expr::UnaryOp {
                op: UnaryOp::Neg,
                operand: Box::new(operand),
            });
        }
        self.parse_postfix()
    }

    // postfix = primary ( "." IDENT | "[" expr "]" | "(" args ")" )*
    fn parse_postfix(&mut self) -> Result<Expr, ParseError> {
        let mut expr = self.parse_primary()?;
        loop {
            match self.peek().kind {
                TokenKind::Dot => {
                    self.advance();
                    // The field after dot
                    let tok = self.advance().clone();
                    let field = match &tok.kind {
                        TokenKind::Ident(name) => name.clone(),
                        // Allow keywords as field names (e.g., obj.in, obj.and)
                        TokenKind::And => "and".to_string(),
                        TokenKind::Or => "or".to_string(),
                        TokenKind::Not => "not".to_string(),
                        TokenKind::In => "in".to_string(),
                        TokenKind::True => "true".to_string(),
                        TokenKind::False => "false".to_string(),
                        TokenKind::Null => "null".to_string(),
                        _ => {
                            return Err(ParseError::Expected(
                                "identifier".to_string(),
                                format!("{}", tok.kind),
                                tok.span.0,
                            ));
                        }
                    };
                    expr = Expr::DotAccess {
                        object: Box::new(expr),
                        field,
                    };
                }
                TokenKind::LBracket => {
                    self.advance();
                    // Index may be any full expression, e.g. arr[i + 1].
                    let index = self.parse_or()?;
                    self.expect(&TokenKind::RBracket)?;
                    expr = Expr::IndexAccess {
                        object: Box::new(expr),
                        index: Box::new(index),
                    };
                }
                TokenKind::LParen => {
                    // Only if the expression so far is an identifier (function name)
                    // or a dot-access chain (method-like call).
                    // For now we handle Ident -> FunctionCall transformation.
                    if let Expr::Ident(name) = expr {
                        self.advance();
                        let args = self.parse_args()?;
                        self.expect(&TokenKind::RParen)?;
                        expr = Expr::FunctionCall { name, args };
                    } else {
                        break;
                    }
                }
                _ => break,
            }
        }
        Ok(expr)
    }

    // args = ( expr ( "," expr )* )?
    fn parse_args(&mut self) -> Result<Vec<Expr>, ParseError> {
        let mut args = Vec::new();
        if self.check(&TokenKind::RParen) {
            return Ok(args);
        }
        args.push(self.parse_or()?);
        while self.peek().kind == TokenKind::Comma {
            self.advance();
            args.push(self.parse_or()?);
        }
        Ok(args)
    }

    // primary = INTEGER | FLOAT | STRING | "true" | "false" | "null"
    //         | IDENT | "(" expr ")" | "[" elements "]"
    fn parse_primary(&mut self) -> Result<Expr, ParseError> {
        let tok = self.peek().clone();
        match &tok.kind {
            TokenKind::Integer(n) => {
                let n = *n;
                self.advance();
                Ok(Expr::Literal(serde_json::json!(n)))
            }
            TokenKind::Float(f) => {
                let f = *f;
                self.advance();
                Ok(Expr::Literal(serde_json::json!(f)))
            }
            TokenKind::StringLit(s) => {
                let s = s.clone();
                self.advance();
                Ok(Expr::Literal(serde_json::Value::String(s)))
            }
            TokenKind::True => {
                self.advance();
                Ok(Expr::Literal(serde_json::json!(true)))
            }
            TokenKind::False => {
                self.advance();
                Ok(Expr::Literal(serde_json::json!(false)))
            }
            TokenKind::Null => {
                self.advance();
                Ok(Expr::Literal(serde_json::json!(null)))
            }
            TokenKind::Ident(name) => {
                let name = name.clone();
                self.advance();
                Ok(Expr::Ident(name))
            }
            TokenKind::LParen => {
                // Parenthesized sub-expression; no AST node — grouping only.
                self.advance();
                let expr = self.parse_or()?;
                self.expect(&TokenKind::RParen)?;
                Ok(expr)
            }
            TokenKind::LBracket => {
                self.advance();
                let mut elements = Vec::new();
                if !self.check(&TokenKind::RBracket) {
                    elements.push(self.parse_or()?);
                    while self.peek().kind == TokenKind::Comma {
                        self.advance();
                        // Allow trailing comma
                        if self.check(&TokenKind::RBracket) {
                            break;
                        }
                        elements.push(self.parse_or()?);
                    }
                }
                self.expect(&TokenKind::RBracket)?;
                Ok(Expr::Array(elements))
            }
            _ => Err(ParseError::UnexpectedToken(
                format!("{}", tok.kind),
                tok.span.0,
            )),
        }
    }
}
/// Unit tests for the recursive-descent parser: AST shape, precedence, and
/// postfix chains, driven end-to-end through the tokenizer.
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::tokenizer::Tokenizer;

    // Tokenize and parse, panicking on any error (test inputs are valid).
    fn parse(input: &str) -> Expr {
        let tokens = Tokenizer::new(input).tokenize().unwrap();
        Parser::new(&tokens).parse().unwrap()
    }

    #[test]
    fn test_simple_add() {
        let ast = parse("2 + 3");
        assert_eq!(
            ast,
            Expr::BinaryOp {
                op: BinaryOp::Add,
                left: Box::new(Expr::Literal(serde_json::json!(2))),
                right: Box::new(Expr::Literal(serde_json::json!(3))),
            }
        );
    }

    #[test]
    fn test_precedence() {
        // 2 + 3 * 4 should parse as 2 + (3 * 4)
        let ast = parse("2 + 3 * 4");
        match ast {
            Expr::BinaryOp {
                op: BinaryOp::Add,
                right,
                ..
            } => {
                assert!(matches!(
                    *right,
                    Expr::BinaryOp {
                        op: BinaryOp::Mul,
                        ..
                    }
                ));
            }
            _ => panic!("Expected Add at top level"),
        }
    }

    #[test]
    fn test_function_call() {
        let ast = parse("length(arr)");
        assert_eq!(
            ast,
            Expr::FunctionCall {
                name: "length".to_string(),
                args: vec![Expr::Ident("arr".to_string())],
            }
        );
    }

    #[test]
    fn test_dot_access() {
        // Dot access is left-associative: (obj.field).sub
        let ast = parse("obj.field.sub");
        assert_eq!(
            ast,
            Expr::DotAccess {
                object: Box::new(Expr::DotAccess {
                    object: Box::new(Expr::Ident("obj".to_string())),
                    field: "field".to_string(),
                }),
                field: "sub".to_string(),
            }
        );
    }

    #[test]
    fn test_array_literal() {
        let ast = parse("[1, 2, 3]");
        assert_eq!(
            ast,
            Expr::Array(vec![
                Expr::Literal(serde_json::json!(1)),
                Expr::Literal(serde_json::json!(2)),
                Expr::Literal(serde_json::json!(3)),
            ])
        );
    }

    #[test]
    fn test_bracket_access() {
        let ast = parse("arr[0]");
        assert_eq!(
            ast,
            Expr::IndexAccess {
                object: Box::new(Expr::Ident("arr".to_string())),
                index: Box::new(Expr::Literal(serde_json::json!(0))),
            }
        );
    }

    #[test]
    fn test_not_operator() {
        let ast = parse("not true");
        assert_eq!(
            ast,
            Expr::UnaryOp {
                op: UnaryOp::Not,
                operand: Box::new(Expr::Literal(serde_json::json!(true))),
            }
        );
    }

    #[test]
    fn test_in_operator() {
        let ast = parse("x in arr");
        assert_eq!(
            ast,
            Expr::BinaryOp {
                op: BinaryOp::In,
                left: Box::new(Expr::Ident("x".to_string())),
                right: Box::new(Expr::Ident("arr".to_string())),
            }
        );
    }

    #[test]
    fn test_complex_expression() {
        // Should parse without error
        let _ast = parse("length(items) > 3 and 5 in items");
    }

    #[test]
    fn test_chained_access() {
        // data.users[1].name
        let _ast = parse("data.users[1].name");
    }

    #[test]
    fn test_nested_function() {
        let _ast = parse("length(split(\"a,b,c\", \",\"))");
    }

    #[test]
    fn test_trailing_comma_in_array() {
        let ast = parse("[1, 2, 3,]");
        assert_eq!(
            ast,
            Expr::Array(vec![
                Expr::Literal(serde_json::json!(1)),
                Expr::Literal(serde_json::json!(2)),
                Expr::Literal(serde_json::json!(3)),
            ])
        );
    }
}

View File

@@ -0,0 +1,512 @@
//! # Expression Tokenizer (Lexer)
//!
//! Converts an expression string into a sequence of tokens.
use std::fmt;
use thiserror::Error;
/// A token produced by the lexer.
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    // The kind of token (operator, delimiter, literal, identifier, …).
    pub kind: TokenKind,
    // Half-open `(start, end)` character offsets into the original input
    // (the tokenizer works over a `Vec<char>`, so these are char indices,
    // not byte indices).
    pub span: (usize, usize),
}
impl Token {
    /// Construct a token of `kind` covering input positions `start..end`.
    pub fn new(kind: TokenKind, start: usize, end: usize) -> Self {
        Self {
            kind,
            span: (start, end),
        }
    }
}
/// The kind of a token.
#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    // Literals
    Integer(i64),
    Float(f64),
    StringLit(String),
    True,
    False,
    Null,
    // Identifier
    Ident(String),
    // Keywords (also parsed as identifiers initially, then classified)
    And,
    Or,
    Not,
    In,
    // Operators
    Plus,
    Minus,
    Star,
    Slash,
    Percent,
    EqEq,
    BangEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Delimiters
    LParen,
    RParen,
    LBracket,
    RBracket,
    Comma,
    Dot,
    // End of input
    Eof,
}
impl fmt::Display for TokenKind {
    /// Render the token as it would appear in source text (string literals
    /// are shown double-quoted, end of input as `EOF`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Data-carrying variants are formatted directly; every
        // fixed-spelling variant maps to a static string below.
        let text = match self {
            TokenKind::Integer(n) => return write!(f, "{}", n),
            TokenKind::Float(n) => return write!(f, "{}", n),
            TokenKind::StringLit(s) => return write!(f, "\"{}\"", s),
            TokenKind::Ident(s) => return write!(f, "{}", s),
            TokenKind::True => "true",
            TokenKind::False => "false",
            TokenKind::Null => "null",
            TokenKind::And => "and",
            TokenKind::Or => "or",
            TokenKind::Not => "not",
            TokenKind::In => "in",
            TokenKind::Plus => "+",
            TokenKind::Minus => "-",
            TokenKind::Star => "*",
            TokenKind::Slash => "/",
            TokenKind::Percent => "%",
            TokenKind::EqEq => "==",
            TokenKind::BangEq => "!=",
            TokenKind::Lt => "<",
            TokenKind::Gt => ">",
            TokenKind::LtEq => "<=",
            TokenKind::GtEq => ">=",
            TokenKind::LParen => "(",
            TokenKind::RParen => ")",
            TokenKind::LBracket => "[",
            TokenKind::RBracket => "]",
            TokenKind::Comma => ",",
            TokenKind::Dot => ".",
            TokenKind::Eof => "EOF",
        };
        f.write_str(text)
    }
}
/// Errors produced while lexing an expression string. Positions are
/// character offsets into the input (the tokenizer indexes a `Vec<char>`).
#[derive(Debug, Error)]
pub enum TokenError {
    /// A character that no token can start with (or a lone `=` / `!`).
    #[error("Unexpected character '{0}' at position {1}")]
    UnexpectedChar(char, usize),
    /// Input ended before the closing quote of a string literal.
    #[error("Unterminated string literal starting at position {0}")]
    UnterminatedString(usize),
    /// Collected digit text failed to parse as `i64` / `f64`.
    #[error("Invalid number literal at position {0}: {1}")]
    InvalidNumber(usize, String),
}
/// The tokenizer / lexer.
///
/// Works over a pre-collected `Vec<char>`, so all positions and token
/// spans are character offsets (not byte offsets) into the input.
pub struct Tokenizer {
    // The full input, decoded into chars for simple indexed lookahead.
    chars: Vec<char>,
    // Index of the next unread character.
    pos: usize,
}
impl Tokenizer {
pub fn new(input: &str) -> Self {
Self {
chars: input.chars().collect(),
pos: 0,
}
}
/// Tokenize the entire input and return a vector of tokens.
pub fn tokenize(&mut self) -> Result<Vec<Token>, TokenError> {
let mut tokens = Vec::new();
loop {
let tok = self.next_token()?;
if tok.kind == TokenKind::Eof {
tokens.push(tok);
break;
}
tokens.push(tok);
}
Ok(tokens)
}
fn peek(&self) -> Option<char> {
self.chars.get(self.pos).copied()
}
fn advance(&mut self) -> Option<char> {
let ch = self.chars.get(self.pos).copied();
if ch.is_some() {
self.pos += 1;
}
ch
}
fn skip_whitespace(&mut self) {
while let Some(ch) = self.peek() {
if ch.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn next_token(&mut self) -> Result<Token, TokenError> {
self.skip_whitespace();
let start = self.pos;
let ch = match self.peek() {
Some(ch) => ch,
None => return Ok(Token::new(TokenKind::Eof, start, start)),
};
// Single-char and multi-char operators/delimiters
match ch {
'+' => {
self.advance();
Ok(Token::new(TokenKind::Plus, start, self.pos))
}
'-' => {
self.advance();
Ok(Token::new(TokenKind::Minus, start, self.pos))
}
'*' => {
self.advance();
Ok(Token::new(TokenKind::Star, start, self.pos))
}
'/' => {
self.advance();
Ok(Token::new(TokenKind::Slash, start, self.pos))
}
'%' => {
self.advance();
Ok(Token::new(TokenKind::Percent, start, self.pos))
}
'(' => {
self.advance();
Ok(Token::new(TokenKind::LParen, start, self.pos))
}
')' => {
self.advance();
Ok(Token::new(TokenKind::RParen, start, self.pos))
}
'[' => {
self.advance();
Ok(Token::new(TokenKind::LBracket, start, self.pos))
}
']' => {
self.advance();
Ok(Token::new(TokenKind::RBracket, start, self.pos))
}
',' => {
self.advance();
Ok(Token::new(TokenKind::Comma, start, self.pos))
}
'.' => {
self.advance();
Ok(Token::new(TokenKind::Dot, start, self.pos))
}
'=' => {
self.advance();
if self.peek() == Some('=') {
self.advance();
Ok(Token::new(TokenKind::EqEq, start, self.pos))
} else {
Err(TokenError::UnexpectedChar('=', start))
}
}
'!' => {
self.advance();
if self.peek() == Some('=') {
self.advance();
Ok(Token::new(TokenKind::BangEq, start, self.pos))
} else {
Err(TokenError::UnexpectedChar('!', start))
}
}
'<' => {
self.advance();
if self.peek() == Some('=') {
self.advance();
Ok(Token::new(TokenKind::LtEq, start, self.pos))
} else {
Ok(Token::new(TokenKind::Lt, start, self.pos))
}
}
'>' => {
self.advance();
if self.peek() == Some('=') {
self.advance();
Ok(Token::new(TokenKind::GtEq, start, self.pos))
} else {
Ok(Token::new(TokenKind::Gt, start, self.pos))
}
}
'"' | '\'' => self.read_string(ch),
c if c.is_ascii_digit() => self.read_number(),
c if c.is_ascii_alphabetic() || c == '_' => self.read_ident(),
other => Err(TokenError::UnexpectedChar(other, start)),
}
}
fn read_string(&mut self, quote: char) -> Result<Token, TokenError> {
let start = self.pos;
self.advance(); // consume opening quote
let mut s = String::new();
loop {
match self.advance() {
Some('\\') => {
// Escape sequence
match self.advance() {
Some('n') => s.push('\n'),
Some('t') => s.push('\t'),
Some('r') => s.push('\r'),
Some('\\') => s.push('\\'),
Some(c) if c == quote => s.push(c),
Some(c) => {
s.push('\\');
s.push(c);
}
None => return Err(TokenError::UnterminatedString(start)),
}
}
Some(c) if c == quote => {
return Ok(Token::new(TokenKind::StringLit(s), start, self.pos));
}
Some(c) => s.push(c),
None => return Err(TokenError::UnterminatedString(start)),
}
}
}
fn read_number(&mut self) -> Result<Token, TokenError> {
let start = self.pos;
let mut num_str = String::new();
let mut is_float = false;
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
num_str.push(ch);
self.advance();
} else if ch == '.' && !is_float {
// Check if this is a decimal point or a method call dot
// Look ahead to see if next char is a digit
let next_pos = self.pos + 1;
if next_pos < self.chars.len() && self.chars[next_pos].is_ascii_digit() {
is_float = true;
num_str.push(ch);
self.advance();
} else {
// It's a dot access, stop number parsing here
break;
}
} else {
break;
}
}
if is_float {
let val: f64 = num_str.parse().map_err(|_| {
TokenError::InvalidNumber(start, num_str.clone())
})?;
Ok(Token::new(TokenKind::Float(val), start, self.pos))
} else {
let val: i64 = num_str.parse().map_err(|_| {
TokenError::InvalidNumber(start, num_str.clone())
})?;
Ok(Token::new(TokenKind::Integer(val), start, self.pos))
}
}
fn read_ident(&mut self) -> Result<Token, TokenError> {
let start = self.pos;
let mut ident = String::new();
while let Some(ch) = self.peek() {
if ch.is_ascii_alphanumeric() || ch == '_' {
ident.push(ch);
self.advance();
} else {
break;
}
}
let kind = match ident.as_str() {
"true" => TokenKind::True,
"false" => TokenKind::False,
"null" => TokenKind::Null,
"and" => TokenKind::And,
"or" => TokenKind::Or,
"not" => TokenKind::Not,
"in" => TokenKind::In,
_ => TokenKind::Ident(ident),
};
Ok(Token::new(kind, start, self.pos))
}
}
#[cfg(test)]
mod tests {
    // Tokenizer unit tests: each test lexes an input string and compares
    // the produced token kinds (spans are not checked here).
    use super::*;
    // Lex `input` and strip spans, panicking on any lexer error.
    fn tokenize(input: &str) -> Vec<TokenKind> {
        let mut t = Tokenizer::new(input);
        t.tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }
    #[test]
    fn test_simple_expression() {
        let kinds = tokenize("2 + 3");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Integer(2),
                TokenKind::Plus,
                TokenKind::Integer(3),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_comparison() {
        let kinds = tokenize("x >= 10");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("x".to_string()),
                TokenKind::GtEq,
                TokenKind::Integer(10),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_keywords() {
        // Keyword spellings must classify to dedicated kinds, not Ident.
        let kinds = tokenize("true and not false or null in x");
        assert_eq!(
            kinds,
            vec![
                TokenKind::True,
                TokenKind::And,
                TokenKind::Not,
                TokenKind::False,
                TokenKind::Or,
                TokenKind::Null,
                TokenKind::In,
                TokenKind::Ident("x".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_string_literals() {
        // Both double- and single-quoted strings are accepted.
        let kinds = tokenize("\"hello\" + 'world'");
        assert_eq!(
            kinds,
            vec![
                TokenKind::StringLit("hello".to_string()),
                TokenKind::Plus,
                TokenKind::StringLit("world".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_float() {
        let kinds = tokenize("3.14");
        assert_eq!(kinds, vec![TokenKind::Float(3.14), TokenKind::Eof]);
    }
    #[test]
    fn test_dot_access() {
        let kinds = tokenize("obj.field");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("obj".to_string()),
                TokenKind::Dot,
                TokenKind::Ident("field".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_function_call() {
        let kinds = tokenize("length(arr)");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("length".to_string()),
                TokenKind::LParen,
                TokenKind::Ident("arr".to_string()),
                TokenKind::RParen,
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_bracket_access() {
        let kinds = tokenize("arr[0]");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Ident("arr".to_string()),
                TokenKind::LBracket,
                TokenKind::Integer(0),
                TokenKind::RBracket,
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_escape_sequences() {
        let kinds = tokenize(r#""hello\nworld""#);
        assert_eq!(
            kinds,
            vec![
                TokenKind::StringLit("hello\nworld".to_string()),
                TokenKind::Eof,
            ]
        );
    }
    #[test]
    fn test_integer_followed_by_dot() {
        // `42.field` - the 42 is an integer and `.field` is separate
        let kinds = tokenize("42.field");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Integer(42),
                TokenKind::Dot,
                TokenKind::Ident("field".to_string()),
                TokenKind::Eof,
            ]
        );
    }
}

View File

@@ -0,0 +1,674 @@
//! # Workflow Expression Validator
//!
//! Static validation of `{{ }}` template expressions in workflow definitions.
//! Catches syntax errors and unresolved variable references **before** the
//! workflow is saved, so users get immediate feedback instead of opaque
//! runtime failures during execution.
//!
//! ## What is validated
//!
//! 1. **Syntax** — every `{{ expr }}` block must parse successfully.
//! 2. **Variable references** — top-level identifiers that are not a known
//! namespace (`parameters`, `workflow`, `task`, `config`, `keystore`,
//! `item`, `index`, `system`, …) must exist in either the workflow's
//! `vars` map or its `param_schema` keys (bare-name fallback targets).
//!
//! ## What is NOT validated
//!
//! - **Type correctness** — e.g. whether `range(parameters.n)` actually
//! receives an integer. That requires runtime values.
//! - **Deep property paths** — e.g. `task.fetch.result.data`. We validate
//! that `task` is a known namespace but not that `fetch` is a real task
//! name (it might not exist yet at save time if tasks are re-ordered).
//! - **Function arity** — built-in functions are not checked for argument
//! count here; the evaluator already reports those errors at runtime.
use std::collections::HashSet;
use serde_json::Value as JsonValue;
use super::expression::{parse_expression, Expr, ParseError};
use super::parser::{PublishDirective, WorkflowDefinition};
// ───────────────────────────────────────────────────────────────────────────
// Public API
// ───────────────────────────────────────────────────────────────────────────
/// A single validation diagnostic.
#[derive(Debug, Clone)]
pub struct ExpressionWarning {
    /// Human-readable location (e.g. `task 'sleep_1' with_items`).
    pub location: String,
    /// The raw template string that was checked.
    pub expression: String,
    /// What went wrong.
    pub message: String,
}
impl std::fmt::Display for ExpressionWarning {
    /// Format as `<location>: <message> — `<expression>``.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            location,
            expression,
            message,
        } = self;
        write!(f, "{}: {} — `{}`", location, message, expression)
    }
}
/// Validate all template expressions in a workflow definition.
///
/// Returns an empty vec on success, or one [`ExpressionWarning`] per problem
/// found. The caller decides whether warnings are fatal (block save) or
/// advisory.
///
/// `param_schema` is the *flat-format* schema (`{ "url": { "type": "string" }, … }`)
/// passed alongside the definition in the save request. Its top-level keys
/// are the declared parameter names.
pub fn validate_workflow_expressions(
    workflow: &WorkflowDefinition,
    param_schema: Option<&JsonValue>,
) -> Vec<ExpressionWarning> {
    let known = build_known_names(workflow, param_schema);
    let mut warnings = Vec::new();

    for task in &workflow.tasks {
        let base = format!("task '{}'", task.name);

        // `with_items` and `when` are plain template strings on the task.
        if let Some(expr) = task.with_items.as_deref() {
            validate_template(expr, &format!("{base} with_items"), &known, &mut warnings);
        }
        if let Some(expr) = task.when.as_deref() {
            validate_template(expr, &format!("{base} when"), &known, &mut warnings);
        }

        // Input values are arbitrary JSON; walk them for template strings.
        for (key, value) in &task.input {
            collect_json_templates(
                value,
                &format!("{base} input.{key}"),
                &known,
                &mut warnings,
            );
        }

        // `next` transitions: the `when` guard plus any Simple publish maps.
        for (idx, transition) in task.next.iter().enumerate() {
            if let Some(guard) = transition.when.as_deref() {
                validate_template(
                    guard,
                    &format!("{base} next[{idx}].when"),
                    &known,
                    &mut warnings,
                );
            }
            for directive in &transition.publish {
                // `PublishDirective::Key` carries no expression to check.
                if let PublishDirective::Simple(map) = directive {
                    for (pk, pv) in map {
                        validate_template(
                            pv,
                            &format!("{base} next[{idx}].publish.{pk}"),
                            &known,
                            &mut warnings,
                        );
                    }
                }
            }
        }

        // Legacy task-level publish directives.
        for directive in &task.publish {
            if let PublishDirective::Simple(map) = directive {
                for (pk, pv) in map {
                    validate_template(
                        pv,
                        &format!("{base} publish.{pk}"),
                        &known,
                        &mut warnings,
                    );
                }
            }
        }
    }

    warnings
}
// ───────────────────────────────────────────────────────────────────────────
// Internals
// ───────────────────────────────────────────────────────────────────────────
/// Canonical namespace identifiers that are always valid as top-level names
/// inside `{{ }}` expressions. These are resolved by `WorkflowContext` at
/// runtime and never need to exist in `vars` or `param_schema`.
const CANONICAL_NAMESPACES: &[&str] = &[
    "parameters",
    "workflow",
    "vars",
    "variables",
    "task",
    "tasks",
    "config",
    "keystore",
    "item",
    "index",
    "system",
];
/// Built-in constants that are valid bare identifiers.
///
/// NOTE(review): the expression tokenizer lexes `true`/`false`/`null` as
/// dedicated keyword tokens, so they should never surface as `Ident`
/// nodes — this list looks like a harmless safety net; confirm before
/// removing.
const BUILTIN_LITERALS: &[&str] = &["true", "false", "null"];
/// Build the set of bare names that are valid in expressions:
/// canonical namespaces + builtin literals + workflow var names +
/// param_schema keys.
fn build_known_names(
    workflow: &WorkflowDefinition,
    param_schema: Option<&JsonValue>,
) -> HashSet<String> {
    // Seed with the always-valid names: runtime namespaces plus the
    // literal keywords.
    let mut names: HashSet<String> = CANONICAL_NAMESPACES
        .iter()
        .chain(BUILTIN_LITERALS)
        .map(|s| (*s).to_string())
        .collect();
    // Workflow vars are referencable by bare name.
    names.extend(workflow.vars.keys().cloned());
    // Parameter schema keys (flat format: top-level keys are param names).
    if let Some(JsonValue::Object(schema)) = param_schema {
        names.extend(schema.keys().cloned());
    }
    // Also accept the workflow-level `parameters` schema if present on the
    // definition itself (some loaders put it there).
    if let Some(JsonValue::Object(ref schema)) = workflow.parameters {
        names.extend(schema.keys().cloned());
    }
    names
}
/// Extract every `{{ … }}` block from a template string and validate it,
/// appending one warning per syntax error or unknown bare-name reference.
fn validate_template(
    template: &str,
    location: &str,
    known_names: &HashSet<String>,
    warnings: &mut Vec<ExpressionWarning>,
) {
    for raw_expr in extract_expressions(template) {
        let src = raw_expr.trim();
        if src.is_empty() {
            continue;
        }
        // Phase 1: the block must parse.
        let ast = match parse_expression(src) {
            Ok(ast) => ast,
            Err(e) => {
                warnings.push(ExpressionWarning {
                    location: location.to_string(),
                    expression: raw_expr.to_string(),
                    message: format!("syntax error: {e}"),
                });
                continue;
            }
        };
        // Phase 2: every top-level identifier must resolve to a canonical
        // namespace or a declared var/parameter.
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        for ident in idents {
            if known_names.contains(&ident) {
                continue;
            }
            warnings.push(ExpressionWarning {
                location: location.to_string(),
                expression: raw_expr.to_string(),
                message: format!(
                    "unknown variable '{}'. Use 'parameters.{}' for input \
                     parameters, or define it in workflow vars",
                    ident, ident,
                ),
            });
        }
    }
}
/// Walk a JSON value depth-first, validating `{{ }}` templates found in
/// every string leaf. Array elements and object members extend `location`
/// with `[i]` / `.key` so warnings point at the exact leaf.
fn collect_json_templates(
    value: &JsonValue,
    location: &str,
    known_names: &HashSet<String>,
    warnings: &mut Vec<ExpressionWarning>,
) {
    match value {
        JsonValue::String(text) => {
            validate_template(text, location, known_names, warnings);
        }
        JsonValue::Array(items) => {
            for (i, item) in items.iter().enumerate() {
                let child = format!("{location}[{i}]");
                collect_json_templates(item, &child, known_names, warnings);
            }
        }
        JsonValue::Object(members) => {
            for (key, member) in members {
                let child = format!("{location}.{key}");
                collect_json_templates(member, &child, known_names, warnings);
            }
        }
        // Numbers, booleans and null cannot contain templates.
        _ => {}
    }
}
/// Extract the inner expression strings from all `{{ … }}` blocks in a
/// template. Each block spans from a `{{` to the *first* following `}}`;
/// an unclosed `{{` terminates the scan, discarding the trailing text.
fn extract_expressions(template: &str) -> Vec<&str> {
    let mut found = Vec::new();
    let mut cursor = template;
    while let Some(open) = cursor.find("{{") {
        let body_start = open + 2;
        match cursor[body_start..].find("}}") {
            Some(close) => {
                found.push(&cursor[body_start..body_start + close]);
                cursor = &cursor[body_start + close + 2..];
            }
            // Unclosed `{{` — stop scanning.
            None => break,
        }
    }
    found
}
/// Collect identifiers that act as *variable references* in `expr`.
///
/// The field name of a `DotAccess` is deliberately not collected: in
/// `parameters.n` only `parameters` is a reference — `n` is a member
/// name. Function names are likewise skipped, but every function
/// argument is walked, so `range(n)` yields `n`.
fn collect_bare_idents(expr: &Expr, out: &mut Vec<String>) {
    match expr {
        Expr::Ident(name) => out.push(name.clone()),
        Expr::Literal(_) => {}
        Expr::Array(items) => {
            for item in items {
                collect_bare_idents(item, out);
            }
        }
        Expr::UnaryOp { operand, .. } => collect_bare_idents(operand, out),
        Expr::BinaryOp { left, right, .. } => {
            collect_bare_idents(left, out);
            collect_bare_idents(right, out);
        }
        // Only the object side can hold references; the field is a name.
        Expr::DotAccess { object, .. } => collect_bare_idents(object, out),
        Expr::IndexAccess { object, index } => {
            collect_bare_idents(object, out);
            collect_bare_idents(index, out);
        }
        // The callee name is not a variable; its arguments may be.
        Expr::FunctionCall { args, .. } => {
            for arg in args {
                collect_bare_idents(arg, out);
            }
        }
    }
}
// ───────────────────────────────────────────────────────────────────────────
// Tests
// ───────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    // Build a minimal workflow shell wrapping `tasks`; individual tests
    // add vars / parameters as needed.
    fn minimal_workflow(tasks: Vec<super::super::parser::Task>) -> WorkflowDefinition {
        WorkflowDefinition {
            r#ref: "test.wf".to_string(),
            label: "Test".to_string(),
            description: None,
            version: "1.0.0".to_string(),
            parameters: None,
            output: None,
            vars: HashMap::new(),
            tasks,
            output_map: None,
            tags: vec![],
        }
    }
    // A bare `core.echo` action task with every optional field left empty.
    fn action_task(name: &str) -> super::super::parser::Task {
        super::super::parser::Task {
            name: name.to_string(),
            r#type: super::super::parser::TaskType::Action,
            action: Some("core.echo".to_string()),
            input: HashMap::new(),
            when: None,
            with_items: None,
            batch_size: None,
            concurrency: None,
            retry: None,
            timeout: None,
            next: vec![],
            on_success: None,
            on_failure: None,
            on_complete: None,
            on_timeout: None,
            decision: vec![],
            publish: vec![],
            join: None,
            tasks: None,
            chart_meta: None,
        }
    }
    // ── extract_expressions ──────────────────────────────────────────
    #[test]
    fn test_extract_single() {
        // Inner whitespace is preserved; callers trim before parsing.
        let exprs = extract_expressions("{{ parameters.n }}");
        assert_eq!(exprs, vec![" parameters.n "]);
    }
    #[test]
    fn test_extract_multiple() {
        let exprs = extract_expressions("Hello {{ name }}, you have {{ count }} items");
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0].trim(), "name");
        assert_eq!(exprs[1].trim(), "count");
    }
    #[test]
    fn test_extract_no_expressions() {
        let exprs = extract_expressions("plain text");
        assert!(exprs.is_empty());
    }
    #[test]
    fn test_extract_unclosed() {
        let exprs = extract_expressions("{{ oops");
        assert!(exprs.is_empty());
    }
    // ── collect_bare_idents ──────────────────────────────────────────
    #[test]
    fn test_bare_ident() {
        let ast = parse_expression("n").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["n"]);
    }
    #[test]
    fn test_dot_access_does_not_collect_field() {
        let ast = parse_expression("parameters.n").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["parameters"]);
    }
    #[test]
    fn test_function_arg_collected() {
        let ast = parse_expression("range(n)").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["n"]);
    }
    #[test]
    fn test_nested_dot_access() {
        // Only the root of a property chain is a variable reference.
        let ast = parse_expression("task.fetch.result.data").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["task"]);
    }
    #[test]
    fn test_binary_op() {
        let ast = parse_expression("parameters.x + workflow.y").unwrap();
        let mut idents = Vec::new();
        collect_bare_idents(&ast, &mut idents);
        assert_eq!(idents, vec!["parameters", "workflow"]);
    }
    // ── validate_workflow_expressions ─────────────────────────────────
    #[test]
    fn test_valid_workflow_no_warnings() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(parameters.n) }}".to_string());
        task.input.insert(
            "message".to_string(),
            serde_json::json!("Hello {{ item }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    #[test]
    fn test_bare_name_from_vars_ok() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let mut wf = minimal_workflow(vec![task]);
        wf.vars.insert("n".to_string(), serde_json::json!(5));
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    #[test]
    fn test_bare_name_from_param_schema_ok() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let wf = minimal_workflow(vec![task]);
        let schema = serde_json::json!({
            "n": { "type": "integer", "required": true }
        });
        let warnings = validate_workflow_expressions(&wf, Some(&schema));
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    #[test]
    fn test_unknown_bare_name_warning() {
        let mut task = action_task("greet");
        task.with_items = Some("{{ range(n) }}".to_string());
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'n'"));
        assert!(warnings[0].message.contains("parameters.n"));
        assert!(warnings[0].location.contains("with_items"));
    }
    #[test]
    fn test_syntax_error_warning() {
        let mut task = action_task("greet");
        task.input.insert(
            "msg".to_string(),
            serde_json::json!("{{ +++ }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("syntax error"));
    }
    #[test]
    fn test_transition_when_validated() {
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ bad_var > 3 }}".to_string()),
            publish: vec![],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'bad_var'"));
        assert!(warnings[0].location.contains("next[0].when"));
    }
    #[test]
    fn test_transition_publish_validated() {
        let mut task = action_task("step1");
        let mut publish_map = HashMap::new();
        publish_map.insert("out".to_string(), "{{ unknown_thing }}".to_string());
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ succeeded() }}".to_string()),
            publish: vec![PublishDirective::Simple(publish_map)],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].message.contains("unknown variable 'unknown_thing'"));
        assert!(warnings[0].location.contains("publish.out"));
    }
    #[test]
    fn test_workflow_functions_no_warning() {
        // succeeded(), failed(), result() etc. are function calls,
        // not variable references — should not produce warnings.
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ succeeded() and result().code == 200 }}".to_string()),
            publish: vec![],
            r#do: Some(vec!["step2".to_string()]),
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task, action_task("step2")]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
    #[test]
    fn test_plain_text_no_warning() {
        let mut task = action_task("step1");
        task.input.insert(
            "msg".to_string(),
            serde_json::json!("just plain text"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty());
    }
    #[test]
    fn test_multiple_errors_collected() {
        let mut task = action_task("step1");
        task.with_items = Some("{{ range(a) }}".to_string());
        task.input.insert(
            "x".to_string(),
            serde_json::json!("{{ b + c }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        // a, b, c are all unknown
        assert_eq!(warnings.len(), 3);
        let names: HashSet<_> = warnings
            .iter()
            .flat_map(|w| {
                // extract the variable name from "unknown variable 'X'"
                w.message
                    .strip_prefix("unknown variable '")
                    .and_then(|s| s.split('\'').next())
                    .map(|s| s.to_string())
            })
            .collect();
        assert!(names.contains("a"));
        assert!(names.contains("b"));
        assert!(names.contains("c"));
    }
    #[test]
    fn test_index_access_validated() {
        let mut task = action_task("step1");
        task.input.insert(
            "val".to_string(),
            serde_json::json!("{{ items[idx] }}"),
        );
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        // Both `items` and `idx` are bare unknowns
        assert_eq!(warnings.len(), 2);
    }
    #[test]
    fn test_builtin_literals_ok() {
        let mut task = action_task("step1");
        task.next = vec![super::super::parser::TaskTransition {
            when: Some("{{ true and not false }}".to_string()),
            publish: vec![],
            r#do: None,
            chart_meta: None,
        }];
        let wf = minimal_workflow(vec![task]);
        let warnings = validate_workflow_expressions(&wf, None);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }
}

View File

@@ -3,6 +3,7 @@
//! This module provides utilities for loading, parsing, validating, and registering
//! workflow definitions from YAML files.
pub mod expression;
pub mod loader;
pub mod pack_service;
pub mod parser;

View File

@@ -356,7 +356,7 @@ async fn test_update_execution_status() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -401,7 +401,7 @@ async fn test_update_execution_result() {
status: Some(ExecutionStatus::Completed),
result: Some(result_data.clone()),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -445,7 +445,7 @@ async fn test_update_execution_executor() {
status: Some(ExecutionStatus::Scheduled),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -492,7 +492,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Scheduling),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -507,7 +507,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Scheduled),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -522,7 +522,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -537,7 +537,7 @@ async fn test_update_execution_status_transitions() {
status: Some(ExecutionStatus::Completed),
result: Some(json!({"success": true})),
executor: None,
workflow_task: None,
..Default::default()
},
)
.await
@@ -578,7 +578,7 @@ async fn test_update_execution_failed_status() {
status: Some(ExecutionStatus::Failed),
result: Some(json!({"error": "Connection timeout"})),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -984,7 +984,7 @@ async fn test_execution_timestamps() {
status: Some(ExecutionStatus::Running),
result: None,
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)
@@ -1095,7 +1095,7 @@ async fn test_execution_result_json() {
status: Some(ExecutionStatus::Completed),
result: Some(complex_result.clone()),
executor: None,
workflow_task: None,
..Default::default()
};
let updated = ExecutionRepository::update(&pool, created.id, update)