From bbe94d75f89641f4c2356e02a23ba3c2fb695422 Mon Sep 17 00:00:00 2001 From: David Culbreth Date: Sun, 1 Mar 2026 20:43:48 -0600 Subject: [PATCH] proper sql filtering --- AGENTS.md | 94 +- crates/api/src/dto/event.rs | 4 + crates/api/src/dto/execution.rs | 35 + crates/api/src/routes/actions.rs | 50 +- crates/api/src/routes/events.rs | 103 +- crates/api/src/routes/executions.rs | 238 ++- crates/api/src/routes/inquiries.rs | 101 +- crates/api/src/routes/keys.rs | 38 +- crates/api/src/routes/rules.rs | 124 +- crates/api/src/routes/triggers.rs | 164 +- crates/api/src/routes/workflows.rs | 105 +- crates/common/src/models.rs | 5 + crates/common/src/repositories/action.rs | 106 ++ crates/common/src/repositories/event.rs | 190 +++ crates/common/src/repositories/execution.rs | 208 ++- crates/common/src/repositories/inquiry.rs | 81 + crates/common/src/repositories/key.rs | 77 + crates/common/src/repositories/rule.rs | 89 ++ crates/common/src/repositories/trigger.rs | 177 +++ crates/common/src/repositories/workflow.rs | 127 ++ crates/common/src/workflow/expression/ast.rs | 112 ++ .../src/workflow/expression/evaluator.rs | 1316 +++++++++++++++++ crates/common/src/workflow/expression/mod.rs | 545 +++++++ .../common/src/workflow/expression/parser.rs | 520 +++++++ .../src/workflow/expression/tokenizer.rs | 512 +++++++ .../src/workflow/expression_validator.rs | 674 +++++++++ crates/common/src/workflow/mod.rs | 1 + .../tests/execution_repository_tests.rs | 20 +- crates/executor/src/inquiry_handler.rs | 3 +- crates/executor/src/retry_manager.rs | 5 +- crates/executor/src/scheduler.rs | 73 +- crates/executor/src/workflow/context.rs | 823 ++++++++--- crates/worker/src/executor.rs | 17 +- docs/examples/complete-workflow.yaml | 62 +- docs/examples/simple-workflow.yaml | 22 +- ...0250101000005_execution_and_operations.sql | 1 + migrations/20250101000008_notify_triggers.sql | 2 + .../20250101000009_timescaledb_history.sql | 11 +- packs.external/nodejs_example | 2 +- packs/core/workflows/install_packs.yaml | 16 +- .../models/ApiResponse_ExecutionResponse.ts | 5 + web/src/api/models/ExecutionResponse.ts | 5 + web/src/api/models/ExecutionSummary.ts | 5 + .../PaginatedResponse_ExecutionSummary.ts | 5 + web/src/api/services/ExecutionsService.ts | 5 + .../components/common/WorkflowTasksPanel.tsx | 26 +- .../executions/ExecutionPreviewPanel.tsx | 75 +- .../executions/WorkflowExecutionTree.tsx | 2 + .../components/workflows/RunWorkflowModal.tsx | 195 +++ web/src/hooks/useExecutions.ts | 45 +- web/src/pages/actions/WorkflowBuilderPage.tsx | 132 +- web/src/pages/executions/ExecutionsPage.tsx | 61 + ...026-02-05-sql-side-filtering-pagination.md | 100 ++ work-summary/2026-02-28-expression-engine.md | 106 ++ 54 files changed, 6692 insertions(+), 928 deletions(-) create mode 100644 crates/common/src/workflow/expression/ast.rs create mode 100644 crates/common/src/workflow/expression/evaluator.rs create mode 100644 crates/common/src/workflow/expression/mod.rs create mode 100644 crates/common/src/workflow/expression/parser.rs create mode 100644 crates/common/src/workflow/expression/tokenizer.rs create mode 100644 crates/common/src/workflow/expression_validator.rs create mode 100644 web/src/components/workflows/RunWorkflowModal.tsx create mode 100644 work-summary/2026-02-05-sql-side-filtering-pagination.md create mode 100644 work-summary/2026-02-28-expression-engine.md diff --git a/AGENTS.md b/AGENTS.md index eb69f35..ee6cab6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -218,10 +218,10 @@ Completion listener advances workflow → Schedules successor tasks → Complete - **FK ON DELETE Policy**: Historical records (executions) use `ON DELETE SET NULL` so they survive entity deletion while preserving text ref fields (`action_ref`, `trigger_ref`, etc.) for auditing. The `event`, `enforcement`, and `execution` tables are TimescaleDB hypertables, so they **cannot be the target of FK constraints** — `enforcement.event`, `execution.enforcement`, `inquiry.execution`, `workflow_execution.execution`, `execution.parent`, and `execution.original_execution` are plain BIGINT columns (no FK) and may become dangling references if the referenced row is deleted. Pack-owned entities (actions, triggers, sensors, rules, runtimes) use `ON DELETE CASCADE` from pack. Workflow executions cascade-delete with their workflow definition. - **Event Table (TimescaleDB Hypertable)**: The `event` table is a TimescaleDB hypertable partitioned on `created` (1-day chunks). Events are **immutable after insert** — there is no `updated` column, no update trigger, and no `Update` repository impl. The `Event` model has no `updated` field. Compression is segmented by `trigger_ref` (after 7 days) and retention is 90 days. The `event_volume_hourly` continuous aggregate queries the `event` table directly. - **Enforcement Table (TimescaleDB Hypertable)**: The `enforcement` table is a TimescaleDB hypertable partitioned on `created` (1-day chunks). Enforcements are updated **exactly once** — the executor sets `status` from `created` to `processed` or `disabled` within ~1 second of creation, well before the 7-day compression window. The `resolved_at` column (nullable `TIMESTAMPTZ`) records when this transition occurred; it is `NULL` while status is `created`. There is no `updated` column. Compression is segmented by `rule_ref` (after 7 days) and retention is 90 days. The `enforcement_volume_hourly` continuous aggregate queries the `enforcement` table directly. -- **Execution Table (TimescaleDB Hypertable)**: The `execution` table is a TimescaleDB hypertable partitioned on `created` (1-day chunks). Executions are updated **~4 times** during their lifecycle (requested → scheduled → running → completed/failed), completing within at most ~1 day — well before the 7-day compression window. The `updated` column and its BEFORE UPDATE trigger are preserved (used by timeout monitor and UI). Compression is segmented by `action_ref` (after 7 days) and retention is 90 days. The `execution_volume_hourly` continuous aggregate queries the execution hypertable directly. The `execution_history` hypertable (field-level diffs) and its continuous aggregates (`execution_status_hourly`, `execution_throughput_hourly`) are preserved alongside — they serve complementary purposes (change tracking vs. volume monitoring). -- **Entity History Tracking (TimescaleDB)**: Append-only `_history` hypertables track field-level changes to `execution` and `worker` tables. Populated by PostgreSQL `AFTER INSERT OR UPDATE OR DELETE` triggers — no Rust code changes needed for recording. Uses JSONB diff format (`old_values`/`new_values`) with a `changed_fields TEXT[]` column for efficient filtering. Worker heartbeat-only updates are excluded. There are **no `event_history` or `enforcement_history` tables** — events are immutable and enforcements have a single deterministic status transition, so both tables are hypertables themselves. See `docs/plans/timescaledb-entity-history.md` for full design. +- **Execution Table (TimescaleDB Hypertable)**: The `execution` table is a TimescaleDB hypertable partitioned on `created` (1-day chunks). Executions are updated **~4 times** during their lifecycle (requested → scheduled → running → completed/failed), completing within at most ~1 day — well before the 7-day compression window. The `updated` column and its BEFORE UPDATE trigger are preserved (used by timeout monitor and UI). The `started_at` column (nullable `TIMESTAMPTZ`) records when the worker picked up the execution (status → `running`); it is `NULL` until then. **Duration** in the UI is computed as `updated - started_at` (not `updated - created`) so that queue/scheduling wait time is excluded. Compression is segmented by `action_ref` (after 7 days) and retention is 90 days. The `execution_volume_hourly` continuous aggregate queries the execution hypertable directly. The `execution_history` hypertable (field-level diffs) and its continuous aggregates (`execution_status_hourly`, `execution_throughput_hourly`) are preserved alongside — they serve complementary purposes (change tracking vs. volume monitoring). +- **Entity History Tracking (TimescaleDB)**: Append-only `
_history` hypertables track field-level changes to `execution` and `worker` tables. Populated by PostgreSQL `AFTER INSERT OR UPDATE OR DELETE` triggers — no Rust code changes needed for recording. Uses JSONB diff format (`old_values`/`new_values`) with a `changed_fields TEXT[]` column for efficient filtering. Worker heartbeat-only updates are excluded. There are **no `event_history` or `enforcement_history` tables** — events are immutable and enforcements have a single deterministic status transition, so both tables are hypertables themselves. See `docs/plans/timescaledb-entity-history.md` for full design. The execution history trigger tracks: `status`, `result`, `executor`, `workflow_task`, `env_vars`, `started_at`. - **History Large-Field Guardrails**: The `execution` history trigger stores a compact **digest summary** instead of the full value for the `result` column (which can be arbitrarily large). The digest is produced by the `_jsonb_digest_summary(JSONB)` helper function and has the shape `{"digest": "md5:", "size": , "type": ""}`. This preserves change-detection semantics while avoiding history table bloat. The full result is always available on the live `execution` row. When adding new large JSONB columns to history triggers, use `_jsonb_digest_summary()` instead of storing the raw value. -- **Nullable FK Fields**: `rule.action` and `rule.trigger` are nullable (`Option` in Rust) — a rule with NULL action/trigger is non-functional but preserved for traceability. `execution.action`, `execution.parent`, `execution.enforcement`, and `event.source` are also nullable. `enforcement.event` is nullable but has no FK constraint (event is a hypertable). `execution.enforcement` is nullable but has no FK constraint (enforcement is a hypertable). All FK columns on the execution table (`action`, `parent`, `original_execution`, `enforcement`, `executor`, `workflow_def`) have no FK constraints (execution is a hypertable). `inquiry.execution` and `workflow_execution.execution` also have no FK constraints. `enforcement.resolved_at` is nullable — `None` while status is `created`, set when resolved. +- **Nullable FK Fields**: `rule.action` and `rule.trigger` are nullable (`Option` in Rust) — a rule with NULL action/trigger is non-functional but preserved for traceability. `execution.action`, `execution.parent`, `execution.enforcement`, `execution.started_at`, and `event.source` are also nullable. `enforcement.event` is nullable but has no FK constraint (event is a hypertable). `execution.enforcement` is nullable but has no FK constraint (enforcement is a hypertable). All FK columns on the execution table (`action`, `parent`, `original_execution`, `enforcement`, `executor`, `workflow_def`) have no FK constraints (execution is a hypertable). `inquiry.execution` and `workflow_execution.execution` also have no FK constraints. `enforcement.resolved_at` is nullable — `None` while status is `created`, set when resolved. `execution.started_at` is nullable — `None` until the worker sets status to `running`. **Table Count**: 20 tables total in the schema (including `runtime_version`, 2 `*_history` hypertables, and the `event`, `enforcement`, + `execution` hypertables) **Migration Count**: 9 migrations (`000001` through `000009`) — see `migrations/` directory - **Pack Component Loading Order**: Runtimes → Triggers → Actions → Sensors (dependency order). Both `PackComponentLoader` (Rust) and `load_core_pack.py` (Python) follow this order. @@ -229,12 +229,12 @@ Completion listener advances workflow → Schedules successor tasks → Complete ### Workflow Execution Orchestration - **Detection**: The `ExecutionScheduler` checks `action.workflow_def.is_some()` before dispatching to a worker. Workflow actions are orchestrated by the executor, not sent to workers. - **Orchestration Flow**: Scheduler loads the `WorkflowDefinition`, builds a `TaskGraph`, creates a `workflow_execution` record, marks the parent execution as Running, builds an initial `WorkflowContext` from execution parameters and workflow vars, then dispatches entry-point tasks as child executions via MQ with rendered inputs. -- **Template Resolution**: Task inputs are rendered through `WorkflowContext.render_json()` before dispatching. Supports `{{ parameters.x }}`, `{{ item }}`, `{{ index }}`, `{{ number_list }}` (direct variable), `{{ task.task_name.field }}`, and function expressions. **Type-preserving**: pure template expressions like `"{{ item }}"` preserve the JSON type (integer `5` stays as `5`, not string `"5"`). Mixed expressions like `"Sleeping for {{ item }} seconds"` remain strings. +- **Template Resolution**: Task inputs are rendered through `WorkflowContext.render_json()` before dispatching. Uses the expression engine for full operator/function support inside `{{ }}`. Canonical namespaces: `parameters`, `workflow` (mutable vars), `task` (results), `config` (pack config), `keystore` (secrets), `item`, `index`, `system`. Backward-compat aliases: `vars`/`variables` → `workflow`, `tasks` → `task`, bare names → `workflow` fallback. **Type-preserving**: pure template expressions like `"{{ item }}"` preserve the JSON type (integer `5` stays as `5`, not string `"5"`). Mixed expressions like `"Sleeping for {{ item }} seconds"` remain strings. - **Function Expressions**: `{{ result() }}` returns the last completed task's result. `{{ result().field.subfield }}` navigates into it. `{{ succeeded() }}`, `{{ failed() }}`, `{{ timed_out() }}` return booleans. These are evaluated by `WorkflowContext.try_evaluate_function_call()`. -- **Publish Directives**: Transition `publish` directives (e.g., `number_list: "{{ result().data.items }}"`) are evaluated when a transition fires. Published variables are persisted to the `workflow_execution.variables` column and available to subsequent tasks. Uses type-preserving rendering so arrays/numbers/booleans retain their types. +- **Publish Directives**: Transition `publish` directives (e.g., `number_list: "{{ result().data.items }}"`) are evaluated when a transition fires. Published variables are persisted to the `workflow_execution.variables` column and available to subsequent tasks via the `workflow` namespace (e.g., `{{ workflow.number_list }}`). Uses type-preserving rendering so arrays/numbers/booleans retain their types. - **Child Task Dispatch**: Each workflow task becomes a child execution with the task's actual action ref (e.g., `core.echo`), `workflow_task` metadata linking it to the `workflow_execution` record, and a parent reference to the workflow execution. Child executions re-enter the normal scheduling pipeline, so nested workflows work recursively. - **with_items Expansion**: Tasks declaring `with_items: "{{ expr }}"` are expanded into child executions. The expression is resolved via the `WorkflowContext` to produce a JSON array, then each item gets its own child execution with `item`/`index` set on the context and `task_index` in `WorkflowTaskMetadata`. Completion tracking waits for ALL sibling items to finish before marking the task as completed/failed and advancing the workflow. -- **with_items Concurrency Limiting**: When a task declares `concurrency: N`, ALL child execution records are created in the database up front (with fully-rendered inputs), but only the first `N` are published to the message queue. The remaining children stay at `Requested` status in the DB. As each item completes, `advance_workflow` counts in-flight siblings (`scheduling`/`scheduled`/`running`), calculates free slots (`concurrency - in_flight`), and calls `publish_pending_with_items_children()` which queries for `Requested`-status siblings ordered by `task_index` and publishes them. The DB `status = 'requested'` query is the authoritative source of undispatched items — no auxiliary state in workflow variables needed. The task is only marked complete when all siblings reach a terminal state. Without a `concurrency` value, all items are dispatched at once (previous behavior). +- **with_items Concurrency Limiting**: ALL child execution records are created in the database up front (with fully-rendered inputs), but only the first `N` are published to the message queue where `N` is the task's `concurrency` value (**default: 1**, i.e. serial execution). The remaining children stay at `Requested` status in the DB. As each item completes, `advance_workflow` counts in-flight siblings (`scheduling`/`scheduled`/`running`), calculates free slots (`concurrency - in_flight`), and calls `publish_pending_with_items_children()` which queries for `Requested`-status siblings ordered by `task_index` and publishes them. The DB `status = 'requested'` query is the authoritative source of undispatched items — no auxiliary state in workflow variables needed. The task is only marked complete when all siblings reach a terminal state. To run all items in parallel, explicitly set `concurrency` to the list length or a suitably large number. - **Advancement**: The `CompletionListener` detects when a completed execution has `workflow_task` metadata and calls `ExecutionScheduler::advance_workflow()`. The scheduler rebuilds the `WorkflowContext` from persisted `workflow_execution.variables` plus all completed child execution results, sets `last_task_outcome`, evaluates transitions (succeeded/failed/always/timed_out/custom with context-based condition evaluation), processes publish directives, schedules successor tasks with rendered inputs, and completes the workflow when all tasks are done. - **Transition Evaluation**: `succeeded()`, `failed()`, `timed_out()`, and `always` (no condition) are supported. Custom conditions are evaluated via `WorkflowContext.evaluate_condition()` with fallback to fire-on-success if evaluation fails. - **Legacy Coordinator**: The prototype `WorkflowCoordinator` in `crates/executor/src/workflow/coordinator.rs` is bypassed — it has hardcoded schema prefixes and is not integrated with the MQ pipeline. @@ -316,7 +316,7 @@ Completion listener advances workflow → Schedules successor tasks → Complete - **Available at**: `http://localhost:8080` (dev), `/api-spec/openapi.json` for spec ### Common Library (`crates/common`) -- **Modules**: `models`, `repositories`, `db`, `config`, `error`, `mq`, `crypto`, `utils`, `workflow`, `pack_registry`, `template_resolver`, `version_matching`, `runtime_detection` +- **Modules**: `models`, `repositories`, `db`, `config`, `error`, `mq`, `crypto`, `utils`, `workflow` (includes `expression` sub-module), `pack_registry`, `template_resolver`, `version_matching`, `runtime_detection` - **Exports**: Commonly used types re-exported from `lib.rs` - **Repository Layer**: All DB access goes through repositories in `repositories/` - **Message Queue**: Abstractions in `mq/` for RabbitMQ communication @@ -338,6 +338,84 @@ Rule `action_params` support Jinja2-style `{{ source.path }}` templates resolved - **Integration**: `crates/executor/src/event_processor.rs` calls `resolve_templates()` in `create_enforcement()` - **IMPORTANT**: The old `trigger.payload.*` syntax was renamed to `event.payload.*` — the payload data comes from the Event, not the Trigger +### Workflow Expression Engine +Workflow templates (`{{ expr }}`) support a full expression language for evaluating conditions, computing values, and transforming data. The engine is in `crates/common/src/workflow/expression/` (tokenizer → parser → evaluator) and is integrated into `WorkflowContext` via the `EvalContext` trait. + +**Canonical Namespaces** — all data inside `{{ }}` expressions is organised into well-defined, non-overlapping namespaces: + +| Namespace | Example | Description | +|-----------|---------|-------------| +| `parameters` | `{{ parameters.url }}` | Immutable workflow input parameters | +| `workflow` | `{{ workflow.counter }}` | Mutable workflow-scoped variables (set via `publish`) | +| `task` | `{{ task.fetch.result.data }}` | Completed task results keyed by task name | +| `config` | `{{ config.api_token }}` | Pack configuration values (read-only) | +| `keystore` | `{{ keystore.secret_key }}` | Encrypted secrets from the key store (read-only) | +| `item` | `{{ item }}` / `{{ item.name }}` | Current element in a `with_items` loop | +| `index` | `{{ index }}` | Zero-based iteration index in a `with_items` loop | +| `system` | `{{ system.workflow_start }}` | System-provided variables | + +Backward-compatible aliases (kept for existing workflow definitions): +- `vars` / `variables` → same as `workflow` +- `tasks` → same as `task` +- Bare variable names (e.g. `{{ my_var }}`) resolve against the `workflow` variable store as a last-resort fallback. + +**IMPORTANT**: New workflow definitions should always use the canonical namespace names. The `config` and `keystore` namespaces are populated by the scheduler from the pack's `config` JSONB column and decrypted `key` table entries respectively. If not populated, they resolve to `null`. + +**Operators** (lowest to highest precedence): +1. `or` — logical OR (short-circuit) +2. `and` — logical AND (short-circuit) +3. `not` — logical NOT (unary) +4. `==`, `!=`, `<`, `>`, `<=`, `>=`, `in` — comparison & membership +5. `+`, `-` — addition/subtraction (also string/array concatenation for `+`) +6. `*`, `/`, `%` — multiplication, division, modulo +7. Unary `-` — negation +8. `.field`, `[index]`, `(args)` — postfix access & function calls + +**Type Rules**: +- **No implicit type coercion**: `"3" == 3` → `false`, `"hello" + 5` → error +- **Int/float cross-comparison allowed**: `3 == 3.0` → `true` +- **Integer preservation**: `2 + 3` → `5` (int), `2 + 1.5` → `3.5` (float), `10 / 4` → `2.5` (float), `10 / 5` → `2` (int) +- **Python-like truthiness**: `null`, `false`, `0`, `""`, `[]`, `{}` are falsy +- **Deep equality**: `==`/`!=` recursively compare objects and arrays +- **Negative indexing**: `arr[-1]` returns last element + +**Built-in Functions**: +- Type conversion: `string(v)`, `number(v)`, `int(v)`, `bool(v)` +- Introspection: `type_of(v)`, `length(v)`, `keys(obj)`, `values(obj)` +- Math: `abs(n)`, `floor(n)`, `ceil(n)`, `round(n)`, `min(a,b)`, `max(a,b)`, `sum(arr)` +- String: `lower(s)`, `upper(s)`, `trim(s)`, `split(s, sep)`, `join(arr, sep)`, `replace(s, old, new)`, `starts_with(s, prefix)`, `ends_with(s, suffix)`, `match(pattern, s)` (regex) +- Collection: `contains(haystack, needle)`, `reversed(v)`, `sort(arr)`, `unique(arr)`, `flat(arr)`, `zip(a, b)`, `range(n)` / `range(start, end)`, `slice(v, start, end)`, `index_of(haystack, needle)`, `count(haystack, needle)`, `merge(obj_a, obj_b)`, `chunks(arr, size)` +- Workflow: `result()`, `succeeded()`, `failed()`, `timed_out()` (resolved via `EvalContext` trait) + +**Usage in Conditions** (`when:` on transitions): +``` +when: "succeeded() and result().code == 200" +when: "length(workflow.items) > 3 and \"admin\" in workflow.roles" +when: "not failed()" +when: "result().status == \"ok\" or result().status == \"accepted\"" +when: "config.retries > 0" +``` + +**Usage in Templates** (`{{ expr }}`): +``` +input: + count: "{{ length(workflow.items) }}" + greeting: "{{ parameters.first + \" \" + parameters.last }}" + doubled: "{{ parameters.x * 2 }}" + names: "{{ join(sort(keys(workflow.data)), \", \") }}" + auth: "Bearer {{ keystore.api_key }}" + endpoint: "{{ config.base_url + \"/api/v1\" }}" + prev_output: "{{ task.fetch.result.data.id }}" +``` + +**Implementation Files**: +- `crates/common/src/workflow/expression/mod.rs` — module entry point, `eval_expression()`, `parse_expression()` +- `crates/common/src/workflow/expression/tokenizer.rs` — lexer +- `crates/common/src/workflow/expression/parser.rs` — recursive-descent parser +- `crates/common/src/workflow/expression/evaluator.rs` — AST evaluator, `EvalContext` trait, built-in functions +- `crates/common/src/workflow/expression/ast.rs` — AST node types (`Expr`, `BinaryOp`, `UnaryOp`) +- `crates/executor/src/workflow/context.rs` — `WorkflowContext` implements `EvalContext` + ### Web UI (`web/`) - **Generated Client**: OpenAPI client auto-generated from API spec - Run: `npm run generate:api` (requires API running on :8080) @@ -519,7 +597,7 @@ When reporting, ask: "Should I fix this first or continue with [original task]?" - **Web UI**: Static files served separately or via API service ## Current Development Status -- ✅ **Complete**: Database migrations (20 tables, 10 migration files), API service (most endpoints), common library, message queue infrastructure, repository layer, JWT auth, CLI tool, Web UI (basic + workflow builder), Executor service (core functionality + workflow orchestration), Worker service (shell/Python execution), Runtime version data model, constraint matching, worker version selection pipeline, version verification at startup, per-version environment isolation, TimescaleDB entity history tracking (execution, worker), Event, enforcement, and execution tables as TimescaleDB hypertables (time-series with retention/compression), History API endpoints (generic + entity-specific with pagination & filtering), History UI panels on entity detail pages (execution), TimescaleDB continuous aggregates (6 hourly rollup views with auto-refresh policies), Analytics API endpoints (7 endpoints under `/api/v1/analytics/` — dashboard, execution status/throughput/failure-rate, event volume, worker status, enforcement volume), Analytics dashboard widgets (bar charts, stacked status charts, failure rate ring gauge, time range selector), Workflow execution orchestration (scheduler detects workflow actions, creates child task executions, completion listener advances workflow via transitions), Workflow template resolution (type-preserving `{{ }}` rendering in task inputs), Workflow `with_items` expansion (parallel child executions per item), Workflow `with_items` concurrency limiting (sliding-window dispatch with pending items stored in workflow variables), Workflow `publish` directive processing (variable propagation between tasks), Workflow function expressions (`result()`, `succeeded()`, `failed()`, `timed_out()`) +- ✅ **Complete**: Database migrations (20 tables, 10 migration files), API service (most endpoints), common library, message queue infrastructure, repository layer, JWT auth, CLI tool, Web UI (basic + workflow builder), Executor service (core functionality + workflow orchestration), Worker service (shell/Python execution), Runtime version data model, constraint matching, worker version selection pipeline, version verification at startup, per-version environment isolation, TimescaleDB entity history tracking (execution, worker), Event, enforcement, and execution tables as TimescaleDB hypertables (time-series with retention/compression), History API endpoints (generic + entity-specific with pagination & filtering), History UI panels on entity detail pages (execution), TimescaleDB continuous aggregates (6 hourly rollup views with auto-refresh policies), Analytics API endpoints (7 endpoints under `/api/v1/analytics/` — dashboard, execution status/throughput/failure-rate, event volume, worker status, enforcement volume), Analytics dashboard widgets (bar charts, stacked status charts, failure rate ring gauge, time range selector), Workflow execution orchestration (scheduler detects workflow actions, creates child task executions, completion listener advances workflow via transitions), Workflow template resolution (type-preserving `{{ }}` rendering in task inputs), Workflow `with_items` expansion (parallel child executions per item), Workflow `with_items` concurrency limiting (sliding-window dispatch with pending items stored in workflow variables), Workflow `publish` directive processing (variable propagation between tasks), Workflow function expressions (`result()`, `succeeded()`, `failed()`, `timed_out()`), Workflow expression engine (full arithmetic/comparison/boolean/membership operators, 30+ built-in functions, recursive-descent parser), Canonical workflow namespaces (`parameters`, `workflow`, `task`, `config`, `keystore`, `item`, `index`, `system`) - 🔄 **In Progress**: Sensor service, advanced workflow features (nested workflow context propagation), Python runtime dependency management, API/UI endpoints for runtime version management - 📋 **Planned**: Notifier service, execution policies, monitoring, pack registry system, configurable retention periods via admin settings, export/archival to external storage diff --git a/crates/api/src/dto/event.rs b/crates/api/src/dto/event.rs index 085a2f7..ef405ce 100644 --- a/crates/api/src/dto/event.rs +++ b/crates/api/src/dto/event.rs @@ -319,6 +319,10 @@ pub struct EnforcementQueryParams { #[param(example = "core.webhook")] pub trigger_ref: Option, + /// Filter by rule reference + #[param(example = "core.on_webhook")] + pub rule_ref: Option, + /// Page number (1-indexed) #[serde(default = "default_page")] #[param(example = 1, minimum = 1)] diff --git a/crates/api/src/dto/execution.rs b/crates/api/src/dto/execution.rs index 2e0d030..a765858 100644 --- a/crates/api/src/dto/execution.rs +++ b/crates/api/src/dto/execution.rs @@ -7,6 +7,7 @@ use utoipa::{IntoParams, ToSchema}; use attune_common::models::enums::ExecutionStatus; use attune_common::models::execution::WorkflowTaskMetadata; +use attune_common::repositories::execution::ExecutionWithRefs; /// Request DTO for creating a manual execution #[derive(Debug, Clone, Deserialize, ToSchema)] @@ -63,6 +64,12 @@ pub struct ExecutionResponse { #[schema(value_type = Object, example = json!({"message_id": "1234567890.123456"}))] pub result: Option, + /// When the execution actually started running (worker picked it up). + /// Null if the execution hasn't started running yet. + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(example = "2024-01-13T10:31:00Z", nullable = true)] + pub started_at: Option>, + /// Workflow task metadata (only populated for workflow task executions) #[serde(skip_serializing_if = "Option::is_none")] #[schema(value_type = Option, nullable = true)] @@ -108,6 +115,12 @@ pub struct ExecutionSummary { #[schema(example = "core.timer")] pub trigger_ref: Option, + /// When the execution actually started running (worker picked it up). + /// Null if the execution hasn't started running yet. + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(example = "2024-01-13T10:31:00Z", nullable = true)] + pub started_at: Option>, + /// Workflow task metadata (only populated for workflow task executions) #[serde(skip_serializing_if = "Option::is_none")] #[schema(value_type = Option, nullable = true)] @@ -207,6 +220,7 @@ impl From for ExecutionResponse { result: execution .result .map(|r| serde_json::to_value(r).unwrap_or(JsonValue::Null)), + started_at: execution.started_at, workflow_task: execution.workflow_task, created: execution.created, updated: execution.updated, @@ -225,6 +239,7 @@ impl From for ExecutionSummary { enforcement: execution.enforcement, rule_ref: None, // Populated separately via enforcement lookup trigger_ref: None, // Populated separately via enforcement lookup + started_at: execution.started_at, workflow_task: execution.workflow_task, created: execution.created, updated: execution.updated, @@ -232,6 +247,26 @@ impl From for ExecutionSummary { } } +/// Convert from the joined query result (execution + enforcement refs). +/// `rule_ref` and `trigger_ref` are already populated from the SQL JOIN. +impl From for ExecutionSummary { + fn from(row: ExecutionWithRefs) -> Self { + Self { + id: row.id, + action_ref: row.action_ref, + status: row.status, + parent: row.parent, + enforcement: row.enforcement, + rule_ref: row.rule_ref, + trigger_ref: row.trigger_ref, + started_at: row.started_at, + workflow_task: row.workflow_task, + created: row.created, + updated: row.updated, + } + } +} + fn default_page() -> u32 { 1 } diff --git a/crates/api/src/routes/actions.rs b/crates/api/src/routes/actions.rs index f95d55b..d087557 100644 --- a/crates/api/src/routes/actions.rs +++ b/crates/api/src/routes/actions.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use validator::Validate; use attune_common::repositories::{ - action::{ActionRepository, CreateActionInput, UpdateActionInput}, + action::{ActionRepository, ActionSearchFilters, CreateActionInput, UpdateActionInput}, pack::PackRepository, queue_stats::QueueStatsRepository, - Create, Delete, FindByRef, List, Update, + Create, Delete, FindByRef, Update, }; use crate::{ @@ -47,21 +47,20 @@ pub async fn list_actions( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get all actions (we'll implement pagination in repository later) - let actions = ActionRepository::list(&state.db).await?; + // All filtering and pagination happen in a single SQL query. + let filters = ActionSearchFilters { + pack: None, + query: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = actions.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(actions.len()); + let result = ActionRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_actions: Vec = actions[start..end] - .iter() - .map(|a| ActionSummary::from(a.clone())) - .collect(); + let paginated_actions: Vec = + result.rows.into_iter().map(ActionSummary::from).collect(); - let response = PaginatedResponse::new(paginated_actions, &pagination, total); + let response = PaginatedResponse::new(paginated_actions, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -92,21 +91,20 @@ pub async fn list_actions_by_pack( .await? .ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?; - // Get actions for this pack - let actions = ActionRepository::find_by_pack(&state.db, pack.id).await?; + // All filtering and pagination happen in a single SQL query. + let filters = ActionSearchFilters { + pack: Some(pack.id), + query: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = actions.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(actions.len()); + let result = ActionRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_actions: Vec = actions[start..end] - .iter() - .map(|a| ActionSummary::from(a.clone())) - .collect(); + let paginated_actions: Vec = + result.rows.into_iter().map(ActionSummary::from).collect(); - let response = PaginatedResponse::new(paginated_actions, &pagination, total); + let response = PaginatedResponse::new(paginated_actions, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/events.rs b/crates/api/src/routes/events.rs index bf82b4b..a2bf704 100644 --- a/crates/api/src/routes/events.rs +++ b/crates/api/src/routes/events.rs @@ -16,9 +16,12 @@ use validator::Validate; use attune_common::{ mq::{EventCreatedPayload, MessageEnvelope, MessageType}, repositories::{ - event::{CreateEventInput, EnforcementRepository, EventRepository}, + event::{ + CreateEventInput, EnforcementRepository, EnforcementSearchFilters, EventRepository, + EventSearchFilters, + }, trigger::TriggerRepository, - Create, FindById, FindByRef, List, + Create, FindById, FindByRef, }, }; @@ -220,53 +223,27 @@ pub async fn list_events( State(state): State>, Query(query): Query, ) -> ApiResult { - // Get events based on filters - let events = if let Some(trigger_id) = query.trigger { - // Filter by trigger ID - EventRepository::find_by_trigger(&state.db, trigger_id).await? - } else if let Some(trigger_ref) = &query.trigger_ref { - // Filter by trigger reference - EventRepository::find_by_trigger_ref(&state.db, trigger_ref).await? - } else { - // Get all events - EventRepository::list(&state.db).await? + // All filtering and pagination happen in a single SQL query. + let filters = EventSearchFilters { + trigger: query.trigger, + trigger_ref: query.trigger_ref.clone(), + source: query.source, + rule_ref: query.rule_ref.clone(), + limit: query.limit(), + offset: query.offset(), }; - // Apply additional filters in memory - let mut filtered_events = events; + let result = EventRepository::search(&state.db, &filters).await?; - if let Some(source_id) = query.source { - filtered_events.retain(|e| e.source == Some(source_id)); - } + let paginated_events: Vec = + result.rows.into_iter().map(EventSummary::from).collect(); - if let Some(rule_ref) = &query.rule_ref { - let rule_ref_lower = rule_ref.to_lowercase(); - filtered_events.retain(|e| { - e.rule_ref - .as_ref() - .map(|r| r.to_lowercase().contains(&rule_ref_lower)) - .unwrap_or(false) - }); - } - - // Calculate pagination - let total = filtered_events.len() as u64; - let start = query.offset() as usize; - let end = (start + query.limit() as usize).min(filtered_events.len()); - - // Get paginated slice - let paginated_events: Vec = filtered_events[start..end] - .iter() - .map(|event| EventSummary::from(event.clone())) - .collect(); - - // Convert query params to pagination params for response let pagination_params = PaginationParams { page: query.page, page_size: query.per_page, }; - let response = PaginatedResponse::new(paginated_events, &pagination_params, total); + let response = PaginatedResponse::new(paginated_events, &pagination_params, result.total); Ok((StatusCode::OK, Json(response))) } @@ -319,46 +296,32 @@ pub async fn list_enforcements( State(state): State>, Query(query): Query, ) -> ApiResult { - // Get enforcements based on filters - let enforcements = if let Some(status) = query.status { - // Filter by status - EnforcementRepository::find_by_status(&state.db, status).await? - } else if let Some(rule_id) = query.rule { - // Filter by rule ID - EnforcementRepository::find_by_rule(&state.db, rule_id).await? - } else if let Some(event_id) = query.event { - // Filter by event ID - EnforcementRepository::find_by_event(&state.db, event_id).await? - } else { - // Get all enforcements - EnforcementRepository::list(&state.db).await? + // All filtering and pagination happen in a single SQL query. + // Filters are combinable (AND), not mutually exclusive. + let filters = EnforcementSearchFilters { + status: query.status, + rule: query.rule, + event: query.event, + trigger_ref: query.trigger_ref.clone(), + rule_ref: query.rule_ref.clone(), + limit: query.limit(), + offset: query.offset(), }; - // Apply additional filters in memory - let mut filtered_enforcements = enforcements; + let result = EnforcementRepository::search(&state.db, &filters).await?; - if let Some(trigger_ref) = &query.trigger_ref { - filtered_enforcements.retain(|e| e.trigger_ref == *trigger_ref); - } - - // Calculate pagination - let total = filtered_enforcements.len() as u64; - let start = query.offset() as usize; - let end = (start + query.limit() as usize).min(filtered_enforcements.len()); - - // Get paginated slice - let paginated_enforcements: Vec = filtered_enforcements[start..end] - .iter() - .map(|enforcement| EnforcementSummary::from(enforcement.clone())) + let paginated_enforcements: Vec = result + .rows + .into_iter() + .map(EnforcementSummary::from) .collect(); - // Convert query params to pagination params for response let pagination_params = PaginationParams { page: query.page, page_size: query.per_page, }; - let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, total); + let response = PaginatedResponse::new(paginated_enforcements, &pagination_params, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/executions.rs b/crates/api/src/routes/executions.rs index 1b65b1e..85bc004 100644 --- a/crates/api/src/routes/executions.rs +++ b/crates/api/src/routes/executions.rs @@ -18,9 +18,10 @@ use attune_common::models::enums::ExecutionStatus; use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType}; use attune_common::repositories::{ action::ActionRepository, - execution::{CreateExecutionInput, ExecutionRepository}, - Create, EnforcementRepository, FindById, FindByRef, List, + execution::{CreateExecutionInput, ExecutionRepository, ExecutionSearchFilters}, + Create, FindById, FindByRef, }; +use sqlx::Row; use crate::{ auth::middleware::RequireAuth, @@ -125,117 +126,37 @@ pub async fn list_executions( RequireAuth(_user): RequireAuth, Query(query): Query, ) -> ApiResult { - // Get executions based on filters - let executions = if let Some(status) = query.status { - // Filter by status - ExecutionRepository::find_by_status(&state.db, status).await? - } else if let Some(enforcement_id) = query.enforcement { - // Filter by enforcement - ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await? - } else { - // Get all executions - ExecutionRepository::list(&state.db).await? + // All filtering, pagination, and the enforcement JOIN happen in a single + // SQL query — no in-memory filtering or post-fetch lookups. + let filters = ExecutionSearchFilters { + status: query.status, + action_ref: query.action_ref.clone(), + pack_name: query.pack_name.clone(), + rule_ref: query.rule_ref.clone(), + trigger_ref: query.trigger_ref.clone(), + executor: query.executor, + result_contains: query.result_contains.clone(), + enforcement: query.enforcement, + parent: query.parent, + top_level_only: query.top_level_only == Some(true), + limit: query.limit(), + offset: query.offset(), }; - // Apply additional filters in memory (could be optimized with database queries) - let mut filtered_executions = executions; + let result = ExecutionRepository::search(&state.db, &filters).await?; - if let Some(action_ref) = &query.action_ref { - filtered_executions.retain(|e| e.action_ref == *action_ref); - } - - if let Some(pack_name) = &query.pack_name { - filtered_executions.retain(|e| { - // action_ref format is "pack.action" - e.action_ref.starts_with(&format!("{}.", pack_name)) - }); - } - - if let Some(result_search) = &query.result_contains { - let search_lower = result_search.to_lowercase(); - filtered_executions.retain(|e| { - if let Some(result) = &e.result { - // Convert result to JSON string and search case-insensitively - let result_str = serde_json::to_string(result).unwrap_or_default(); - result_str.to_lowercase().contains(&search_lower) - } else { - false - } - }); - } - - if let Some(parent_id) = query.parent { - filtered_executions.retain(|e| e.parent == Some(parent_id)); - } - - if query.top_level_only == Some(true) { - filtered_executions.retain(|e| e.parent.is_none()); - } - - if let Some(executor_id) = query.executor { - filtered_executions.retain(|e| e.executor == Some(executor_id)); - } - - // Fetch enforcements for all executions to populate rule_ref and trigger_ref - let enforcement_ids: Vec = filtered_executions - .iter() - .filter_map(|e| e.enforcement) + let paginated_executions: Vec = result + .rows + .into_iter() + .map(ExecutionSummary::from) .collect(); - let enforcement_map: std::collections::HashMap = if !enforcement_ids.is_empty() { - let enforcements = EnforcementRepository::list(&state.db).await?; - enforcements.into_iter().map(|enf| (enf.id, enf)).collect() - } else { - std::collections::HashMap::new() - }; - - // Filter by rule_ref if specified - if let Some(rule_ref) = &query.rule_ref { - filtered_executions.retain(|e| { - e.enforcement - .and_then(|enf_id| enforcement_map.get(&enf_id)) - .map(|enf| enf.rule_ref == *rule_ref) - .unwrap_or(false) - }); - } - - // Filter by trigger_ref if specified - if let Some(trigger_ref) = &query.trigger_ref { - filtered_executions.retain(|e| { - e.enforcement - .and_then(|enf_id| enforcement_map.get(&enf_id)) - .map(|enf| enf.trigger_ref == *trigger_ref) - .unwrap_or(false) - }); - } - - // Calculate pagination - let total = filtered_executions.len() as u64; - let start = query.offset() as usize; - let end = (start + query.limit() as usize).min(filtered_executions.len()); - - // Get paginated slice and populate rule_ref/trigger_ref from enforcements - let paginated_executions: Vec = filtered_executions[start..end] - .iter() - .map(|e| { - let mut summary = ExecutionSummary::from(e.clone()); - if let Some(enf_id) = e.enforcement { - if let Some(enforcement) = enforcement_map.get(&enf_id) { - summary.rule_ref = Some(enforcement.rule_ref.clone()); - summary.trigger_ref = Some(enforcement.trigger_ref.clone()); - } - } - summary - }) - .collect(); - - // Convert query params to pagination params for response let pagination_params = PaginationParams { page: query.page, page_size: query.per_page, }; - let response = PaginatedResponse::new(paginated_executions, &pagination_params, total); + let response = PaginatedResponse::new(paginated_executions, &pagination_params, result.total); Ok((StatusCode::OK, Json(response))) } @@ -310,21 +231,23 @@ pub async fn list_executions_by_status( } }; - // Get executions by status - let executions = ExecutionRepository::find_by_status(&state.db, status).await?; + // Use the search method for SQL-side filtering + pagination. + let filters = ExecutionSearchFilters { + status: Some(status), + limit: pagination.limit(), + offset: pagination.offset(), + ..Default::default() + }; - // Calculate pagination - let total = executions.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(executions.len()); + let result = ExecutionRepository::search(&state.db, &filters).await?; - // Get paginated slice - let paginated_executions: Vec = executions[start..end] - .iter() - .map(|e| ExecutionSummary::from(e.clone())) + let paginated_executions: Vec = result + .rows + .into_iter() + .map(ExecutionSummary::from) .collect(); - let response = PaginatedResponse::new(paginated_executions, &pagination, total); + let response = PaginatedResponse::new(paginated_executions, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -350,21 +273,23 @@ pub async fn list_executions_by_enforcement( Path(enforcement_id): Path, Query(pagination): Query, ) -> ApiResult { - // Get executions by enforcement - let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?; + // Use the search method for SQL-side filtering + pagination. + let filters = ExecutionSearchFilters { + enforcement: Some(enforcement_id), + limit: pagination.limit(), + offset: pagination.offset(), + ..Default::default() + }; - // Calculate pagination - let total = executions.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(executions.len()); + let result = ExecutionRepository::search(&state.db, &filters).await?; - // Get paginated slice - let paginated_executions: Vec = executions[start..end] - .iter() - .map(|e| ExecutionSummary::from(e.clone())) + let paginated_executions: Vec = result + .rows + .into_iter() + .map(ExecutionSummary::from) .collect(); - let response = PaginatedResponse::new(paginated_executions, &pagination, total); + let response = PaginatedResponse::new(paginated_executions, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -384,34 +309,37 @@ pub async fn get_execution_stats( State(state): State>, RequireAuth(_user): RequireAuth, ) -> ApiResult { - // Get all executions (limited by repository to 1000) - let executions = ExecutionRepository::list(&state.db).await?; + // Use a single SQL query with COUNT + GROUP BY instead of fetching all rows. + let rows = sqlx::query( + "SELECT status::text AS status, COUNT(*) AS cnt FROM execution GROUP BY status", + ) + .fetch_all(&state.db) + .await?; - // Calculate statistics - let total = executions.len(); - let completed = executions - .iter() - .filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed) - .count(); - let failed = executions - .iter() - .filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed) - .count(); - let running = executions - .iter() - .filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running) - .count(); - let pending = executions - .iter() - .filter(|e| { - matches!( - e.status, - attune_common::models::enums::ExecutionStatus::Requested - | attune_common::models::enums::ExecutionStatus::Scheduling - | attune_common::models::enums::ExecutionStatus::Scheduled - ) - }) - .count(); + let mut completed: i64 = 0; + let mut failed: i64 = 0; + let mut running: i64 = 0; + let mut pending: i64 = 0; + let mut cancelled: i64 = 0; + let mut timeout: i64 = 0; + let mut abandoned: i64 = 0; + let mut total: i64 = 0; + + for row in &rows { + let status: &str = row.get("status"); + let cnt: i64 = row.get("cnt"); + total += cnt; + match status { + "completed" => completed = cnt, + "failed" => failed = cnt, + "running" => running = cnt, + "requested" | "scheduling" | "scheduled" => pending += cnt, + "cancelled" | "canceling" => cancelled += cnt, + "timeout" => timeout = cnt, + "abandoned" => abandoned = cnt, + _ => {} + } + } let stats = serde_json::json!({ "total": total, @@ -419,9 +347,9 @@ pub async fn get_execution_stats( "failed": failed, "running": running, "pending": pending, - "cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(), - "timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(), - "abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(), + "cancelled": cancelled, + "timeout": timeout, + "abandoned": abandoned, }); let response = ApiResponse::new(stats); diff --git a/crates/api/src/routes/inquiries.rs b/crates/api/src/routes/inquiries.rs index e13b548..db9024c 100644 --- a/crates/api/src/routes/inquiries.rs +++ b/crates/api/src/routes/inquiries.rs @@ -14,8 +14,10 @@ use attune_common::{ mq::{InquiryRespondedPayload, MessageEnvelope, MessageType}, repositories::{ execution::ExecutionRepository, - inquiry::{CreateInquiryInput, InquiryRepository, UpdateInquiryInput}, - Create, Delete, FindById, List, Update, + inquiry::{ + CreateInquiryInput, InquiryRepository, InquirySearchFilters, UpdateInquiryInput, + }, + Create, Delete, FindById, Update, }, }; @@ -51,45 +53,30 @@ pub async fn list_inquiries( State(state): State>, Query(query): Query, ) -> ApiResult { - // Get inquiries based on filters - let inquiries = if let Some(status) = query.status { - // Filter by status - InquiryRepository::find_by_status(&state.db, status).await? - } else if let Some(execution_id) = query.execution { - // Filter by execution - InquiryRepository::find_by_execution(&state.db, execution_id).await? - } else { - // Get all inquiries - InquiryRepository::list(&state.db).await? + // All filtering and pagination happen in a single SQL query. + // Filters are combinable (AND), not mutually exclusive. + let limit = query.limit.unwrap_or(50).min(500) as u32; + let offset = query.offset.unwrap_or(0) as u32; + + let filters = InquirySearchFilters { + status: query.status, + execution: query.execution, + assigned_to: query.assigned_to, + limit, + offset, }; - // Apply additional filters in memory - let mut filtered_inquiries = inquiries; + let result = InquiryRepository::search(&state.db, &filters).await?; - if let Some(assigned_to) = query.assigned_to { - filtered_inquiries.retain(|i| i.assigned_to == Some(assigned_to)); - } + let paginated_inquiries: Vec = + result.rows.into_iter().map(InquirySummary::from).collect(); - // Calculate pagination - let total = filtered_inquiries.len() as u64; - let offset = query.offset.unwrap_or(0); - let limit = query.limit.unwrap_or(50).min(500); - let start = offset; - let end = (start + limit).min(filtered_inquiries.len()); - - // Get paginated slice - let paginated_inquiries: Vec = filtered_inquiries[start..end] - .iter() - .map(|inquiry| InquirySummary::from(inquiry.clone())) - .collect(); - - // Convert to pagination params for response let pagination_params = PaginationParams { - page: (offset / limit.max(1)) as u32 + 1, - page_size: limit as u32, + page: (offset / limit.max(1)) + 1, + page_size: limit, }; - let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, total); + let response = PaginatedResponse::new(paginated_inquiries, &pagination_params, result.total); Ok((StatusCode::OK, Json(response))) } @@ -161,20 +148,21 @@ pub async fn list_inquiries_by_status( } }; - let inquiries = InquiryRepository::find_by_status(&state.db, status).await?; + // Use the search method for SQL-side filtering + pagination. + let filters = InquirySearchFilters { + status: Some(status), + execution: None, + assigned_to: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = inquiries.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(inquiries.len()); + let result = InquiryRepository::search(&state.db, &filters).await?; - // Get paginated slice - let paginated_inquiries: Vec = inquiries[start..end] - .iter() - .map(|inquiry| InquirySummary::from(inquiry.clone())) - .collect(); + let paginated_inquiries: Vec = + result.rows.into_iter().map(InquirySummary::from).collect(); - let response = PaginatedResponse::new(paginated_inquiries, &pagination, total); + let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -209,20 +197,21 @@ pub async fn list_inquiries_by_execution( ApiError::NotFound(format!("Execution with ID {} not found", execution_id)) })?; - let inquiries = InquiryRepository::find_by_execution(&state.db, execution_id).await?; + // Use the search method for SQL-side filtering + pagination. + let filters = InquirySearchFilters { + status: None, + execution: Some(execution_id), + assigned_to: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = inquiries.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(inquiries.len()); + let result = InquiryRepository::search(&state.db, &filters).await?; - // Get paginated slice - let paginated_inquiries: Vec = inquiries[start..end] - .iter() - .map(|inquiry| InquirySummary::from(inquiry.clone())) - .collect(); + let paginated_inquiries: Vec = + result.rows.into_iter().map(InquirySummary::from).collect(); - let response = PaginatedResponse::new(paginated_inquiries, &pagination, total); + let response = PaginatedResponse::new(paginated_inquiries, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/keys.rs b/crates/api/src/routes/keys.rs index 1d42e29..7163072 100644 --- a/crates/api/src/routes/keys.rs +++ b/crates/api/src/routes/keys.rs @@ -13,10 +13,10 @@ use validator::Validate; use attune_common::models::OwnerType; use attune_common::repositories::{ action::ActionRepository, - key::{CreateKeyInput, KeyRepository, UpdateKeyInput}, + key::{CreateKeyInput, KeyRepository, KeySearchFilters, UpdateKeyInput}, pack::PackRepository, trigger::SensorRepository, - Create, Delete, FindByRef, List, Update, + Create, Delete, FindByRef, Update, }; use crate::auth::RequireAuth; @@ -46,40 +46,24 @@ pub async fn list_keys( State(state): State>, Query(query): Query, ) -> ApiResult { - // Get keys based on filters - let keys = if let Some(owner_type) = query.owner_type { - // Filter by owner type - KeyRepository::find_by_owner_type(&state.db, owner_type).await? - } else { - // Get all keys - KeyRepository::list(&state.db).await? + // All filtering and pagination happen in a single SQL query. + let filters = KeySearchFilters { + owner_type: query.owner_type, + owner: query.owner.clone(), + limit: query.limit(), + offset: query.offset(), }; - // Apply additional filters in memory - let mut filtered_keys = keys; + let result = KeyRepository::search(&state.db, &filters).await?; - if let Some(owner) = &query.owner { - filtered_keys.retain(|k| k.owner.as_ref() == Some(owner)); - } + let paginated_keys: Vec = result.rows.into_iter().map(KeySummary::from).collect(); - // Calculate pagination - let total = filtered_keys.len() as u64; - let start = query.offset() as usize; - let end = (start + query.limit() as usize).min(filtered_keys.len()); - - // Get paginated slice (values redacted in summary) - let paginated_keys: Vec = filtered_keys[start..end] - .iter() - .map(|key| KeySummary::from(key.clone())) - .collect(); - - // Convert query params to pagination params for response let pagination_params = PaginationParams { page: query.page, page_size: query.per_page, }; - let response = PaginatedResponse::new(paginated_keys, &pagination_params, total); + let response = PaginatedResponse::new(paginated_keys, &pagination_params, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/rules.rs b/crates/api/src/routes/rules.rs index b56f5b0..ca1ed1b 100644 --- a/crates/api/src/routes/rules.rs +++ b/crates/api/src/routes/rules.rs @@ -17,9 +17,9 @@ use attune_common::mq::{ use attune_common::repositories::{ action::ActionRepository, pack::PackRepository, - rule::{CreateRuleInput, RuleRepository, UpdateRuleInput}, + rule::{CreateRuleInput, RuleRepository, RuleSearchFilters, UpdateRuleInput}, trigger::TriggerRepository, - Create, Delete, FindByRef, List, Update, + Create, Delete, FindByRef, Update, }; use crate::{ @@ -50,21 +50,21 @@ pub async fn list_rules( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get all rules - let rules = RuleRepository::list(&state.db).await?; + let filters = RuleSearchFilters { + pack: None, + action: None, + trigger: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = rules.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(rules.len()); + let result = RuleRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_rules: Vec = rules[start..end] - .iter() - .map(|r| RuleSummary::from(r.clone())) - .collect(); + let paginated_rules: Vec = + result.rows.into_iter().map(RuleSummary::from).collect(); - let response = PaginatedResponse::new(paginated_rules, &pagination, total); + let response = PaginatedResponse::new(paginated_rules, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -85,21 +85,21 @@ pub async fn list_enabled_rules( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get enabled rules - let rules = RuleRepository::find_enabled(&state.db).await?; + let filters = RuleSearchFilters { + pack: None, + action: None, + trigger: None, + enabled: Some(true), + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = rules.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(rules.len()); + let result = RuleRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_rules: Vec = rules[start..end] - .iter() - .map(|r| RuleSummary::from(r.clone())) - .collect(); + let paginated_rules: Vec = + result.rows.into_iter().map(RuleSummary::from).collect(); - let response = PaginatedResponse::new(paginated_rules, &pagination, total); + let response = PaginatedResponse::new(paginated_rules, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -130,21 +130,21 @@ pub async fn list_rules_by_pack( .await? .ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?; - // Get rules for this pack - let rules = RuleRepository::find_by_pack(&state.db, pack.id).await?; + let filters = RuleSearchFilters { + pack: Some(pack.id), + action: None, + trigger: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = rules.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(rules.len()); + let result = RuleRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_rules: Vec = rules[start..end] - .iter() - .map(|r| RuleSummary::from(r.clone())) - .collect(); + let paginated_rules: Vec = + result.rows.into_iter().map(RuleSummary::from).collect(); - let response = PaginatedResponse::new(paginated_rules, &pagination, total); + let response = PaginatedResponse::new(paginated_rules, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -175,21 +175,21 @@ pub async fn list_rules_by_action( .await? .ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?; - // Get rules for this action - let rules = RuleRepository::find_by_action(&state.db, action.id).await?; + let filters = RuleSearchFilters { + pack: None, + action: Some(action.id), + trigger: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = rules.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(rules.len()); + let result = RuleRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_rules: Vec = rules[start..end] - .iter() - .map(|r| RuleSummary::from(r.clone())) - .collect(); + let paginated_rules: Vec = + result.rows.into_iter().map(RuleSummary::from).collect(); - let response = PaginatedResponse::new(paginated_rules, &pagination, total); + let response = PaginatedResponse::new(paginated_rules, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -220,21 +220,21 @@ pub async fn list_rules_by_trigger( .await? .ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?; - // Get rules for this trigger - let rules = RuleRepository::find_by_trigger(&state.db, trigger.id).await?; + let filters = RuleSearchFilters { + pack: None, + action: None, + trigger: Some(trigger.id), + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = rules.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(rules.len()); + let result = RuleRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_rules: Vec = rules[start..end] - .iter() - .map(|r| RuleSummary::from(r.clone())) - .collect(); + let paginated_rules: Vec = + result.rows.into_iter().map(RuleSummary::from).collect(); - let response = PaginatedResponse::new(paginated_rules, &pagination, total); + let response = PaginatedResponse::new(paginated_rules, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/triggers.rs b/crates/api/src/routes/triggers.rs index 2cd46d4..fe96590 100644 --- a/crates/api/src/routes/triggers.rs +++ b/crates/api/src/routes/triggers.rs @@ -14,10 +14,10 @@ use attune_common::repositories::{ pack::PackRepository, runtime::RuntimeRepository, trigger::{ - CreateSensorInput, CreateTriggerInput, SensorRepository, TriggerRepository, - UpdateSensorInput, UpdateTriggerInput, + CreateSensorInput, CreateTriggerInput, SensorRepository, SensorSearchFilters, + TriggerRepository, TriggerSearchFilters, UpdateSensorInput, UpdateTriggerInput, }, - Create, Delete, FindByRef, List, Update, + Create, Delete, FindByRef, Update, }; use crate::{ @@ -54,21 +54,19 @@ pub async fn list_triggers( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get all triggers - let triggers = TriggerRepository::list(&state.db).await?; + let filters = TriggerSearchFilters { + pack: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = triggers.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(triggers.len()); + let result = TriggerRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_triggers: Vec = triggers[start..end] - .iter() - .map(|t| TriggerSummary::from(t.clone())) - .collect(); + let paginated_triggers: Vec = + result.rows.into_iter().map(TriggerSummary::from).collect(); - let response = PaginatedResponse::new(paginated_triggers, &pagination, total); + let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -89,21 +87,19 @@ pub async fn list_enabled_triggers( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get enabled triggers - let triggers = TriggerRepository::find_enabled(&state.db).await?; + let filters = TriggerSearchFilters { + pack: None, + enabled: Some(true), + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = triggers.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(triggers.len()); + let result = TriggerRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_triggers: Vec = triggers[start..end] - .iter() - .map(|t| TriggerSummary::from(t.clone())) - .collect(); + let paginated_triggers: Vec = + result.rows.into_iter().map(TriggerSummary::from).collect(); - let response = PaginatedResponse::new(paginated_triggers, &pagination, total); + let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -134,21 +130,19 @@ pub async fn list_triggers_by_pack( .await? .ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?; - // Get triggers for this pack - let triggers = TriggerRepository::find_by_pack(&state.db, pack.id).await?; + let filters = TriggerSearchFilters { + pack: Some(pack.id), + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = triggers.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(triggers.len()); + let result = TriggerRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_triggers: Vec = triggers[start..end] - .iter() - .map(|t| TriggerSummary::from(t.clone())) - .collect(); + let paginated_triggers: Vec = + result.rows.into_iter().map(TriggerSummary::from).collect(); - let response = PaginatedResponse::new(paginated_triggers, &pagination, total); + let response = PaginatedResponse::new(paginated_triggers, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -438,21 +432,20 @@ pub async fn list_sensors( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get all sensors - let sensors = SensorRepository::list(&state.db).await?; + let filters = SensorSearchFilters { + pack: None, + trigger: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = sensors.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(sensors.len()); + let result = SensorRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_sensors: Vec = sensors[start..end] - .iter() - .map(|s| SensorSummary::from(s.clone())) - .collect(); + let paginated_sensors: Vec = + result.rows.into_iter().map(SensorSummary::from).collect(); - let response = PaginatedResponse::new(paginated_sensors, &pagination, total); + let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -473,21 +466,20 @@ pub async fn list_enabled_sensors( RequireAuth(_user): RequireAuth, Query(pagination): Query, ) -> ApiResult { - // Get enabled sensors - let sensors = SensorRepository::find_enabled(&state.db).await?; + let filters = SensorSearchFilters { + pack: None, + trigger: None, + enabled: Some(true), + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = sensors.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(sensors.len()); + let result = SensorRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_sensors: Vec = sensors[start..end] - .iter() - .map(|s| SensorSummary::from(s.clone())) - .collect(); + let paginated_sensors: Vec = + result.rows.into_iter().map(SensorSummary::from).collect(); - let response = PaginatedResponse::new(paginated_sensors, &pagination, total); + let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -518,21 +510,20 @@ pub async fn list_sensors_by_pack( .await? .ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?; - // Get sensors for this pack - let sensors = SensorRepository::find_by_pack(&state.db, pack.id).await?; + let filters = SensorSearchFilters { + pack: Some(pack.id), + trigger: None, + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = sensors.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(sensors.len()); + let result = SensorRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_sensors: Vec = sensors[start..end] - .iter() - .map(|s| SensorSummary::from(s.clone())) - .collect(); + let paginated_sensors: Vec = + result.rows.into_iter().map(SensorSummary::from).collect(); - let response = PaginatedResponse::new(paginated_sensors, &pagination, total); + let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -563,21 +554,20 @@ pub async fn list_sensors_by_trigger( .await? .ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?; - // Get sensors for this trigger - let sensors = SensorRepository::find_by_trigger(&state.db, trigger.id).await?; + let filters = SensorSearchFilters { + pack: None, + trigger: Some(trigger.id), + enabled: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = sensors.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(sensors.len()); + let result = SensorRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_sensors: Vec = sensors[start..end] - .iter() - .map(|s| SensorSummary::from(s.clone())) - .collect(); + let paginated_sensors: Vec = + result.rows.into_iter().map(SensorSummary::from).collect(); - let response = PaginatedResponse::new(paginated_sensors, &pagination, total); + let response = PaginatedResponse::new(paginated_sensors, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/api/src/routes/workflows.rs b/crates/api/src/routes/workflows.rs index c200e91..951a644 100644 --- a/crates/api/src/routes/workflows.rs +++ b/crates/api/src/routes/workflows.rs @@ -16,8 +16,9 @@ use attune_common::repositories::{ pack::PackRepository, workflow::{ CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository, + WorkflowSearchFilters, }, - Create, Delete, FindByRef, List, Update, + Create, Delete, FindByRef, Update, }; use crate::{ @@ -54,64 +55,30 @@ pub async fn list_workflows( // Validate search params search_params.validate()?; - // Get workflows based on filters - let mut workflows = if let Some(tags_str) = &search_params.tags { - // Filter by tags - let tags: Vec<&str> = tags_str.split(',').map(|s| s.trim()).collect(); - let mut results = Vec::new(); - for tag in tags { - let mut tag_results = WorkflowDefinitionRepository::find_by_tag(&state.db, tag).await?; - results.append(&mut tag_results); - } - // Remove duplicates by ID - results.sort_by_key(|w| w.id); - results.dedup_by_key(|w| w.id); - results - } else if search_params.enabled == Some(true) { - // Filter by enabled status (only return enabled workflows) - WorkflowDefinitionRepository::find_enabled(&state.db).await? - } else { - // Get all workflows - WorkflowDefinitionRepository::list(&state.db).await? + // Parse comma-separated tags into a Vec if provided + let tags = search_params.tags.as_ref().map(|t| { + t.split(',') + .map(|s| s.trim().to_string()) + .collect::>() + }); + + // All filtering and pagination happen in a single SQL query. + let filters = WorkflowSearchFilters { + pack: None, + pack_ref: search_params.pack_ref.clone(), + enabled: search_params.enabled, + tags, + search: search_params.search.clone(), + limit: pagination.limit(), + offset: pagination.offset(), }; - // Apply enabled filter if specified and not already filtered by it - if let Some(enabled) = search_params.enabled { - if search_params.tags.is_some() { - // If we filtered by tags, also apply enabled filter - workflows.retain(|w| w.enabled == enabled); - } - } + let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?; - // Apply search filter if provided - if let Some(search_term) = &search_params.search { - let search_lower = search_term.to_lowercase(); - workflows.retain(|w| { - w.label.to_lowercase().contains(&search_lower) - || w.description - .as_ref() - .map(|d| d.to_lowercase().contains(&search_lower)) - .unwrap_or(false) - }); - } + let paginated_workflows: Vec = + result.rows.into_iter().map(WorkflowSummary::from).collect(); - // Apply pack_ref filter if provided - if let Some(pack_ref) = &search_params.pack_ref { - workflows.retain(|w| w.pack_ref == *pack_ref); - } - - // Calculate pagination - let total = workflows.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(workflows.len()); - - // Get paginated slice - let paginated_workflows: Vec = workflows[start..end] - .iter() - .map(|w| WorkflowSummary::from(w.clone())) - .collect(); - - let response = PaginatedResponse::new(paginated_workflows, &pagination, total); + let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } @@ -138,25 +105,27 @@ pub async fn list_workflows_by_pack( Query(pagination): Query, ) -> ApiResult { // Verify pack exists - let pack = PackRepository::find_by_ref(&state.db, &pack_ref) + let _pack = PackRepository::find_by_ref(&state.db, &pack_ref) .await? .ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?; - // Get workflows for this pack - let workflows = WorkflowDefinitionRepository::find_by_pack(&state.db, pack.id).await?; + // All filtering and pagination happen in a single SQL query. + let filters = WorkflowSearchFilters { + pack: None, + pack_ref: Some(pack_ref), + enabled: None, + tags: None, + search: None, + limit: pagination.limit(), + offset: pagination.offset(), + }; - // Calculate pagination - let total = workflows.len() as u64; - let start = ((pagination.page - 1) * pagination.limit()) as usize; - let end = (start + pagination.limit() as usize).min(workflows.len()); + let result = WorkflowDefinitionRepository::list_search(&state.db, &filters).await?; - // Get paginated slice - let paginated_workflows: Vec = workflows[start..end] - .iter() - .map(|w| WorkflowSummary::from(w.clone())) - .collect(); + let paginated_workflows: Vec = + result.rows.into_iter().map(WorkflowSummary::from).collect(); - let response = PaginatedResponse::new(paginated_workflows, &pagination, total); + let response = PaginatedResponse::new(paginated_workflows, &pagination, result.total); Ok((StatusCode::OK, Json(response))) } diff --git a/crates/common/src/models.rs b/crates/common/src/models.rs index 01cba9a..03a069d 100644 --- a/crates/common/src/models.rs +++ b/crates/common/src/models.rs @@ -1104,6 +1104,11 @@ pub mod execution { pub status: ExecutionStatus, pub result: Option, + /// When the execution actually started running (worker picked it up). + /// Set when status transitions to `Running`. Used to compute accurate + /// duration that excludes queue/scheduling wait time. + pub started_at: Option>, + /// Workflow task metadata (only populated for workflow task executions) /// /// Provides direct access to workflow orchestration state without JOINs. diff --git a/crates/common/src/repositories/action.rs b/crates/common/src/repositories/action.rs index 1d87881..21a170f 100644 --- a/crates/common/src/repositories/action.rs +++ b/crates/common/src/repositories/action.rs @@ -8,6 +8,26 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, FindByRef, List, Repository, Update}; +/// Filters for [`ActionRepository::list_search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct ActionSearchFilters { + /// Filter by pack ID + pub pack: Option, + /// Text search across ref, label, description (case-insensitive) + pub query: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`ActionRepository::list_search`]. +#[derive(Debug)] +pub struct ActionSearchResult { + pub rows: Vec, + pub total: u64, +} + /// Repository for Action operations pub struct ActionRepository; @@ -287,6 +307,92 @@ impl Delete for ActionRepository { } impl ActionRepository { + /// Search actions with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn list_search<'e, E>( + db: E, + filters: &ActionSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_version_constraint, param_schema, out_schema, workflow_def, is_adhoc, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM action")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM action"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(pack_id) = filters.pack { + push_condition!("pack = ", pack_id); + } + if let Some(ref query) = filters.query { + let pattern = format!("%{}%", query.to_lowercase()); + // Search needs an OR across multiple columns, wrapped in parens + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push("(LOWER(ref) LIKE "); + qb.push_bind(pattern.clone()); + qb.push(" OR LOWER(label) LIKE "); + qb.push_bind(pattern.clone()); + qb.push(" OR LOWER(description) LIKE "); + qb.push_bind(pattern.clone()); + qb.push(")"); + + count_qb.push("(LOWER(ref) LIKE "); + count_qb.push_bind(pattern.clone()); + count_qb.push(" OR LOWER(label) LIKE "); + count_qb.push_bind(pattern.clone()); + count_qb.push(" OR LOWER(description) LIKE "); + count_qb.push_bind(pattern); + count_qb.push(")"); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY ref ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(ActionSearchResult { rows, total }) + } + /// Find actions by pack ID pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result> where diff --git a/crates/common/src/repositories/event.rs b/crates/common/src/repositories/event.rs index 20316ad..60a0706 100644 --- a/crates/common/src/repositories/event.rs +++ b/crates/common/src/repositories/event.rs @@ -15,6 +15,56 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, List, Repository, Update}; +// ============================================================================ +// Event Search +// ============================================================================ + +/// Filters for [`EventRepository::search`]. +/// +/// All fields are optional. When set, the corresponding WHERE clause is added. +/// Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct EventSearchFilters { + pub trigger: Option, + pub trigger_ref: Option, + pub source: Option, + pub rule_ref: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`EventRepository::search`]. +#[derive(Debug)] +pub struct EventSearchResult { + pub rows: Vec, + pub total: u64, +} + +// ============================================================================ +// Enforcement Search +// ============================================================================ + +/// Filters for [`EnforcementRepository::search`]. +/// +/// All fields are optional and combinable. Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct EnforcementSearchFilters { + pub rule: Option, + pub event: Option, + pub status: Option, + pub trigger_ref: Option, + pub rule_ref: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`EnforcementRepository::search`]. +#[derive(Debug)] +pub struct EnforcementSearchResult { + pub rows: Vec, + pub total: u64, +} + /// Repository for Event operations pub struct EventRepository; @@ -173,6 +223,75 @@ impl EventRepository { Ok(events) } + + /// Search events with all filters pushed into SQL. + /// + /// Builds a dynamic query so that every filter, pagination, and the total + /// count are handled in the database — no in-memory filtering or slicing. + pub async fn search<'e, E>(db: E, filters: &EventSearchFilters) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, trigger, trigger_ref, config, payload, source, source_ref, rule, rule_ref, created"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM event")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM event"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(trigger_id) = filters.trigger { + push_condition!("trigger = ", trigger_id); + } + if let Some(ref trigger_ref) = filters.trigger_ref { + push_condition!("trigger_ref = ", trigger_ref.clone()); + } + if let Some(source_id) = filters.source { + push_condition!("source = ", source_id); + } + if let Some(ref rule_ref) = filters.rule_ref { + push_condition!( + "LOWER(rule_ref) LIKE ", + format!("%{}%", rule_ref.to_lowercase()) + ); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY created DESC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(EventSearchResult { rows, total }) + } } // ============================================================================ @@ -425,4 +544,75 @@ impl EnforcementRepository { Ok(enforcements) } + + /// Search enforcements with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn search<'e, E>( + db: E, + filters: &EnforcementSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, rule, rule_ref, trigger_ref, config, event, status, payload, condition, conditions, created, resolved_at"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM enforcement")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM enforcement"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(status) = &filters.status { + push_condition!("status = ", status.clone()); + } + if let Some(rule_id) = filters.rule { + push_condition!("rule = ", rule_id); + } + if let Some(event_id) = filters.event { + push_condition!("event = ", event_id); + } + if let Some(ref trigger_ref) = filters.trigger_ref { + push_condition!("trigger_ref = ", trigger_ref.clone()); + } + if let Some(ref rule_ref) = filters.rule_ref { + push_condition!("rule_ref = ", rule_ref.clone()); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY created DESC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(EnforcementSearchResult { rows, total }) + } } diff --git a/crates/common/src/repositories/execution.rs b/crates/common/src/repositories/execution.rs index 64bab7f..224d6a3 100644 --- a/crates/common/src/repositories/execution.rs +++ b/crates/common/src/repositories/execution.rs @@ -1,11 +1,71 @@ //! Execution repository for database operations +use chrono::{DateTime, Utc}; + use crate::models::{enums::ExecutionStatus, execution::*, Id, JsonDict}; use crate::Result; use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, List, Repository, Update}; +/// Filters for the [`ExecutionRepository::search`] query-builder method. +/// +/// Every field is optional. When set, the corresponding `WHERE` clause is +/// appended to the query. Pagination (`limit`/`offset`) is always applied. +/// +/// Filters that involve the `enforcement` table (`rule_ref`, `trigger_ref`) +/// cause a `LEFT JOIN enforcement` to be added automatically. +#[derive(Debug, Clone, Default)] +pub struct ExecutionSearchFilters { + pub status: Option, + pub action_ref: Option, + pub pack_name: Option, + pub rule_ref: Option, + pub trigger_ref: Option, + pub executor: Option, + pub result_contains: Option, + pub enforcement: Option, + pub parent: Option, + pub top_level_only: bool, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`ExecutionRepository::search`]. +/// +/// Includes the matching rows *and* the total count (before LIMIT/OFFSET) +/// so the caller can build pagination metadata without a second round-trip. +#[derive(Debug)] +pub struct ExecutionSearchResult { + pub rows: Vec, + pub total: u64, +} + +/// An execution row with optional `rule_ref` / `trigger_ref` populated from +/// the joined `enforcement` table. This avoids a separate in-memory lookup. +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct ExecutionWithRefs { + // — execution columns (same order as SELECT_COLUMNS) — + pub id: Id, + pub action: Option, + pub action_ref: String, + pub config: Option, + pub env_vars: Option, + pub parent: Option, + pub enforcement: Option, + pub executor: Option, + pub status: ExecutionStatus, + pub result: Option, + pub started_at: Option>, + #[sqlx(json, default)] + pub workflow_task: Option, + pub created: DateTime, + pub updated: DateTime, + // — joined from enforcement — + pub rule_ref: Option, + pub trigger_ref: Option, +} + /// Column list for SELECT queries on the execution table. /// /// Defined once to avoid drift between queries and the `Execution` model. @@ -13,7 +73,7 @@ use super::{Create, Delete, FindById, List, Repository, Update}; /// are NOT in the Rust struct, so `SELECT *` must never be used. pub const SELECT_COLUMNS: &str = "\ id, action, action_ref, config, env_vars, parent, enforcement, \ - executor, status, result, workflow_task, created, updated"; + executor, status, result, started_at, workflow_task, created, updated"; pub struct ExecutionRepository; @@ -43,6 +103,7 @@ pub struct UpdateExecutionInput { pub status: Option, pub result: Option, pub executor: Option, + pub started_at: Option>, pub workflow_task: Option, } @@ -52,6 +113,7 @@ impl From for UpdateExecutionInput { status: Some(execution.status), result: execution.result, executor: execution.executor, + started_at: execution.started_at, workflow_task: execution.workflow_task, } } @@ -146,6 +208,13 @@ impl Update for ExecutionRepository { query.push("executor = ").push_bind(executor_id); has_updates = true; } + if let Some(started_at) = input.started_at { + if has_updates { + query.push(", "); + } + query.push("started_at = ").push_bind(started_at); + has_updates = true; + } if let Some(workflow_task) = &input.workflow_task { if has_updates { query.push(", "); @@ -239,4 +308,141 @@ impl ExecutionRepository { .await .map_err(Into::into) } + + /// Search executions with all filters pushed into SQL. + /// + /// Builds a dynamic query with only the WHERE clauses that apply, + /// a LEFT JOIN on `enforcement` when `rule_ref` or `trigger_ref` filters + /// are present (or always, to populate those columns on the result), + /// and proper LIMIT/OFFSET so pagination is server-side. + /// + /// Returns both the matching page of rows and the total count. + pub async fn search<'e, E>( + db: E, + filters: &ExecutionSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + // We always LEFT JOIN enforcement so we can return rule_ref/trigger_ref + // on every row without a second round-trip. + let prefixed_select = SELECT_COLUMNS + .split(", ") + .map(|col| format!("e.{col}")) + .collect::>() + .join(", "); + + let select_clause = format!( + "{prefixed_select}, enf.rule_ref AS rule_ref, enf.trigger_ref AS trigger_ref" + ); + + let from_clause = "FROM execution e LEFT JOIN enforcement enf ON e.enforcement = enf.id"; + + // ── Build WHERE clauses ────────────────────────────────────────── + let mut conditions: Vec = Vec::new(); + + // We'll collect bind values to push into the QueryBuilder afterwards. + // Because QueryBuilder doesn't let us interleave raw SQL and binds in + // arbitrary order easily, we build the SQL string with numbered $N + // placeholders and then bind in order. + + // Track the next placeholder index ($1, $2, …). + // We can't use QueryBuilder's push_bind because we need the COUNT(*) + // query to share the same WHERE clause text. Instead we build the + // clause once and execute both queries with manual binds. + + // ── Use QueryBuilder for the data query ────────────────────────── + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_clause} {from_clause}")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT COUNT(*) AS total {from_clause}")); + + // Helper: append the same condition to both builders. + // We need a tiny state machine since push_bind moves the value. + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + let needs_where = conditions.is_empty(); + conditions.push(String::new()); // just to track count + if needs_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + macro_rules! push_raw_condition { + ($cond:expr) => {{ + let needs_where = conditions.is_empty(); + conditions.push(String::new()); + if needs_where { + qb.push(concat!(" WHERE ", $cond)); + count_qb.push(concat!(" WHERE ", $cond)); + } else { + qb.push(concat!(" AND ", $cond)); + count_qb.push(concat!(" AND ", $cond)); + } + }}; + } + + if let Some(status) = &filters.status { + push_condition!("e.status = ", status.clone()); + } + if let Some(action_ref) = &filters.action_ref { + push_condition!("e.action_ref = ", action_ref.clone()); + } + if let Some(pack_name) = &filters.pack_name { + let pattern = format!("{pack_name}.%"); + push_condition!("e.action_ref LIKE ", pattern); + } + if let Some(enforcement_id) = filters.enforcement { + push_condition!("e.enforcement = ", enforcement_id); + } + if let Some(parent_id) = filters.parent { + push_condition!("e.parent = ", parent_id); + } + if filters.top_level_only { + push_raw_condition!("e.parent IS NULL"); + } + if let Some(executor_id) = filters.executor { + push_condition!("e.executor = ", executor_id); + } + if let Some(rule_ref) = &filters.rule_ref { + push_condition!("enf.rule_ref = ", rule_ref.clone()); + } + if let Some(trigger_ref) = &filters.trigger_ref { + push_condition!("enf.trigger_ref = ", trigger_ref.clone()); + } + if let Some(search) = &filters.result_contains { + let pattern = format!("%{}%", search.to_lowercase()); + push_condition!("LOWER(e.result::text) LIKE ", pattern); + } + + // ── COUNT query ────────────────────────────────────────────────── + let total: i64 = count_qb + .build_query_scalar() + .fetch_one(db) + .await?; + let total = total.max(0) as u64; + + // ── Data query with ORDER BY + pagination ──────────────────────── + qb.push(" ORDER BY e.created DESC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb + .build_query_as() + .fetch_all(db) + .await?; + + Ok(ExecutionSearchResult { rows, total }) + } } diff --git a/crates/common/src/repositories/inquiry.rs b/crates/common/src/repositories/inquiry.rs index 972bbf9..87b7f74 100644 --- a/crates/common/src/repositories/inquiry.rs +++ b/crates/common/src/repositories/inquiry.rs @@ -7,6 +7,25 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, List, Repository, Update}; +/// Filters for [`InquiryRepository::search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct InquirySearchFilters { + pub status: Option, + pub execution: Option, + pub assigned_to: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`InquiryRepository::search`]. +#[derive(Debug)] +pub struct InquirySearchResult { + pub rows: Vec, + pub total: u64, +} + pub struct InquiryRepository; impl Repository for InquiryRepository { @@ -157,4 +176,66 @@ impl InquiryRepository { "SELECT id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated FROM inquiry WHERE execution = $1 ORDER BY created DESC" ).bind(execution_id).fetch_all(executor).await.map_err(Into::into) } + + /// Search inquiries with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn search<'e, E>(db: E, filters: &InquirySearchFilters) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, execution, prompt, response_schema, assigned_to, status, response, timeout_at, responded_at, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM inquiry")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM inquiry"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(status) = &filters.status { + push_condition!("status = ", status.clone()); + } + if let Some(execution_id) = filters.execution { + push_condition!("execution = ", execution_id); + } + if let Some(assigned_to) = filters.assigned_to { + push_condition!("assigned_to = ", assigned_to); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY created DESC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(InquirySearchResult { rows, total }) + } } diff --git a/crates/common/src/repositories/key.rs b/crates/common/src/repositories/key.rs index cd87597..ff88161 100644 --- a/crates/common/src/repositories/key.rs +++ b/crates/common/src/repositories/key.rs @@ -6,6 +6,24 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, List, Repository, Update}; +/// Filters for [`KeyRepository::search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct KeySearchFilters { + pub owner_type: Option, + pub owner: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`KeyRepository::search`]. +#[derive(Debug)] +pub struct KeySearchResult { + pub rows: Vec, + pub total: u64, +} + pub struct KeyRepository; impl Repository for KeyRepository { @@ -165,4 +183,63 @@ impl KeyRepository { "SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated FROM key WHERE owner_type = $1 ORDER BY ref ASC" ).bind(owner_type).fetch_all(executor).await.map_err(Into::into) } + + /// Search keys with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn search<'e, E>(db: E, filters: &KeySearchFilters) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref, owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted, encryption_key_hash, value, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM key")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM key"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(ref owner_type) = filters.owner_type { + push_condition!("owner_type = ", owner_type.clone()); + } + if let Some(ref owner) = filters.owner { + push_condition!("owner = ", owner.clone()); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY ref ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(KeySearchResult { rows, total }) + } } diff --git a/crates/common/src/repositories/rule.rs b/crates/common/src/repositories/rule.rs index 66dc2d2..42cf10d 100644 --- a/crates/common/src/repositories/rule.rs +++ b/crates/common/src/repositories/rule.rs @@ -8,6 +8,30 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, FindByRef, List, Repository, Update}; +/// Filters for [`RuleRepository::list_search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct RuleSearchFilters { + /// Filter by pack ID + pub pack: Option, + /// Filter by action ID + pub action: Option, + /// Filter by trigger ID + pub trigger: Option, + /// Filter by enabled status + pub enabled: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`RuleRepository::list_search`]. +#[derive(Debug)] +pub struct RuleSearchResult { + pub rows: Vec, + pub total: u64, +} + /// Input for restoring an ad-hoc rule during pack reinstallation. /// Unlike `CreateRuleInput`, action and trigger IDs are optional because /// the referenced entities may not exist yet or may have been removed. @@ -275,6 +299,71 @@ impl Delete for RuleRepository { } impl RuleRepository { + /// Search rules with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn list_search<'e, E>(db: E, filters: &RuleSearchFilters) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, pack, pack_ref, label, description, action, action_ref, trigger, trigger_ref, conditions, action_params, trigger_params, enabled, is_adhoc, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM rule")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM rule"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(pack_id) = filters.pack { + push_condition!("pack = ", pack_id); + } + if let Some(action_id) = filters.action { + push_condition!("action = ", action_id); + } + if let Some(trigger_id) = filters.trigger { + push_condition!("trigger = ", trigger_id); + } + if let Some(enabled) = filters.enabled { + push_condition!("enabled = ", enabled); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY ref ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(RuleSearchResult { rows, total }) + } + /// Find rules by pack ID pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result> where diff --git a/crates/common/src/repositories/trigger.rs b/crates/common/src/repositories/trigger.rs index e11a4b6..20eb2ba 100644 --- a/crates/common/src/repositories/trigger.rs +++ b/crates/common/src/repositories/trigger.rs @@ -9,6 +9,56 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, FindByRef, List, Repository, Update}; +// ============================================================================ +// Trigger Search +// ============================================================================ + +/// Filters for [`TriggerRepository::list_search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct TriggerSearchFilters { + /// Filter by pack ID + pub pack: Option, + /// Filter by enabled status + pub enabled: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`TriggerRepository::list_search`]. +#[derive(Debug)] +pub struct TriggerSearchResult { + pub rows: Vec, + pub total: u64, +} + +// ============================================================================ +// Sensor Search +// ============================================================================ + +/// Filters for [`SensorRepository::list_search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +#[derive(Debug, Clone, Default)] +pub struct SensorSearchFilters { + /// Filter by pack ID + pub pack: Option, + /// Filter by trigger ID + pub trigger: Option, + /// Filter by enabled status + pub enabled: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`SensorRepository::list_search`]. +#[derive(Debug)] +pub struct SensorSearchResult { + pub rows: Vec, + pub total: u64, +} + /// Repository for Trigger operations pub struct TriggerRepository; @@ -251,6 +301,68 @@ impl Delete for TriggerRepository { } impl TriggerRepository { + /// Search triggers with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn list_search<'e, E>( + db: E, + filters: &TriggerSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, pack, pack_ref, label, description, enabled, param_schema, out_schema, webhook_enabled, webhook_key, webhook_config, is_adhoc, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM trigger")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM trigger"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(pack_id) = filters.pack { + push_condition!("pack = ", pack_id); + } + if let Some(enabled) = filters.enabled { + push_condition!("enabled = ", enabled); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY ref ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(TriggerSearchResult { rows, total }) + } + /// Find triggers by pack ID pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result> where @@ -795,6 +907,71 @@ impl Delete for SensorRepository { } impl SensorRepository { + /// Search sensors with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + pub async fn list_search<'e, E>( + db: E, + filters: &SensorSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, pack, pack_ref, label, description, entrypoint, runtime, runtime_ref, runtime_version_constraint, trigger, trigger_ref, enabled, param_schema, config, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM sensor")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM sensor"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(pack_id) = filters.pack { + push_condition!("pack = ", pack_id); + } + if let Some(trigger_id) = filters.trigger { + push_condition!("trigger = ", trigger_id); + } + if let Some(enabled) = filters.enabled { + push_condition!("enabled = ", enabled); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY ref ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(SensorSearchResult { rows, total }) + } + /// Find sensors by trigger ID pub async fn find_by_trigger<'e, E>(executor: E, trigger_id: Id) -> Result> where diff --git a/crates/common/src/repositories/workflow.rs b/crates/common/src/repositories/workflow.rs index 29ba26e..63374ae 100644 --- a/crates/common/src/repositories/workflow.rs +++ b/crates/common/src/repositories/workflow.rs @@ -6,6 +6,37 @@ use sqlx::{Executor, Postgres, QueryBuilder}; use super::{Create, Delete, FindById, FindByRef, List, Repository, Update}; +// ============================================================================ +// Workflow Definition Search +// ============================================================================ + +/// Filters for [`WorkflowDefinitionRepository::list_search`]. +/// +/// All fields are optional and combinable (AND). Pagination is always applied. +/// Tag filtering uses `ANY(tags)` for each tag (OR across tags, AND with other filters). +#[derive(Debug, Clone, Default)] +pub struct WorkflowSearchFilters { + /// Filter by pack ID + pub pack: Option, + /// Filter by pack reference + pub pack_ref: Option, + /// Filter by enabled status + pub enabled: Option, + /// Filter by tags (OR across tags — matches if any tag is present) + pub tags: Option>, + /// Text search across label and description (case-insensitive substring) + pub search: Option, + pub limit: u32, + pub offset: u32, +} + +/// Result of [`WorkflowDefinitionRepository::list_search`]. +#[derive(Debug)] +pub struct WorkflowSearchResult { + pub rows: Vec, + pub total: u64, +} + // ============================================================================ // WORKFLOW DEFINITION REPOSITORY // ============================================================================ @@ -226,6 +257,102 @@ impl Delete for WorkflowDefinitionRepository { } impl WorkflowDefinitionRepository { + /// Search workflow definitions with all filters pushed into SQL. + /// + /// All filter fields are combinable (AND). Pagination is server-side. + /// Tags use an OR match — a workflow matches if it contains ANY of the + /// requested tags (via `tags && ARRAY[...]`). + pub async fn list_search<'e, E>( + db: E, + filters: &WorkflowSearchFilters, + ) -> Result + where + E: Executor<'e, Database = Postgres> + Copy + 'e, + { + let select_cols = "id, ref, pack, pack_ref, label, description, version, param_schema, out_schema, definition, tags, enabled, created, updated"; + + let mut qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new(format!("SELECT {select_cols} FROM workflow_definition")); + let mut count_qb: QueryBuilder<'_, Postgres> = + QueryBuilder::new("SELECT COUNT(*) FROM workflow_definition"); + + let mut has_where = false; + + macro_rules! push_condition { + ($cond_prefix:expr, $value:expr) => {{ + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push($cond_prefix); + qb.push_bind($value.clone()); + count_qb.push($cond_prefix); + count_qb.push_bind($value); + }}; + } + + if let Some(pack_id) = filters.pack { + push_condition!("pack = ", pack_id); + } + if let Some(ref pack_ref) = filters.pack_ref { + push_condition!("pack_ref = ", pack_ref.clone()); + } + if let Some(enabled) = filters.enabled { + push_condition!("enabled = ", enabled); + } + if let Some(ref tags) = filters.tags { + if !tags.is_empty() { + // Use PostgreSQL array overlap operator: tags && ARRAY[...] + push_condition!("tags && ", tags.clone()); + } + } + if let Some(ref search) = filters.search { + let pattern = format!("%{}%", search.to_lowercase()); + // Search needs an OR across multiple columns, wrapped in parens + if !has_where { + qb.push(" WHERE "); + count_qb.push(" WHERE "); + has_where = true; + } else { + qb.push(" AND "); + count_qb.push(" AND "); + } + qb.push("(LOWER(label) LIKE "); + qb.push_bind(pattern.clone()); + qb.push(" OR LOWER(COALESCE(description, '')) LIKE "); + qb.push_bind(pattern.clone()); + qb.push(")"); + + count_qb.push("(LOWER(label) LIKE "); + count_qb.push_bind(pattern.clone()); + count_qb.push(" OR LOWER(COALESCE(description, '')) LIKE "); + count_qb.push_bind(pattern); + count_qb.push(")"); + } + + // Suppress unused-assignment warning from the macro's last expansion. + let _ = has_where; + + // Count + let total: i64 = count_qb.build_query_scalar().fetch_one(db).await?; + let total = total.max(0) as u64; + + // Data query + qb.push(" ORDER BY label ASC"); + qb.push(" LIMIT "); + qb.push_bind(filters.limit as i64); + qb.push(" OFFSET "); + qb.push_bind(filters.offset as i64); + + let rows: Vec = qb.build_query_as().fetch_all(db).await?; + + Ok(WorkflowSearchResult { rows, total }) + } + /// Find all workflows for a specific pack by pack ID pub async fn find_by_pack<'e, E>(executor: E, pack_id: Id) -> Result> where diff --git a/crates/common/src/workflow/expression/ast.rs b/crates/common/src/workflow/expression/ast.rs new file mode 100644 index 0000000..170c1e8 --- /dev/null +++ b/crates/common/src/workflow/expression/ast.rs @@ -0,0 +1,112 @@ +//! # Expression AST +//! +//! Defines the abstract syntax tree nodes produced by the parser and consumed +//! by the evaluator. + +use std::fmt; + +/// A binary operator connecting two sub-expressions. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryOp { + // Arithmetic + Add, + Sub, + Mul, + Div, + Mod, + // Comparison + Eq, + Ne, + Lt, + Gt, + Le, + Ge, + // Logical + And, + Or, + // Membership + In, +} + +impl fmt::Display for BinaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BinaryOp::Add => write!(f, "+"), + BinaryOp::Sub => write!(f, "-"), + BinaryOp::Mul => write!(f, "*"), + BinaryOp::Div => write!(f, "/"), + BinaryOp::Mod => write!(f, "%"), + BinaryOp::Eq => write!(f, "=="), + BinaryOp::Ne => write!(f, "!="), + BinaryOp::Lt => write!(f, "<"), + BinaryOp::Gt => write!(f, ">"), + BinaryOp::Le => write!(f, "<="), + BinaryOp::Ge => write!(f, ">="), + BinaryOp::And => write!(f, "and"), + BinaryOp::Or => write!(f, "or"), + BinaryOp::In => write!(f, "in"), + } + } +} + +/// A unary operator applied to a single sub-expression. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOp { + /// Arithmetic negation: `-x` + Neg, + /// Logical negation: `not x` + Not, +} + +impl fmt::Display for UnaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UnaryOp::Neg => write!(f, "-"), + UnaryOp::Not => write!(f, "not"), + } + } +} + +/// An expression AST node. +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + /// A literal JSON value: number, string, bool, or null. + Literal(serde_json::Value), + + /// An array literal: `[expr, expr, ...]` + Array(Vec), + + /// A variable reference by name (e.g., `x`, `parameters`, `item`). + Ident(String), + + /// Binary operation: `left op right` + BinaryOp { + op: BinaryOp, + left: Box, + right: Box, + }, + + /// Unary operation: `op operand` + UnaryOp { + op: UnaryOp, + operand: Box, + }, + + /// Property access: `expr.field` + DotAccess { + object: Box, + field: String, + }, + + /// Index/bracket access: `expr[index_expr]` + IndexAccess { + object: Box, + index: Box, + }, + + /// Function call: `name(arg1, arg2, ...)` + FunctionCall { + name: String, + args: Vec, + }, +} diff --git a/crates/common/src/workflow/expression/evaluator.rs b/crates/common/src/workflow/expression/evaluator.rs new file mode 100644 index 0000000..406c1a7 --- /dev/null +++ b/crates/common/src/workflow/expression/evaluator.rs @@ -0,0 +1,1316 @@ +//! # Expression Evaluator +//! +//! Walks the AST and produces a `JsonValue` result. + +use super::ast::{BinaryOp, Expr, UnaryOp}; +use regex::Regex; +use serde_json::{json, Value as JsonValue}; +use thiserror::Error; + +/// Result type for evaluation operations. +pub type EvalResult = Result; + +/// Errors that can occur during expression evaluation. +#[derive(Debug, Error)] +pub enum EvalError { + #[error("Variable not found: {0}")] + VariableNotFound(String), + + #[error("Type error: {0}")] + TypeError(String), + + #[error("Division by zero")] + DivisionByZero, + + #[error("Index out of bounds: {0}")] + IndexOutOfBounds(String), + + #[error("Unknown function: {0}")] + UnknownFunction(String), + + #[error("Wrong number of arguments for {0}: expected {1}, got {2}")] + WrongArgCount(String, String, usize), + + #[error("Parse error: {0}")] + ParseError(String), + + #[error("Regex error: {0}")] + RegexError(String), +} + +/// Trait for resolving variables and workflow-specific functions from +/// the execution context. +pub trait EvalContext { + /// Resolve a top-level variable name to its JSON value. + fn resolve_variable(&self, name: &str) -> EvalResult; + + /// Try to call a workflow-specific function (e.g., `result()`, `succeeded()`). + /// Return `Ok(Some(value))` if handled, `Ok(None)` if not recognized. + fn call_workflow_function( + &self, + name: &str, + args: &[JsonValue], + ) -> EvalResult>; +} + +/// Evaluate an AST expression against the given context. +pub fn eval(expr: &Expr, ctx: &dyn EvalContext) -> EvalResult { + match expr { + Expr::Literal(v) => Ok(v.clone()), + + Expr::Array(elements) => { + let mut arr = Vec::with_capacity(elements.len()); + for elem in elements { + arr.push(eval(elem, ctx)?); + } + Ok(JsonValue::Array(arr)) + } + + Expr::Ident(name) => ctx.resolve_variable(name), + + Expr::BinaryOp { op, left, right } => { + // Short-circuit for `and` / `or` + if *op == BinaryOp::And { + let lv = eval(left, ctx)?; + if !is_truthy(&lv) { + return Ok(json!(false)); + } + let rv = eval(right, ctx)?; + return Ok(json!(is_truthy(&rv))); + } + if *op == BinaryOp::Or { + let lv = eval(left, ctx)?; + if is_truthy(&lv) { + return Ok(json!(true)); + } + let rv = eval(right, ctx)?; + return Ok(json!(is_truthy(&rv))); + } + + let lv = eval(left, ctx)?; + let rv = eval(right, ctx)?; + eval_binary_op(*op, &lv, &rv) + } + + Expr::UnaryOp { op, operand } => { + let v = eval(operand, ctx)?; + eval_unary_op(*op, &v) + } + + Expr::DotAccess { object, field } => { + let obj = eval(object, ctx)?; + dot_access(&obj, field) + } + + Expr::IndexAccess { object, index } => { + let obj = eval(object, ctx)?; + let idx = eval(index, ctx)?; + index_access(&obj, &idx) + } + + Expr::FunctionCall { name, args } => { + // First, try workflow-specific functions (result(), succeeded(), etc.) + // We evaluate args lazily for workflow fns that take 0 args. + let evaluated_args: Vec = args + .iter() + .map(|a| eval(a, ctx)) + .collect::>>()?; + + if let Some(val) = ctx.call_workflow_function(name, &evaluated_args)? { + return Ok(val); + } + + // Built-in functions + eval_builtin_function(name, &evaluated_args) + } + } +} + +// --------------------------------------------------------------- +// Truthiness +// --------------------------------------------------------------- + +/// Determine if a JSON value is "truthy" (Python-like semantics). +pub fn is_truthy(v: &JsonValue) -> bool { + match v { + JsonValue::Null => false, + JsonValue::Bool(b) => *b, + JsonValue::Number(n) => { + if let Some(i) = n.as_i64() { + i != 0 + } else if let Some(f) = n.as_f64() { + f != 0.0 + } else { + true + } + } + JsonValue::String(s) => !s.is_empty(), + JsonValue::Array(a) => !a.is_empty(), + JsonValue::Object(o) => !o.is_empty(), + } +} + +// --------------------------------------------------------------- +// Binary operations +// --------------------------------------------------------------- + +fn eval_binary_op(op: BinaryOp, left: &JsonValue, right: &JsonValue) -> EvalResult { + match op { + // Arithmetic + BinaryOp::Add => eval_add(left, right), + BinaryOp::Sub => eval_arithmetic(left, right, |a, b| a - b, |a, b| a - b, "-"), + BinaryOp::Mul => eval_arithmetic(left, right, |a, b| a * b, |a, b| a * b, "*"), + BinaryOp::Div => eval_div(left, right), + BinaryOp::Mod => eval_mod(left, right), + + // Comparison + BinaryOp::Eq => Ok(json!(json_eq(left, right))), + BinaryOp::Ne => Ok(json!(!json_eq(left, right))), + BinaryOp::Lt => eval_ordering(left, right, |o| o == std::cmp::Ordering::Less), + BinaryOp::Gt => eval_ordering(left, right, |o| o == std::cmp::Ordering::Greater), + BinaryOp::Le => eval_ordering(left, right, |o| o != std::cmp::Ordering::Greater), + BinaryOp::Ge => eval_ordering(left, right, |o| o != std::cmp::Ordering::Less), + + // Membership + BinaryOp::In => eval_in(left, right), + + // And/Or handled in eval() with short-circuit + BinaryOp::And | BinaryOp::Or => unreachable!(), + } +} + +fn eval_add(left: &JsonValue, right: &JsonValue) -> EvalResult { + // String concatenation + if left.is_string() && right.is_string() { + let l = left.as_str().unwrap(); + let r = right.as_str().unwrap(); + return Ok(json!(format!("{}{}", l, r))); + } + + // Array concatenation + if left.is_array() && right.is_array() { + let mut result = left.as_array().unwrap().clone(); + result.extend(right.as_array().unwrap().iter().cloned()); + return Ok(JsonValue::Array(result)); + } + + // Numeric addition + eval_arithmetic(left, right, |a, b| a + b, |a, b| a + b, "+") +} + +fn eval_arithmetic( + left: &JsonValue, + right: &JsonValue, + int_op: impl Fn(i64, i64) -> i64, + float_op: impl Fn(f64, f64) -> f64, + op_name: &str, +) -> EvalResult { + match (as_numeric(left), as_numeric(right)) { + (Some(NumericValue::Int(a)), Some(NumericValue::Int(b))) => Ok(json!(int_op(a, b))), + (Some(a), Some(b)) => Ok(json!(float_op(a.as_f64(), b.as_f64()))), + _ => Err(EvalError::TypeError(format!( + "Cannot apply '{}' to {} and {}", + op_name, + type_name(left), + type_name(right) + ))), + } +} + +fn eval_div(left: &JsonValue, right: &JsonValue) -> EvalResult { + match (as_numeric(left), as_numeric(right)) { + (Some(_), Some(b)) if b.as_f64() == 0.0 => Err(EvalError::DivisionByZero), + (Some(NumericValue::Int(a)), Some(NumericValue::Int(b))) => { + // Integer division stays integer if divisible + if a % b == 0 { + Ok(json!(a / b)) + } else { + Ok(json!(a as f64 / b as f64)) + } + } + (Some(a), Some(b)) => Ok(json!(a.as_f64() / b.as_f64())), + _ => Err(EvalError::TypeError(format!( + "Cannot apply '/' to {} and {}", + type_name(left), + type_name(right) + ))), + } +} + +fn eval_mod(left: &JsonValue, right: &JsonValue) -> EvalResult { + match (as_numeric(left), as_numeric(right)) { + (Some(_), Some(b)) if b.as_f64() == 0.0 => Err(EvalError::DivisionByZero), + (Some(NumericValue::Int(a)), Some(NumericValue::Int(b))) => Ok(json!(a % b)), + (Some(a), Some(b)) => Ok(json!(a.as_f64() % b.as_f64())), + _ => Err(EvalError::TypeError(format!( + "Cannot apply '%' to {} and {}", + type_name(left), + type_name(right) + ))), + } +} + +// --------------------------------------------------------------- +// Comparison helpers +// --------------------------------------------------------------- + +/// Deep equality that allows int/float cross-comparison. +fn json_eq(a: &JsonValue, b: &JsonValue) -> bool { + match (a, b) { + (JsonValue::Null, JsonValue::Null) => true, + (JsonValue::Bool(a), JsonValue::Bool(b)) => a == b, + (JsonValue::Number(_), JsonValue::Number(_)) => { + // Allow int/float comparison + match (as_numeric(a), as_numeric(b)) { + (Some(a), Some(b)) => a.as_f64() == b.as_f64(), + _ => false, + } + } + (JsonValue::String(a), JsonValue::String(b)) => a == b, + (JsonValue::Array(a), JsonValue::Array(b)) => { + if a.len() != b.len() { + return false; + } + a.iter().zip(b.iter()).all(|(x, y)| json_eq(x, y)) + } + (JsonValue::Object(a), JsonValue::Object(b)) => { + if a.len() != b.len() { + return false; + } + a.iter() + .all(|(k, v)| b.get(k).map_or(false, |bv| json_eq(v, bv))) + } + // Different types (other than number cross-compare) are never equal + _ => false, + } +} + +fn eval_ordering( + left: &JsonValue, + right: &JsonValue, + predicate: impl Fn(std::cmp::Ordering) -> bool, +) -> EvalResult { + // Number comparison (int/float cross-allowed) + if let (Some(a), Some(b)) = (as_numeric(left), as_numeric(right)) { + let af = a.as_f64(); + let bf = b.as_f64(); + let ord = af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal); + return Ok(json!(predicate(ord))); + } + + // String comparison + if let (Some(a), Some(b)) = (left.as_str(), right.as_str()) { + return Ok(json!(predicate(a.cmp(b)))); + } + + // List comparison (lexicographic) + if let (Some(a), Some(b)) = (left.as_array(), right.as_array()) { + let ord = compare_arrays(a, b)?; + return Ok(json!(predicate(ord))); + } + + Err(EvalError::TypeError(format!( + "Cannot compare {} and {} with ordering operators", + type_name(left), + type_name(right) + ))) +} + +fn compare_arrays(a: &[JsonValue], b: &[JsonValue]) -> EvalResult { + for (x, y) in a.iter().zip(b.iter()) { + if let (Some(xn), Some(yn)) = (as_numeric(x), as_numeric(y)) { + let ord = xn + .as_f64() + .partial_cmp(&yn.as_f64()) + .unwrap_or(std::cmp::Ordering::Equal); + if ord != std::cmp::Ordering::Equal { + return Ok(ord); + } + } else if let (Some(xs), Some(ys)) = (x.as_str(), y.as_str()) { + let ord = xs.cmp(ys); + if ord != std::cmp::Ordering::Equal { + return Ok(ord); + } + } else { + return Err(EvalError::TypeError( + "Cannot compare heterogeneous array elements for ordering".to_string(), + )); + } + } + Ok(a.len().cmp(&b.len())) +} + +fn eval_in(needle: &JsonValue, haystack: &JsonValue) -> EvalResult { + match haystack { + JsonValue::Array(arr) => Ok(json!(arr.iter().any(|item| json_eq(needle, item)))), + JsonValue::Object(obj) => { + if let Some(key) = needle.as_str() { + Ok(json!(obj.contains_key(key))) + } else { + Err(EvalError::TypeError( + "Only string keys can be tested for membership in objects".to_string(), + )) + } + } + JsonValue::String(s) => { + if let Some(sub) = needle.as_str() { + Ok(json!(s.contains(sub))) + } else { + Err(EvalError::TypeError( + "Only strings can be tested for substring membership".to_string(), + )) + } + } + _ => Err(EvalError::TypeError(format!( + "'in' requires array, object, or string on right side, got {}", + type_name(haystack) + ))), + } +} + +// --------------------------------------------------------------- +// Unary operations +// --------------------------------------------------------------- + +fn eval_unary_op(op: UnaryOp, val: &JsonValue) -> EvalResult { + match op { + UnaryOp::Neg => { + if let Some(n) = as_numeric(val) { + match n { + NumericValue::Int(i) => Ok(json!(-i)), + NumericValue::Float(f) => Ok(json!(-f)), + } + } else { + Err(EvalError::TypeError(format!( + "Cannot negate {}", + type_name(val) + ))) + } + } + UnaryOp::Not => Ok(json!(!is_truthy(val))), + } +} + +// --------------------------------------------------------------- +// Property / index access +// --------------------------------------------------------------- + +fn dot_access(obj: &JsonValue, field: &str) -> EvalResult { + match obj { + JsonValue::Object(map) => map + .get(field) + .cloned() + .ok_or_else(|| EvalError::VariableNotFound(format!("field '{}'", field))), + _ => Err(EvalError::TypeError(format!( + "Cannot access property '{}' on {}", + field, + type_name(obj) + ))), + } +} + +fn index_access(obj: &JsonValue, index: &JsonValue) -> EvalResult { + match obj { + JsonValue::Array(arr) => { + if let Some(i) = index.as_i64() { + let i = if i < 0 { + // Negative indexing + (arr.len() as i64 + i) as usize + } else { + i as usize + }; + arr.get(i) + .cloned() + .ok_or_else(|| EvalError::IndexOutOfBounds(format!("{}", i))) + } else { + Err(EvalError::TypeError( + "Array index must be an integer".to_string(), + )) + } + } + JsonValue::Object(map) => { + if let Some(key) = index.as_str() { + map.get(key) + .cloned() + .ok_or_else(|| EvalError::VariableNotFound(format!("key '{}'", key))) + } else { + Err(EvalError::TypeError( + "Object key must be a string".to_string(), + )) + } + } + JsonValue::String(s) => { + if let Some(i) = index.as_i64() { + let chars: Vec = s.chars().collect(); + let i = if i < 0 { + (chars.len() as i64 + i) as usize + } else { + i as usize + }; + chars + .get(i) + .map(|c| json!(c.to_string())) + .ok_or_else(|| EvalError::IndexOutOfBounds(format!("{}", i))) + } else { + Err(EvalError::TypeError( + "String index must be an integer".to_string(), + )) + } + } + _ => Err(EvalError::TypeError(format!( + "Cannot index into {}", + type_name(obj) + ))), + } +} + +// --------------------------------------------------------------- +// Built-in functions +// --------------------------------------------------------------- + +fn eval_builtin_function(name: &str, args: &[JsonValue]) -> EvalResult { + match name { + // -- Type conversion -- + "string" => { + expect_args(name, args, 1)?; + Ok(json!(value_to_string(&args[0]))) + } + "number" => { + expect_args(name, args, 1)?; + to_number(&args[0]) + } + "int" => { + expect_args(name, args, 1)?; + to_int(&args[0]) + } + "bool" => { + expect_args(name, args, 1)?; + Ok(json!(is_truthy(&args[0]))) + } + + // -- Introspection -- + "type_of" => { + expect_args(name, args, 1)?; + Ok(json!(type_name(&args[0]))) + } + "length" => { + expect_args(name, args, 1)?; + fn_length(&args[0]) + } + "keys" => { + expect_args(name, args, 1)?; + fn_keys(&args[0]) + } + "values" => { + expect_args(name, args, 1)?; + fn_values(&args[0]) + } + + // -- Math -- + "abs" => { + expect_args(name, args, 1)?; + fn_abs(&args[0]) + } + "floor" => { + expect_args(name, args, 1)?; + fn_floor(&args[0]) + } + "ceil" => { + expect_args(name, args, 1)?; + fn_ceil(&args[0]) + } + "round" => { + expect_args(name, args, 1)?; + fn_round(&args[0]) + } + "min" => { + expect_args(name, args, 2)?; + fn_min(&args[0], &args[1]) + } + "max" => { + expect_args(name, args, 2)?; + fn_max(&args[0], &args[1]) + } + "sum" => { + expect_args(name, args, 1)?; + fn_sum(&args[0]) + } + + // -- String -- + "lower" => { + expect_args(name, args, 1)?; + fn_lower(&args[0]) + } + "upper" => { + expect_args(name, args, 1)?; + fn_upper(&args[0]) + } + "trim" => { + expect_args(name, args, 1)?; + fn_trim(&args[0]) + } + "split" => { + expect_args(name, args, 2)?; + fn_split(&args[0], &args[1]) + } + "join" => { + expect_args(name, args, 2)?; + fn_join(&args[0], &args[1]) + } + "replace" => { + expect_args(name, args, 3)?; + fn_replace(&args[0], &args[1], &args[2]) + } + "starts_with" => { + expect_args(name, args, 2)?; + fn_starts_with(&args[0], &args[1]) + } + "ends_with" => { + expect_args(name, args, 2)?; + fn_ends_with(&args[0], &args[1]) + } + "match" => { + expect_args(name, args, 2)?; + fn_match(&args[0], &args[1]) + } + + // -- Collections -- + "contains" => { + expect_args(name, args, 2)?; + eval_in(&args[1], &args[0]) + } + "reversed" => { + expect_args(name, args, 1)?; + fn_reversed(&args[0]) + } + "sort" => { + expect_args(name, args, 1)?; + fn_sort(&args[0]) + } + "unique" => { + expect_args(name, args, 1)?; + fn_unique(&args[0]) + } + "flat" => { + expect_args(name, args, 1)?; + fn_flat(&args[0]) + } + "zip" => { + expect_args(name, args, 2)?; + fn_zip(&args[0], &args[1]) + } + "range" => { + if args.len() == 1 { + fn_range_1(&args[0]) + } else if args.len() == 2 { + fn_range_2(&args[0], &args[1]) + } else { + Err(EvalError::WrongArgCount( + name.to_string(), + "1 or 2".to_string(), + args.len(), + )) + } + } + "slice" => { + if args.len() == 2 { + fn_slice(&args[0], &args[1], &JsonValue::Null) + } else if args.len() == 3 { + fn_slice(&args[0], &args[1], &args[2]) + } else { + Err(EvalError::WrongArgCount( + name.to_string(), + "2 or 3".to_string(), + args.len(), + )) + } + } + "index_of" => { + expect_args(name, args, 2)?; + fn_index_of(&args[0], &args[1]) + } + "count" => { + expect_args(name, args, 2)?; + fn_count(&args[0], &args[1]) + } + "merge" => { + expect_args(name, args, 2)?; + fn_merge(&args[0], &args[1]) + } + "chunks" => { + expect_args(name, args, 2)?; + fn_chunks(&args[0], &args[1]) + } + + _ => Err(EvalError::UnknownFunction(name.to_string())), + } +} + +fn expect_args(name: &str, args: &[JsonValue], expected: usize) -> EvalResult<()> { + if args.len() != expected { + Err(EvalError::WrongArgCount( + name.to_string(), + expected.to_string(), + args.len(), + )) + } else { + Ok(()) + } +} + +// --------------------------------------------------------------- +// Numeric helpers +// --------------------------------------------------------------- + +#[derive(Debug, Clone, Copy)] +enum NumericValue { + Int(i64), + Float(f64), +} + +impl NumericValue { + fn as_f64(self) -> f64 { + match self { + NumericValue::Int(i) => i as f64, + NumericValue::Float(f) => f, + } + } +} + +fn as_numeric(v: &JsonValue) -> Option { + if let Some(i) = v.as_i64() { + Some(NumericValue::Int(i)) + } else if let Some(f) = v.as_f64() { + Some(NumericValue::Float(f)) + } else { + None + } +} + +fn type_name(v: &JsonValue) -> &'static str { + match v { + JsonValue::Null => "null", + JsonValue::Bool(_) => "bool", + JsonValue::Number(_) => "number", + JsonValue::String(_) => "string", + JsonValue::Array(_) => "array", + JsonValue::Object(_) => "object", + } +} + +fn value_to_string(v: &JsonValue) -> String { + match v { + JsonValue::String(s) => s.clone(), + JsonValue::Null => "null".to_string(), + JsonValue::Bool(b) => b.to_string(), + JsonValue::Number(n) => n.to_string(), + other => serde_json::to_string(other).unwrap_or_default(), + } +} + +// --------------------------------------------------------------- +// Type conversion functions +// --------------------------------------------------------------- + +fn to_number(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Number(_) => Ok(v.clone()), + JsonValue::String(s) => { + if let Ok(f) = s.parse::() { + Ok(json!(f)) + } else { + Err(EvalError::TypeError(format!( + "Cannot convert string '{}' to number", + s + ))) + } + } + JsonValue::Bool(b) => Ok(json!(if *b { 1.0 } else { 0.0 })), + _ => Err(EvalError::TypeError(format!( + "Cannot convert {} to number", + type_name(v) + ))), + } +} + +fn to_int(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(json!(i)) + } else if let Some(f) = n.as_f64() { + Ok(json!(f as i64)) + } else { + Err(EvalError::TypeError("Cannot convert number to int".to_string())) + } + } + JsonValue::String(s) => { + // Try integer first, then float truncation + if let Ok(i) = s.parse::() { + Ok(json!(i)) + } else if let Ok(f) = s.parse::() { + Ok(json!(f as i64)) + } else { + Err(EvalError::TypeError(format!( + "Cannot convert string '{}' to int", + s + ))) + } + } + JsonValue::Bool(b) => Ok(json!(if *b { 1 } else { 0 })), + _ => Err(EvalError::TypeError(format!( + "Cannot convert {} to int", + type_name(v) + ))), + } +} + +// --------------------------------------------------------------- +// Introspection functions +// --------------------------------------------------------------- + +fn fn_length(v: &JsonValue) -> EvalResult { + match v { + JsonValue::String(s) => Ok(json!(s.len())), + JsonValue::Array(a) => Ok(json!(a.len())), + JsonValue::Object(o) => Ok(json!(o.len())), + _ => Err(EvalError::TypeError(format!( + "length() requires string, array, or object, got {}", + type_name(v) + ))), + } +} + +fn fn_keys(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Object(obj) => { + let keys: Vec = obj.keys().map(|k| json!(k)).collect(); + Ok(JsonValue::Array(keys)) + } + _ => Err(EvalError::TypeError(format!( + "keys() requires object, got {}", + type_name(v) + ))), + } +} + +fn fn_values(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Object(obj) => { + let values: Vec = obj.values().cloned().collect(); + Ok(JsonValue::Array(values)) + } + _ => Err(EvalError::TypeError(format!( + "values() requires object, got {}", + type_name(v) + ))), + } +} + +// --------------------------------------------------------------- +// Math functions +// --------------------------------------------------------------- + +fn fn_abs(v: &JsonValue) -> EvalResult { + match as_numeric(v) { + Some(NumericValue::Int(i)) => Ok(json!(i.abs())), + Some(NumericValue::Float(f)) => Ok(json!(f.abs())), + None => Err(EvalError::TypeError(format!( + "abs() requires number, got {}", + type_name(v) + ))), + } +} + +fn fn_floor(v: &JsonValue) -> EvalResult { + match as_numeric(v) { + Some(NumericValue::Int(i)) => Ok(json!(i)), + Some(NumericValue::Float(f)) => Ok(json!(f.floor() as i64)), + None => Err(EvalError::TypeError(format!( + "floor() requires number, got {}", + type_name(v) + ))), + } +} + +fn fn_ceil(v: &JsonValue) -> EvalResult { + match as_numeric(v) { + Some(NumericValue::Int(i)) => Ok(json!(i)), + Some(NumericValue::Float(f)) => Ok(json!(f.ceil() as i64)), + None => Err(EvalError::TypeError(format!( + "ceil() requires number, got {}", + type_name(v) + ))), + } +} + +fn fn_round(v: &JsonValue) -> EvalResult { + match as_numeric(v) { + Some(NumericValue::Int(i)) => Ok(json!(i)), + Some(NumericValue::Float(f)) => Ok(json!(f.round() as i64)), + None => Err(EvalError::TypeError(format!( + "round() requires number, got {}", + type_name(v) + ))), + } +} + +fn fn_min(a: &JsonValue, b: &JsonValue) -> EvalResult { + match (as_numeric(a), as_numeric(b)) { + (Some(NumericValue::Int(x)), Some(NumericValue::Int(y))) => Ok(json!(x.min(y))), + (Some(x), Some(y)) => Ok(json!(x.as_f64().min(y.as_f64()))), + _ => { + // String min + if let (Some(sa), Some(sb)) = (a.as_str(), b.as_str()) { + Ok(json!(sa.min(sb))) + } else { + Err(EvalError::TypeError( + "min() requires two numbers or two strings".to_string(), + )) + } + } + } +} + +fn fn_max(a: &JsonValue, b: &JsonValue) -> EvalResult { + match (as_numeric(a), as_numeric(b)) { + (Some(NumericValue::Int(x)), Some(NumericValue::Int(y))) => Ok(json!(x.max(y))), + (Some(x), Some(y)) => Ok(json!(x.as_f64().max(y.as_f64()))), + _ => { + if let (Some(sa), Some(sb)) = (a.as_str(), b.as_str()) { + Ok(json!(sa.max(sb))) + } else { + Err(EvalError::TypeError( + "max() requires two numbers or two strings".to_string(), + )) + } + } + } +} + +fn fn_sum(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Array(arr) => { + let mut has_float = false; + let mut int_sum: i64 = 0; + let mut float_sum: f64 = 0.0; + + for item in arr { + match as_numeric(item) { + Some(NumericValue::Int(i)) => { + int_sum += i; + float_sum += i as f64; + } + Some(NumericValue::Float(f)) => { + has_float = true; + float_sum += f; + } + None => { + return Err(EvalError::TypeError(format!( + "sum() requires array of numbers, found {}", + type_name(item) + ))); + } + } + } + + if has_float { + Ok(json!(float_sum)) + } else { + Ok(json!(int_sum)) + } + } + _ => Err(EvalError::TypeError(format!( + "sum() requires array, got {}", + type_name(v) + ))), + } +} + +// --------------------------------------------------------------- +// String functions +// --------------------------------------------------------------- + +fn fn_lower(v: &JsonValue) -> EvalResult { + require_string("lower", v).map(|s| json!(s.to_lowercase())) +} + +fn fn_upper(v: &JsonValue) -> EvalResult { + require_string("upper", v).map(|s| json!(s.to_uppercase())) +} + +fn fn_trim(v: &JsonValue) -> EvalResult { + require_string("trim", v).map(|s| json!(s.trim())) +} + +fn fn_split(s: &JsonValue, sep: &JsonValue) -> EvalResult { + let s = require_string("split", s)?; + let sep = require_string("split", sep)?; + let parts: Vec = s.split(sep).map(|p| json!(p)).collect(); + Ok(JsonValue::Array(parts)) +} + +fn fn_join(arr: &JsonValue, sep: &JsonValue) -> EvalResult { + let arr = arr.as_array().ok_or_else(|| { + EvalError::TypeError(format!( + "join() first argument must be array, got {}", + type_name(arr) + )) + })?; + let sep = require_string("join", sep)?; + let strings: Result, _> = arr.iter().map(|v| { + Ok(value_to_string(v)) + }).collect(); + Ok(json!(strings?.join(sep))) +} + +fn fn_replace(s: &JsonValue, old: &JsonValue, new: &JsonValue) -> EvalResult { + let s = require_string("replace", s)?; + let old = require_string("replace", old)?; + let new_s = require_string("replace", new)?; + Ok(json!(s.replace(old, new_s))) +} + +fn fn_starts_with(s: &JsonValue, prefix: &JsonValue) -> EvalResult { + let s = require_string("starts_with", s)?; + let prefix = require_string("starts_with", prefix)?; + Ok(json!(s.starts_with(prefix))) +} + +fn fn_ends_with(s: &JsonValue, suffix: &JsonValue) -> EvalResult { + let s = require_string("ends_with", s)?; + let suffix = require_string("ends_with", suffix)?; + Ok(json!(s.ends_with(suffix))) +} + +fn fn_match(pattern: &JsonValue, s: &JsonValue) -> EvalResult { + let pattern = require_string("match", pattern)?; + let s = require_string("match", s)?; + let re = Regex::new(pattern) + .map_err(|e| EvalError::RegexError(format!("{}", e)))?; + Ok(json!(re.is_match(s))) +} + +fn require_string<'a>(func: &str, v: &'a JsonValue) -> EvalResult<&'a str> { + v.as_str().ok_or_else(|| { + EvalError::TypeError(format!( + "{}() requires string argument, got {}", + func, + type_name(v) + )) + }) +} + +// --------------------------------------------------------------- +// Collection functions +// --------------------------------------------------------------- + +fn fn_reversed(v: &JsonValue) -> EvalResult { + match v { + JsonValue::Array(arr) => { + let mut rev = arr.clone(); + rev.reverse(); + Ok(JsonValue::Array(rev)) + } + JsonValue::String(s) => { + Ok(json!(s.chars().rev().collect::())) + } + _ => Err(EvalError::TypeError(format!( + "reversed() requires array or string, got {}", + type_name(v) + ))), + } +} + +fn fn_sort(v: &JsonValue) -> EvalResult { + let arr = v.as_array().ok_or_else(|| { + EvalError::TypeError(format!("sort() requires array, got {}", type_name(v))) + })?; + + let mut sorted = arr.clone(); + // Sort stably; numbers first, then strings + let mut err: Option = None; + sorted.sort_by(|a, b| { + if err.is_some() { + return std::cmp::Ordering::Equal; + } + match (as_numeric(a), as_numeric(b)) { + (Some(x), Some(y)) => x + .as_f64() + .partial_cmp(&y.as_f64()) + .unwrap_or(std::cmp::Ordering::Equal), + _ => { + if let (Some(sa), Some(sb)) = (a.as_str(), b.as_str()) { + sa.cmp(sb) + } else { + err = Some(EvalError::TypeError( + "sort() requires array of numbers or strings".to_string(), + )); + std::cmp::Ordering::Equal + } + } + } + }); + + if let Some(e) = err { + return Err(e); + } + + Ok(JsonValue::Array(sorted)) +} + +fn fn_unique(v: &JsonValue) -> EvalResult { + let arr = v.as_array().ok_or_else(|| { + EvalError::TypeError(format!("unique() requires array, got {}", type_name(v))) + })?; + + let mut seen = Vec::new(); + let mut result = Vec::new(); + for item in arr { + if !seen.iter().any(|s| json_eq(s, item)) { + seen.push(item.clone()); + result.push(item.clone()); + } + } + + Ok(JsonValue::Array(result)) +} + +fn fn_flat(v: &JsonValue) -> EvalResult { + let arr = v.as_array().ok_or_else(|| { + EvalError::TypeError(format!("flat() requires array, got {}", type_name(v))) + })?; + + let mut result = Vec::new(); + for item in arr { + if let JsonValue::Array(inner) = item { + result.extend(inner.iter().cloned()); + } else { + result.push(item.clone()); + } + } + + Ok(JsonValue::Array(result)) +} + +fn fn_zip(a: &JsonValue, b: &JsonValue) -> EvalResult { + let a_arr = a.as_array().ok_or_else(|| { + EvalError::TypeError(format!("zip() first argument must be array, got {}", type_name(a))) + })?; + let b_arr = b.as_array().ok_or_else(|| { + EvalError::TypeError(format!( + "zip() second argument must be array, got {}", + type_name(b) + )) + })?; + + let pairs: Vec = a_arr + .iter() + .zip(b_arr.iter()) + .map(|(x, y)| json!([x, y])) + .collect(); + + Ok(JsonValue::Array(pairs)) +} + +fn fn_range_1(end: &JsonValue) -> EvalResult { + let n = end.as_i64().ok_or_else(|| { + EvalError::TypeError("range() requires integer argument".to_string()) + })?; + let arr: Vec = (0..n).map(|i| json!(i)).collect(); + Ok(JsonValue::Array(arr)) +} + +fn fn_range_2(start: &JsonValue, end: &JsonValue) -> EvalResult { + let s = start.as_i64().ok_or_else(|| { + EvalError::TypeError("range() requires integer arguments".to_string()) + })?; + let e = end.as_i64().ok_or_else(|| { + EvalError::TypeError("range() requires integer arguments".to_string()) + })?; + let arr: Vec = (s..e).map(|i| json!(i)).collect(); + Ok(JsonValue::Array(arr)) +} + +fn fn_slice(v: &JsonValue, start: &JsonValue, end: &JsonValue) -> EvalResult { + let s = start.as_i64().ok_or_else(|| { + EvalError::TypeError("slice() start must be integer".to_string()) + })? as usize; + + match v { + JsonValue::Array(arr) => { + let e = if end.is_null() { + arr.len() + } else { + end.as_i64() + .ok_or_else(|| EvalError::TypeError("slice() end must be integer".to_string()))? + as usize + }; + let e = e.min(arr.len()); + let s = s.min(e); + Ok(JsonValue::Array(arr[s..e].to_vec())) + } + JsonValue::String(str_val) => { + let chars: Vec = str_val.chars().collect(); + let e = if end.is_null() { + chars.len() + } else { + end.as_i64() + .ok_or_else(|| EvalError::TypeError("slice() end must be integer".to_string()))? + as usize + }; + let e = e.min(chars.len()); + let s = s.min(e); + Ok(json!(chars[s..e].iter().collect::())) + } + _ => Err(EvalError::TypeError(format!( + "slice() requires array or string, got {}", + type_name(v) + ))), + } +} + +fn fn_index_of(haystack: &JsonValue, needle: &JsonValue) -> EvalResult { + match haystack { + JsonValue::Array(arr) => { + for (i, item) in arr.iter().enumerate() { + if json_eq(item, needle) { + return Ok(json!(i as i64)); + } + } + Ok(json!(-1)) + } + JsonValue::String(s) => { + let needle = needle.as_str().ok_or_else(|| { + EvalError::TypeError("index_of() needle must be string for string search".to_string()) + })?; + match s.find(needle) { + Some(pos) => Ok(json!(pos as i64)), + None => Ok(json!(-1)), + } + } + _ => Err(EvalError::TypeError(format!( + "index_of() requires array or string, got {}", + type_name(haystack) + ))), + } +} + +fn fn_count(haystack: &JsonValue, needle: &JsonValue) -> EvalResult { + match haystack { + JsonValue::Array(arr) => { + let count = arr.iter().filter(|item| json_eq(item, needle)).count(); + Ok(json!(count as i64)) + } + JsonValue::String(s) => { + let needle = needle.as_str().ok_or_else(|| { + EvalError::TypeError("count() needle must be string for string search".to_string()) + })?; + Ok(json!(s.matches(needle).count() as i64)) + } + _ => Err(EvalError::TypeError(format!( + "count() requires array or string, got {}", + type_name(haystack) + ))), + } +} + +fn fn_merge(a: &JsonValue, b: &JsonValue) -> EvalResult { + match (a, b) { + (JsonValue::Object(obj_a), JsonValue::Object(obj_b)) => { + let mut result = obj_a.clone(); + for (k, v) in obj_b { + result.insert(k.clone(), v.clone()); + } + Ok(JsonValue::Object(result)) + } + _ => Err(EvalError::TypeError(format!( + "merge() requires two objects, got {} and {}", + type_name(a), + type_name(b) + ))), + } +} + +fn fn_chunks(v: &JsonValue, size: &JsonValue) -> EvalResult { + let n = size.as_i64().ok_or_else(|| { + EvalError::TypeError("chunks() size must be a positive integer".to_string()) + })?; + if n <= 0 { + return Err(EvalError::TypeError( + "chunks() size must be a positive integer".to_string(), + )); + } + let n = n as usize; + + match v { + JsonValue::Array(arr) => { + let chunks: Vec = arr + .chunks(n) + .map(|c| JsonValue::Array(c.to_vec())) + .collect(); + Ok(JsonValue::Array(chunks)) + } + _ => Err(EvalError::TypeError(format!( + "chunks() requires array, got {}", + type_name(v) + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_is_truthy() { + assert!(!is_truthy(&json!(null))); + assert!(!is_truthy(&json!(false))); + assert!(!is_truthy(&json!(0))); + assert!(!is_truthy(&json!(""))); + assert!(!is_truthy(&json!([]))); + assert!(!is_truthy(&json!({}))); + + assert!(is_truthy(&json!(true))); + assert!(is_truthy(&json!(1))); + assert!(is_truthy(&json!("a"))); + assert!(is_truthy(&json!([1]))); + assert!(is_truthy(&json!({"a": 1}))); + } + + #[test] + fn test_json_eq_cross_numeric() { + assert!(json_eq(&json!(3), &json!(3.0))); + assert!(json_eq(&json!(3.0), &json!(3))); + assert!(!json_eq(&json!(3), &json!(3.1))); + } + + #[test] + fn test_json_eq_recursive() { + assert!(json_eq( + &json!({"a": [1, 2], "b": {"c": 3}}), + &json!({"b": {"c": 3}, "a": [1, 2]}) + )); + assert!(!json_eq( + &json!({"a": [1, 2]}), + &json!({"a": [1, 3]}) + )); + } + + #[test] + fn test_negative_indexing() { + let arr = json!([10, 20, 30]); + assert_eq!(index_access(&arr, &json!(-1)).unwrap(), json!(30)); + assert_eq!(index_access(&arr, &json!(-2)).unwrap(), json!(20)); + } + + #[test] + fn test_integer_division() { + // 10 / 5 = 2 (integer) + assert_eq!(eval_div(&json!(10), &json!(5)).unwrap(), json!(2)); + // 10 / 3 = 3.333... (float because not evenly divisible) + let result = eval_div(&json!(10), &json!(3)).unwrap(); + assert!(result.is_f64()); + } +} diff --git a/crates/common/src/workflow/expression/mod.rs b/crates/common/src/workflow/expression/mod.rs new file mode 100644 index 0000000..dc7af12 --- /dev/null +++ b/crates/common/src/workflow/expression/mod.rs @@ -0,0 +1,545 @@ +//! # Workflow Expression Engine +//! +//! A complete expression evaluator for workflow templates, supporting arithmetic, +//! comparison, boolean logic, member access, and built-in functions over JSON values. +//! +//! ## Architecture +//! +//! The engine is structured as a classic three-phase interpreter: +//! +//! 1. **Lexer** (`tokenizer.rs`) — converts expression strings into a stream of tokens +//! 2. **Parser** (`parser.rs`) — builds an AST from tokens using recursive descent +//! 3. **Evaluator** (`evaluator.rs`) — walks the AST and produces a `JsonValue` result +//! +//! ## Supported Operators +//! +//! ### Arithmetic +//! - `+` (addition for numbers, concatenation for strings) +//! - `-` (subtraction, unary negation) +//! - `*`, `/`, `%` (multiplication, division, modulo) +//! +//! ### Comparison +//! - `==`, `!=` (equality — works on all types, recursive for objects/arrays) +//! - `>`, `<`, `>=`, `<=` (ordering — numbers and strings only) +//! - Float/int comparisons allowed: `3 == 3.0` → true +//! +//! ### Boolean / Logical +//! - `and`, `or`, `not` +//! +//! ### Membership & Access +//! - `.` — object property access +//! - `[n]` — array index / object bracket access +//! - `in` — membership test (item in list, key in object, substring in string) +//! +//! ## Built-in Functions +//! +//! ### Type conversion +//! - `string(v)`, `number(v)`, `int(v)`, `bool(v)` +//! +//! ### Introspection +//! - `type_of(v)`, `length(v)`, `keys(obj)`, `values(obj)` +//! +//! ### Math +//! - `abs(n)`, `floor(n)`, `ceil(n)`, `round(n)`, `min(a,b)`, `max(a,b)`, `sum(arr)` +//! +//! ### String +//! - `lower(s)`, `upper(s)`, `trim(s)`, `split(s, sep)`, `join(arr, sep)` +//! - `replace(s, old, new)`, `starts_with(s, prefix)`, `ends_with(s, suffix)` +//! - `match(pattern, s)` — regex match +//! +//! ### Collection +//! - `contains(haystack, needle)`, `reversed(v)`, `sort(arr)`, `unique(arr)` +//! - `flat(arr)`, `zip(a, b)`, `range(n)` / `range(start, end)` +//! +//! ### Workflow-specific +//! - `result()`, `succeeded()`, `failed()`, `timed_out()` + +mod ast; +mod evaluator; +mod parser; +mod tokenizer; + +pub use ast::{BinaryOp, Expr, UnaryOp}; +pub use evaluator::{is_truthy, EvalContext, EvalError, EvalResult}; +pub use parser::{ParseError, Parser}; +pub use tokenizer::{Token, TokenKind, Tokenizer}; + +use serde_json::Value as JsonValue; + +/// Parse and evaluate an expression string against the given context. +/// +/// This is the main entry point for the expression engine. It tokenizes the +/// input, parses it into an AST, and evaluates it to produce a `JsonValue`. +pub fn eval_expression(input: &str, ctx: &dyn EvalContext) -> EvalResult { + let tokens = Tokenizer::new(input).tokenize().map_err(|e| { + EvalError::ParseError(format!("{}", e)) + })?; + let ast = Parser::new(&tokens).parse().map_err(|e| { + EvalError::ParseError(format!("{}", e)) + })?; + evaluator::eval(&ast, ctx) +} + +/// Parse an expression string into an AST without evaluating it. +/// +/// Useful for validation or inspection. +pub fn parse_expression(input: &str) -> Result { + let tokens = Tokenizer::new(input).tokenize().map_err(|e| { + ParseError::TokenError(format!("{}", e)) + })?; + Parser::new(&tokens).parse() +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::collections::HashMap; + + /// A minimal eval context for integration tests. + struct TestContext { + variables: HashMap, + } + + impl TestContext { + fn new() -> Self { + Self { + variables: HashMap::new(), + } + } + + fn with_var(mut self, name: &str, value: JsonValue) -> Self { + self.variables.insert(name.to_string(), value); + self + } + } + + impl EvalContext for TestContext { + fn resolve_variable(&self, name: &str) -> EvalResult { + self.variables + .get(name) + .cloned() + .ok_or_else(|| EvalError::VariableNotFound(name.to_string())) + } + + fn call_workflow_function( + &self, + _name: &str, + _args: &[JsonValue], + ) -> EvalResult> { + Ok(None) + } + } + + // --------------------------------------------------------------- + // Arithmetic + // --------------------------------------------------------------- + + #[test] + fn test_integer_arithmetic() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("2 + 3", &ctx).unwrap(), json!(5)); + assert_eq!(eval_expression("10 - 4", &ctx).unwrap(), json!(6)); + assert_eq!(eval_expression("3 * 7", &ctx).unwrap(), json!(21)); + assert_eq!(eval_expression("15 / 5", &ctx).unwrap(), json!(3)); + assert_eq!(eval_expression("17 % 5", &ctx).unwrap(), json!(2)); + } + + #[test] + fn test_float_arithmetic() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("2.5 + 1.5", &ctx).unwrap(), json!(4.0)); + assert_eq!(eval_expression("10.0 / 3.0", &ctx).unwrap(), json!(10.0 / 3.0)); + } + + #[test] + fn test_mixed_int_float() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("2 + 1.5", &ctx).unwrap(), json!(3.5)); + // Integer division yields float when not evenly divisible + assert_eq!(eval_expression("10 / 4", &ctx).unwrap(), json!(2.5)); + assert_eq!(eval_expression("10 / 4.0", &ctx).unwrap(), json!(2.5)); + // Evenly divisible integer division stays integer + assert_eq!(eval_expression("10 / 5", &ctx).unwrap(), json!(2)); + } + + #[test] + fn test_operator_precedence() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("2 + 3 * 4", &ctx).unwrap(), json!(14)); + assert_eq!(eval_expression("(2 + 3) * 4", &ctx).unwrap(), json!(20)); + assert_eq!(eval_expression("10 - 2 * 3 + 1", &ctx).unwrap(), json!(5)); + } + + #[test] + fn test_unary_negation() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("-5", &ctx).unwrap(), json!(-5)); + assert_eq!(eval_expression("-2 + 3", &ctx).unwrap(), json!(1)); + assert_eq!(eval_expression("-(2 + 3)", &ctx).unwrap(), json!(-5)); + } + + #[test] + fn test_string_concatenation() { + let ctx = TestContext::new(); + assert_eq!( + eval_expression("\"hello\" + \" \" + \"world\"", &ctx).unwrap(), + json!("hello world") + ); + } + + // --------------------------------------------------------------- + // Comparison + // --------------------------------------------------------------- + + #[test] + fn test_number_comparison() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("3 == 3", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3 != 4", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3 > 2", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3 < 2", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("3 >= 3", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3 <= 4", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_int_float_equality() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("3 == 3.0", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3.0 == 3", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("3 != 3.1", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_string_comparison() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("\"abc\" == \"abc\"", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("\"abc\" < \"abd\"", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("\"abc\" > \"abb\"", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_null_equality() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("null == null", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("null != null", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("null == 0", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("null == false", &ctx).unwrap(), json!(false)); + } + + #[test] + fn test_array_equality() { + let ctx = TestContext::new() + .with_var("a", json!([1, 2, 3])) + .with_var("b", json!([1, 2, 3])) + .with_var("c", json!([1, 2, 4])); + assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_object_equality() { + let ctx = TestContext::new() + .with_var("a", json!({"x": 1, "y": 2})) + .with_var("b", json!({"y": 2, "x": 1})) + .with_var("c", json!({"x": 1, "y": 3})); + assert_eq!(eval_expression("a == b", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("a != c", &ctx).unwrap(), json!(true)); + } + + // --------------------------------------------------------------- + // Boolean / Logical + // --------------------------------------------------------------- + + #[test] + fn test_boolean_operators() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("true and true", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("true and false", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("false or true", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("false or false", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("not true", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("not false", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_boolean_precedence() { + let ctx = TestContext::new(); + // `and` binds tighter than `or` + assert_eq!( + eval_expression("true or false and false", &ctx).unwrap(), + json!(true) + ); + assert_eq!( + eval_expression("(true or false) and false", &ctx).unwrap(), + json!(false) + ); + } + + // --------------------------------------------------------------- + // Membership & access + // --------------------------------------------------------------- + + #[test] + fn test_dot_access() { + let ctx = TestContext::new() + .with_var("obj", json!({"a": {"b": 42}})); + assert_eq!(eval_expression("obj.a.b", &ctx).unwrap(), json!(42)); + } + + #[test] + fn test_bracket_access() { + let ctx = TestContext::new() + .with_var("arr", json!([10, 20, 30])) + .with_var("obj", json!({"key": "value"})); + assert_eq!(eval_expression("arr[1]", &ctx).unwrap(), json!(20)); + assert_eq!(eval_expression("obj[\"key\"]", &ctx).unwrap(), json!("value")); + } + + #[test] + fn test_in_operator() { + let ctx = TestContext::new() + .with_var("arr", json!([1, 2, 3])) + .with_var("obj", json!({"key": "val"})); + assert_eq!(eval_expression("2 in arr", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("5 in arr", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("\"key\" in obj", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("\"nope\" in obj", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("\"ell\" in \"hello\"", &ctx).unwrap(), json!(true)); + } + + // --------------------------------------------------------------- + // Built-in functions + // --------------------------------------------------------------- + + #[test] + fn test_length() { + let ctx = TestContext::new() + .with_var("arr", json!([1, 2, 3])) + .with_var("obj", json!({"a": 1, "b": 2})); + assert_eq!(eval_expression("length(arr)", &ctx).unwrap(), json!(3)); + assert_eq!(eval_expression("length(\"hello\")", &ctx).unwrap(), json!(5)); + assert_eq!(eval_expression("length(obj)", &ctx).unwrap(), json!(2)); + } + + #[test] + fn test_type_conversions() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("string(42)", &ctx).unwrap(), json!("42")); + assert_eq!(eval_expression("number(\"3.14\")", &ctx).unwrap(), json!(3.14)); + assert_eq!(eval_expression("int(3.9)", &ctx).unwrap(), json!(3)); + assert_eq!(eval_expression("int(\"42\")", &ctx).unwrap(), json!(42)); + assert_eq!(eval_expression("bool(1)", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("bool(0)", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("bool(\"\")", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("bool(\"x\")", &ctx).unwrap(), json!(true)); + } + + #[test] + fn test_type_of() { + let ctx = TestContext::new() + .with_var("arr", json!([1])) + .with_var("obj", json!({})); + assert_eq!(eval_expression("type_of(42)", &ctx).unwrap(), json!("number")); + assert_eq!(eval_expression("type_of(\"hi\")", &ctx).unwrap(), json!("string")); + assert_eq!(eval_expression("type_of(true)", &ctx).unwrap(), json!("bool")); + assert_eq!(eval_expression("type_of(null)", &ctx).unwrap(), json!("null")); + assert_eq!(eval_expression("type_of(arr)", &ctx).unwrap(), json!("array")); + assert_eq!(eval_expression("type_of(obj)", &ctx).unwrap(), json!("object")); + } + + #[test] + fn test_keys_values() { + let ctx = TestContext::new() + .with_var("obj", json!({"b": 2, "a": 1})); + let keys = eval_expression("sort(keys(obj))", &ctx).unwrap(); + assert_eq!(keys, json!(["a", "b"])); + let values = eval_expression("sort(values(obj))", &ctx).unwrap(); + assert_eq!(values, json!([1, 2])); + } + + #[test] + fn test_math_functions() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("abs(-5)", &ctx).unwrap(), json!(5)); + assert_eq!(eval_expression("floor(3.7)", &ctx).unwrap(), json!(3)); + assert_eq!(eval_expression("ceil(3.2)", &ctx).unwrap(), json!(4)); + assert_eq!(eval_expression("round(3.5)", &ctx).unwrap(), json!(4)); + assert_eq!(eval_expression("min(3, 7)", &ctx).unwrap(), json!(3)); + assert_eq!(eval_expression("max(3, 7)", &ctx).unwrap(), json!(7)); + assert_eq!(eval_expression("sum([1, 2, 3, 4])", &ctx).unwrap(), json!(10)); + } + + #[test] + fn test_string_functions() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("lower(\"HELLO\")", &ctx).unwrap(), json!("hello")); + assert_eq!(eval_expression("upper(\"hello\")", &ctx).unwrap(), json!("HELLO")); + assert_eq!(eval_expression("trim(\" hi \")", &ctx).unwrap(), json!("hi")); + assert_eq!( + eval_expression("replace(\"hello world\", \"world\", \"rust\")", &ctx).unwrap(), + json!("hello rust") + ); + assert_eq!( + eval_expression("starts_with(\"hello\", \"hel\")", &ctx).unwrap(), + json!(true) + ); + assert_eq!( + eval_expression("ends_with(\"hello\", \"llo\")", &ctx).unwrap(), + json!(true) + ); + assert_eq!( + eval_expression("split(\"a,b,c\", \",\")", &ctx).unwrap(), + json!(["a", "b", "c"]) + ); + assert_eq!( + eval_expression("join([\"a\", \"b\", \"c\"], \",\")", &ctx).unwrap(), + json!("a,b,c") + ); + } + + #[test] + fn test_regex_match() { + let ctx = TestContext::new(); + assert_eq!( + eval_expression("match(\"^hello\", \"hello world\")", &ctx).unwrap(), + json!(true) + ); + assert_eq!( + eval_expression("match(\"^world\", \"hello world\")", &ctx).unwrap(), + json!(false) + ); + } + + #[test] + fn test_collection_functions() { + let ctx = TestContext::new() + .with_var("arr", json!([3, 1, 2])); + assert_eq!(eval_expression("sort(arr)", &ctx).unwrap(), json!([1, 2, 3])); + assert_eq!(eval_expression("reversed(arr)", &ctx).unwrap(), json!([2, 1, 3])); + assert_eq!( + eval_expression("unique([1, 2, 2, 3, 1])", &ctx).unwrap(), + json!([1, 2, 3]) + ); + assert_eq!( + eval_expression("flat([[1, 2], [3, 4]])", &ctx).unwrap(), + json!([1, 2, 3, 4]) + ); + assert_eq!( + eval_expression("zip([1, 2], [\"a\", \"b\"])", &ctx).unwrap(), + json!([[1, "a"], [2, "b"]]) + ); + } + + #[test] + fn test_range() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("range(5)", &ctx).unwrap(), json!([0, 1, 2, 3, 4])); + assert_eq!(eval_expression("range(2, 5)", &ctx).unwrap(), json!([2, 3, 4])); + } + + #[test] + fn test_reversed_string() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("reversed(\"abc\")", &ctx).unwrap(), json!("cba")); + } + + #[test] + fn test_contains_function() { + let ctx = TestContext::new(); + assert_eq!( + eval_expression("contains([1, 2, 3], 2)", &ctx).unwrap(), + json!(true) + ); + assert_eq!( + eval_expression("contains(\"hello\", \"ell\")", &ctx).unwrap(), + json!(true) + ); + } + + // --------------------------------------------------------------- + // Complex expressions + // --------------------------------------------------------------- + + #[test] + fn test_complex_expression() { + let ctx = TestContext::new() + .with_var("items", json!([1, 2, 3, 4, 5])); + assert_eq!( + eval_expression("length(items) > 3 and 5 in items", &ctx).unwrap(), + json!(true) + ); + } + + #[test] + fn test_chained_access() { + let ctx = TestContext::new() + .with_var("data", json!({"users": [{"name": "Alice"}, {"name": "Bob"}]})); + assert_eq!( + eval_expression("data.users[1].name", &ctx).unwrap(), + json!("Bob") + ); + } + + #[test] + fn test_ternary_via_boolean() { + let ctx = TestContext::new() + .with_var("x", json!(10)); + // No ternary operator, but boolean expressions work for conditions + assert_eq!( + eval_expression("x > 5 and x < 20", &ctx).unwrap(), + json!(true) + ); + } + + #[test] + fn test_no_implicit_type_coercion() { + let ctx = TestContext::new(); + // String + number should error, not silently coerce + assert!(eval_expression("\"hello\" + 5", &ctx).is_err()); + // Comparing different types (other than int/float) should return false for == + assert_eq!(eval_expression("\"3\" == 3", &ctx).unwrap(), json!(false)); + } + + #[test] + fn test_division_by_zero() { + let ctx = TestContext::new(); + assert!(eval_expression("5 / 0", &ctx).is_err()); + assert!(eval_expression("5 % 0", &ctx).is_err()); + } + + #[test] + fn test_array_literal() { + let ctx = TestContext::new(); + assert_eq!( + eval_expression("[1, 2, 3]", &ctx).unwrap(), + json!([1, 2, 3]) + ); + assert_eq!( + eval_expression("[\"a\", \"b\"]", &ctx).unwrap(), + json!(["a", "b"]) + ); + } + + #[test] + fn test_nested_function_calls() { + let ctx = TestContext::new(); + assert_eq!( + eval_expression("length(split(\"a,b,c\", \",\"))", &ctx).unwrap(), + json!(3) + ); + assert_eq!( + eval_expression("join(sort([\"c\", \"a\", \"b\"]), \"-\")", &ctx).unwrap(), + json!("a-b-c") + ); + } + + #[test] + fn test_boolean_literals() { + let ctx = TestContext::new(); + assert_eq!(eval_expression("true", &ctx).unwrap(), json!(true)); + assert_eq!(eval_expression("false", &ctx).unwrap(), json!(false)); + assert_eq!(eval_expression("null", &ctx).unwrap(), json!(null)); + } +} diff --git a/crates/common/src/workflow/expression/parser.rs b/crates/common/src/workflow/expression/parser.rs new file mode 100644 index 0000000..1e9cf67 --- /dev/null +++ b/crates/common/src/workflow/expression/parser.rs @@ -0,0 +1,520 @@ +//! # Expression Parser +//! +//! Recursive-descent parser that transforms a token stream into an AST. +//! +//! ## Operator Precedence (lowest to highest) +//! +//! 1. `or` +//! 2. `and` +//! 3. `not` (unary) +//! 4. `==`, `!=`, `<`, `>`, `<=`, `>=`, `in` +//! 5. `+`, `-` (addition / subtraction) +//! 6. `*`, `/`, `%` +//! 7. Unary `-` +//! 8. Postfix: `.field`, `[index]`, `(args)` + +use super::ast::{BinaryOp, Expr, UnaryOp}; +use super::tokenizer::{Token, TokenKind}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ParseError { + #[error("Unexpected token {0} at position {1}")] + UnexpectedToken(String, usize), + + #[error("Expected {0}, found {1} at position {2}")] + Expected(String, String, usize), + + #[error("Unexpected end of expression")] + UnexpectedEof, + + #[error("Token error: {0}")] + TokenError(String), +} + +/// The parser state. +pub struct Parser<'a> { + tokens: &'a [Token], + pos: usize, +} + +impl<'a> Parser<'a> { + pub fn new(tokens: &'a [Token]) -> Self { + Self { tokens, pos: 0 } + } + + /// Parse the token stream into a single expression AST. + pub fn parse(&mut self) -> Result { + let expr = self.parse_or()?; + // We should be at EOF now + if !self.at_end() { + let tok = self.peek(); + return Err(ParseError::UnexpectedToken( + format!("{}", tok.kind), + tok.span.0, + )); + } + Ok(expr) + } + + // ----- Helpers ----- + + fn peek(&self) -> &Token { + &self.tokens[self.pos.min(self.tokens.len() - 1)] + } + + fn at_end(&self) -> bool { + self.peek().kind == TokenKind::Eof + } + + fn advance(&mut self) -> &Token { + let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)]; + if self.pos < self.tokens.len() { + self.pos += 1; + } + tok + } + + fn expect(&mut self, expected: &TokenKind) -> Result<&Token, ParseError> { + let tok = self.peek(); + if std::mem::discriminant(&tok.kind) == std::mem::discriminant(expected) { + Ok(self.advance()) + } else { + Err(ParseError::Expected( + format!("{}", expected), + format!("{}", tok.kind), + tok.span.0, + )) + } + } + + fn check(&self, kind: &TokenKind) -> bool { + std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(kind) + } + + // ----- Grammar rules ----- + + // or_expr = and_expr ( "or" and_expr )* + fn parse_or(&mut self) -> Result { + let mut left = self.parse_and()?; + while self.peek().kind == TokenKind::Or { + self.advance(); + let right = self.parse_and()?; + left = Expr::BinaryOp { + op: BinaryOp::Or, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + // and_expr = not_expr ( "and" not_expr )* + fn parse_and(&mut self) -> Result { + let mut left = self.parse_not()?; + while self.peek().kind == TokenKind::And { + self.advance(); + let right = self.parse_not()?; + left = Expr::BinaryOp { + op: BinaryOp::And, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + // not_expr = "not" not_expr | comparison + fn parse_not(&mut self) -> Result { + if self.peek().kind == TokenKind::Not { + self.advance(); + let operand = self.parse_not()?; + return Ok(Expr::UnaryOp { + op: UnaryOp::Not, + operand: Box::new(operand), + }); + } + self.parse_comparison() + } + + // comparison = addition ( ("==" | "!=" | "<" | ">" | "<=" | ">=" | "in") addition )* + fn parse_comparison(&mut self) -> Result { + let mut left = self.parse_addition()?; + + loop { + let op = match self.peek().kind { + TokenKind::EqEq => BinaryOp::Eq, + TokenKind::BangEq => BinaryOp::Ne, + TokenKind::Lt => BinaryOp::Lt, + TokenKind::Gt => BinaryOp::Gt, + TokenKind::LtEq => BinaryOp::Le, + TokenKind::GtEq => BinaryOp::Ge, + TokenKind::In => BinaryOp::In, + _ => break, + }; + self.advance(); + let right = self.parse_addition()?; + left = Expr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + + Ok(left) + } + + // addition = multiplication ( ("+" | "-") multiplication )* + fn parse_addition(&mut self) -> Result { + let mut left = self.parse_multiplication()?; + + loop { + let op = match self.peek().kind { + TokenKind::Plus => BinaryOp::Add, + TokenKind::Minus => BinaryOp::Sub, + _ => break, + }; + self.advance(); + let right = self.parse_multiplication()?; + left = Expr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + + Ok(left) + } + + // multiplication = unary ( ("*" | "/" | "%") unary )* + fn parse_multiplication(&mut self) -> Result { + let mut left = self.parse_unary()?; + + loop { + let op = match self.peek().kind { + TokenKind::Star => BinaryOp::Mul, + TokenKind::Slash => BinaryOp::Div, + TokenKind::Percent => BinaryOp::Mod, + _ => break, + }; + self.advance(); + let right = self.parse_unary()?; + left = Expr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + + Ok(left) + } + + // unary = "-" unary | postfix + fn parse_unary(&mut self) -> Result { + if self.peek().kind == TokenKind::Minus { + self.advance(); + let operand = self.parse_unary()?; + return Ok(Expr::UnaryOp { + op: UnaryOp::Neg, + operand: Box::new(operand), + }); + } + self.parse_postfix() + } + + // postfix = primary ( "." IDENT | "[" expr "]" | "(" args ")" )* + fn parse_postfix(&mut self) -> Result { + let mut expr = self.parse_primary()?; + + loop { + match self.peek().kind { + TokenKind::Dot => { + self.advance(); + // The field after dot + let tok = self.advance().clone(); + let field = match &tok.kind { + TokenKind::Ident(name) => name.clone(), + // Allow keywords as field names (e.g., obj.in, obj.and) + TokenKind::And => "and".to_string(), + TokenKind::Or => "or".to_string(), + TokenKind::Not => "not".to_string(), + TokenKind::In => "in".to_string(), + TokenKind::True => "true".to_string(), + TokenKind::False => "false".to_string(), + TokenKind::Null => "null".to_string(), + _ => { + return Err(ParseError::Expected( + "identifier".to_string(), + format!("{}", tok.kind), + tok.span.0, + )); + } + }; + expr = Expr::DotAccess { + object: Box::new(expr), + field, + }; + } + TokenKind::LBracket => { + self.advance(); + let index = self.parse_or()?; + self.expect(&TokenKind::RBracket)?; + expr = Expr::IndexAccess { + object: Box::new(expr), + index: Box::new(index), + }; + } + TokenKind::LParen => { + // Only if the expression so far is an identifier (function name) + // or a dot-access chain (method-like call). + // For now we handle Ident -> FunctionCall transformation. + if let Expr::Ident(name) = expr { + self.advance(); + let args = self.parse_args()?; + self.expect(&TokenKind::RParen)?; + expr = Expr::FunctionCall { name, args }; + } else { + break; + } + } + _ => break, + } + } + + Ok(expr) + } + + // args = ( expr ( "," expr )* )? + fn parse_args(&mut self) -> Result, ParseError> { + let mut args = Vec::new(); + if self.check(&TokenKind::RParen) { + return Ok(args); + } + args.push(self.parse_or()?); + while self.peek().kind == TokenKind::Comma { + self.advance(); + args.push(self.parse_or()?); + } + Ok(args) + } + + // primary = INTEGER | FLOAT | STRING | "true" | "false" | "null" + // | IDENT | "(" expr ")" | "[" elements "]" + fn parse_primary(&mut self) -> Result { + let tok = self.peek().clone(); + match &tok.kind { + TokenKind::Integer(n) => { + let n = *n; + self.advance(); + Ok(Expr::Literal(serde_json::json!(n))) + } + TokenKind::Float(f) => { + let f = *f; + self.advance(); + Ok(Expr::Literal(serde_json::json!(f))) + } + TokenKind::StringLit(s) => { + let s = s.clone(); + self.advance(); + Ok(Expr::Literal(serde_json::Value::String(s))) + } + TokenKind::True => { + self.advance(); + Ok(Expr::Literal(serde_json::json!(true))) + } + TokenKind::False => { + self.advance(); + Ok(Expr::Literal(serde_json::json!(false))) + } + TokenKind::Null => { + self.advance(); + Ok(Expr::Literal(serde_json::json!(null))) + } + TokenKind::Ident(name) => { + let name = name.clone(); + self.advance(); + Ok(Expr::Ident(name)) + } + TokenKind::LParen => { + self.advance(); + let expr = self.parse_or()?; + self.expect(&TokenKind::RParen)?; + Ok(expr) + } + TokenKind::LBracket => { + self.advance(); + let mut elements = Vec::new(); + if !self.check(&TokenKind::RBracket) { + elements.push(self.parse_or()?); + while self.peek().kind == TokenKind::Comma { + self.advance(); + // Allow trailing comma + if self.check(&TokenKind::RBracket) { + break; + } + elements.push(self.parse_or()?); + } + } + self.expect(&TokenKind::RBracket)?; + Ok(Expr::Array(elements)) + } + _ => Err(ParseError::UnexpectedToken( + format!("{}", tok.kind), + tok.span.0, + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::tokenizer::Tokenizer; + + fn parse(input: &str) -> Expr { + let tokens = Tokenizer::new(input).tokenize().unwrap(); + Parser::new(&tokens).parse().unwrap() + } + + #[test] + fn test_simple_add() { + let ast = parse("2 + 3"); + assert_eq!( + ast, + Expr::BinaryOp { + op: BinaryOp::Add, + left: Box::new(Expr::Literal(serde_json::json!(2))), + right: Box::new(Expr::Literal(serde_json::json!(3))), + } + ); + } + + #[test] + fn test_precedence() { + // 2 + 3 * 4 should parse as 2 + (3 * 4) + let ast = parse("2 + 3 * 4"); + match ast { + Expr::BinaryOp { + op: BinaryOp::Add, + right, + .. + } => { + assert!(matches!( + *right, + Expr::BinaryOp { + op: BinaryOp::Mul, + .. + } + )); + } + _ => panic!("Expected Add at top level"), + } + } + + #[test] + fn test_function_call() { + let ast = parse("length(arr)"); + assert_eq!( + ast, + Expr::FunctionCall { + name: "length".to_string(), + args: vec![Expr::Ident("arr".to_string())], + } + ); + } + + #[test] + fn test_dot_access() { + let ast = parse("obj.field.sub"); + assert_eq!( + ast, + Expr::DotAccess { + object: Box::new(Expr::DotAccess { + object: Box::new(Expr::Ident("obj".to_string())), + field: "field".to_string(), + }), + field: "sub".to_string(), + } + ); + } + + #[test] + fn test_array_literal() { + let ast = parse("[1, 2, 3]"); + assert_eq!( + ast, + Expr::Array(vec![ + Expr::Literal(serde_json::json!(1)), + Expr::Literal(serde_json::json!(2)), + Expr::Literal(serde_json::json!(3)), + ]) + ); + } + + #[test] + fn test_bracket_access() { + let ast = parse("arr[0]"); + assert_eq!( + ast, + Expr::IndexAccess { + object: Box::new(Expr::Ident("arr".to_string())), + index: Box::new(Expr::Literal(serde_json::json!(0))), + } + ); + } + + #[test] + fn test_not_operator() { + let ast = parse("not true"); + assert_eq!( + ast, + Expr::UnaryOp { + op: UnaryOp::Not, + operand: Box::new(Expr::Literal(serde_json::json!(true))), + } + ); + } + + #[test] + fn test_in_operator() { + let ast = parse("x in arr"); + assert_eq!( + ast, + Expr::BinaryOp { + op: BinaryOp::In, + left: Box::new(Expr::Ident("x".to_string())), + right: Box::new(Expr::Ident("arr".to_string())), + } + ); + } + + #[test] + fn test_complex_expression() { + // Should parse without error + let _ast = parse("length(items) > 3 and 5 in items"); + } + + #[test] + fn test_chained_access() { + // data.users[1].name + let _ast = parse("data.users[1].name"); + } + + #[test] + fn test_nested_function() { + let _ast = parse("length(split(\"a,b,c\", \",\"))"); + } + + #[test] + fn test_trailing_comma_in_array() { + let ast = parse("[1, 2, 3,]"); + assert_eq!( + ast, + Expr::Array(vec![ + Expr::Literal(serde_json::json!(1)), + Expr::Literal(serde_json::json!(2)), + Expr::Literal(serde_json::json!(3)), + ]) + ); + } +} diff --git a/crates/common/src/workflow/expression/tokenizer.rs b/crates/common/src/workflow/expression/tokenizer.rs new file mode 100644 index 0000000..7feda1f --- /dev/null +++ b/crates/common/src/workflow/expression/tokenizer.rs @@ -0,0 +1,512 @@ +//! # Expression Tokenizer (Lexer) +//! +//! Converts an expression string into a sequence of tokens. + +use std::fmt; +use thiserror::Error; + +/// A token produced by the lexer. +#[derive(Debug, Clone, PartialEq)] +pub struct Token { + pub kind: TokenKind, + pub span: (usize, usize), +} + +impl Token { + pub fn new(kind: TokenKind, start: usize, end: usize) -> Self { + Self { + kind, + span: (start, end), + } + } +} + +/// The kind of a token. +#[derive(Debug, Clone, PartialEq)] +pub enum TokenKind { + // Literals + Integer(i64), + Float(f64), + StringLit(String), + True, + False, + Null, + + // Identifier + Ident(String), + + // Keywords (also parsed as identifiers initially, then classified) + And, + Or, + Not, + In, + + // Operators + Plus, + Minus, + Star, + Slash, + Percent, + EqEq, + BangEq, + Lt, + Gt, + LtEq, + GtEq, + + // Delimiters + LParen, + RParen, + LBracket, + RBracket, + Comma, + Dot, + + // End of input + Eof, +} + +impl fmt::Display for TokenKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TokenKind::Integer(n) => write!(f, "{}", n), + TokenKind::Float(n) => write!(f, "{}", n), + TokenKind::StringLit(s) => write!(f, "\"{}\"", s), + TokenKind::True => write!(f, "true"), + TokenKind::False => write!(f, "false"), + TokenKind::Null => write!(f, "null"), + TokenKind::Ident(s) => write!(f, "{}", s), + TokenKind::And => write!(f, "and"), + TokenKind::Or => write!(f, "or"), + TokenKind::Not => write!(f, "not"), + TokenKind::In => write!(f, "in"), + TokenKind::Plus => write!(f, "+"), + TokenKind::Minus => write!(f, "-"), + TokenKind::Star => write!(f, "*"), + TokenKind::Slash => write!(f, "/"), + TokenKind::Percent => write!(f, "%"), + TokenKind::EqEq => write!(f, "=="), + TokenKind::BangEq => write!(f, "!="), + TokenKind::Lt => write!(f, "<"), + TokenKind::Gt => write!(f, ">"), + TokenKind::LtEq => write!(f, "<="), + TokenKind::GtEq => write!(f, ">="), + TokenKind::LParen => write!(f, "("), + TokenKind::RParen => write!(f, ")"), + TokenKind::LBracket => write!(f, "["), + TokenKind::RBracket => write!(f, "]"), + TokenKind::Comma => write!(f, ","), + TokenKind::Dot => write!(f, "."), + TokenKind::Eof => write!(f, "EOF"), + } + } +} + +#[derive(Debug, Error)] +pub enum TokenError { + #[error("Unexpected character '{0}' at position {1}")] + UnexpectedChar(char, usize), + + #[error("Unterminated string literal starting at position {0}")] + UnterminatedString(usize), + + #[error("Invalid number literal at position {0}: {1}")] + InvalidNumber(usize, String), +} + +/// The tokenizer / lexer. +pub struct Tokenizer { + chars: Vec, + pos: usize, +} + +impl Tokenizer { + pub fn new(input: &str) -> Self { + Self { + chars: input.chars().collect(), + pos: 0, + } + } + + /// Tokenize the entire input and return a vector of tokens. + pub fn tokenize(&mut self) -> Result, TokenError> { + let mut tokens = Vec::new(); + loop { + let tok = self.next_token()?; + if tok.kind == TokenKind::Eof { + tokens.push(tok); + break; + } + tokens.push(tok); + } + Ok(tokens) + } + + fn peek(&self) -> Option { + self.chars.get(self.pos).copied() + } + + fn advance(&mut self) -> Option { + let ch = self.chars.get(self.pos).copied(); + if ch.is_some() { + self.pos += 1; + } + ch + } + + fn skip_whitespace(&mut self) { + while let Some(ch) = self.peek() { + if ch.is_whitespace() { + self.advance(); + } else { + break; + } + } + } + + fn next_token(&mut self) -> Result { + self.skip_whitespace(); + + let start = self.pos; + + let ch = match self.peek() { + Some(ch) => ch, + None => return Ok(Token::new(TokenKind::Eof, start, start)), + }; + + // Single-char and multi-char operators/delimiters + match ch { + '+' => { + self.advance(); + Ok(Token::new(TokenKind::Plus, start, self.pos)) + } + '-' => { + self.advance(); + Ok(Token::new(TokenKind::Minus, start, self.pos)) + } + '*' => { + self.advance(); + Ok(Token::new(TokenKind::Star, start, self.pos)) + } + '/' => { + self.advance(); + Ok(Token::new(TokenKind::Slash, start, self.pos)) + } + '%' => { + self.advance(); + Ok(Token::new(TokenKind::Percent, start, self.pos)) + } + '(' => { + self.advance(); + Ok(Token::new(TokenKind::LParen, start, self.pos)) + } + ')' => { + self.advance(); + Ok(Token::new(TokenKind::RParen, start, self.pos)) + } + '[' => { + self.advance(); + Ok(Token::new(TokenKind::LBracket, start, self.pos)) + } + ']' => { + self.advance(); + Ok(Token::new(TokenKind::RBracket, start, self.pos)) + } + ',' => { + self.advance(); + Ok(Token::new(TokenKind::Comma, start, self.pos)) + } + '.' => { + self.advance(); + Ok(Token::new(TokenKind::Dot, start, self.pos)) + } + '=' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Ok(Token::new(TokenKind::EqEq, start, self.pos)) + } else { + Err(TokenError::UnexpectedChar('=', start)) + } + } + '!' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Ok(Token::new(TokenKind::BangEq, start, self.pos)) + } else { + Err(TokenError::UnexpectedChar('!', start)) + } + } + '<' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Ok(Token::new(TokenKind::LtEq, start, self.pos)) + } else { + Ok(Token::new(TokenKind::Lt, start, self.pos)) + } + } + '>' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Ok(Token::new(TokenKind::GtEq, start, self.pos)) + } else { + Ok(Token::new(TokenKind::Gt, start, self.pos)) + } + } + '"' | '\'' => self.read_string(ch), + c if c.is_ascii_digit() => self.read_number(), + c if c.is_ascii_alphabetic() || c == '_' => self.read_ident(), + other => Err(TokenError::UnexpectedChar(other, start)), + } + } + + fn read_string(&mut self, quote: char) -> Result { + let start = self.pos; + self.advance(); // consume opening quote + let mut s = String::new(); + loop { + match self.advance() { + Some('\\') => { + // Escape sequence + match self.advance() { + Some('n') => s.push('\n'), + Some('t') => s.push('\t'), + Some('r') => s.push('\r'), + Some('\\') => s.push('\\'), + Some(c) if c == quote => s.push(c), + Some(c) => { + s.push('\\'); + s.push(c); + } + None => return Err(TokenError::UnterminatedString(start)), + } + } + Some(c) if c == quote => { + return Ok(Token::new(TokenKind::StringLit(s), start, self.pos)); + } + Some(c) => s.push(c), + None => return Err(TokenError::UnterminatedString(start)), + } + } + } + + fn read_number(&mut self) -> Result { + let start = self.pos; + let mut num_str = String::new(); + let mut is_float = false; + + while let Some(ch) = self.peek() { + if ch.is_ascii_digit() { + num_str.push(ch); + self.advance(); + } else if ch == '.' && !is_float { + // Check if this is a decimal point or a method call dot + // Look ahead to see if next char is a digit + let next_pos = self.pos + 1; + if next_pos < self.chars.len() && self.chars[next_pos].is_ascii_digit() { + is_float = true; + num_str.push(ch); + self.advance(); + } else { + // It's a dot access, stop number parsing here + break; + } + } else { + break; + } + } + + if is_float { + let val: f64 = num_str.parse().map_err(|_| { + TokenError::InvalidNumber(start, num_str.clone()) + })?; + Ok(Token::new(TokenKind::Float(val), start, self.pos)) + } else { + let val: i64 = num_str.parse().map_err(|_| { + TokenError::InvalidNumber(start, num_str.clone()) + })?; + Ok(Token::new(TokenKind::Integer(val), start, self.pos)) + } + } + + fn read_ident(&mut self) -> Result { + let start = self.pos; + let mut ident = String::new(); + while let Some(ch) = self.peek() { + if ch.is_ascii_alphanumeric() || ch == '_' { + ident.push(ch); + self.advance(); + } else { + break; + } + } + + let kind = match ident.as_str() { + "true" => TokenKind::True, + "false" => TokenKind::False, + "null" => TokenKind::Null, + "and" => TokenKind::And, + "or" => TokenKind::Or, + "not" => TokenKind::Not, + "in" => TokenKind::In, + _ => TokenKind::Ident(ident), + }; + + Ok(Token::new(kind, start, self.pos)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tokenize(input: &str) -> Vec { + let mut t = Tokenizer::new(input); + t.tokenize() + .unwrap() + .into_iter() + .map(|t| t.kind) + .collect() + } + + #[test] + fn test_simple_expression() { + let kinds = tokenize("2 + 3"); + assert_eq!( + kinds, + vec![ + TokenKind::Integer(2), + TokenKind::Plus, + TokenKind::Integer(3), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_comparison() { + let kinds = tokenize("x >= 10"); + assert_eq!( + kinds, + vec![ + TokenKind::Ident("x".to_string()), + TokenKind::GtEq, + TokenKind::Integer(10), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_keywords() { + let kinds = tokenize("true and not false or null in x"); + assert_eq!( + kinds, + vec![ + TokenKind::True, + TokenKind::And, + TokenKind::Not, + TokenKind::False, + TokenKind::Or, + TokenKind::Null, + TokenKind::In, + TokenKind::Ident("x".to_string()), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_string_literals() { + let kinds = tokenize("\"hello\" + 'world'"); + assert_eq!( + kinds, + vec![ + TokenKind::StringLit("hello".to_string()), + TokenKind::Plus, + TokenKind::StringLit("world".to_string()), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_float() { + let kinds = tokenize("3.14"); + assert_eq!(kinds, vec![TokenKind::Float(3.14), TokenKind::Eof]); + } + + #[test] + fn test_dot_access() { + let kinds = tokenize("obj.field"); + assert_eq!( + kinds, + vec![ + TokenKind::Ident("obj".to_string()), + TokenKind::Dot, + TokenKind::Ident("field".to_string()), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_function_call() { + let kinds = tokenize("length(arr)"); + assert_eq!( + kinds, + vec![ + TokenKind::Ident("length".to_string()), + TokenKind::LParen, + TokenKind::Ident("arr".to_string()), + TokenKind::RParen, + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_bracket_access() { + let kinds = tokenize("arr[0]"); + assert_eq!( + kinds, + vec![ + TokenKind::Ident("arr".to_string()), + TokenKind::LBracket, + TokenKind::Integer(0), + TokenKind::RBracket, + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_escape_sequences() { + let kinds = tokenize(r#""hello\nworld""#); + assert_eq!( + kinds, + vec![ + TokenKind::StringLit("hello\nworld".to_string()), + TokenKind::Eof, + ] + ); + } + + #[test] + fn test_integer_followed_by_dot() { + // `42.field` - the 42 is an integer and `.field` is separate + let kinds = tokenize("42.field"); + assert_eq!( + kinds, + vec![ + TokenKind::Integer(42), + TokenKind::Dot, + TokenKind::Ident("field".to_string()), + TokenKind::Eof, + ] + ); + } +} diff --git a/crates/common/src/workflow/expression_validator.rs b/crates/common/src/workflow/expression_validator.rs new file mode 100644 index 0000000..6f228ca --- /dev/null +++ b/crates/common/src/workflow/expression_validator.rs @@ -0,0 +1,674 @@ +//! # Workflow Expression Validator +//! +//! Static validation of `{{ }}` template expressions in workflow definitions. +//! Catches syntax errors and unresolved variable references **before** the +//! workflow is saved, so users get immediate feedback instead of opaque +//! runtime failures during execution. +//! +//! ## What is validated +//! +//! 1. **Syntax** — every `{{ expr }}` block must parse successfully. +//! 2. **Variable references** — top-level identifiers that are not a known +//! namespace (`parameters`, `workflow`, `task`, `config`, `keystore`, +//! `item`, `index`, `system`, …) must exist in either the workflow's +//! `vars` map or its `param_schema` keys (bare-name fallback targets). +//! +//! ## What is NOT validated +//! +//! - **Type correctness** — e.g. whether `range(parameters.n)` actually +//! receives an integer. That requires runtime values. +//! - **Deep property paths** — e.g. `task.fetch.result.data`. We validate +//! that `task` is a known namespace but not that `fetch` is a real task +//! name (it might not exist yet at save time if tasks are re-ordered). +//! - **Function arity** — built-in functions are not checked for argument +//! count here; the evaluator already reports those errors at runtime. + +use std::collections::HashSet; + +use serde_json::Value as JsonValue; + +use super::expression::{parse_expression, Expr, ParseError}; +use super::parser::{PublishDirective, WorkflowDefinition}; + +// ─────────────────────────────────────────────────────────────────────────── +// Public API +// ─────────────────────────────────────────────────────────────────────────── + +/// A single validation diagnostic. +#[derive(Debug, Clone)] +pub struct ExpressionWarning { + /// Human-readable location (e.g. `task 'sleep_1' with_items`). + pub location: String, + /// The raw template string that was checked. + pub expression: String, + /// What went wrong. + pub message: String, +} + +impl std::fmt::Display for ExpressionWarning { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {} — `{}`", self.location, self.message, self.expression) + } +} + +/// Validate all template expressions in a workflow definition. +/// +/// Returns an empty vec on success, or one [`ExpressionWarning`] per problem +/// found. The caller decides whether warnings are fatal (block save) or +/// advisory. +/// +/// `param_schema` is the *flat-format* schema (`{ "url": { "type": "string" }, … }`) +/// passed alongside the definition in the save request. Its top-level keys +/// are the declared parameter names. +pub fn validate_workflow_expressions( + workflow: &WorkflowDefinition, + param_schema: Option<&JsonValue>, +) -> Vec { + let known_names = build_known_names(workflow, param_schema); + let mut warnings = Vec::new(); + + for task in &workflow.tasks { + let task_loc = format!("task '{}'", task.name); + + // ── with_items expression ──────────────────────────────────── + if let Some(ref expr) = task.with_items { + validate_template( + expr, + &format!("{task_loc} with_items"), + &known_names, + &mut warnings, + ); + } + + // ── task-level when condition ──────────────────────────────── + if let Some(ref expr) = task.when { + validate_template( + expr, + &format!("{task_loc} when"), + &known_names, + &mut warnings, + ); + } + + // ── input templates ────────────────────────────────────────── + for (key, value) in &task.input { + collect_json_templates( + value, + &format!("{task_loc} input.{key}"), + &known_names, + &mut warnings, + ); + } + + // ── next transitions ───────────────────────────────────────── + for (ti, transition) in task.next.iter().enumerate() { + if let Some(ref when_expr) = transition.when { + validate_template( + when_expr, + &format!("{task_loc} next[{ti}].when"), + &known_names, + &mut warnings, + ); + } + + for directive in &transition.publish { + match directive { + PublishDirective::Simple(map) => { + for (pk, pv) in map { + validate_template( + pv, + &format!("{task_loc} next[{ti}].publish.{pk}"), + &known_names, + &mut warnings, + ); + } + } + PublishDirective::Key(_) => { /* nothing to validate */ } + } + } + } + + // ── legacy task-level publish ──────────────────────────────── + for directive in &task.publish { + if let PublishDirective::Simple(map) = directive { + for (pk, pv) in map { + validate_template( + pv, + &format!("{task_loc} publish.{pk}"), + &known_names, + &mut warnings, + ); + } + } + } + } + + warnings +} + +// ─────────────────────────────────────────────────────────────────────────── +// Internals +// ─────────────────────────────────────────────────────────────────────────── + +/// Canonical namespace identifiers that are always valid as top-level names +/// inside `{{ }}` expressions. These are resolved by `WorkflowContext` at +/// runtime and never need to exist in `vars` or `param_schema`. +const CANONICAL_NAMESPACES: &[&str] = &[ + "parameters", + "workflow", + "vars", + "variables", + "task", + "tasks", + "config", + "keystore", + "item", + "index", + "system", +]; + +/// Built-in constants that are valid bare identifiers. +const BUILTIN_LITERALS: &[&str] = &["true", "false", "null"]; + +/// Build the set of bare names that are valid in expressions: +/// canonical namespaces + workflow var names + param_schema keys. +fn build_known_names( + workflow: &WorkflowDefinition, + param_schema: Option<&JsonValue>, +) -> HashSet { + let mut names: HashSet = CANONICAL_NAMESPACES + .iter() + .map(|s| (*s).to_string()) + .collect(); + + for lit in BUILTIN_LITERALS { + names.insert((*lit).to_string()); + } + + // Workflow vars + for key in workflow.vars.keys() { + names.insert(key.clone()); + } + + // Parameter schema keys (flat format: top-level keys are param names) + if let Some(JsonValue::Object(map)) = param_schema { + for key in map.keys() { + names.insert(key.clone()); + } + } + + // Also accept the workflow-level `parameters` schema if present on the + // definition itself (some loaders put it there). + if let Some(JsonValue::Object(ref map)) = workflow.parameters { + for key in map.keys() { + names.insert(key.clone()); + } + } + + names +} + +/// Extract `{{ … }}` blocks from a template string and validate each one. +fn validate_template( + template: &str, + location: &str, + known_names: &HashSet, + warnings: &mut Vec, +) { + for raw_expr in extract_expressions(template) { + let trimmed = raw_expr.trim(); + if trimmed.is_empty() { + continue; + } + + // Phase 1: parse + match parse_expression(trimmed) { + Err(e) => { + warnings.push(ExpressionWarning { + location: location.to_string(), + expression: raw_expr.to_string(), + message: format!("syntax error: {e}"), + }); + } + Ok(ast) => { + // Phase 2: check bare-name references + let mut bare_idents = Vec::new(); + collect_bare_idents(&ast, &mut bare_idents); + + for ident in bare_idents { + if !known_names.contains(&ident) { + warnings.push(ExpressionWarning { + location: location.to_string(), + expression: raw_expr.to_string(), + message: format!( + "unknown variable '{}'. Use 'parameters.{}' for input \ + parameters, or define it in workflow vars", + ident, ident, + ), + }); + } + } + } + } + } +} + +/// Recursively walk a JSON value looking for string leaves that contain +/// `{{ }}` templates. +fn collect_json_templates( + value: &JsonValue, + location: &str, + known_names: &HashSet, + warnings: &mut Vec, +) { + match value { + JsonValue::String(s) => { + validate_template(s, location, known_names, warnings); + } + JsonValue::Array(arr) => { + for (i, item) in arr.iter().enumerate() { + collect_json_templates( + item, + &format!("{location}[{i}]"), + known_names, + warnings, + ); + } + } + JsonValue::Object(map) => { + for (key, val) in map { + collect_json_templates( + val, + &format!("{location}.{key}"), + known_names, + warnings, + ); + } + } + _ => { /* numbers, bools, null — nothing to validate */ } + } +} + +/// Extract the inner expression strings from all `{{ … }}` blocks in a +/// template. Handles nested braces conservatively (takes everything between +/// the outermost `{{` and `}}`). +fn extract_expressions(template: &str) -> Vec<&str> { + let mut results = Vec::new(); + let mut rest = template; + + while let Some(start) = rest.find("{{") { + let after_open = start + 2; + if let Some(end) = rest[after_open..].find("}}") { + results.push(&rest[after_open..after_open + end]); + rest = &rest[after_open + end + 2..]; + } else { + // Unclosed `{{` — skip + break; + } + } + + results +} + +/// Collect bare `Ident` nodes that appear at the *top level* of an +/// expression — i.e. identifiers that are not the right-hand side of a +/// `.field` access (those are field names, not variable references). +/// +/// For `DotAccess { object: Ident("parameters"), field: "n" }` we collect +/// `"parameters"` but NOT `"n"`. +/// +/// For `FunctionCall { name: "range", args: [Ident("n")] }` we collect +/// `"n"` (it's a bare variable reference used as a function argument). +fn collect_bare_idents(expr: &Expr, out: &mut Vec) { + match expr { + Expr::Ident(name) => { + out.push(name.clone()); + } + Expr::Literal(_) => {} + Expr::Array(items) => { + for item in items { + collect_bare_idents(item, out); + } + } + Expr::BinaryOp { left, right, .. } => { + collect_bare_idents(left, out); + collect_bare_idents(right, out); + } + Expr::UnaryOp { operand, .. } => { + collect_bare_idents(operand, out); + } + Expr::DotAccess { object, .. } => { + // Only recurse into the object side — the field name is not a + // variable reference. + collect_bare_idents(object, out); + } + Expr::IndexAccess { object, index } => { + collect_bare_idents(object, out); + collect_bare_idents(index, out); + } + Expr::FunctionCall { args, .. } => { + // Function name itself is not a variable reference. + for arg in args { + collect_bare_idents(arg, out); + } + } + } +} + +// ─────────────────────────────────────────────────────────────────────────── +// Tests +// ─────────────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn minimal_workflow(tasks: Vec) -> WorkflowDefinition { + WorkflowDefinition { + r#ref: "test.wf".to_string(), + label: "Test".to_string(), + description: None, + version: "1.0.0".to_string(), + parameters: None, + output: None, + vars: HashMap::new(), + tasks, + output_map: None, + tags: vec![], + } + } + + fn action_task(name: &str) -> super::super::parser::Task { + super::super::parser::Task { + name: name.to_string(), + r#type: super::super::parser::TaskType::Action, + action: Some("core.echo".to_string()), + input: HashMap::new(), + when: None, + with_items: None, + batch_size: None, + concurrency: None, + retry: None, + timeout: None, + next: vec![], + on_success: None, + on_failure: None, + on_complete: None, + on_timeout: None, + decision: vec![], + publish: vec![], + join: None, + tasks: None, + chart_meta: None, + } + } + + // ── extract_expressions ────────────────────────────────────────── + + #[test] + fn test_extract_single() { + let exprs = extract_expressions("{{ parameters.n }}"); + assert_eq!(exprs, vec![" parameters.n "]); + } + + #[test] + fn test_extract_multiple() { + let exprs = extract_expressions("Hello {{ name }}, you have {{ count }} items"); + assert_eq!(exprs.len(), 2); + assert_eq!(exprs[0].trim(), "name"); + assert_eq!(exprs[1].trim(), "count"); + } + + #[test] + fn test_extract_no_expressions() { + let exprs = extract_expressions("plain text"); + assert!(exprs.is_empty()); + } + + #[test] + fn test_extract_unclosed() { + let exprs = extract_expressions("{{ oops"); + assert!(exprs.is_empty()); + } + + // ── collect_bare_idents ────────────────────────────────────────── + + #[test] + fn test_bare_ident() { + let ast = parse_expression("n").unwrap(); + let mut idents = Vec::new(); + collect_bare_idents(&ast, &mut idents); + assert_eq!(idents, vec!["n"]); + } + + #[test] + fn test_dot_access_does_not_collect_field() { + let ast = parse_expression("parameters.n").unwrap(); + let mut idents = Vec::new(); + collect_bare_idents(&ast, &mut idents); + assert_eq!(idents, vec!["parameters"]); + } + + #[test] + fn test_function_arg_collected() { + let ast = parse_expression("range(n)").unwrap(); + let mut idents = Vec::new(); + collect_bare_idents(&ast, &mut idents); + assert_eq!(idents, vec!["n"]); + } + + #[test] + fn test_nested_dot_access() { + let ast = parse_expression("task.fetch.result.data").unwrap(); + let mut idents = Vec::new(); + collect_bare_idents(&ast, &mut idents); + assert_eq!(idents, vec!["task"]); + } + + #[test] + fn test_binary_op() { + let ast = parse_expression("parameters.x + workflow.y").unwrap(); + let mut idents = Vec::new(); + collect_bare_idents(&ast, &mut idents); + assert_eq!(idents, vec!["parameters", "workflow"]); + } + + // ── validate_workflow_expressions ───────────────────────────────── + + #[test] + fn test_valid_workflow_no_warnings() { + let mut task = action_task("greet"); + task.with_items = Some("{{ range(parameters.n) }}".to_string()); + task.input.insert( + "message".to_string(), + serde_json::json!("Hello {{ item }}"), + ); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + } + + #[test] + fn test_bare_name_from_vars_ok() { + let mut task = action_task("greet"); + task.with_items = Some("{{ range(n) }}".to_string()); + + let mut wf = minimal_workflow(vec![task]); + wf.vars.insert("n".to_string(), serde_json::json!(5)); + + let warnings = validate_workflow_expressions(&wf, None); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + } + + #[test] + fn test_bare_name_from_param_schema_ok() { + let mut task = action_task("greet"); + task.with_items = Some("{{ range(n) }}".to_string()); + + let wf = minimal_workflow(vec![task]); + let schema = serde_json::json!({ + "n": { "type": "integer", "required": true } + }); + + let warnings = validate_workflow_expressions(&wf, Some(&schema)); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + } + + #[test] + fn test_unknown_bare_name_warning() { + let mut task = action_task("greet"); + task.with_items = Some("{{ range(n) }}".to_string()); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + + assert_eq!(warnings.len(), 1); + assert!(warnings[0].message.contains("unknown variable 'n'")); + assert!(warnings[0].message.contains("parameters.n")); + assert!(warnings[0].location.contains("with_items")); + } + + #[test] + fn test_syntax_error_warning() { + let mut task = action_task("greet"); + task.input.insert( + "msg".to_string(), + serde_json::json!("{{ +++ }}"), + ); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + + assert_eq!(warnings.len(), 1); + assert!(warnings[0].message.contains("syntax error")); + } + + #[test] + fn test_transition_when_validated() { + let mut task = action_task("step1"); + task.next = vec![super::super::parser::TaskTransition { + when: Some("{{ bad_var > 3 }}".to_string()), + publish: vec![], + r#do: Some(vec!["step2".to_string()]), + chart_meta: None, + }]; + + let wf = minimal_workflow(vec![task, action_task("step2")]); + let warnings = validate_workflow_expressions(&wf, None); + + assert_eq!(warnings.len(), 1); + assert!(warnings[0].message.contains("unknown variable 'bad_var'")); + assert!(warnings[0].location.contains("next[0].when")); + } + + #[test] + fn test_transition_publish_validated() { + let mut task = action_task("step1"); + let mut publish_map = HashMap::new(); + publish_map.insert("out".to_string(), "{{ unknown_thing }}".to_string()); + task.next = vec![super::super::parser::TaskTransition { + when: Some("{{ succeeded() }}".to_string()), + publish: vec![PublishDirective::Simple(publish_map)], + r#do: Some(vec!["step2".to_string()]), + chart_meta: None, + }]; + + let wf = minimal_workflow(vec![task, action_task("step2")]); + let warnings = validate_workflow_expressions(&wf, None); + + assert_eq!(warnings.len(), 1); + assert!(warnings[0].message.contains("unknown variable 'unknown_thing'")); + assert!(warnings[0].location.contains("publish.out")); + } + + #[test] + fn test_workflow_functions_no_warning() { + // succeeded(), failed(), result() etc. are function calls, + // not variable references — should not produce warnings. + let mut task = action_task("step1"); + task.next = vec![super::super::parser::TaskTransition { + when: Some("{{ succeeded() and result().code == 200 }}".to_string()), + publish: vec![], + r#do: Some(vec!["step2".to_string()]), + chart_meta: None, + }]; + + let wf = minimal_workflow(vec![task, action_task("step2")]); + let warnings = validate_workflow_expressions(&wf, None); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + } + + #[test] + fn test_plain_text_no_warning() { + let mut task = action_task("step1"); + task.input.insert( + "msg".to_string(), + serde_json::json!("just plain text"), + ); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + assert!(warnings.is_empty()); + } + + #[test] + fn test_multiple_errors_collected() { + let mut task = action_task("step1"); + task.with_items = Some("{{ range(a) }}".to_string()); + task.input.insert( + "x".to_string(), + serde_json::json!("{{ b + c }}"), + ); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + + // a, b, c are all unknown + assert_eq!(warnings.len(), 3); + let names: HashSet<_> = warnings + .iter() + .flat_map(|w| { + // extract the variable name from "unknown variable 'X'" + w.message + .strip_prefix("unknown variable '") + .and_then(|s| s.split('\'').next()) + .map(|s| s.to_string()) + }) + .collect(); + assert!(names.contains("a")); + assert!(names.contains("b")); + assert!(names.contains("c")); + } + + #[test] + fn test_index_access_validated() { + let mut task = action_task("step1"); + task.input.insert( + "val".to_string(), + serde_json::json!("{{ items[idx] }}"), + ); + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + + // Both `items` and `idx` are bare unknowns + assert_eq!(warnings.len(), 2); + } + + #[test] + fn test_builtin_literals_ok() { + let mut task = action_task("step1"); + task.next = vec![super::super::parser::TaskTransition { + when: Some("{{ true and not false }}".to_string()), + publish: vec![], + r#do: None, + chart_meta: None, + }]; + + let wf = minimal_workflow(vec![task]); + let warnings = validate_workflow_expressions(&wf, None); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + } +} diff --git a/crates/common/src/workflow/mod.rs b/crates/common/src/workflow/mod.rs index 1192f44..ba6dd4c 100644 --- a/crates/common/src/workflow/mod.rs +++ b/crates/common/src/workflow/mod.rs @@ -3,6 +3,7 @@ //! This module provides utilities for loading, parsing, validating, and registering //! workflow definitions from YAML files. +pub mod expression; pub mod loader; pub mod pack_service; pub mod parser; diff --git a/crates/common/tests/execution_repository_tests.rs b/crates/common/tests/execution_repository_tests.rs index 4cb9186..01f0425 100644 --- a/crates/common/tests/execution_repository_tests.rs +++ b/crates/common/tests/execution_repository_tests.rs @@ -356,7 +356,7 @@ async fn test_update_execution_status() { status: Some(ExecutionStatus::Running), result: None, executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) @@ -401,7 +401,7 @@ async fn test_update_execution_result() { status: Some(ExecutionStatus::Completed), result: Some(result_data.clone()), executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) @@ -445,7 +445,7 @@ async fn test_update_execution_executor() { status: Some(ExecutionStatus::Scheduled), result: None, executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) @@ -492,7 +492,7 @@ async fn test_update_execution_status_transitions() { status: Some(ExecutionStatus::Scheduling), result: None, executor: None, - workflow_task: None, + ..Default::default() }, ) .await @@ -507,7 +507,7 @@ async fn test_update_execution_status_transitions() { status: Some(ExecutionStatus::Scheduled), result: None, executor: None, - workflow_task: None, + ..Default::default() }, ) .await @@ -522,7 +522,7 @@ async fn test_update_execution_status_transitions() { status: Some(ExecutionStatus::Running), result: None, executor: None, - workflow_task: None, + ..Default::default() }, ) .await @@ -537,7 +537,7 @@ async fn test_update_execution_status_transitions() { status: Some(ExecutionStatus::Completed), result: Some(json!({"success": true})), executor: None, - workflow_task: None, + ..Default::default() }, ) .await @@ -578,7 +578,7 @@ async fn test_update_execution_failed_status() { status: Some(ExecutionStatus::Failed), result: Some(json!({"error": "Connection timeout"})), executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) @@ -984,7 +984,7 @@ async fn test_execution_timestamps() { status: Some(ExecutionStatus::Running), result: None, executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) @@ -1095,7 +1095,7 @@ async fn test_execution_result_json() { status: Some(ExecutionStatus::Completed), result: Some(complex_result.clone()), executor: None, - workflow_task: None, + ..Default::default() }; let updated = ExecutionRepository::update(&pool, created.id, update) diff --git a/crates/executor/src/inquiry_handler.rs b/crates/executor/src/inquiry_handler.rs index 5206e2e..9f05e09 100644 --- a/crates/executor/src/inquiry_handler.rs +++ b/crates/executor/src/inquiry_handler.rs @@ -244,8 +244,7 @@ impl InquiryHandler { let update_input = UpdateExecutionInput { status: None, // Keep current status, let worker handle completion result: Some(updated_result), - executor: None, - workflow_task: None, // Not updating workflow metadata + ..Default::default() }; ExecutionRepository::update(pool, execution.id, update_input).await?; diff --git a/crates/executor/src/retry_manager.rs b/crates/executor/src/retry_manager.rs index dd2deca..414e922 100644 --- a/crates/executor/src/retry_manager.rs +++ b/crates/executor/src/retry_manager.rs @@ -381,10 +381,7 @@ impl RetryManager { &self.pool, execution_id, UpdateExecutionInput { - status: None, - result: None, - executor: None, - workflow_task: None, + ..Default::default() }, ) .await?; diff --git a/crates/executor/src/scheduler.rs b/crates/executor/src/scheduler.rs index 9c21fd4..39f2ba2 100644 --- a/crates/executor/src/scheduler.rs +++ b/crates/executor/src/scheduler.rs @@ -66,6 +66,53 @@ fn extract_workflow_params(config: &Option) -> JsonValue { } } +/// Apply default values from a workflow's `param_schema` to the provided +/// parameters. +/// +/// The param_schema uses the flat format where each key maps to an object +/// that may contain a `"default"` field: +/// +/// ```json +/// { "n": { "type": "integer", "default": 10 } } +/// ``` +/// +/// Any parameter that has a default in the schema but is missing (or `null`) +/// in the supplied `params` will be filled in. Parameters already provided +/// by the caller are never overwritten. +fn apply_param_defaults(params: JsonValue, param_schema: &Option) -> JsonValue { + let schema = match param_schema { + Some(s) if s.is_object() => s, + _ => return params, + }; + + let mut obj = match params { + JsonValue::Object(m) => m, + _ => return params, + }; + + if let Some(schema_obj) = schema.as_object() { + for (key, prop) in schema_obj { + // Only fill in missing / null parameters + let needs_default = match obj.get(key) { + None => true, + Some(JsonValue::Null) => true, + _ => false, + }; + if needs_default { + if let Some(default_val) = prop.get("default") { + debug!( + "Applying default for parameter '{}': {}", + key, default_val + ); + obj.insert(key.clone(), default_val.clone()); + } + } + } + } + + JsonValue::Object(obj) +} + /// Payload for execution scheduled messages #[derive(Debug, Clone, Serialize, Deserialize)] struct ExecutionScheduledPayload { @@ -316,7 +363,10 @@ impl ExecutionScheduler { // Build initial workflow context from execution parameters and // workflow-level vars so that entry-point task inputs are rendered. + // Apply defaults from the workflow's param_schema for any parameters + // that were not supplied by the caller. let workflow_params = extract_workflow_params(&execution.config); + let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema); let wf_ctx = WorkflowContext::new( workflow_params, definition @@ -563,7 +613,7 @@ impl ExecutionScheduler { }; let total = items.len(); - let concurrency_limit = task_node.concurrency.unwrap_or(total); + let concurrency_limit = task_node.concurrency.unwrap_or(1); let dispatch_count = total.min(concurrency_limit); info!( @@ -842,6 +892,18 @@ impl ExecutionScheduler { return Ok(()); } + // Load the workflow definition so we can apply param_schema defaults + let workflow_def = + WorkflowDefinitionRepository::find_by_id(pool, workflow_execution.workflow_def) + .await? + .ok_or_else(|| { + anyhow::anyhow!( + "Workflow definition {} not found for workflow_execution {}", + workflow_execution.workflow_def, + workflow_execution_id + ) + })?; + // Rebuild the task graph from the stored JSON let graph: TaskGraph = serde_json::from_value(workflow_execution.task_graph.clone()) .map_err(|e| { @@ -897,7 +959,7 @@ impl ExecutionScheduler { let concurrency_limit = graph .get_task(task_name) .and_then(|n| n.concurrency) - .unwrap_or(usize::MAX); + .unwrap_or(1); let free_slots = concurrency_limit.saturating_sub(in_flight_count.0 as usize); @@ -995,6 +1057,7 @@ impl ExecutionScheduler { // results so that successor task inputs can be rendered. // ----------------------------------------------------------------- let workflow_params = extract_workflow_params(&parent_execution.config); + let workflow_params = apply_param_defaults(workflow_params, &workflow_def.param_schema); // Collect results from all completed children of this workflow let child_executions = @@ -1619,11 +1682,11 @@ mod tests { let dispatch_count = total.min(concurrency_limit); assert_eq!(dispatch_count, 3); - // No concurrency limit → dispatch all + // No concurrency limit → default to serial (1 at a time) let concurrency: Option = None; - let concurrency_limit = concurrency.unwrap_or(total); + let concurrency_limit = concurrency.unwrap_or(1); let dispatch_count = total.min(concurrency_limit); - assert_eq!(dispatch_count, 20); + assert_eq!(dispatch_count, 1); // Concurrency exceeds total → dispatch all let concurrency: Option = Some(50); diff --git a/crates/executor/src/workflow/context.rs b/crates/executor/src/workflow/context.rs index 32b1ed1..6c9b806 100644 --- a/crates/executor/src/workflow/context.rs +++ b/crates/executor/src/workflow/context.rs @@ -3,6 +3,33 @@ //! This module manages workflow execution context, including variables, //! template rendering, and data flow between tasks. //! +//! ## Canonical Namespaces +//! +//! All data accessible inside `{{ }}` template expressions is organised into +//! well-defined, non-overlapping namespaces: +//! +//! | Namespace | Example | Description | +//! |-----------|---------|-------------| +//! | `parameters` | `{{ parameters.url }}` | Immutable workflow input parameters | +//! | `workflow` | `{{ workflow.counter }}` | Mutable workflow-scoped variables (set via `publish`) | +//! | `task` | `{{ task.fetch.result.data }}` | Completed task results keyed by task name | +//! | `config` | `{{ config.api_token }}` | Pack configuration values (read-only) | +//! | `keystore` | `{{ keystore.secret_key }}` | Encrypted secrets from the key store (read-only) | +//! | `item` | `{{ item }}` or `{{ item.name }}` | Current element in a `with_items` loop | +//! | `index` | `{{ index }}` | Zero-based iteration index in a `with_items` loop | +//! | `system` | `{{ system.workflow_start }}` | System-provided variables | +//! +//! ### Backward-compatible aliases +//! +//! The following aliases resolve to the same data as their canonical form and +//! are kept for backward compatibility with existing workflow definitions: +//! +//! - `vars` / `variables` → same as `workflow` +//! - `tasks` → same as `task` +//! +//! Bare variable names (e.g. `{{ my_var }}`) also resolve against the +//! `workflow` variable store as a last-resort fallback. +//! //! ## Function-call expressions //! //! Templates support Orquesta-style function calls: @@ -19,6 +46,9 @@ //! expression instead of stringifying it. This means `"{{ item }}"` resolving //! to integer `5` stays as `5`, not the string `"5"`. +use attune_common::workflow::expression::{ + self, is_truthy, EvalContext, EvalError, EvalResult as ExprResult, +}; use dashmap::DashMap; use serde_json::{json, Value as JsonValue}; use std::collections::HashMap; @@ -63,18 +93,25 @@ pub enum TaskOutcome { /// not the underlying data, making it O(1) instead of O(context_size). #[derive(Debug, Clone)] pub struct WorkflowContext { - /// Workflow-level variables (shared via Arc) + /// Mutable workflow-scoped variables. Canonical namespace: `workflow`. + /// Also accessible as `vars`, `variables`, or bare names (fallback). variables: Arc>, - /// Workflow input parameters (shared via Arc) + /// Immutable workflow input parameters. Canonical namespace: `parameters`. parameters: Arc, - /// Task results (shared via Arc, keyed by task name) + /// Completed task results keyed by task name. Canonical namespace: `task`. task_results: Arc>, - /// System variables (shared via Arc) + /// System-provided variables. Canonical namespace: `system`. system: Arc>, + /// Pack configuration values (read-only). Canonical namespace: `config`. + pack_config: Arc, + + /// Encrypted keystore values (read-only). Canonical namespace: `keystore`. + keystore: Arc, + /// Current item (for with-items iteration) - per-item data current_item: Option, @@ -89,7 +126,11 @@ pub struct WorkflowContext { } impl WorkflowContext { - /// Create a new workflow context + /// Create a new workflow context. + /// + /// `parameters` — the immutable input parameters for this workflow run. + /// `initial_vars` — initial workflow-scoped variables (from the workflow + /// definition's `vars` section). pub fn new(parameters: JsonValue, initial_vars: HashMap) -> Self { let system = DashMap::new(); system.insert("workflow_start".to_string(), json!(chrono::Utc::now())); @@ -104,6 +145,8 @@ impl WorkflowContext { parameters: Arc::new(parameters), task_results: Arc::new(DashMap::new()), system: Arc::new(system), + pack_config: Arc::new(JsonValue::Null), + keystore: Arc::new(JsonValue::Null), current_item: None, current_index: None, last_task_result: None, @@ -142,6 +185,8 @@ impl WorkflowContext { parameters: Arc::new(parameters), task_results: Arc::new(results), system: Arc::new(system), + pack_config: Arc::new(JsonValue::Null), + keystore: Arc::new(JsonValue::Null), current_item: None, current_index: None, last_task_result: None, @@ -149,28 +194,38 @@ impl WorkflowContext { } } - /// Set a variable + /// Set a workflow-scoped variable (accessible as `workflow.`). pub fn set_var(&mut self, name: &str, value: JsonValue) { self.variables.insert(name.to_string(), value); } - /// Get a variable + /// Get a workflow-scoped variable by name. pub fn get_var(&self, name: &str) -> Option { self.variables.get(name).map(|entry| entry.value().clone()) } - /// Store a task result + /// Store a completed task's result (accessible as `task..*`). pub fn set_task_result(&mut self, task_name: &str, result: JsonValue) { self.task_results.insert(task_name.to_string(), result); } - /// Get a task result + /// Get a task result by task name. pub fn get_task_result(&self, task_name: &str) -> Option { self.task_results .get(task_name) .map(|entry| entry.value().clone()) } + /// Set the pack configuration (accessible as `config.`). + pub fn set_pack_config(&mut self, config: JsonValue) { + self.pack_config = Arc::new(config); + } + + /// Set the keystore secrets (accessible as `keystore.`). + pub fn set_keystore(&mut self, secrets: JsonValue) { + self.keystore = Arc::new(secrets); + } + /// Set current item for iteration pub fn set_current_item(&mut self, item: JsonValue, index: usize) { self.current_item = Some(item); @@ -299,220 +354,55 @@ impl WorkflowContext { } } - /// Evaluate a template expression - fn evaluate_expression(&self, expr: &str) -> ContextResult { - // --------------------------------------------------------------- - // Function-call expressions: result(), succeeded(), failed(), timed_out() - // --------------------------------------------------------------- - // We handle these *before* splitting on `.` because the function - // name contains parentheses which would confuse the dot-split. - // - // Supported patterns: - // result() → last task result - // result().foo.bar → nested access into result - // result().data.items → nested access into result - // succeeded() → boolean - // failed() → boolean - // timed_out() → boolean - // --------------------------------------------------------------- - - if let Some(result_val) = self.try_evaluate_function_call(expr)? { - return Ok(result_val); - } - - // --------------------------------------------------------------- - // Dot-path expressions - // --------------------------------------------------------------- - let parts: Vec<&str> = expr.split('.').collect(); - - if parts.is_empty() { - return Err(ContextError::InvalidExpression(expr.to_string())); - } - - match parts[0] { - "parameters" => self.get_nested_value(&self.parameters, &parts[1..]), - "vars" | "variables" => { - if parts.len() < 2 { - return Err(ContextError::InvalidExpression(expr.to_string())); - } - let var_name = parts[1]; - if let Some(entry) = self.variables.get(var_name) { - let value = entry.value().clone(); - drop(entry); - if parts.len() > 2 { - self.get_nested_value(&value, &parts[2..]) - } else { - Ok(value) - } - } else { - Err(ContextError::VariableNotFound(var_name.to_string())) - } - } - "task" | "tasks" => { - if parts.len() < 2 { - return Err(ContextError::InvalidExpression(expr.to_string())); - } - let task_name = parts[1]; - if let Some(entry) = self.task_results.get(task_name) { - let result = entry.value().clone(); - drop(entry); - if parts.len() > 2 { - self.get_nested_value(&result, &parts[2..]) - } else { - Ok(result) - } - } else { - Err(ContextError::VariableNotFound(format!( - "task.{}", - task_name - ))) - } - } - "item" => { - if let Some(ref item) = self.current_item { - if parts.len() > 1 { - self.get_nested_value(item, &parts[1..]) - } else { - Ok(item.clone()) - } - } else { - Err(ContextError::VariableNotFound("item".to_string())) - } - } - "index" => { - if let Some(index) = self.current_index { - Ok(json!(index)) - } else { - Err(ContextError::VariableNotFound("index".to_string())) - } - } - "system" => { - if parts.len() < 2 { - return Err(ContextError::InvalidExpression(expr.to_string())); - } - let key = parts[1]; - if let Some(entry) = self.system.get(key) { - Ok(entry.value().clone()) - } else { - Err(ContextError::VariableNotFound(format!("system.{}", key))) - } - } - // Direct variable reference (e.g., `number_list` published by a - // previous task's transition) - var_name => { - if let Some(entry) = self.variables.get(var_name) { - let value = entry.value().clone(); - drop(entry); - if parts.len() > 1 { - self.get_nested_value(&value, &parts[1..]) - } else { - Ok(value) - } - } else { - Err(ContextError::VariableNotFound(var_name.to_string())) - } - } - } - } - - /// Try to evaluate `expr` as a function-call expression. + /// Evaluate a template expression using the expression engine. /// - /// Returns `Ok(Some(value))` if the expression starts with a recognised - /// function call, `Ok(None)` if it does not match, or `Err` on failure. - fn try_evaluate_function_call(&self, expr: &str) -> ContextResult> { - // succeeded() - if expr == "succeeded()" { - let val = self - .last_task_outcome - .map(|o| o == TaskOutcome::Succeeded) - .unwrap_or(false); - return Ok(Some(json!(val))); - } - - // failed() - if expr == "failed()" { - let val = self - .last_task_outcome - .map(|o| o == TaskOutcome::Failed) - .unwrap_or(false); - return Ok(Some(json!(val))); - } - - // timed_out() - if expr == "timed_out()" { - let val = self - .last_task_outcome - .map(|o| o == TaskOutcome::TimedOut) - .unwrap_or(false); - return Ok(Some(json!(val))); - } - - // result() or result().path.to.field - if expr == "result()" || expr.starts_with("result().") { - let base = self.last_task_result.clone().unwrap_or(JsonValue::Null); - - if expr == "result()" { - return Ok(Some(base)); - } - - // Strip "result()." prefix and navigate the remaining path - let rest = &expr["result().".len()..]; - let path_parts: Vec<&str> = rest.split('.').collect(); - let val = self.get_nested_value(&base, &path_parts)?; - return Ok(Some(val)); - } - - Ok(None) + /// Supports the full expression language including arithmetic, comparison, + /// boolean logic, member access, and built-in functions. Falls back to + /// legacy dot-path resolution for simple variable references when the + /// expression engine cannot parse the input. + fn evaluate_expression(&self, expr: &str) -> ContextResult { + // Use the expression engine for all expressions. It handles: + // - Dot-path access: parameters.config.port + // - Bracket access: arr[0], obj["key"] + // - Arithmetic: 2 + 3, length(items) * 2 + // - Comparison: x > 5, status == "ok" + // - Boolean logic: x > 0 and x < 10 + // - Function calls: length(arr), result(), succeeded() + // - Membership: "key" in obj, 5 in arr + expression::eval_expression(expr, self).map_err(|e| match e { + EvalError::VariableNotFound(name) => ContextError::VariableNotFound(name), + EvalError::TypeError(msg) => ContextError::TypeConversion(msg), + EvalError::ParseError(msg) => ContextError::InvalidExpression(msg), + other => ContextError::InvalidExpression(format!("{}", other)), + }) } - /// Get nested value from JSON - fn get_nested_value(&self, value: &JsonValue, path: &[&str]) -> ContextResult { - let mut current = value; - - for key in path { - match current { - JsonValue::Object(obj) => { - current = obj - .get(*key) - .ok_or_else(|| ContextError::VariableNotFound(key.to_string()))?; - } - JsonValue::Array(arr) => { - let index: usize = key.parse().map_err(|_| { - ContextError::InvalidExpression(format!("Invalid array index: {}", key)) - })?; - current = arr.get(index).ok_or_else(|| { - ContextError::InvalidExpression(format!( - "Array index out of bounds: {}", - index - )) - })?; - } - _ => { - return Err(ContextError::InvalidExpression(format!( - "Cannot access property '{}' on non-object/array value", - key - ))); - } - } - } - - Ok(current.clone()) - } - - /// Evaluate a conditional expression (for 'when' clauses) + /// Evaluate a conditional expression (for 'when' clauses). + /// + /// Uses the full expression engine so conditions can contain comparisons, + /// boolean operators, function calls, and arithmetic. For example: + /// + /// ```text + /// succeeded() + /// result().status == "ok" + /// length(items) > 3 and "admin" in roles + /// not failed() + /// ``` pub fn evaluate_condition(&self, condition: &str) -> ContextResult { - // For now, simple boolean evaluation - // TODO: Support more complex expressions (comparisons, logical operators) - - let rendered = self.render_template(condition)?; - - // Try to parse as boolean - match rendered.trim().to_lowercase().as_str() { - "true" | "1" | "yes" => Ok(true), - "false" | "0" | "no" | "" => Ok(false), - other => { - // Try to evaluate as truthy/falsy - Ok(!other.is_empty()) + // Try the expression engine first — it handles complex conditions + // like `result().code == 200 and succeeded()`. + match expression::eval_expression(condition, self) { + Ok(val) => Ok(is_truthy(&val)), + Err(_) => { + // Fall back to template rendering for backward compat with + // simple template conditions like `{{ succeeded() }}` (though + // bare expressions are preferred going forward). + let rendered = self.render_template(condition)?; + match rendered.trim().to_lowercase().as_str() { + "true" | "1" | "yes" => Ok(true), + "false" | "0" | "no" | "" => Ok(false), + _ => Ok(!rendered.trim().is_empty()), + } } } } @@ -574,6 +464,8 @@ impl WorkflowContext { "parameters": self.parameters.as_ref(), "task_results": task_results, "system": system, + "pack_config": self.pack_config.as_ref(), + "keystore": self.keystore.as_ref(), }) } @@ -602,11 +494,16 @@ impl WorkflowContext { } } + let pack_config = data["pack_config"].clone(); + let keystore = data["keystore"].clone(); + Ok(Self { variables: Arc::new(variables), parameters: Arc::new(parameters), task_results: Arc::new(task_results), system: Arc::new(system), + pack_config: Arc::new(pack_config), + keystore: Arc::new(keystore), current_item: None, current_index: None, last_task_result: None, @@ -626,10 +523,122 @@ fn value_to_string(value: &JsonValue) -> String { } } +// --------------------------------------------------------------- +// EvalContext implementation — bridges the expression engine into +// the WorkflowContext's variable resolution and workflow functions. +// --------------------------------------------------------------- + +impl EvalContext for WorkflowContext { + fn resolve_variable(&self, name: &str) -> ExprResult { + match name { + // ── Canonical namespaces ────────────────────────────── + "parameters" => Ok(self.parameters.as_ref().clone()), + + // `workflow` is the canonical name for mutable vars. + // `vars` and `variables` are backward-compatible aliases. + "workflow" | "vars" | "variables" => { + let map: serde_json::Map = self + .variables + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + Ok(JsonValue::Object(map)) + } + + // `task` (alias: `tasks`) — completed task results. + "task" | "tasks" => { + let map: serde_json::Map = self + .task_results + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + Ok(JsonValue::Object(map)) + } + + // `config` — pack configuration (read-only). + "config" => Ok(self.pack_config.as_ref().clone()), + + // `keystore` — encrypted secrets (read-only). + "keystore" => Ok(self.keystore.as_ref().clone()), + + // ── Iteration context ──────────────────────────────── + "item" => self + .current_item + .clone() + .ok_or_else(|| EvalError::VariableNotFound("item".to_string())), + "index" => self + .current_index + .map(|i| json!(i)) + .ok_or_else(|| EvalError::VariableNotFound("index".to_string())), + + // ── System variables ────────────────────────────────── + "system" => { + let map: serde_json::Map = self + .system + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + Ok(JsonValue::Object(map)) + } + + // ── Bare-name fallback ─────────────────────────────── + // Resolve against workflow variables last so that + // `{{ my_var }}` still works as shorthand for + // `{{ workflow.my_var }}`. + _ => { + if let Some(entry) = self.variables.get(name) { + Ok(entry.value().clone()) + } else { + Err(EvalError::VariableNotFound(name.to_string())) + } + } + } + } + + fn call_workflow_function( + &self, + name: &str, + _args: &[JsonValue], + ) -> ExprResult> { + match name { + "succeeded" => { + let val = self + .last_task_outcome + .map(|o| o == TaskOutcome::Succeeded) + .unwrap_or(false); + Ok(Some(json!(val))) + } + "failed" => { + let val = self + .last_task_outcome + .map(|o| o == TaskOutcome::Failed) + .unwrap_or(false); + Ok(Some(json!(val))) + } + "timed_out" => { + let val = self + .last_task_outcome + .map(|o| o == TaskOutcome::TimedOut) + .unwrap_or(false); + Ok(Some(json!(val))) + } + "result" => { + let base = self.last_task_result.clone().unwrap_or(JsonValue::Null); + Ok(Some(base)) + } + _ => Ok(None), + } + } +} + #[cfg(test)] mod tests { use super::*; + // --------------------------------------------------------------- + // parameters namespace + // --------------------------------------------------------------- + #[test] fn test_basic_template_rendering() { let params = json!({ @@ -641,28 +650,6 @@ mod tests { assert_eq!(result, "Hello World!"); } - #[test] - fn test_variable_access() { - let mut vars = HashMap::new(); - vars.insert("greeting".to_string(), json!("Hello")); - - let ctx = WorkflowContext::new(json!({}), vars); - - let result = ctx.render_template("{{ greeting }} World").unwrap(); - assert_eq!(result, "Hello World"); - } - - #[test] - fn test_task_result_access() { - let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); - ctx.set_task_result("task1", json!({"status": "success"})); - - let result = ctx - .render_template("Status: {{ task.task1.status }}") - .unwrap(); - assert_eq!(result, "Status: success"); - } - #[test] fn test_nested_value_access() { let params = json!({ @@ -680,6 +667,143 @@ mod tests { assert_eq!(result, "Port: 8080"); } + // --------------------------------------------------------------- + // workflow namespace (canonical) + vars/variables aliases + // --------------------------------------------------------------- + + #[test] + fn test_workflow_namespace_canonical() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("greeting", json!("Hello")); + + // Canonical: workflow. + let result = ctx.render_template("{{ workflow.greeting }} World").unwrap(); + assert_eq!(result, "Hello World"); + } + + #[test] + fn test_workflow_namespace_vars_alias() { + let mut vars = HashMap::new(); + vars.insert("greeting".to_string(), json!("Hello")); + let ctx = WorkflowContext::new(json!({}), vars); + + // Backward-compat alias: vars. + let result = ctx.render_template("{{ vars.greeting }} World").unwrap(); + assert_eq!(result, "Hello World"); + } + + #[test] + fn test_workflow_namespace_variables_alias() { + let mut vars = HashMap::new(); + vars.insert("greeting".to_string(), json!("Hello")); + let ctx = WorkflowContext::new(json!({}), vars); + + // Backward-compat alias: variables. + let result = ctx.render_template("{{ variables.greeting }} World").unwrap(); + assert_eq!(result, "Hello World"); + } + + #[test] + fn test_variable_access_bare_name_fallback() { + let mut vars = HashMap::new(); + vars.insert("greeting".to_string(), json!("Hello")); + + let ctx = WorkflowContext::new(json!({}), vars); + + // Bare name falls back to workflow variables + let result = ctx.render_template("{{ greeting }} World").unwrap(); + assert_eq!(result, "Hello World"); + } + + // --------------------------------------------------------------- + // task namespace + // --------------------------------------------------------------- + + #[test] + fn test_task_result_access() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_task_result("task1", json!({"status": "success"})); + + let result = ctx + .render_template("Status: {{ task.task1.status }}") + .unwrap(); + assert_eq!(result, "Status: success"); + } + + #[test] + fn test_task_result_deep_access() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_task_result("fetch", json!({"result": {"data": {"id": 42}}})); + + let val = ctx.evaluate_expression("task.fetch.result.data.id").unwrap(); + assert_eq!(val, json!(42)); + } + + #[test] + fn test_task_result_stdout() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_task_result("run_cmd", json!({"result": {"stdout": "hello world"}})); + + let val = ctx.evaluate_expression("task.run_cmd.result.stdout").unwrap(); + assert_eq!(val, json!("hello world")); + } + + // --------------------------------------------------------------- + // config namespace (pack configuration) + // --------------------------------------------------------------- + + #[test] + fn test_config_namespace() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_pack_config(json!({"api_token": "tok_abc123", "base_url": "https://api.example.com"})); + + let val = ctx.evaluate_expression("config.api_token").unwrap(); + assert_eq!(val, json!("tok_abc123")); + + let result = ctx + .render_template("URL: {{ config.base_url }}") + .unwrap(); + assert_eq!(result, "URL: https://api.example.com"); + } + + #[test] + fn test_config_namespace_nested() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_pack_config(json!({"slack": {"webhook_url": "https://hooks.slack.com/xxx"}})); + + let val = ctx.evaluate_expression("config.slack.webhook_url").unwrap(); + assert_eq!(val, json!("https://hooks.slack.com/xxx")); + } + + // --------------------------------------------------------------- + // keystore namespace (encrypted secrets) + // --------------------------------------------------------------- + + #[test] + fn test_keystore_namespace() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_keystore(json!({"secret_key": "s3cr3t", "db_password": "hunter2"})); + + let val = ctx.evaluate_expression("keystore.secret_key").unwrap(); + assert_eq!(val, json!("s3cr3t")); + + let val = ctx.evaluate_expression("keystore.db_password").unwrap(); + assert_eq!(val, json!("hunter2")); + } + + #[test] + fn test_keystore_bracket_access() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_keystore(json!({"My Secret Key": "value123"})); + + let val = ctx.evaluate_expression("keystore[\"My Secret Key\"]").unwrap(); + assert_eq!(val, json!("value123")); + } + + // --------------------------------------------------------------- + // item / index (with_items iteration) + // --------------------------------------------------------------- + #[test] fn test_item_context() { let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); @@ -691,6 +815,10 @@ mod tests { assert_eq!(result, "Item: item1, Index: 0"); } + // --------------------------------------------------------------- + // Condition evaluation + // --------------------------------------------------------------- + #[test] fn test_condition_evaluation() { let params = json!({"enabled": true}); @@ -700,6 +828,133 @@ mod tests { assert!(!ctx.evaluate_condition("false").unwrap()); } + #[test] + fn test_condition_with_comparison() { + let ctx = WorkflowContext::new(json!({"count": 10}), HashMap::new()); + assert!(ctx.evaluate_condition("parameters.count > 5").unwrap()); + assert!(!ctx.evaluate_condition("parameters.count < 5").unwrap()); + assert!(ctx.evaluate_condition("parameters.count == 10").unwrap()); + assert!(ctx.evaluate_condition("parameters.count >= 10").unwrap()); + assert!(ctx.evaluate_condition("parameters.count != 99").unwrap()); + } + + #[test] + fn test_condition_with_boolean_operators() { + let ctx = WorkflowContext::new(json!({"x": 10, "y": 20}), HashMap::new()); + assert!(ctx + .evaluate_condition("parameters.x > 5 and parameters.y > 15") + .unwrap()); + assert!(!ctx + .evaluate_condition("parameters.x > 5 and parameters.y > 25") + .unwrap()); + assert!(ctx + .evaluate_condition("parameters.x > 50 or parameters.y > 15") + .unwrap()); + assert!(ctx + .evaluate_condition("not parameters.x > 50") + .unwrap()); + } + + #[test] + fn test_condition_with_in_operator() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("roles", json!(["admin", "user"])); + // Via bare-name fallback + assert!(ctx.evaluate_condition("\"admin\" in roles").unwrap()); + assert!(!ctx.evaluate_condition("\"root\" in roles").unwrap()); + // Via canonical workflow namespace + assert!(ctx.evaluate_condition("\"admin\" in workflow.roles").unwrap()); + } + + #[test] + fn test_condition_with_function_calls() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_last_task_outcome( + json!({"status": "ok", "code": 200}), + TaskOutcome::Succeeded, + ); + assert!(ctx.evaluate_condition("succeeded()").unwrap()); + assert!(!ctx.evaluate_condition("failed()").unwrap()); + assert!(ctx + .evaluate_condition("succeeded() and result().code == 200") + .unwrap()); + assert!(!ctx + .evaluate_condition("succeeded() and result().code == 404") + .unwrap()); + } + + #[test] + fn test_condition_with_length() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("items", json!([1, 2, 3, 4, 5])); + assert!(ctx.evaluate_condition("length(items) > 3").unwrap()); + assert!(!ctx.evaluate_condition("length(items) > 10").unwrap()); + assert!(ctx + .evaluate_condition("length(items) == 5") + .unwrap()); + } + + #[test] + fn test_condition_with_config() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_pack_config(json!({"retries": 3})); + assert!(ctx.evaluate_condition("config.retries > 0").unwrap()); + assert!(ctx.evaluate_condition("config.retries == 3").unwrap()); + } + + // --------------------------------------------------------------- + // Expression engine in templates + // --------------------------------------------------------------- + + #[test] + fn test_expression_arithmetic() { + let ctx = WorkflowContext::new(json!({"x": 10}), HashMap::new()); + let input = json!({"result": "{{ parameters.x + 5 }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["result"], json!(15)); + } + + #[test] + fn test_expression_string_concat() { + let ctx = WorkflowContext::new( + json!({"first": "Hello", "second": "World"}), + HashMap::new(), + ); + let input = json!({"msg": "{{ parameters.first + \" \" + parameters.second }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["msg"], json!("Hello World")); + } + + #[test] + fn test_expression_nested_functions() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("data", json!("a,b,c")); + let input = json!({"count": "{{ length(split(data, \",\")) }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["count"], json!(3)); + } + + #[test] + fn test_expression_bracket_access() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("arr", json!([10, 20, 30])); + let input = json!({"second": "{{ arr[1] }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["second"], json!(20)); + } + + #[test] + fn test_expression_type_conversion() { + let ctx = WorkflowContext::new(json!({}), HashMap::new()); + let input = json!({"val": "{{ int(3.9) }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["val"], json!(3)); + } + + // --------------------------------------------------------------- + // render_json type-preserving behaviour + // --------------------------------------------------------------- + #[test] fn test_render_json() { let params = json!({"name": "test"}); @@ -769,6 +1024,10 @@ mod tests { assert!(result["ok"].is_boolean()); } + // --------------------------------------------------------------- + // result() / succeeded() / failed() / timed_out() + // --------------------------------------------------------------- + #[test] fn test_result_function() { let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); @@ -813,6 +1072,10 @@ mod tests { assert_eq!(ctx.evaluate_expression("timed_out()").unwrap(), json!(true)); } + // --------------------------------------------------------------- + // Publish + // --------------------------------------------------------------- + #[test] fn test_publish_with_result_function() { let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); @@ -846,6 +1109,28 @@ mod tests { assert_eq!(ctx.get_var("my_var").unwrap(), result); } + #[test] + fn test_published_var_accessible_via_workflow_namespace() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_var("counter", json!(42)); + + // Via canonical namespace + let val = ctx.evaluate_expression("workflow.counter").unwrap(); + assert_eq!(val, json!(42)); + + // Via backward-compat alias + let val = ctx.evaluate_expression("vars.counter").unwrap(); + assert_eq!(val, json!(42)); + + // Via bare-name fallback + let val = ctx.evaluate_expression("counter").unwrap(); + assert_eq!(val, json!(42)); + } + + // --------------------------------------------------------------- + // Rebuild / Export / Import round-trip + // --------------------------------------------------------------- + #[test] fn test_rebuild_context() { let stored_vars = json!({"number_list": [0, 1, 2]}); @@ -868,17 +1153,31 @@ mod tests { let mut ctx = WorkflowContext::new(json!({"key": "value"}), HashMap::new()); ctx.set_var("test", json!("data")); ctx.set_task_result("task1", json!({"result": "ok"})); + ctx.set_pack_config(json!({"setting": "val"})); + ctx.set_keystore(json!({"secret": "hidden"})); let exported = ctx.export(); - let _imported = WorkflowContext::import(exported).unwrap(); + let imported = WorkflowContext::import(exported).unwrap(); - assert_eq!(ctx.get_var("test").unwrap(), json!("data")); + assert_eq!(imported.get_var("test").unwrap(), json!("data")); assert_eq!( - ctx.get_task_result("task1").unwrap(), + imported.get_task_result("task1").unwrap(), json!({"result": "ok"}) ); + assert_eq!( + imported.evaluate_expression("config.setting").unwrap(), + json!("val") + ); + assert_eq!( + imported.evaluate_expression("keystore.secret").unwrap(), + json!("hidden") + ); } + // --------------------------------------------------------------- + // with_items type preservation + // --------------------------------------------------------------- + #[test] fn test_with_items_integer_type_preservation() { // Simulates the sleep_2 task from the hello_workflow: @@ -902,4 +1201,40 @@ mod tests { assert_eq!(rendered["message"], json!("Sleeping for 3 seconds ")); assert!(rendered["message"].is_string()); } + + // --------------------------------------------------------------- + // Cross-namespace expressions + // --------------------------------------------------------------- + + #[test] + fn test_cross_namespace_expression() { + let mut ctx = WorkflowContext::new(json!({"limit": 5}), HashMap::new()); + ctx.set_var("items", json!([1, 2, 3])); + ctx.set_pack_config(json!({"multiplier": 2})); + + assert!(ctx + .evaluate_condition("length(workflow.items) < parameters.limit") + .unwrap()); + let val = ctx + .evaluate_expression("parameters.limit * config.multiplier") + .unwrap(); + assert_eq!(val, json!(10)); + } + + #[test] + fn test_keystore_in_template() { + let mut ctx = WorkflowContext::new(json!({}), HashMap::new()); + ctx.set_keystore(json!({"api_key": "abc-123"})); + + let input = json!({"auth": "Bearer {{ keystore.api_key }}"}); + let result = ctx.render_json(&input).unwrap(); + assert_eq!(result["auth"], json!("Bearer abc-123")); + } + + #[test] + fn test_config_null_when_not_set() { + let ctx = WorkflowContext::new(json!({}), HashMap::new()); + let val = ctx.evaluate_expression("config").unwrap(); + assert_eq!(val, json!(null)); + } } diff --git a/crates/worker/src/executor.rs b/crates/worker/src/executor.rs index 4869516..2db978b 100644 --- a/crates/worker/src/executor.rs +++ b/crates/worker/src/executor.rs @@ -654,8 +654,7 @@ impl ActionExecutor { let input = UpdateExecutionInput { status: Some(ExecutionStatus::Completed), result: Some(result_data), - executor: None, - workflow_task: None, // Not updating workflow metadata + ..Default::default() }; ExecutionRepository::update(&self.pool, execution_id, input).await?; @@ -755,8 +754,7 @@ impl ActionExecutor { let input = UpdateExecutionInput { status: Some(ExecutionStatus::Failed), result: Some(result_data), - executor: None, - workflow_task: None, // Not updating workflow metadata + ..Default::default() }; ExecutionRepository::update(&self.pool, execution_id, input).await?; @@ -775,11 +773,16 @@ impl ActionExecutor { execution_id, status ); + let started_at = if status == ExecutionStatus::Running { + Some(chrono::Utc::now()) + } else { + None + }; + let input = UpdateExecutionInput { status: Some(status), - result: None, - executor: None, - workflow_task: None, // Not updating workflow metadata + started_at, + ..Default::default() }; ExecutionRepository::update(&self.pool, execution_id, input).await?; diff --git a/docs/examples/complete-workflow.yaml b/docs/examples/complete-workflow.yaml index b3479f6..de9b133 100644 --- a/docs/examples/complete-workflow.yaml +++ b/docs/examples/complete-workflow.yaml @@ -134,7 +134,7 @@ tasks: publish: - approval_granted: "{{ task.require_production_approval.result.approved }}" decision: - - when: "{{ vars.approval_granted == true }}" + - when: "{{ workflow.approval_granted == true }}" next: create_deployment_record - default: cancel_deployment on_timeout: deployment_approval_timeout @@ -244,7 +244,7 @@ tasks: publish: - canary_passed: "{{ task.monitor_canary.result.success }}" decision: - - when: "{{ vars.canary_passed == true }}" + - when: "{{ workflow.canary_passed == true }}" next: promote_canary - default: rollback_canary @@ -305,7 +305,7 @@ tasks: - name: parallel_health_checks action: http.get - with_items: "{{ vars.health_check_urls }}" + with_items: "{{ workflow.health_check_urls }}" batch_size: 5 # Check 5 URLs at a time input: url: "{{ item }}" @@ -323,7 +323,7 @@ tasks: input: app_name: "{{ parameters.app_name }}" environment: "{{ parameters.environment }}" - base_urls: "{{ vars.health_check_urls }}" + base_urls: "{{ workflow.health_check_urls }}" timeout: 600 on_success: verify_metrics on_failure: handle_smoke_test_failures @@ -361,12 +361,12 @@ tasks: - name: finalize_deployment action: deployments.update_status input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" status: "success" metadata: version: "{{ parameters.version }}" regions: "{{ parameters.regions }}" - duration: "{{ system.timestamp - vars.start_time }}" + duration: "{{ system.timestamp - workflow.start_time }}" publish: - successful_regions: "{{ parameters.regions }}" on_success: post_deployment_tasks @@ -392,7 +392,7 @@ tasks: input: app_name: "{{ parameters.app_name }}" version: "{{ parameters.version }}" - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" on_complete: notify_deployment_success # ============================================================================ @@ -402,7 +402,7 @@ tasks: - name: handle_deployment_failures action: deployments.analyze_failures input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" failed_tasks: "{{ task.deploy_to_all_regions.failed_items }}" publish: - failed_regions: "{{ task.handle_deployment_failures.result.failed_regions }}" @@ -414,7 +414,7 @@ tasks: - name: handle_health_check_failures action: diagnostics.analyze_health_failures input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" failed_urls: "{{ task.parallel_health_checks.failed_items }}" decision: - when: "{{ parameters.rollback_on_failure == true }}" @@ -424,7 +424,7 @@ tasks: - name: handle_smoke_test_failures action: testing.capture_smoke_test_results input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" results: "{{ task.run_smoke_tests.result }}" decision: - when: "{{ parameters.rollback_on_failure == true }}" @@ -434,7 +434,7 @@ tasks: - name: handle_metrics_failures action: monitoring.capture_metric_violations input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" violations: "{{ task.verify_metrics.result.violations }}" decision: - when: "{{ parameters.rollback_on_failure == true }}" @@ -467,11 +467,11 @@ tasks: - name: update_deployment_rolled_back action: deployments.update_status input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" status: "rolled_back" metadata: reason: "deployment_failure" - failed_regions: "{{ vars.failed_regions }}" + failed_regions: "{{ workflow.failed_regions }}" on_complete: notify_deployment_rolled_back # ============================================================================ @@ -481,7 +481,7 @@ tasks: - name: cancel_deployment action: deployments.update_status input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" status: "cancelled" metadata: reason: "approval_denied" @@ -490,7 +490,7 @@ tasks: - name: deployment_approval_timeout action: deployments.update_status input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" status: "cancelled" metadata: reason: "approval_timeout" @@ -499,7 +499,7 @@ tasks: - name: cleanup_failed_deployment action: deployments.cleanup_resources input: - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" reason: "pre_deployment_checks_failed" on_complete: notify_deployment_failed @@ -516,12 +516,12 @@ tasks: Application: {{ parameters.app_name }} Version: {{ parameters.version }} Environment: {{ parameters.environment }} - Regions: {{ vars.successful_regions | join(', ') }} - Duration: {{ system.timestamp - vars.start_time }}s - Deployment ID: {{ vars.deployment_id }} + Regions: {{ workflow.successful_regions | join(', ') }} + Duration: {{ system.timestamp - workflow.start_time }}s + Deployment ID: {{ workflow.deployment_id }} metadata: severity: "info" - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" - name: notify_deployment_failed action: notifications.send_multi_channel @@ -532,11 +532,11 @@ tasks: Application: {{ parameters.app_name }} Version: {{ parameters.version }} Environment: {{ parameters.environment }} - Failed Regions: {{ vars.failed_regions | join(', ') }} - Deployment ID: {{ vars.deployment_id }} + Failed Regions: {{ workflow.failed_regions | join(', ') }} + Deployment ID: {{ workflow.deployment_id }} metadata: severity: "error" - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" - name: notify_deployment_rolled_back action: notifications.send_multi_channel @@ -548,10 +548,10 @@ tasks: Version: {{ parameters.version }} Environment: {{ parameters.environment }} Rollback completed for all regions - Deployment ID: {{ vars.deployment_id }}" + Deployment ID: {{ workflow.deployment_id }}" metadata: severity: "warning" - deployment_id: "{{ vars.deployment_id }}" + deployment_id: "{{ workflow.deployment_id }}" - name: notify_deployment_cancelled action: notifications.send_multi_channel @@ -648,8 +648,8 @@ tasks: ⚠️ Partial Deployment Application: {{ parameters.app_name }} Version: {{ parameters.version }} - Successful Regions: {{ vars.successful_regions | join(', ') }} - Failed Regions: {{ vars.failed_regions | join(', ') }} + Successful Regions: {{ workflow.successful_regions | join(', ') }} + Failed Regions: {{ workflow.failed_regions | join(', ') }} metadata: severity: "error" @@ -703,8 +703,8 @@ tasks: # Workflow output mapping output_map: - deployment_id: "{{ vars.deployment_id }}" - status: "{{ vars.rollback_initiated ? 'rolled_back' : 'success' }}" + deployment_id: "{{ workflow.deployment_id }}" + status: "{{ workflow.rollback_initiated ? 'rolled_back' : 'success' }}" deployed_version: "{{ parameters.version }}" - deployment_urls: "{{ vars.health_check_urls }}" - duration_seconds: "{{ system.timestamp - vars.start_time }}" + deployment_urls: "{{ workflow.health_check_urls }}" + duration_seconds: "{{ system.timestamp - workflow.start_time }}" diff --git a/docs/examples/simple-workflow.yaml b/docs/examples/simple-workflow.yaml index c580e66..75ce0e4 100644 --- a/docs/examples/simple-workflow.yaml +++ b/docs/examples/simple-workflow.yaml @@ -41,9 +41,12 @@ tasks: action: core.echo input: message: "Starting workflow with: {{ parameters.message }}" - publish: - - timestamp: "{{ system.timestamp }}" - on_success: process_message + next: + - when: "{{ succeeded() }}" + publish: + - timestamp: "{{ system.workflow_start }}" + do: + - process_message # Task 2: Process the message - name: process_message @@ -51,18 +54,21 @@ tasks: input: text: "{{ parameters.message }}" uppercase: "{{ parameters.uppercase }}" - publish: - - processed_message: "{{ task.process_message.result.text }}" - on_success: finalize + next: + - when: "{{ succeeded() }}" + publish: + - processed_message: "{{ task.process_message.result.text }}" + do: + - finalize # Task 3: Finalize and log result - name: finalize action: core.echo input: - message: "Workflow complete. Result: {{ vars.processed_message }}" + message: "Workflow complete. Result: {{ workflow.processed_message }}" # Map workflow outputs output_map: original: "{{ parameters.message }}" - processed: "{{ vars.processed_message }}" + processed: "{{ workflow.processed_message }}" final: "{{ task.finalize.result.message }}" diff --git a/migrations/20250101000005_execution_and_operations.sql b/migrations/20250101000005_execution_and_operations.sql index e5f3108..6146f77 100644 --- a/migrations/20250101000005_execution_and_operations.sql +++ b/migrations/20250101000005_execution_and_operations.sql @@ -28,6 +28,7 @@ CREATE TABLE execution ( executor BIGINT, -- references identity(id); no FK because execution becomes a hypertable status execution_status_enum NOT NULL DEFAULT 'requested', result JSONB, + started_at TIMESTAMPTZ, -- set when execution transitions to 'running' created TIMESTAMPTZ NOT NULL DEFAULT NOW(), is_workflow BOOLEAN DEFAULT false NOT NULL, workflow_def BIGINT, -- references workflow_definition(id); no FK because execution becomes a hypertable diff --git a/migrations/20250101000008_notify_triggers.sql b/migrations/20250101000008_notify_triggers.sql index a068337..127958d 100644 --- a/migrations/20250101000008_notify_triggers.sql +++ b/migrations/20250101000008_notify_triggers.sql @@ -34,6 +34,7 @@ BEGIN 'trigger_ref', enforcement_trigger_ref, 'parent', NEW.parent, 'result', NEW.result, + 'started_at', NEW.started_at, 'created', NEW.created, 'updated', NEW.updated ); @@ -75,6 +76,7 @@ BEGIN 'trigger_ref', enforcement_trigger_ref, 'parent', NEW.parent, 'result', NEW.result, + 'started_at', NEW.started_at, 'created', NEW.created, 'updated', NEW.updated ); diff --git a/migrations/20250101000009_timescaledb_history.sql b/migrations/20250101000009_timescaledb_history.sql index 02d832f..bdeae86 100644 --- a/migrations/20250101000009_timescaledb_history.sql +++ b/migrations/20250101000009_timescaledb_history.sql @@ -196,7 +196,7 @@ COMMENT ON TABLE execution IS 'Executions represent action runs with workflow su -- ---------------------------------------------------------------------------- -- execution history trigger --- Tracked fields: status, result, executor, workflow_task, env_vars +-- Tracked fields: status, result, executor, workflow_task, env_vars, started_at -- Note: result uses _jsonb_digest_summary() to avoid storing large payloads -- ---------------------------------------------------------------------------- @@ -215,7 +215,8 @@ BEGIN 'action_ref', NEW.action_ref, 'executor', NEW.executor, 'parent', NEW.parent, - 'enforcement', NEW.enforcement + 'enforcement', NEW.enforcement, + 'started_at', NEW.started_at )); RETURN NEW; END IF; @@ -260,6 +261,12 @@ BEGIN new_vals := new_vals || jsonb_build_object('env_vars', NEW.env_vars); END IF; + IF OLD.started_at IS DISTINCT FROM NEW.started_at THEN + changed := array_append(changed, 'started_at'); + old_vals := old_vals || jsonb_build_object('started_at', OLD.started_at); + new_vals := new_vals || jsonb_build_object('started_at', NEW.started_at); + END IF; + -- Only record if something actually changed IF array_length(changed, 1) > 0 THEN INSERT INTO execution_history (time, operation, entity_id, entity_ref, changed_fields, old_values, new_values) diff --git a/packs.external/nodejs_example b/packs.external/nodejs_example index 62c42b3..44b1181 160000 --- a/packs.external/nodejs_example +++ b/packs.external/nodejs_example @@ -1 +1 @@ -Subproject commit 62c42b399668e63486d99250db5cdfcb256352b0 +Subproject commit 44b1181d206b95cb1ea738cdb24fae6da4150783 diff --git a/packs/core/workflows/install_packs.yaml b/packs/core/workflows/install_packs.yaml index 7cd389b..e6abc53 100644 --- a/packs/core/workflows/install_packs.yaml +++ b/packs/core/workflows/install_packs.yaml @@ -80,7 +80,7 @@ tasks: action: core.download_packs input: packs: "{{ parameters.packs }}" - destination_dir: "{{ vars.temp_dir }}" + destination_dir: "{{ workflow.temp_dir }}" registry_url: "{{ parameters.registry_url }}" ref_spec: "{{ parameters.ref_spec }}" api_url: "{{ parameters.api_url }}" @@ -109,7 +109,7 @@ tasks: - name: get_dependencies action: core.get_pack_dependencies input: - pack_paths: "{{ vars.downloaded_packs | map(attribute='pack_path') | list }}" + pack_paths: "{{ workflow.downloaded_packs | map(attribute='pack_path') | list }}" api_url: "{{ parameters.api_url }}" skip_validation: false publish: @@ -125,7 +125,7 @@ tasks: - name: install_dependencies action: core.install_packs input: - packs: "{{ vars.missing_dependencies | map(attribute='pack_ref') | list }}" + packs: "{{ workflow.missing_dependencies | map(attribute='pack_ref') | list }}" skip_dependencies: false skip_tests: "{{ parameters.skip_tests }}" skip_env_build: "{{ parameters.skip_env_build }}" @@ -147,7 +147,7 @@ tasks: - name: build_environments action: core.build_pack_envs input: - pack_paths: "{{ vars.downloaded_packs | map(attribute='pack_path') | list }}" + pack_paths: "{{ workflow.downloaded_packs | map(attribute='pack_path') | list }}" packs_base_dir: "{{ parameters.packs_base_dir }}" python_version: "3.11" nodejs_version: "20" @@ -174,7 +174,7 @@ tasks: - name: run_tests action: core.run_pack_tests input: - pack_paths: "{{ vars.downloaded_packs | map(attribute='pack_path') | list }}" + pack_paths: "{{ workflow.downloaded_packs | map(attribute='pack_path') | list }}" timeout: 300 fail_on_error: false on_success: register_packs @@ -188,7 +188,7 @@ tasks: - name: register_packs action: core.register_packs input: - pack_paths: "{{ vars.downloaded_packs | map(attribute='pack_path') | list }}" + pack_paths: "{{ workflow.downloaded_packs | map(attribute='pack_path') | list }}" packs_base_dir: "{{ parameters.packs_base_dir }}" skip_validation: false skip_tests: "{{ parameters.skip_tests }}" @@ -201,7 +201,7 @@ tasks: - name: cleanup_success action: core.noop input: - message: "Pack installation completed successfully. Cleaning up temporary directory: {{ vars.temp_dir }}" + message: "Pack installation completed successfully. Cleaning up temporary directory: {{ workflow.temp_dir }}" publish: - cleanup_status: "success" @@ -209,7 +209,7 @@ tasks: - name: cleanup_on_failure action: core.noop input: - message: "Pack installation failed. Cleaning up temporary directory: {{ vars.temp_dir }}" + message: "Pack installation failed. Cleaning up temporary directory: {{ workflow.temp_dir }}" publish: - cleanup_status: "failed" diff --git a/web/src/api/models/ApiResponse_ExecutionResponse.ts b/web/src/api/models/ApiResponse_ExecutionResponse.ts index 77e9f57..8c5320c 100644 --- a/web/src/api/models/ApiResponse_ExecutionResponse.ts +++ b/web/src/api/models/ApiResponse_ExecutionResponse.ts @@ -47,6 +47,11 @@ export type ApiResponse_ExecutionResponse = { * Execution result/output */ result: Record; + /** + * When the execution actually started running (worker picked it up). + * Null if the execution hasn't started running yet. + */ + started_at?: string | null; /** * Execution status */ diff --git a/web/src/api/models/ExecutionResponse.ts b/web/src/api/models/ExecutionResponse.ts index 763dd36..f5e4e10 100644 --- a/web/src/api/models/ExecutionResponse.ts +++ b/web/src/api/models/ExecutionResponse.ts @@ -43,6 +43,11 @@ export type ExecutionResponse = { * Execution result/output */ result: Record; + /** + * When the execution actually started running (worker picked it up). + * Null if the execution hasn't started running yet. + */ + started_at?: string | null; /** * Execution status */ diff --git a/web/src/api/models/ExecutionSummary.ts b/web/src/api/models/ExecutionSummary.ts index b46dbd6..b373c7c 100644 --- a/web/src/api/models/ExecutionSummary.ts +++ b/web/src/api/models/ExecutionSummary.ts @@ -35,6 +35,11 @@ export type ExecutionSummary = { * Execution status */ status: ExecutionStatus; + /** + * When the execution actually started running (worker picked it up). + * Null if the execution hasn't started running yet. + */ + started_at?: string | null; /** * Trigger reference (if triggered by a trigger) */ diff --git a/web/src/api/models/PaginatedResponse_ExecutionSummary.ts b/web/src/api/models/PaginatedResponse_ExecutionSummary.ts index f0a12de..1a7c442 100644 --- a/web/src/api/models/PaginatedResponse_ExecutionSummary.ts +++ b/web/src/api/models/PaginatedResponse_ExecutionSummary.ts @@ -40,6 +40,11 @@ export type PaginatedResponse_ExecutionSummary = { * Execution status */ status: ExecutionStatus; + /** + * When the execution actually started running (worker picked it up). + * Null if the execution hasn't started running yet. + */ + started_at?: string | null; /** * Trigger reference (if triggered by a trigger) */ diff --git a/web/src/api/services/ExecutionsService.ts b/web/src/api/services/ExecutionsService.ts index 78bf7f2..5f3cd50 100644 --- a/web/src/api/services/ExecutionsService.ts +++ b/web/src/api/services/ExecutionsService.ts @@ -239,6 +239,11 @@ export class ExecutionsService { * Execution result/output */ result: Record; + /** + * When the execution actually started running (worker picked it up). + * Null if the execution hasn't started running yet. + */ + started_at?: string | null; /** * Execution status */ diff --git a/web/src/components/common/WorkflowTasksPanel.tsx b/web/src/components/common/WorkflowTasksPanel.tsx index 0d16941..da9d1bd 100644 --- a/web/src/components/common/WorkflowTasksPanel.tsx +++ b/web/src/components/common/WorkflowTasksPanel.tsx @@ -15,6 +15,7 @@ import { RotateCcw, } from "lucide-react"; import { useChildExecutions } from "@/hooks/useExecutions"; +import { useExecutionStream } from "@/hooks/useExecutionStream"; interface WorkflowTasksPanelProps { /** The parent (workflow) execution ID */ @@ -95,6 +96,11 @@ export default function WorkflowTasksPanel({ const [isCollapsed, setIsCollapsed] = useState(defaultCollapsed); const { data, isLoading, error } = useChildExecutions(parentExecutionId); + // Subscribe to the unfiltered execution stream so that child execution + // WebSocket notifications update the ["executions", { parent }] query cache + // in real-time (the detail page only subscribes filtered by its own ID). + useExecutionStream({ enabled: true }); + const tasks = useMemo(() => { if (!data?.data) return []; return data.data; @@ -211,15 +217,20 @@ export default function WorkflowTasksPanel({ const maxRetries = wt?.max_retries ?? 0; const timedOut = wt?.timed_out ?? false; - // Compute duration from created → updated (best available) + // Compute duration from started_at → updated (actual run time) + const startedAt = task.started_at + ? new Date(task.started_at) + : null; const created = new Date(task.created); const updated = new Date(task.updated); + const isTerminal = + task.status === "completed" || + task.status === "failed" || + task.status === "timeout"; const durationMs = wt?.duration_ms ?? - (task.status === "completed" || - task.status === "failed" || - task.status === "timeout" - ? updated.getTime() - created.getTime() + (isTerminal && startedAt + ? updated.getTime() - startedAt.getTime() : null); return ( @@ -277,7 +288,10 @@ export default function WorkflowTasksPanel({
{task.status === "running" ? ( - {formatDistanceToNow(created, { addSuffix: false })}… + {formatDistanceToNow(startedAt ?? created, { + addSuffix: false, + })} + … ) : durationMs != null && durationMs > 0 ? ( formatDuration(durationMs) diff --git a/web/src/components/executions/ExecutionPreviewPanel.tsx b/web/src/components/executions/ExecutionPreviewPanel.tsx index 25532ff..0fa5316 100644 --- a/web/src/components/executions/ExecutionPreviewPanel.tsx +++ b/web/src/components/executions/ExecutionPreviewPanel.tsx @@ -70,11 +70,14 @@ const ExecutionPreviewPanel = memo(function ExecutionPreviewPanel({ execution?.status === "scheduled" || execution?.status === "requested"; + const startedAt = execution?.started_at + ? new Date(execution.started_at) + : null; const created = execution ? new Date(execution.created) : null; const updated = execution ? new Date(execution.updated) : null; const durationMs = - created && updated && !isRunning - ? updated.getTime() - created.getTime() + startedAt && updated && !isRunning + ? updated.getTime() - startedAt.getTime() : null; return ( @@ -175,9 +178,9 @@ const ExecutionPreviewPanel = memo(function ExecutionPreviewPanel({
Elapsed
-
+
- {formatDistanceToNow(created!)} + {formatDistanceToNow(startedAt ?? created!)}
)} @@ -240,41 +243,39 @@ const ExecutionPreviewPanel = memo(function ExecutionPreviewPanel({ {/* Config / Parameters */} - {execution.config && - Object.keys(execution.config).length > 0 && ( -
-
- Parameters -
-
-
-                      {JSON.stringify(execution.config, null, 2)}
-                    
-
-
- )} + {execution.config && Object.keys(execution.config).length > 0 && ( +
+
+ Parameters +
+
+
+                    {JSON.stringify(execution.config, null, 2)}
+                  
+
+
+ )} {/* Result */} - {execution.result && - Object.keys(execution.result).length > 0 && ( -
-
- Result -
-
-
-                      {JSON.stringify(execution.result, null, 2)}
-                    
-
-
- )} + {execution.result && Object.keys(execution.result).length > 0 && ( +
+
+ Result +
+
+
+                    {JSON.stringify(execution.result, null, 2)}
+                  
+
+
+ )} )} diff --git a/web/src/components/executions/WorkflowExecutionTree.tsx b/web/src/components/executions/WorkflowExecutionTree.tsx index 1adda5f..39f1214 100644 --- a/web/src/components/executions/WorkflowExecutionTree.tsx +++ b/web/src/components/executions/WorkflowExecutionTree.tsx @@ -128,6 +128,7 @@ const ChildExecutionRow = memo(function ChildExecutionRow({ return ( <>
{/* Main execution row */} Promise; + /** Called when the modal is closed (cancel or after successful execution) */ + onClose: () => void; + /** Optional label for display */ + label?: string; +} + +/** + * Modal for running a workflow with optional parameter overrides. + * + * Shown from the workflow builder's "Run" button when the workflow has + * parameters defined. Displays a ParamSchemaForm pre-populated with + * default values, saves the workflow first, then creates an execution + * and opens the execution detail page in a new tab. + */ +export default function RunWorkflowModal({ + actionRef, + paramSchema, + onSave, + onClose, + label, +}: RunWorkflowModalProps) { + const requestExecution = useRequestExecution(); + + const paramProperties = extractProperties(paramSchema); + + // Build initial values from schema defaults + const buildInitialValues = (): Record => { + const values: Record = {}; + for (const [key, prop] of Object.entries(paramProperties)) { + if (prop?.default !== undefined) { + values[key] = prop.default; + } + } + return values; + }; + + const [parameters, setParameters] = + useState>(buildInitialValues); + const [paramErrors, setParamErrors] = useState>({}); + const [error, setError] = useState(null); + const [phase, setPhase] = useState<"idle" | "saving" | "executing">("idle"); + + const isSubmitting = phase !== "idle"; + + const handleExecute = useCallback(async () => { + // Validate parameters against schema + const errors = validateParamSchema(paramSchema, parameters); + setParamErrors(errors); + if (Object.keys(errors).length > 0) return; + + setError(null); + + // Phase 1: Save the workflow + setPhase("saving"); + try { + const saved = await onSave(); + if (!saved) { + setPhase("idle"); + return; // save failed — error shown by parent + } + } catch { + setError("Failed to save workflow"); + setPhase("idle"); + return; + } + + // Phase 2: Execute + setPhase("executing"); + try { + // Strip out empty-string values so the backend applies schema defaults + // for parameters the user left blank. + const cleanedParams: Record = {}; + for (const [key, value] of Object.entries(parameters)) { + if (value !== "" && value !== undefined) { + cleanedParams[key] = value; + } + } + + const response = await requestExecution.mutateAsync({ + actionRef, + parameters: cleanedParams, + }); + const executionId = response.data.id; + + // Open execution in new tab and close the modal + window.open(`/executions/${executionId}`, "_blank"); + onClose(); + } catch (err: unknown) { + const e = err as { body?: { message?: string }; message?: string }; + const message = + e?.body?.message || e?.message || "Failed to start execution"; + setError(message); + setPhase("idle"); + } + }, [paramSchema, parameters, onSave, actionRef, requestExecution, onClose]); + + return ( +
+
+ {/* Header */} +
+
+

+ Run Workflow +

+

+ {label || actionRef} +

+
+ +
+ + {/* Body */} +
+ {error && ( +
+ {error} +
+ )} + +
+

+ Parameters +

+

+ Override default values or leave as-is to use the schema defaults. +

+ +
+
+ + {/* Footer */} +
+ + +
+
+
+ ); +} diff --git a/web/src/hooks/useExecutions.ts b/web/src/hooks/useExecutions.ts index f417729..96e20a8 100644 --- a/web/src/hooks/useExecutions.ts +++ b/web/src/hooks/useExecutions.ts @@ -1,6 +1,13 @@ -import { useQuery, keepPreviousData } from "@tanstack/react-query"; +import { + useQuery, + useMutation, + useQueryClient, + keepPreviousData, +} from "@tanstack/react-query"; import { ExecutionsService } from "@/api"; import type { ExecutionStatus } from "@/api"; +import { OpenAPI } from "@/api/core/OpenAPI"; +import { request as __request } from "@/api/core/request"; interface ExecutionsQueryParams { page?: number; @@ -69,6 +76,42 @@ export function useExecution(id: number) { * Enabled only when `parentId` is provided. Polls every 5 seconds while any * child execution is still in a running/pending state so the UI stays current. */ +/** + * Request a manual execution of an action (or workflow). + * + * Calls POST /api/v1/executions/execute and returns the created execution, + * including its `id` which callers can use to navigate to the detail page. + */ +export function useRequestExecution() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async ({ + actionRef, + parameters, + }: { + actionRef: string; + parameters?: Record; + }) => { + const response = await __request(OpenAPI, { + method: "POST", + url: "/api/v1/executions/execute", + body: { + action_ref: actionRef, + parameters: parameters ?? null, + }, + mediaType: "application/json", + }); + return response as { + data: { id: number; status: string; action_ref: string }; + }; + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ["executions"] }); + }, + }); +} + export function useChildExecutions(parentId: number | undefined) { return useQuery({ queryKey: ["executions", { parent: parentId }], diff --git a/web/src/pages/actions/WorkflowBuilderPage.tsx b/web/src/pages/actions/WorkflowBuilderPage.tsx index c9312b7..b5f2fa9 100644 --- a/web/src/pages/actions/WorkflowBuilderPage.tsx +++ b/web/src/pages/actions/WorkflowBuilderPage.tsx @@ -4,6 +4,7 @@ import { useQueries } from "@tanstack/react-query"; import { ArrowLeft, Save, + Play, AlertTriangle, FileCode, Code, @@ -11,6 +12,9 @@ import { X, Zap, Settings2, + ExternalLink, + Copy, + Check, } from "lucide-react"; import SearchableSelect from "@/components/common/SearchableSelect"; import yaml from "js-yaml"; @@ -23,6 +27,9 @@ import TaskInspector from "@/components/workflows/TaskInspector"; import { useActions } from "@/hooks/useActions"; import { ActionsService } from "@/api"; import { usePacks } from "@/hooks/usePacks"; +import { useRequestExecution } from "@/hooks/useExecutions"; +import RunWorkflowModal from "@/components/workflows/RunWorkflowModal"; +import type { ParamSchema } from "@/components/common/ParamSchemaForm"; import { useWorkflow } from "@/hooks/useWorkflows"; import { useSaveWorkflowFile, @@ -77,6 +84,7 @@ export default function WorkflowBuilderPage() { // Mutations const saveWorkflowFile = useSaveWorkflowFile(); const updateWorkflowFile = useUpdateWorkflowFile(); + const requestExecution = useRequestExecution(); // Builder state const [state, setState] = useState(INITIAL_STATE); @@ -85,6 +93,9 @@ export default function WorkflowBuilderPage() { const [showErrors, setShowErrors] = useState(false); const [saveError, setSaveError] = useState(null); const [saveSuccess, setSaveSuccess] = useState(false); + const [runError, setRunError] = useState(null); + const [showRunModal, setShowRunModal] = useState(false); + const [yamlCopied, setYamlCopied] = useState(false); const [initialized, setInitialized] = useState(false); const [showYamlPreview, setShowYamlPreview] = useState(false); const [sidebarTab, setSidebarTab] = useState<"actions" | "inputs">("actions"); @@ -428,7 +439,7 @@ export default function WorkflowBuilderPage() { if (errors.length > 0) { setShowErrors(true); - return; + return false; } const definition = builderStateToDefinition(state, actionSchemaMap); @@ -495,16 +506,19 @@ export default function WorkflowBuilderPage() { if (!isEditing) { const newRef = `${state.packRef}.${state.name}`; navigate(`/actions/workflows/${newRef}/edit`, { replace: true }); - return; + return true; } setSaveSuccess(true); setTimeout(() => setSaveSuccess(false), 3000); + + return true; // indicate success } catch (err: unknown) { const error = err as { body?: { message?: string }; message?: string }; const message = error?.body?.message || error?.message || "Failed to save workflow"; setSaveError(message); + return false; // indicate failure } }, [ state, @@ -516,6 +530,49 @@ export default function WorkflowBuilderPage() { navigate, ]); + // Check whether the workflow has any parameters defined + const hasParameters = useMemo( + () => Object.keys(state.parameters).length > 0, + [state.parameters], + ); + + const handleRun = useCallback(async () => { + setRunError(null); + + if (hasParameters) { + // Open the modal so the user can review / override parameter values + setShowRunModal(true); + return; + } + + // No parameters — save and execute immediately + const saved = await doSave(); + if (!saved) return; // save failed — error already shown + + const actionRef = editRef || `${state.packRef}.${state.name}`; + + try { + const response = await requestExecution.mutateAsync({ + actionRef, + parameters: {}, + }); + const executionId = response.data.id; + window.open(`/executions/${executionId}`, "_blank"); + } catch (err: unknown) { + const error = err as { body?: { message?: string }; message?: string }; + const message = + error?.body?.message || error?.message || "Failed to start execution"; + setRunError(message); + } + }, [ + hasParameters, + doSave, + editRef, + state.packRef, + state.name, + requestExecution, + ]); + const handleSave = useCallback(() => { // If there's a start-node problem, show the toast immediately and // require confirmation before saving @@ -547,6 +604,7 @@ export default function WorkflowBuilderPage() { }, [state, showYamlPreview, actionSchemaMap]); const isSaving = saveWorkflowFile.isPending || updateWorkflowFile.isPending; + const isExecuting = requestExecution.isPending; if (isEditing && workflowLoading) { return ( @@ -684,6 +742,16 @@ export default function WorkflowBuilderPage() { )} + {/* Run error indicator */} + {runError && ( + + ✗ {runError} + + )} + {/* Save button */} + + {/* Run button */} + @@ -771,6 +864,30 @@ export default function WorkflowBuilderPage() { (read-only preview of the generated YAML) +
+ +
               {yamlPreview}
@@ -965,6 +1082,17 @@ export default function WorkflowBuilderPage() {
         
       )}
 
+      {/* Run workflow modal (shown when workflow has parameters) */}
+      {showRunModal && (
+         setShowRunModal(false)}
+        />
+      )}
+
       {/* Inline style for fade-in animation */}