//! Worker Service Module //! //! Main service orchestration for the Attune Worker Service. //! Manages worker registration, heartbeat, message consumption, and action execution. //! //! ## Startup Sequence //! //! 1. Connect to database and message queue //! 2. Load runtimes from database → create `ProcessRuntime` instances //! 3. Register worker and set up MQ infrastructure //! 4. **Verify runtime versions** — run verification commands for each registered //! `RuntimeVersion` to determine which are available on this host/container //! 5. **Set up runtime environments** — create per-version environments for packs //! 6. Start heartbeat, execution consumer, pack registration consumer, and cancel consumer use attune_common::config::Config; use attune_common::db::Database; use attune_common::error::{Error, Result}; use attune_common::models::ExecutionStatus; use attune_common::mq::{ config::MessageQueueConfig as MqConfig, Connection, Consumer, ConsumerConfig, ExecutionCancelRequestedPayload, ExecutionCompletedPayload, ExecutionStatusChangedPayload, MessageEnvelope, MessageType, PackRegisteredPayload, Publisher, PublisherConfig, }; use attune_common::repositories::{execution::ExecutionRepository, FindById}; use attune_common::runtime_detection::runtime_in_filter; use chrono::Utc; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tokio::sync::{Mutex, RwLock, Semaphore}; use tokio::task::{JoinHandle, JoinSet}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; use crate::artifacts::ArtifactManager; use crate::env_setup; use crate::executor::ActionExecutor; use crate::heartbeat::HeartbeatManager; use crate::registration::WorkerRegistration; use crate::runtime::local::LocalRuntime; use crate::runtime::native::NativeRuntime; use crate::runtime::process::ProcessRuntime; use crate::runtime::shell::ShellRuntime; use crate::runtime::RuntimeRegistry; use crate::secrets::SecretManager; use crate::version_verify; use attune_common::repositories::runtime::RuntimeRepository; use attune_common::repositories::List; /// Message payload for execution.scheduled events #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionScheduledPayload { pub execution_id: i64, pub action_ref: String, pub worker_id: i64, } /// Worker service that manages execution lifecycle pub struct WorkerService { #[allow(dead_code)] config: Config, db_pool: PgPool, registration: Arc>, heartbeat: Arc, executor: Arc, mq_connection: Arc, publisher: Arc, consumer: Option>, consumer_handle: Option>, pack_consumer: Option>, pack_consumer_handle: Option>, cancel_consumer: Option>, cancel_consumer_handle: Option>, worker_id: Option, /// Runtime filter derived from ATTUNE_WORKER_RUNTIMES runtime_filter: Option>, /// Base directory for pack files packs_base_dir: PathBuf, /// Base directory for isolated runtime environments runtime_envs_dir: PathBuf, /// Semaphore to limit concurrent executions execution_semaphore: Arc, /// Tracks in-flight execution tasks for graceful shutdown in_flight_tasks: Arc>>, /// Maps execution ID → CancellationToken for running processes. /// When a cancel request arrives, the token is triggered, causing /// the process executor to send SIGINT → SIGTERM → SIGKILL. cancel_tokens: Arc>>, } impl WorkerService { /// Create a new worker service pub async fn new(config: Config) -> Result { info!("Initializing Worker Service"); // Initialize database let db = Database::new(&config.database).await?; let pool = db.pool().clone(); info!("Database connection established"); // Initialize message queue connection let mq_url = config .message_queue .as_ref() .ok_or_else(|| Error::Internal("Message queue configuration is required".to_string()))? .url .as_str(); let mq_connection = Connection::connect(mq_url) .await .map_err(|e| Error::Internal(format!("Failed to connect to message queue: {}", e)))?; info!("Message queue connection established"); // Setup common message queue infrastructure (exchanges and DLX) let mq_config = MqConfig::default(); match mq_connection.setup_common_infrastructure(&mq_config).await { Ok(_) => info!("Common message queue infrastructure setup completed"), Err(e) => { warn!( "Failed to setup common MQ infrastructure (may already exist): {}", e ); } } // Initialize message queue publisher let publisher = Publisher::new( &mq_connection, PublisherConfig { confirm_publish: true, timeout_secs: 30, exchange: "attune.executions".to_string(), }, ) .await .map_err(|e| Error::Internal(format!("Failed to create publisher: {}", e)))?; info!("Message queue publisher initialized"); // Initialize worker registration let registration = Arc::new(RwLock::new(WorkerRegistration::new(pool.clone(), &config))); // Initialize artifact manager (legacy, for stdout/stderr log storage) let artifact_base_dir = std::path::PathBuf::from( config .worker .as_ref() .and_then(|w| w.name.clone()) .map(|name| format!("/tmp/attune/artifacts/{}", name)) .unwrap_or_else(|| "/tmp/attune/artifacts".to_string()), ); let artifact_manager = ArtifactManager::new(artifact_base_dir); artifact_manager.initialize().await?; // Initialize artifacts directory for file-backed artifact storage (shared volume). // Execution processes write artifact files here; the API serves them from the same path. let artifacts_dir = std::path::PathBuf::from(&config.artifacts_dir); if let Err(e) = tokio::fs::create_dir_all(&artifacts_dir).await { warn!( "Failed to create artifacts directory '{}': {}. File-backed artifacts may not work.", artifacts_dir.display(), e, ); } else { info!( "Artifacts directory initialized at: {}", artifacts_dir.display() ); } let packs_base_dir = std::path::PathBuf::from(&config.packs_base_dir); let runtime_envs_dir = std::path::PathBuf::from(&config.runtime_envs_dir); // Determine which runtimes to register based on configuration // ATTUNE_WORKER_RUNTIMES env var filters which runtimes this worker handles. // If not set, all action runtimes from the database are loaded. let runtime_filter: Option> = std::env::var("ATTUNE_WORKER_RUNTIMES").ok().map(|env_val| { info!( "Filtering runtimes from ATTUNE_WORKER_RUNTIMES: {}", env_val ); env_val .split(',') .map(|s| s.trim().to_lowercase()) .filter(|s| !s.is_empty()) .collect() }); // Initialize runtime registry let mut runtime_registry = RuntimeRegistry::new(); // Load runtimes from the database and create ProcessRuntime instances. // Each runtime row's `execution_config` JSONB drives how the ProcessRuntime // invokes interpreters, manages environments, and installs dependencies. // We skip runtimes with empty execution_config (e.g., core.native) since // they execute binaries directly and don't need a ProcessRuntime wrapper. match RuntimeRepository::list(&pool).await { Ok(db_runtimes) => { let executable_runtimes: Vec<_> = db_runtimes .into_iter() .filter(|r| { let config = r.parsed_execution_config(); // A runtime is executable if it has a non-default interpreter // (the default is "/bin/sh" from InterpreterConfig::default, // but runtimes with no execution_config at all will have an // empty JSON object that deserializes to defaults with no // file_extension — those are not real process runtimes). config.interpreter.file_extension.is_some() || r.execution_config != serde_json::json!({}) }) .collect(); info!( "Found {} executable runtime(s) in database", executable_runtimes.len() ); for rt in executable_runtimes { let rt_name = rt.name.to_lowercase(); // Apply filter if ATTUNE_WORKER_RUNTIMES is set. // Uses alias-aware matching so that e.g. filter "node" // matches DB runtime name "Node.js" (lowercased to "node.js"). if let Some(ref filter) = runtime_filter { if !runtime_in_filter(&rt_name, filter) { debug!( "Skipping runtime '{}' (not in ATTUNE_WORKER_RUNTIMES filter)", rt_name ); continue; } } let exec_config = rt.parsed_execution_config(); let process_runtime = ProcessRuntime::new( rt_name.clone(), exec_config, packs_base_dir.clone(), runtime_envs_dir.clone(), ); runtime_registry.register(Box::new(process_runtime)); info!( "Registered ProcessRuntime '{}' from database (ref: {})", rt_name, rt.r#ref ); } } Err(e) => { warn!( "Failed to load runtimes from database: {}. \ Falling back to built-in defaults.", e ); } } // If no runtimes were loaded from the DB, register built-in defaults if runtime_registry.list_runtimes().is_empty() { info!("No runtimes loaded from database, registering built-in defaults"); // Shell runtime (always available) runtime_registry.register(Box::new(ShellRuntime::new())); info!("Registered built-in Shell runtime"); // Native runtime (for compiled binaries) runtime_registry.register(Box::new(NativeRuntime::new())); info!("Registered built-in Native runtime"); // Local runtime as catch-all fallback let local_runtime = LocalRuntime::new(); runtime_registry.register(Box::new(local_runtime)); info!("Registered Local runtime (fallback)"); } // Validate all registered runtimes runtime_registry .validate_all() .await .map_err(|e| Error::Internal(format!("Failed to validate runtimes: {}", e)))?; info!( "Successfully validated runtimes: {:?}", runtime_registry.list_runtimes() ); // Initialize secret manager let encryption_key = config.security.encryption_key.clone(); let secret_manager = SecretManager::new(pool.clone(), encryption_key)?; info!("Secret manager initialized"); // Initialize action executor let max_stdout_bytes = config .worker .as_ref() .map(|w| w.max_stdout_bytes) .unwrap_or(10 * 1024 * 1024); let max_stderr_bytes = config .worker .as_ref() .map(|w| w.max_stderr_bytes) .unwrap_or(10 * 1024 * 1024); // Get API URL from environment or construct from server config let api_url = std::env::var("ATTUNE_API_URL") .unwrap_or_else(|_| format!("http://{}:{}", config.server.host, config.server.port)); // Build JWT config for generating execution-scoped tokens let jwt_config = attune_common::auth::jwt::JwtConfig { secret: config .security .jwt_secret .clone() .unwrap_or_else(|| "insecure_default_secret_change_in_production".to_string()), access_token_expiration: config.security.jwt_access_expiration as i64, refresh_token_expiration: config.security.jwt_refresh_expiration as i64, }; let executor = Arc::new(ActionExecutor::new( pool.clone(), runtime_registry, artifact_manager, secret_manager, max_stdout_bytes, max_stderr_bytes, packs_base_dir.clone(), artifacts_dir, api_url, jwt_config, )); // Initialize heartbeat manager let heartbeat_interval = config .worker .as_ref() .map(|w| w.heartbeat_interval) .unwrap_or(30); let heartbeat = Arc::new(HeartbeatManager::new( registration.clone(), heartbeat_interval, )); // Capture the runtime filter for use in env setup let runtime_filter_for_service = runtime_filter.clone(); // Read max concurrent tasks from config let max_concurrent_tasks = config .worker .as_ref() .map(|w| w.max_concurrent_tasks) .unwrap_or(10); info!( "Worker configured for max {} concurrent executions", max_concurrent_tasks ); Ok(Self { config, db_pool: pool, registration, heartbeat, executor, mq_connection: Arc::new(mq_connection), publisher: Arc::new(publisher), consumer: None, consumer_handle: None, pack_consumer: None, pack_consumer_handle: None, cancel_consumer: None, cancel_consumer_handle: None, worker_id: None, runtime_filter: runtime_filter_for_service, packs_base_dir, runtime_envs_dir, execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)), in_flight_tasks: Arc::new(Mutex::new(JoinSet::new())), cancel_tokens: Arc::new(Mutex::new(HashMap::new())), }) } /// Start the worker service pub async fn start(&mut self) -> Result<()> { info!("Starting Worker Service"); // Detect runtime capabilities and register worker let worker_id = { let mut reg = self.registration.write().await; reg.detect_capabilities(&self.config).await?; reg.register().await? }; self.worker_id = Some(worker_id); info!("Worker registered with ID: {}", worker_id); // Setup worker-specific message queue infrastructure // (includes per-worker execution queue AND pack registration queue) let mq_config = MqConfig::default(); self.mq_connection .setup_worker_infrastructure(worker_id, &mq_config) .await .map_err(|e| { Error::Internal(format!("Failed to setup worker MQ infrastructure: {}", e)) })?; info!("Worker-specific message queue infrastructure setup completed"); // Verify which runtime versions are available on this system. // This updates the `available` flag in the database so that // `select_best_version()` only considers genuinely present versions. self.verify_runtime_versions().await; // Proactively set up runtime environments for all registered packs. // This runs before we start consuming execution messages so that // environments are ready by the time the first execution arrives. // Now version-aware: creates per-version environments where needed. self.scan_and_setup_environments().await; // Start heartbeat self.heartbeat.start().await?; // Start consuming execution messages self.start_execution_consumer().await?; // Start consuming pack registration events self.start_pack_consumer().await?; // Start consuming cancel requests self.start_cancel_consumer().await?; info!("Worker Service started successfully"); Ok(()) } /// Stop the worker service gracefully /// /// Shutdown order (mirrors sensor service pattern): /// 1. Deregister worker (mark inactive to stop receiving new work) /// 2. Stop heartbeat /// 3. Wait for in-flight tasks with timeout /// 4. Close MQ connection /// 5. Close DB connection /// /// Verify which runtime versions are available on this host/container. /// /// Runs each version's verification commands (from `distributions` JSONB) /// and updates the `available` flag in the database. This ensures that /// `select_best_version()` only considers versions whose interpreters /// are genuinely present. async fn verify_runtime_versions(&self) { let filter_refs: Option> = self.runtime_filter.clone(); let filter_slice: Option<&[String]> = filter_refs.as_deref(); let result = version_verify::verify_all_runtime_versions(&self.db_pool, filter_slice).await; if !result.errors.is_empty() { warn!( "Runtime version verification completed with {} error(s): {:?}", result.errors.len(), result.errors, ); } else { info!( "Runtime version verification complete: {} checked, \ {} available, {} unavailable", result.total_checked, result.available, result.unavailable, ); } } /// Scan all registered packs and create missing runtime environments. async fn scan_and_setup_environments(&self) { let filter_refs: Option> = self.runtime_filter.clone(); let filter_slice: Option<&[String]> = filter_refs.as_deref(); let result = env_setup::scan_and_setup_all_environments( &self.db_pool, filter_slice, &self.packs_base_dir, &self.runtime_envs_dir, ) .await; if !result.errors.is_empty() { warn!( "Environment startup scan completed with {} error(s): {:?}", result.errors.len(), result.errors, ); } else { info!( "Environment startup scan completed: {} pack(s) scanned, \ {} environment(s) ensured, {} skipped", result.packs_scanned, result.environments_created, result.environments_skipped, ); } } /// Start consuming pack.registered events from the per-worker packs queue. async fn start_pack_consumer(&mut self) -> Result<()> { let worker_id = self .worker_id .ok_or_else(|| Error::Internal("Worker not registered".to_string()))?; let queue_name = format!("worker.{}.packs", worker_id); info!( "Starting pack registration consumer for queue: {}", queue_name ); let consumer = Arc::new( Consumer::new( &self.mq_connection, ConsumerConfig { queue: queue_name.clone(), tag: format!("worker-{}-packs", worker_id), prefetch_count: 5, auto_ack: false, exclusive: false, }, ) .await .map_err(|e| Error::Internal(format!("Failed to create pack consumer: {}", e)))?, ); let db_pool = self.db_pool.clone(); let consumer_for_task = consumer.clone(); let queue_name_for_log = queue_name.clone(); let runtime_filter = self.runtime_filter.clone(); let packs_base_dir = self.packs_base_dir.clone(); let runtime_envs_dir = self.runtime_envs_dir.clone(); let handle = tokio::spawn(async move { info!( "Pack consumer loop started for queue '{}'", queue_name_for_log ); let result = consumer_for_task .consume_with_handler(move |envelope: MessageEnvelope| { let db_pool = db_pool.clone(); let runtime_filter = runtime_filter.clone(); let packs_base_dir = packs_base_dir.clone(); let runtime_envs_dir = runtime_envs_dir.clone(); async move { info!( "Received pack.registered event for pack '{}' (version {})", envelope.payload.pack_ref, envelope.payload.version, ); let filter_slice: Option> = runtime_filter; let filter_ref: Option<&[String]> = filter_slice.as_deref(); let pack_result = env_setup::setup_environments_for_registered_pack( &db_pool, &envelope.payload, filter_ref, &packs_base_dir, &runtime_envs_dir, ) .await; if !pack_result.errors.is_empty() { warn!( "Pack '{}' environment setup had {} error(s): {:?}", pack_result.pack_ref, pack_result.errors.len(), pack_result.errors, ); } else if !pack_result.environments_created.is_empty() { info!( "Pack '{}' environments set up: {:?}", pack_result.pack_ref, pack_result.environments_created, ); } Ok(()) } }) .await; match result { Ok(()) => info!( "Pack consumer loop for queue '{}' ended", queue_name_for_log ), Err(e) => error!( "Pack consumer loop for queue '{}' failed: {}", queue_name_for_log, e ), } }); self.pack_consumer = Some(consumer); self.pack_consumer_handle = Some(handle); info!("Pack registration consumer initialized"); Ok(()) } pub async fn stop(&mut self) -> Result<()> { info!("Stopping Worker Service - initiating graceful shutdown"); // 1. Mark worker as inactive first to stop receiving new tasks // Use if-let instead of ? so shutdown continues even if DB call fails { let reg = self.registration.read().await; info!("Marking worker as inactive to stop receiving new tasks"); if let Err(e) = reg.deregister().await { error!("Failed to deregister worker: {}", e); } } // 2. Stop heartbeat info!("Stopping heartbeat updates"); self.heartbeat.stop().await; // Wait a bit for heartbeat loop to notice the flag tokio::time::sleep(Duration::from_millis(100)).await; // 3. Wait for in-flight tasks to complete (with timeout) let shutdown_timeout = self .config .worker .as_ref() .and_then(|w| w.shutdown_timeout) .unwrap_or(30); // Default: 30 seconds info!( "Waiting up to {} seconds for in-flight tasks to complete", shutdown_timeout ); let timeout_duration = Duration::from_secs(shutdown_timeout); match tokio::time::timeout(timeout_duration, self.wait_for_in_flight_tasks()).await { Ok(_) => info!("All in-flight tasks completed"), Err(_) => warn!("Shutdown timeout reached - some tasks may have been interrupted"), } // 4. Abort consumer tasks and close message queue connection if let Some(handle) = self.consumer_handle.take() { info!("Stopping execution consumer task..."); handle.abort(); // Wait briefly for the task to finish let _ = handle.await; } if let Some(handle) = self.pack_consumer_handle.take() { info!("Stopping pack consumer task..."); handle.abort(); let _ = handle.await; } if let Some(handle) = self.cancel_consumer_handle.take() { info!("Stopping cancel consumer task..."); handle.abort(); let _ = handle.await; } info!("Closing message queue connection..."); if let Err(e) = self.mq_connection.close().await { warn!("Error closing message queue: {}", e); } // 5. Close database connection info!("Closing database connection..."); self.db_pool.close().await; info!("Worker Service stopped"); Ok(()) } /// Wait for in-flight tasks to complete async fn wait_for_in_flight_tasks(&self) { loop { let remaining = { let mut tasks = self.in_flight_tasks.lock().await; // Drain any already-completed tasks while tasks.try_join_next().is_some() {} tasks.len() }; if remaining == 0 { info!("All in-flight execution tasks have completed"); break; } info!( "Waiting for {} in-flight execution task(s) to complete...", remaining ); tokio::time::sleep(Duration::from_secs(1)).await; } } /// Start consuming execution.scheduled messages /// /// Spawns the consumer loop as a background task so that `start()` returns /// immediately, allowing the caller to set up signal handlers. /// /// Executions are spawned as concurrent background tasks, limited by the /// `execution_semaphore`. The consumer blocks when the concurrency limit is /// reached, providing natural backpressure via RabbitMQ. async fn start_execution_consumer(&mut self) -> Result<()> { let worker_id = self .worker_id .ok_or_else(|| Error::Internal("Worker not registered".to_string()))?; // Queue name for this worker (already created in setup_worker_infrastructure) let queue_name = format!("worker.{}.executions", worker_id); // Set prefetch slightly above max concurrent tasks to keep the pipeline filled let max_concurrent = self .config .worker .as_ref() .map(|w| w.max_concurrent_tasks) .unwrap_or(10); let prefetch_count = (max_concurrent as u16).saturating_add(2); info!( "Starting consumer for worker queue: {} (prefetch: {}, max_concurrent: {})", queue_name, prefetch_count, max_concurrent ); // Create consumer let consumer = Arc::new( Consumer::new( &self.mq_connection, ConsumerConfig { queue: queue_name.clone(), tag: format!("worker-{}", worker_id), prefetch_count, auto_ack: false, exclusive: false, }, ) .await .map_err(|e| Error::Internal(format!("Failed to create consumer: {}", e)))?, ); info!("Consumer created for queue: {}", queue_name); // Clone Arc references for the spawned task let executor = self.executor.clone(); let publisher = self.publisher.clone(); let db_pool = self.db_pool.clone(); let consumer_for_task = consumer.clone(); let queue_name_for_log = queue_name.clone(); let semaphore = self.execution_semaphore.clone(); let in_flight = self.in_flight_tasks.clone(); let cancel_tokens = self.cancel_tokens.clone(); // Spawn the consumer loop as a background task so start() can return let handle = tokio::spawn(async move { info!("Consumer loop started for queue '{}'", queue_name_for_log); let result = consumer_for_task .consume_with_handler( move |envelope: MessageEnvelope| { let executor = executor.clone(); let publisher = publisher.clone(); let db_pool = db_pool.clone(); let semaphore = semaphore.clone(); let in_flight = in_flight.clone(); let cancel_tokens = cancel_tokens.clone(); async move { let execution_id = envelope.payload.execution_id; // Acquire a concurrency permit. This blocks if we're at the // max concurrent execution limit, providing natural backpressure: // the message won't be acked until we can actually start working, // so RabbitMQ will stop delivering once prefetch is exhausted. let permit = semaphore.clone().acquire_owned().await.map_err(|_| { attune_common::mq::error::MqError::Channel( "Execution semaphore closed".to_string(), ) })?; info!( "Acquired execution permit for execution {} ({} permits remaining)", execution_id, semaphore.available_permits() ); // Create a cancellation token for this execution let cancel_token = CancellationToken::new(); { let mut tokens = cancel_tokens.lock().await; tokens.insert(execution_id, cancel_token.clone()); } // Spawn the actual execution as a background task so this // handler returns immediately, acking the message and freeing // the consumer loop to process the next delivery. let mut tasks = in_flight.lock().await; tasks.spawn(async move { // The permit is moved into this task and will be released // when the task completes (on drop). let _permit = permit; if let Err(e) = Self::handle_execution_scheduled( executor, publisher, db_pool, envelope, cancel_token, ) .await { error!("Execution {} handler error: {}", execution_id, e); } // Remove the cancel token now that execution is done let mut tokens = cancel_tokens.lock().await; tokens.remove(&execution_id); }); Ok(()) } }, ) .await; match result { Ok(()) => info!("Consumer loop for queue '{}' ended", queue_name_for_log), Err(e) => error!( "Consumer loop for queue '{}' failed: {}", queue_name_for_log, e ), } }); // Store consumer reference and task handle self.consumer = Some(consumer); self.consumer_handle = Some(handle); info!("Message queue consumer initialized"); Ok(()) } /// Handle execution.scheduled message async fn handle_execution_scheduled( executor: Arc, publisher: Arc, db_pool: PgPool, envelope: MessageEnvelope, cancel_token: CancellationToken, ) -> Result<()> { let execution_id = envelope.payload.execution_id; info!( "Processing execution.scheduled for execution: {}", execution_id ); // Check if the execution was already cancelled before we started // (e.g. pre-running cancellation via the API). { if let Ok(Some(exec)) = ExecutionRepository::find_by_id(&db_pool, execution_id).await { if matches!( exec.status, ExecutionStatus::Cancelled | ExecutionStatus::Canceling ) { info!( "Execution {} already in {:?} state, skipping", execution_id, exec.status ); // If it was Canceling, finalize to Cancelled if exec.status == ExecutionStatus::Canceling { let _ = Self::publish_status_update( &db_pool, &publisher, execution_id, ExecutionStatus::Cancelled, None, Some("Cancelled before execution started".to_string()), ) .await; let _ = Self::publish_completion_notification( &db_pool, &publisher, execution_id, ExecutionStatus::Cancelled, ) .await; } return Ok(()); } } } // Publish status: running if let Err(e) = Self::publish_status_update( &db_pool, &publisher, execution_id, ExecutionStatus::Running, None, None, ) .await { error!("Failed to publish running status: {}", e); // Continue anyway - we'll update the database directly } // Execute the action (with cancellation support) match executor .execute_with_cancel(execution_id, cancel_token.clone()) .await { Ok(result) => { // Check if this was a cancellation let was_cancelled = cancel_token.is_cancelled() || result .error .as_deref() .is_some_and(|e| e.contains("cancelled")); if was_cancelled { info!( "Execution {} was cancelled in {}ms", execution_id, result.duration_ms ); // Publish status: cancelled if let Err(e) = Self::publish_status_update( &db_pool, &publisher, execution_id, ExecutionStatus::Cancelled, None, Some("Cancelled by user".to_string()), ) .await { error!("Failed to publish cancelled status: {}", e); } // Publish completion notification for queue management if let Err(e) = Self::publish_completion_notification( &db_pool, &publisher, execution_id, ExecutionStatus::Cancelled, ) .await { error!( "Failed to publish completion notification for cancelled execution {}: {}", execution_id, e ); } } else { info!( "Execution {} completed successfully in {}ms", execution_id, result.duration_ms ); // Publish status: completed if let Err(e) = Self::publish_status_update( &db_pool, &publisher, execution_id, ExecutionStatus::Completed, result.result.clone(), None, ) .await { error!("Failed to publish success status: {}", e); } // Publish completion notification for queue management if let Err(e) = Self::publish_completion_notification( &db_pool, &publisher, execution_id, ExecutionStatus::Completed, ) .await { error!( "Failed to publish completion notification for execution {}: {}", execution_id, e ); // Continue - this is important for queue management but not fatal } } } Err(e) => { error!("Execution {} failed: {}", execution_id, e); // Publish status: failed if let Err(e) = Self::publish_status_update( &db_pool, &publisher, execution_id, ExecutionStatus::Failed, None, Some(e.to_string()), ) .await { error!("Failed to publish failure status: {}", e); } // Publish completion notification for queue management if let Err(e) = Self::publish_completion_notification( &db_pool, &publisher, execution_id, ExecutionStatus::Failed, ) .await { error!( "Failed to publish completion notification for execution {}: {}", execution_id, e ); // Continue - this is important for queue management but not fatal } } } Ok(()) } /// Start consuming execution cancel requests from the per-worker cancel queue. async fn start_cancel_consumer(&mut self) -> Result<()> { let worker_id = self .worker_id .ok_or_else(|| Error::Internal("Worker not registered".to_string()))?; let queue_name = format!("worker.{}.cancel", worker_id); info!("Starting cancel consumer for queue: {}", queue_name); let consumer = Arc::new( Consumer::new( &self.mq_connection, ConsumerConfig { queue: queue_name.clone(), tag: format!("worker-{}-cancel", worker_id), prefetch_count: 10, auto_ack: false, exclusive: false, }, ) .await .map_err(|e| Error::Internal(format!("Failed to create cancel consumer: {}", e)))?, ); let consumer_for_task = consumer.clone(); let cancel_tokens = self.cancel_tokens.clone(); let queue_name_for_log = queue_name.clone(); let handle = tokio::spawn(async move { info!( "Cancel consumer loop started for queue '{}'", queue_name_for_log ); let result = consumer_for_task .consume_with_handler( move |envelope: MessageEnvelope| { let cancel_tokens = cancel_tokens.clone(); async move { let execution_id = envelope.payload.execution_id; info!("Received cancel request for execution {}", execution_id); let tokens = cancel_tokens.lock().await; if let Some(token) = tokens.get(&execution_id) { info!("Triggering cancellation for execution {}", execution_id); token.cancel(); } else { warn!( "No cancel token found for execution {} \ (may have already completed or not yet started)", execution_id ); } Ok(()) } }, ) .await; match result { Ok(()) => info!( "Cancel consumer loop for queue '{}' ended", queue_name_for_log ), Err(e) => error!( "Cancel consumer loop for queue '{}' failed: {}", queue_name_for_log, e ), } }); self.cancel_consumer = Some(consumer); self.cancel_consumer_handle = Some(handle); info!("Cancel consumer initialized for queue: {}", queue_name); Ok(()) } /// Publish execution status update async fn publish_status_update( db_pool: &PgPool, publisher: &Publisher, execution_id: i64, status: ExecutionStatus, _result: Option, _error: Option, ) -> Result<()> { // Fetch execution to get action_ref and previous status let execution = ExecutionRepository::find_by_id(db_pool, execution_id) .await? .ok_or_else(|| { Error::Internal(format!( "Execution {} not found for status update", execution_id )) })?; let new_status_str = match status { ExecutionStatus::Running => "running", ExecutionStatus::Completed => "completed", ExecutionStatus::Failed => "failed", ExecutionStatus::Canceling => "canceling", ExecutionStatus::Cancelled => "cancelled", ExecutionStatus::Timeout => "timeout", _ => "unknown", }; let previous_status_str = format!("{:?}", execution.status).to_lowercase(); let payload = ExecutionStatusChangedPayload { execution_id, action_ref: execution.action_ref, previous_status: previous_status_str, new_status: new_status_str.to_string(), changed_at: Utc::now(), }; let message_type = MessageType::ExecutionStatusChanged; let envelope = MessageEnvelope::new(message_type, payload).with_source("worker"); publisher .publish_envelope(&envelope) .await .map_err(|e| Error::Internal(format!("Failed to publish status update: {}", e)))?; Ok(()) } /// Publish execution completion notification for queue management async fn publish_completion_notification( db_pool: &PgPool, publisher: &Publisher, execution_id: i64, final_status: ExecutionStatus, ) -> Result<()> { // Fetch execution to get action_id and other required fields let execution = ExecutionRepository::find_by_id(db_pool, execution_id) .await? .ok_or_else(|| { Error::Internal(format!( "Execution {} not found after completion", execution_id )) })?; // Extract action_id - it should always be present for valid executions let action_id = execution.action.ok_or_else(|| { Error::Internal(format!( "Execution {} has no associated action", execution_id )) })?; info!( "Publishing completion notification for execution {} (action_id: {})", execution_id, action_id ); let payload = ExecutionCompletedPayload { execution_id: execution.id, action_id, action_ref: execution.action_ref.clone(), status: format!("{:?}", final_status), result: execution.result.clone(), completed_at: Utc::now(), }; let envelope = MessageEnvelope::new(MessageType::ExecutionCompleted, payload).with_source("worker"); publisher.publish_envelope(&envelope).await.map_err(|e| { Error::Internal(format!("Failed to publish completion notification: {}", e)) })?; info!( "Completion notification published for execution {}", execution_id ); Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_queue_name_format() { let worker_id = 42; let queue_name = format!("worker.{}.executions", worker_id); assert_eq!(queue_name, "worker.42.executions"); } #[test] fn test_status_string_conversion() { let status = ExecutionStatus::Running; let status_str = match status { ExecutionStatus::Running => "running", _ => "unknown", }; assert_eq!(status_str, "running"); } #[test] fn test_execution_completed_payload_structure() { let payload = ExecutionCompletedPayload { execution_id: 123, action_id: 456, action_ref: "test.action".to_string(), status: "Completed".to_string(), result: Some(serde_json::json!({"output": "test"})), completed_at: Utc::now(), }; assert_eq!(payload.execution_id, 123); assert_eq!(payload.action_id, 456); assert_eq!(payload.action_ref, "test.action"); assert_eq!(payload.status, "Completed"); assert!(payload.result.is_some()); } // Test removed - ExecutionStatusPayload struct doesn't exist // #[test] // fn test_execution_status_payload_structure() { // ... // } #[test] fn test_execution_scheduled_payload_structure() { let payload = ExecutionScheduledPayload { execution_id: 111, action_ref: "core.test".to_string(), worker_id: 222, }; assert_eq!(payload.execution_id, 111); assert_eq!(payload.action_ref, "core.test"); assert_eq!(payload.worker_id, 222); } #[test] fn test_status_format_for_completion() { let status = ExecutionStatus::Completed; let status_str = format!("{:?}", status); assert_eq!(status_str, "Completed"); let status = ExecutionStatus::Failed; let status_str = format!("{:?}", status); assert_eq!(status_str, "Failed"); let status = ExecutionStatus::Timeout; let status_str = format!("{:?}", status); assert_eq!(status_str, "Timeout"); let status = ExecutionStatus::Cancelled; let status_str = format!("{:?}", status); assert_eq!(status_str, "Cancelled"); } }