From e89b5991ec5dc58e3485e646fee893167e4af2e5 Mon Sep 17 00:00:00 2001 From: David Culbreth Date: Wed, 25 Feb 2026 14:16:56 -0600 Subject: [PATCH] concurrent action execution --- crates/executor/src/scheduler.rs | 53 +++++-- crates/worker/src/service.rs | 108 +++++++++++-- .../components/workflows/TaskInspector.tsx | 3 +- .../components/workflows/WorkflowEdges.tsx | 12 +- web/src/hooks/useExecutionStream.ts | 147 ++++++++++++++---- web/src/types/workflow.ts | 8 + 6 files changed, 257 insertions(+), 74 deletions(-) diff --git a/crates/executor/src/scheduler.rs b/crates/executor/src/scheduler.rs index 18eae2b..23a1ae1 100644 --- a/crates/executor/src/scheduler.rs +++ b/crates/executor/src/scheduler.rs @@ -22,6 +22,7 @@ use chrono::Utc; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use sqlx::PgPool; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; use tracing::{debug, error, info, warn}; @@ -40,6 +41,8 @@ pub struct ExecutionScheduler { pool: PgPool, publisher: Arc, consumer: Arc, + /// Round-robin counter for distributing executions across workers + round_robin_counter: AtomicUsize, } /// Default heartbeat interval in seconds (should match worker config default) @@ -56,6 +59,7 @@ impl ExecutionScheduler { pool, publisher, consumer, + round_robin_counter: AtomicUsize::new(0), } } @@ -65,6 +69,12 @@ impl ExecutionScheduler { let pool = self.pool.clone(); let publisher = self.publisher.clone(); + // Share the counter with the handler closure via Arc. + // We wrap &self's AtomicUsize in a new Arc by copying the + // current value so the closure is 'static. + let counter = Arc::new(AtomicUsize::new( + self.round_robin_counter.load(Ordering::Relaxed), + )); // Use the handler pattern to consume messages self.consumer @@ -72,10 +82,13 @@ impl ExecutionScheduler { move |envelope: MessageEnvelope| { let pool = pool.clone(); let publisher = publisher.clone(); + let counter = counter.clone(); async move { - if let Err(e) = - Self::process_execution_requested(&pool, &publisher, &envelope).await + if let Err(e) = Self::process_execution_requested( + &pool, &publisher, &counter, &envelope, + ) + .await { error!("Error scheduling execution: {}", e); // Return error to trigger nack with requeue @@ -94,6 +107,7 @@ impl ExecutionScheduler { async fn process_execution_requested( pool: &PgPool, publisher: &Publisher, + round_robin_counter: &AtomicUsize, envelope: &MessageEnvelope, ) -> Result<()> { debug!("Processing execution requested message: {:?}", envelope); @@ -110,8 +124,8 @@ impl ExecutionScheduler { // Fetch action to determine runtime requirements let action = Self::get_action_for_execution(pool, &execution).await?; - // Select appropriate worker - let worker = Self::select_worker(pool, &action).await?; + // Select appropriate worker (round-robin among compatible workers) + let worker = Self::select_worker(pool, &action, round_robin_counter).await?; info!( "Selected worker {} for execution {}", @@ -158,9 +172,13 @@ impl ExecutionScheduler { } /// Select an appropriate worker for the execution + /// + /// Uses round-robin selection among compatible, active, and healthy workers + /// to distribute load evenly across the worker pool. async fn select_worker( pool: &PgPool, action: &Action, + round_robin_counter: &AtomicUsize, ) -> Result { // Get runtime requirements for the action let runtime = if let Some(runtime_id) = action.runtime { @@ -219,17 +237,21 @@ impl ExecutionScheduler { )); } - // TODO: Implement intelligent worker selection: - // - Consider worker load/capacity - // - Consider worker affinity (same pack, same runtime) - // - Consider geographic locality - // - Round-robin or least-connections strategy - - // For now, just select the first available worker - Ok(fresh_workers + // Round-robin selection: distribute executions evenly across workers. + // Each call increments the counter and picks the next worker in the list. + let count = round_robin_counter.fetch_add(1, Ordering::Relaxed); + let index = count % fresh_workers.len(); + let selected = fresh_workers .into_iter() - .next() - .expect("Worker list should not be empty")) + .nth(index) + .expect("Worker list should not be empty"); + + info!( + "Selected worker {} (id={}) via round-robin (index {} of available workers)", + selected.name, selected.id, index + ); + + Ok(selected) } /// Check if a worker supports a given runtime @@ -291,7 +313,8 @@ impl ExecutionScheduler { let now = Utc::now(); let age = now.signed_duration_since(last_heartbeat); - let max_age = Duration::from_secs(DEFAULT_HEARTBEAT_INTERVAL * HEARTBEAT_STALENESS_MULTIPLIER); + let max_age = + Duration::from_secs(DEFAULT_HEARTBEAT_INTERVAL * HEARTBEAT_STALENESS_MULTIPLIER); let is_fresh = age.to_std().unwrap_or(Duration::MAX) <= max_age; diff --git a/crates/worker/src/service.rs b/crates/worker/src/service.rs index 131e4e7..de66cc0 100644 --- a/crates/worker/src/service.rs +++ b/crates/worker/src/service.rs @@ -19,8 +19,8 @@ use sqlx::PgPool; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use tokio::sync::RwLock; -use tokio::task::JoinHandle; +use tokio::sync::{Mutex, RwLock, Semaphore}; +use tokio::task::{JoinHandle, JoinSet}; use tracing::{debug, error, info, warn}; use crate::artifacts::ArtifactManager; @@ -67,6 +67,10 @@ pub struct WorkerService { packs_base_dir: PathBuf, /// Base directory for isolated runtime environments runtime_envs_dir: PathBuf, + /// Semaphore to limit concurrent executions + execution_semaphore: Arc, + /// Tracks in-flight execution tasks for graceful shutdown + in_flight_tasks: Arc>>, } impl WorkerService { @@ -292,6 +296,17 @@ impl WorkerService { // Capture the runtime filter for use in env setup let runtime_filter_for_service = runtime_filter.clone(); + // Read max concurrent tasks from config + let max_concurrent_tasks = config + .worker + .as_ref() + .map(|w| w.max_concurrent_tasks) + .unwrap_or(10); + info!( + "Worker configured for max {} concurrent executions", + max_concurrent_tasks + ); + Ok(Self { config, db_pool: pool, @@ -308,6 +323,8 @@ impl WorkerService { runtime_filter: runtime_filter_for_service, packs_base_dir, runtime_envs_dir, + execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)), + in_flight_tasks: Arc::new(Mutex::new(JoinSet::new())), }) } @@ -563,17 +580,24 @@ impl WorkerService { /// Wait for in-flight tasks to complete async fn wait_for_in_flight_tasks(&self) { - // Poll for active executions with short intervals loop { - // Check if executor has any active tasks - // Note: This is a simplified check. In a real implementation, - // we would track active execution count in the executor. - tokio::time::sleep(Duration::from_millis(500)).await; + let remaining = { + let mut tasks = self.in_flight_tasks.lock().await; + // Drain any already-completed tasks + while tasks.try_join_next().is_some() {} + tasks.len() + }; - // TODO: Add proper tracking of active executions in ActionExecutor - // For now, we just wait a reasonable amount of time - // This will be improved when we add execution tracking - break; + if remaining == 0 { + info!("All in-flight execution tasks have completed"); + break; + } + + info!( + "Waiting for {} in-flight execution task(s) to complete...", + remaining + ); + tokio::time::sleep(Duration::from_secs(1)).await; } } @@ -581,6 +605,10 @@ impl WorkerService { /// /// Spawns the consumer loop as a background task so that `start()` returns /// immediately, allowing the caller to set up signal handlers. + /// + /// Executions are spawned as concurrent background tasks, limited by the + /// `execution_semaphore`. The consumer blocks when the concurrency limit is + /// reached, providing natural backpressure via RabbitMQ. async fn start_execution_consumer(&mut self) -> Result<()> { let worker_id = self .worker_id @@ -589,7 +617,19 @@ impl WorkerService { // Queue name for this worker (already created in setup_worker_infrastructure) let queue_name = format!("worker.{}.executions", worker_id); - info!("Starting consumer for worker queue: {}", queue_name); + // Set prefetch slightly above max concurrent tasks to keep the pipeline filled + let max_concurrent = self + .config + .worker + .as_ref() + .map(|w| w.max_concurrent_tasks) + .unwrap_or(10); + let prefetch_count = (max_concurrent as u16).saturating_add(2); + + info!( + "Starting consumer for worker queue: {} (prefetch: {}, max_concurrent: {})", + queue_name, prefetch_count, max_concurrent + ); // Create consumer let consumer = Arc::new( @@ -598,7 +638,7 @@ impl WorkerService { ConsumerConfig { queue: queue_name.clone(), tag: format!("worker-{}", worker_id), - prefetch_count: 10, + prefetch_count, auto_ack: false, exclusive: false, }, @@ -615,6 +655,8 @@ impl WorkerService { let db_pool = self.db_pool.clone(); let consumer_for_task = consumer.clone(); let queue_name_for_log = queue_name.clone(); + let semaphore = self.execution_semaphore.clone(); + let in_flight = self.in_flight_tasks.clone(); // Spawn the consumer loop as a background task so start() can return let handle = tokio::spawn(async move { @@ -625,11 +667,47 @@ impl WorkerService { let executor = executor.clone(); let publisher = publisher.clone(); let db_pool = db_pool.clone(); + let semaphore = semaphore.clone(); + let in_flight = in_flight.clone(); async move { - Self::handle_execution_scheduled(executor, publisher, db_pool, envelope) + let execution_id = envelope.payload.execution_id; + + // Acquire a concurrency permit. This blocks if we're at the + // max concurrent execution limit, providing natural backpressure: + // the message won't be acked until we can actually start working, + // so RabbitMQ will stop delivering once prefetch is exhausted. + let permit = semaphore.clone().acquire_owned().await.map_err(|_| { + attune_common::mq::error::MqError::Channel( + "Execution semaphore closed".to_string(), + ) + })?; + + info!( + "Acquired execution permit for execution {} ({} permits remaining)", + execution_id, + semaphore.available_permits() + ); + + // Spawn the actual execution as a background task so this + // handler returns immediately, acking the message and freeing + // the consumer loop to process the next delivery. + let mut tasks = in_flight.lock().await; + tasks.spawn(async move { + // The permit is moved into this task and will be released + // when the task completes (on drop). + let _permit = permit; + + if let Err(e) = Self::handle_execution_scheduled( + executor, publisher, db_pool, envelope, + ) .await - .map_err(|e| format!("Execution handler error: {}", e).into()) + { + error!("Execution {} handler error: {}", execution_id, e); + } + }); + + Ok(()) } }, ) diff --git a/web/src/components/workflows/TaskInspector.tsx b/web/src/components/workflows/TaskInspector.tsx index a00b8d4..1fe03c1 100644 --- a/web/src/components/workflows/TaskInspector.tsx +++ b/web/src/components/workflows/TaskInspector.tsx @@ -22,6 +22,7 @@ import { PRESET_WHEN, PRESET_LABELS, PRESET_COLORS, + EDGE_TYPE_COLORS, classifyTransitionWhen, transitionLabel, } from "@/types/workflow"; @@ -581,7 +582,7 @@ export default function TaskInspector({ y2="1" stroke={ transition.color || - PRESET_COLORS[ + EDGE_TYPE_COLORS[ classifyTransitionWhen(transition.when) ] || "#6b7280" diff --git a/web/src/components/workflows/WorkflowEdges.tsx b/web/src/components/workflows/WorkflowEdges.tsx index 5e07b21..850ebfc 100644 --- a/web/src/components/workflows/WorkflowEdges.tsx +++ b/web/src/components/workflows/WorkflowEdges.tsx @@ -2,10 +2,9 @@ import { memo, useMemo, useState, useCallback, useRef, useEffect } from "react"; import type { WorkflowEdge, WorkflowTask, - EdgeType, NodePosition, } from "@/types/workflow"; -import { PRESET_COLORS } from "@/types/workflow"; +import { PRESET_COLORS, EDGE_TYPE_COLORS } from "@/types/workflow"; import type { TransitionPreset } from "./TaskNode"; import type { ScreenToCanvas } from "./WorkflowCanvas"; @@ -58,13 +57,8 @@ interface WorkflowEdgesProps { const NODE_WIDTH = 240; const NODE_HEIGHT = 96; -/** Color for each edge type */ -const EDGE_COLORS: Record = { - success: "#22c55e", // green-500 - failure: "#ef4444", // red-500 - complete: "#6b7280", // gray-500 (unconditional / always) - custom: "#8b5cf6", // violet-500 -}; +/** Color for each edge type (alias for shared constant) */ +const EDGE_COLORS = EDGE_TYPE_COLORS; /** SVG stroke-dasharray values for each user-facing line style */ import type { LineStyle } from "@/types/workflow"; diff --git a/web/src/hooks/useExecutionStream.ts b/web/src/hooks/useExecutionStream.ts index e907222..0b508bf 100644 --- a/web/src/hooks/useExecutionStream.ts +++ b/web/src/hooks/useExecutionStream.ts @@ -17,8 +17,33 @@ interface UseExecutionStreamOptions { } /** - * Check if an execution matches the given query parameters - * Only checks fields that are reliably present in WebSocket payloads + * Notification metadata fields that come from the PostgreSQL trigger payload + * but are NOT part of the ExecutionSummary API model. These are stripped + * before storing execution data in the React Query cache. + */ +const NOTIFICATION_META_FIELDS = [ + "entity_type", + "entity_id", + "old_status", + "action_id", +] as const; + +/** + * Strip notification-only metadata fields from the payload so cached data + * matches the shape returned by the API (ExecutionSummary / ExecutionResponse). + */ +function stripNotificationMeta(payload: any): any { + if (!payload || typeof payload !== "object") return payload; + const cleaned = { ...payload }; + for (const key of NOTIFICATION_META_FIELDS) { + delete cleaned[key]; + } + return cleaned; +} + +/** + * Check if an execution matches the given query parameters. + * Only checks fields that are reliably present in WebSocket payloads. */ function executionMatchesParams(execution: any, params: any): boolean { if (!params) return true; @@ -55,7 +80,7 @@ function executionMatchesParams(execution: any, params: any): boolean { } /** - * Check if query params include filters not present in WebSocket payloads + * Check if query params include filters not present in WebSocket payloads. */ function hasUnsupportedFilters(params: any): boolean { if (!params) return false; @@ -88,8 +113,11 @@ export function useExecutionStream(options: UseExecutionStreamOptions = {}) { return; } - // Extract execution data from notification payload (flat structure) - const executionData = notification.payload as any; + // Extract execution data from notification payload (flat structure). + // Keep raw payload for old_status inspection, but use cleaned data for cache. + const rawPayload = notification.payload as any; + const oldStatus: string | undefined = rawPayload?.old_status; + const executionData = stripNotificationMeta(rawPayload); // Update specific execution query if it exists queryClient.setQueryData( @@ -106,8 +134,8 @@ export function useExecutionStream(options: UseExecutionStreamOptions = {}) { }, ); - // Update execution list queries by modifying existing data - // We need to iterate manually to access query keys for filtering + // Update execution list queries by modifying existing data. + // We need to iterate manually to access query keys for filtering. const queries = queryClient .getQueriesData({ queryKey: ["executions"], exact: false }) .filter(([, data]) => data && Array.isArray((data as any)?.data)); @@ -123,45 +151,96 @@ export function useExecutionStream(options: UseExecutionStreamOptions = {}) { (exec: any) => exec.id === notification.entity_id, ); + // Merge the updated fields to determine if the execution matches the query + const mergedExecution = + existingIndex >= 0 + ? { ...old.data[existingIndex], ...executionData } + : executionData; + const matchesQuery = executionMatchesParams( + mergedExecution, + queryParams, + ); + let updatedData; + let totalItemsDelta = 0; + if (existingIndex >= 0) { - // Always update existing execution in the list - updatedData = [...old.data]; - updatedData[existingIndex] = { - ...updatedData[existingIndex], - ...executionData, - }; - - // Note: We don't remove executions from cache based on filters. - // The cache represents what the API query returned. - // Client-side filtering (in the page component) handles what's displayed. - } else { - // For new executions, be conservative with filters we can't verify - // If filters include rule_ref/trigger_ref, don't add new executions - // (these fields may not be in WebSocket payload) - if (hasUnsupportedFilters(queryParams)) { - // Don't add new execution when using filters we can't verify - return; - } - - // Only add new execution if it matches the query parameters - // (not the display filters - those are handled client-side) - if (executionMatchesParams(executionData, queryParams)) { - // Add to beginning and cap at 50 items to prevent performance issues - updatedData = [executionData, ...old.data].slice(0, 50); + // ── Execution IS in the local data array ── + if (matchesQuery) { + // Still matches — update in place, no total_items change + updatedData = [...old.data]; + updatedData[existingIndex] = mergedExecution; } else { - // Don't modify the list if the new execution doesn't match the query - return; + // No longer matches the query filter — remove it + updatedData = old.data.filter( + (_: any, i: number) => i !== existingIndex, + ); + totalItemsDelta = -1; + } + } else { + // ── Execution is NOT in the local data array ── + // This happens when the execution is beyond the fetched page boundary + // (e.g., running count query with pageSize=1) or was pushed out by + // the 50-item cap after many new executions were prepended. + + if (oldStatus) { + // This is a status-change notification (has old_status from the + // PostgreSQL trigger). Use old_status to detect whether the + // execution crossed a query filter boundary — even though it's + // not in our local data array, total_items must stay accurate. + const virtualOldExecution = { + ...mergedExecution, + status: oldStatus, + }; + const oldMatchedQuery = executionMatchesParams( + virtualOldExecution, + queryParams, + ); + + if (oldMatchedQuery && !matchesQuery) { + // Execution LEFT this query's result set (e.g., was running, + // now completed). Decrement total_items but don't touch the + // data array — the item was never in it. + updatedData = old.data; + totalItemsDelta = -1; + } else if (!oldMatchedQuery && matchesQuery) { + // Execution ENTERED this query's result set. + if (hasUnsupportedFilters(queryParams)) { + return; + } + updatedData = [executionData, ...old.data].slice(0, 50); + totalItemsDelta = 1; + } else { + // No boundary crossing: either both match (execution was + // already counted in total_items — don't double-count) or + // neither matches (irrelevant to this query). + return; + } + } else { + // No old_status: this is likely an execution_created notification + // (INSERT trigger). Use the standard add-if-matches logic. + if (hasUnsupportedFilters(queryParams)) { + return; + } + + if (matchesQuery) { + // Add to beginning and cap at 50 items to prevent unbounded growth + updatedData = [executionData, ...old.data].slice(0, 50); + totalItemsDelta = 1; + } else { + return; + } } } // Update the query with the new data + const newTotal = (old.pagination?.total_items || 0) + totalItemsDelta; queryClient.setQueryData(queryKey, { ...old, data: updatedData, pagination: { ...old.pagination, - total_items: (old.pagination?.total_items || 0) + 1, + total_items: Math.max(0, newTotal), }, }); }); diff --git a/web/src/types/workflow.ts b/web/src/types/workflow.ts index b37b370..5ca99a5 100644 --- a/web/src/types/workflow.ts +++ b/web/src/types/workflow.ts @@ -133,6 +133,14 @@ export const PRESET_STYLES: Record = { */ export type EdgeType = "success" | "failure" | "complete" | "custom"; +/** Default colors for each EdgeType (mirrors PRESET_COLORS but keyed by EdgeType). */ +export const EDGE_TYPE_COLORS: Record = { + success: "#22c55e", // green-500 + failure: "#ef4444", // red-500 + complete: "#6b7280", // gray-500 (unconditional / always) + custom: "#8b5cf6", // violet-500 +}; + export function classifyTransitionWhen(when?: string): EdgeType { if (!when) return "complete"; // unconditional const lower = when.toLowerCase().replace(/\s+/g, "");