From a4c303ec84e1237d80eb9776cb7e4d80b128e6e6 Mon Sep 17 00:00:00 2001 From: David Culbreth Date: Wed, 1 Apr 2026 20:38:18 -0500 Subject: [PATCH] merging semgrep-scan --- crates/common/src/mq/error.rs | 7 +- crates/common/src/repositories/action.rs | 83 ++- crates/executor/src/completion_listener.rs | 113 ++- crates/executor/src/enforcement_processor.rs | 29 +- crates/executor/src/policy_enforcer.rs | 364 +++++++++- crates/executor/src/queue_manager.rs | 652 +++++++++++++++--- crates/executor/src/scheduler.rs | 393 ++++++++++- crates/executor/src/service.rs | 1 + .../tests/fifo_ordering_integration_test.rs | 26 +- .../executor/tests/policy_enforcer_tests.rs | 14 +- 10 files changed, 1434 insertions(+), 248 deletions(-) diff --git a/crates/common/src/mq/error.rs b/crates/common/src/mq/error.rs index 318ed1a..e13c3f2 100644 --- a/crates/common/src/mq/error.rs +++ b/crates/common/src/mq/error.rs @@ -102,7 +102,12 @@ impl MqError { pub fn is_retriable(&self) -> bool { matches!( self, - MqError::Connection(_) | MqError::Channel(_) | MqError::Timeout(_) | MqError::Pool(_) + MqError::Connection(_) + | MqError::Channel(_) + | MqError::Publish(_) + | MqError::Timeout(_) + | MqError::Pool(_) + | MqError::Lapin(_) ) } diff --git a/crates/common/src/repositories/action.rs b/crates/common/src/repositories/action.rs index 76a3d38..c265e31 100644 --- a/crates/common/src/repositories/action.rs +++ b/crates/common/src/repositories/action.rs @@ -571,7 +571,7 @@ impl Repository for PolicyRepository { type Entity = Policy; fn table_name() -> &'static str { - "policies" + "policy" } } @@ -612,7 +612,7 @@ impl FindById for PolicyRepository { r#" SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags, created, updated - FROM policies + FROM policy WHERE id = $1 "#, ) @@ -634,7 +634,7 @@ impl FindByRef for PolicyRepository { r#" SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags, created, updated - FROM policies + FROM policy WHERE ref = $1 "#, ) @@ -656,7 +656,7 @@ impl List for PolicyRepository { r#" SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags, created, updated - FROM policies + FROM policy ORDER BY ref ASC "#, ) @@ -678,7 +678,7 @@ impl Create for PolicyRepository { // Try to insert - database will enforce uniqueness constraint let policy = sqlx::query_as::<_, Policy>( r#" - INSERT INTO policies (ref, pack, pack_ref, action, action_ref, parameters, + INSERT INTO policy (ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, ref, pack, pack_ref, action, action_ref, parameters, method, @@ -720,7 +720,7 @@ impl Update for PolicyRepository { where E: Executor<'e, Database = Postgres> + 'e, { - let mut query = QueryBuilder::new("UPDATE policies SET "); + let mut query = QueryBuilder::new("UPDATE policy SET "); let mut has_updates = false; if let Some(parameters) = &input.parameters { @@ -798,7 +798,7 @@ impl Delete for PolicyRepository { where E: Executor<'e, Database = Postgres> + 'e, { - let result = sqlx::query("DELETE FROM policies WHERE id = $1") + let result = sqlx::query("DELETE FROM policy WHERE id = $1") .bind(id) .execute(executor) .await?; @@ -817,7 +817,7 @@ impl PolicyRepository { r#" SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags, created, updated - FROM policies + FROM policy WHERE action = $1 ORDER BY ref ASC "#, @@ -838,7 +838,7 @@ impl PolicyRepository { r#" SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, threshold, name, description, tags, created, updated - FROM policies + FROM policy WHERE $1 = ANY(tags) ORDER BY ref ASC "#, @@ -849,4 +849,69 @@ impl PolicyRepository { Ok(policies) } + + /// Find the most recent action-specific policy. + pub async fn find_latest_by_action<'e, E>(executor: E, action_id: Id) -> Result> + where + E: Executor<'e, Database = Postgres> + 'e, + { + let policy = sqlx::query_as::<_, Policy>( + r#" + SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, + threshold, name, description, tags, created, updated + FROM policy + WHERE action = $1 + ORDER BY created DESC + LIMIT 1 + "#, + ) + .bind(action_id) + .fetch_optional(executor) + .await?; + + Ok(policy) + } + + /// Find the most recent pack-specific policy. + pub async fn find_latest_by_pack<'e, E>(executor: E, pack_id: Id) -> Result> + where + E: Executor<'e, Database = Postgres> + 'e, + { + let policy = sqlx::query_as::<_, Policy>( + r#" + SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, + threshold, name, description, tags, created, updated + FROM policy + WHERE pack = $1 AND action IS NULL + ORDER BY created DESC + LIMIT 1 + "#, + ) + .bind(pack_id) + .fetch_optional(executor) + .await?; + + Ok(policy) + } + + /// Find the most recent global policy. + pub async fn find_latest_global<'e, E>(executor: E) -> Result> + where + E: Executor<'e, Database = Postgres> + 'e, + { + let policy = sqlx::query_as::<_, Policy>( + r#" + SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, + threshold, name, description, tags, created, updated + FROM policy + WHERE pack IS NULL AND action IS NULL + ORDER BY created DESC + LIMIT 1 + "#, + ) + .fetch_optional(executor) + .await?; + + Ok(policy) + } } diff --git a/crates/executor/src/completion_listener.rs b/crates/executor/src/completion_listener.rs index 45bbbb5..92da0a6 100644 --- a/crates/executor/src/completion_listener.rs +++ b/crates/executor/src/completion_listener.rs @@ -11,7 +11,10 @@ use anyhow::Result; use attune_common::{ - mq::{Consumer, ExecutionCompletedPayload, MessageEnvelope, Publisher}, + mq::{ + Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope, + MessageType, MqError, Publisher, + }, repositories::{execution::ExecutionRepository, FindById}, }; use sqlx::PgPool; @@ -36,6 +39,19 @@ pub struct CompletionListener { } impl CompletionListener { + fn retryable_mq_error(error: &anyhow::Error) -> Option { + let mq_error = error.downcast_ref::()?; + Some(match mq_error { + MqError::Connection(msg) => MqError::Connection(msg.clone()), + MqError::Channel(msg) => MqError::Channel(msg.clone()), + MqError::Publish(msg) => MqError::Publish(msg.clone()), + MqError::Timeout(msg) => MqError::Timeout(msg.clone()), + MqError::Pool(msg) => MqError::Pool(msg.clone()), + MqError::Lapin(err) => MqError::Connection(err.to_string()), + _ => return None, + }) + } + /// Create a new completion listener pub fn new( pool: PgPool, @@ -82,6 +98,9 @@ impl CompletionListener { { error!("Error processing execution completion: {}", e); // Return error to trigger nack with requeue + if let Some(mq_err) = Self::retryable_mq_error(&e) { + return Err(mq_err); + } return Err( format!("Failed to process execution completion: {}", e).into() ); @@ -187,17 +206,37 @@ impl CompletionListener { action_id, execution_id ); - match queue_manager.notify_completion(action_id).await { - Ok(notified) => { - if notified { - info!( - "Queue slot released for action {}, next execution notified", - action_id - ); + match queue_manager.release_active_slot(execution_id).await { + Ok(release) => { + if let Some(release) = release { + if let Some(next_execution_id) = release.next_execution_id { + info!( + "Queue slot released for action {}, next execution {} can proceed", + action_id, next_execution_id + ); + if let Err(republish_err) = Self::publish_execution_requested( + pool, + publisher, + action_id, + next_execution_id, + ) + .await + { + queue_manager + .restore_active_slot(execution_id, &release) + .await?; + return Err(republish_err); + } + } else { + debug!( + "Queue slot released for action {}, no executions waiting", + action_id + ); + } } else { debug!( - "Queue slot released for action {}, no executions waiting", - action_id + "Execution {} had no active queue slot to release", + execution_id ); } } @@ -225,6 +264,38 @@ impl CompletionListener { Ok(()) } + + async fn publish_execution_requested( + pool: &PgPool, + publisher: &Publisher, + action_id: i64, + execution_id: i64, + ) -> Result<()> { + let execution = ExecutionRepository::find_by_id(pool, execution_id) + .await? + .ok_or_else(|| anyhow::anyhow!("Execution {} not found", execution_id))?; + + let payload = ExecutionRequestedPayload { + execution_id, + action_id: Some(action_id), + action_ref: execution.action_ref.clone(), + parent_id: execution.parent, + enforcement_id: execution.enforcement, + config: execution.config.clone(), + }; + + let envelope = MessageEnvelope::new(MessageType::ExecutionRequested, payload) + .with_source("executor-completion-listener"); + + publisher.publish_envelope(&envelope).await?; + + debug!( + "Republished deferred ExecutionRequested for execution {}", + execution_id + ); + + Ok(()) + } } #[cfg(test)] @@ -239,7 +310,7 @@ mod tests { // Simulate acquiring a slot queue_manager - .enqueue_and_wait(action_id, 100, 1) + .enqueue_and_wait(action_id, 100, 1, None) .await .unwrap(); @@ -249,7 +320,7 @@ mod tests { assert_eq!(stats.queue_length, 0); // Simulate completion notification - let notified = queue_manager.notify_completion(action_id).await.unwrap(); + let notified = queue_manager.notify_completion(100).await.unwrap(); assert!(!notified); // No one waiting // Verify slot is released @@ -264,7 +335,7 @@ mod tests { // Fill capacity queue_manager - .enqueue_and_wait(action_id, 100, 1) + .enqueue_and_wait(action_id, 100, 1, None) .await .unwrap(); @@ -272,7 +343,7 @@ mod tests { let queue_manager_clone = queue_manager.clone(); let handle = tokio::spawn(async move { queue_manager_clone - .enqueue_and_wait(action_id, 101, 1) + .enqueue_and_wait(action_id, 101, 1, None) .await .unwrap(); }); @@ -286,7 +357,7 @@ mod tests { assert_eq!(stats.queue_length, 1); // Notify completion - let notified = queue_manager.notify_completion(action_id).await.unwrap(); + let notified = queue_manager.notify_completion(100).await.unwrap(); assert!(notified); // Should wake the waiting execution // Wait for queued execution to proceed @@ -306,7 +377,7 @@ mod tests { // Fill capacity queue_manager - .enqueue_and_wait(action_id, 100, 1) + .enqueue_and_wait(action_id, 100, 1, None) .await .unwrap(); @@ -320,7 +391,7 @@ mod tests { let handle = tokio::spawn(async move { queue_manager - .enqueue_and_wait(action_id, exec_id, 1) + .enqueue_and_wait(action_id, exec_id, 1, None) .await .unwrap(); order.lock().await.push(exec_id); @@ -333,9 +404,9 @@ mod tests { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Release them one by one - for _ in 0..3 { + for execution_id in 100..103 { tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - queue_manager.notify_completion(action_id).await.unwrap(); + queue_manager.notify_completion(execution_id).await.unwrap(); } // Wait for all to complete @@ -351,10 +422,10 @@ mod tests { #[tokio::test] async fn test_completion_with_no_queue() { let queue_manager = Arc::new(ExecutionQueueManager::with_defaults()); - let action_id = 999; // Non-existent action + let execution_id = 999; // Non-existent execution // Should succeed but not notify anyone - let result = queue_manager.notify_completion(action_id).await; + let result = queue_manager.notify_completion(execution_id).await; assert!(result.is_ok()); assert!(!result.unwrap()); } diff --git a/crates/executor/src/enforcement_processor.rs b/crates/executor/src/enforcement_processor.rs index 328ad4a..7b5cfb6 100644 --- a/crates/executor/src/enforcement_processor.rs +++ b/crates/executor/src/enforcement_processor.rs @@ -230,7 +230,7 @@ impl EnforcementProcessor { async fn create_execution( pool: &PgPool, publisher: &Publisher, - policy_enforcer: &PolicyEnforcer, + _policy_enforcer: &PolicyEnforcer, _queue_manager: &ExecutionQueueManager, enforcement: &Enforcement, rule: &Rule, @@ -257,33 +257,10 @@ impl EnforcementProcessor { enforcement.id, rule.id, action_id ); - let pack_id = rule.pack; let action_ref = &rule.action_ref; - // Enforce policies and wait for queue slot if needed - info!( - "Enforcing policies for action {} (enforcement: {})", - action_id, enforcement.id - ); - - // Use enforcement ID for queue tracking (execution doesn't exist yet) - if let Err(e) = policy_enforcer - .enforce_and_wait(action_id, Some(pack_id), enforcement.id) - .await - { - error!( - "Policy enforcement failed for enforcement {}: {}", - enforcement.id, e - ); - return Err(e); - } - - info!( - "Policy check passed and queue slot obtained for enforcement: {}", - enforcement.id - ); - - // Now create execution in database (we have a queue slot) + // Create the execution row first; scheduler-side policy enforcement + // now handles both rule-triggered and manual executions uniformly. let execution_input = CreateExecutionInput { action: Some(action_id), action_ref: action_ref.clone(), diff --git a/crates/executor/src/policy_enforcer.rs b/crates/executor/src/policy_enforcer.rs index d11a935..dc4b82f 100644 --- a/crates/executor/src/policy_enforcer.rs +++ b/crates/executor/src/policy_enforcer.rs @@ -10,14 +10,23 @@ use anyhow::Result; use chrono::{DateTime, Duration, Utc}; use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; use sqlx::PgPool; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use tracing::{debug, info, warn}; -use attune_common::models::{enums::ExecutionStatus, Id}; +use attune_common::{ + models::{ + enums::{ExecutionStatus, PolicyMethod}, + Id, Policy, + }, + repositories::action::PolicyRepository, +}; -use crate::queue_manager::ExecutionQueueManager; +use crate::queue_manager::{ + ExecutionQueueManager, QueuedRemovalOutcome, SlotEnqueueOutcome, SlotReleaseOutcome, +}; /// Policy violation type #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -79,16 +88,38 @@ impl std::fmt::Display for PolicyViolation { } /// Execution policy configuration -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionPolicy { /// Rate limit: maximum executions per time window pub rate_limit: Option, /// Concurrency limit: maximum concurrent executions pub concurrency_limit: Option, + /// How a concurrency violation should be handled. + pub concurrency_method: PolicyMethod, + /// Parameter paths used to scope concurrency grouping. + pub concurrency_parameters: Vec, /// Resource quotas pub quotas: Option>, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingPolicyOutcome { + Ready, + Queued, +} + +impl Default for ExecutionPolicy { + fn default() -> Self { + Self { + rate_limit: None, + concurrency_limit: None, + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), + quotas: None, + } + } +} + /// Rate limit configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RateLimit { @@ -98,6 +129,25 @@ pub struct RateLimit { pub window_seconds: u32, } +#[derive(Debug, Clone)] +struct ResolvedConcurrencyPolicy { + limit: u32, + method: PolicyMethod, + parameters: Vec, +} + +impl From for ExecutionPolicy { + fn from(policy: Policy) -> Self { + Self { + rate_limit: None, + concurrency_limit: Some(policy.threshold as u32), + concurrency_method: policy.method, + concurrency_parameters: policy.parameters, + quotas: None, + } + } +} + /// Policy enforcement scope #[derive(Debug, Clone, PartialEq, Eq)] #[allow(dead_code)] // Used in tests @@ -185,6 +235,174 @@ impl PolicyEnforcer { self.action_policies.insert(action_id, policy); } + /// Best-effort release for a slot acquired during scheduling when the + /// execution never reaches the worker/completion path. + pub async fn release_execution_slot( + &self, + execution_id: Id, + ) -> Result> { + match &self.queue_manager { + Some(queue_manager) => queue_manager.release_active_slot(execution_id).await, + None => Ok(None), + } + } + + pub async fn restore_execution_slot( + &self, + execution_id: Id, + outcome: &SlotReleaseOutcome, + ) -> Result<()> { + match &self.queue_manager { + Some(queue_manager) => { + queue_manager + .restore_active_slot(execution_id, outcome) + .await + } + None => Ok(()), + } + } + + pub async fn remove_queued_execution( + &self, + execution_id: Id, + ) -> Result> { + match &self.queue_manager { + Some(queue_manager) => queue_manager.remove_queued_execution(execution_id).await, + None => Ok(None), + } + } + + pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> { + match &self.queue_manager { + Some(queue_manager) => queue_manager.restore_queued_execution(outcome).await, + None => Ok(()), + } + } + + pub async fn enforce_for_scheduling( + &self, + action_id: Id, + pack_id: Option, + execution_id: Id, + config: Option<&JsonValue>, + ) -> Result { + if let Some(violation) = self + .check_policies_except_concurrency(action_id, pack_id) + .await? + { + warn!("Policy violation for action {}: {}", action_id, violation); + return Err(anyhow::anyhow!("Policy violation: {}", violation)); + } + + if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? { + let group_key = self.build_parameter_group_key(&concurrency.parameters, config); + + if let Some(queue_manager) = &self.queue_manager { + match concurrency.method { + PolicyMethod::Enqueue => { + return match queue_manager + .enqueue(action_id, execution_id, concurrency.limit, group_key) + .await? + { + SlotEnqueueOutcome::Acquired => Ok(SchedulingPolicyOutcome::Ready), + SlotEnqueueOutcome::Enqueued => Ok(SchedulingPolicyOutcome::Queued), + }; + } + PolicyMethod::Cancel => { + let outcome = queue_manager + .try_acquire( + action_id, + execution_id, + concurrency.limit, + group_key.clone(), + ) + .await?; + + if !outcome.acquired { + let violation = PolicyViolation::ConcurrencyLimitExceeded { + limit: concurrency.limit, + current_count: outcome.current_count, + }; + warn!("Policy violation for action {}: {}", action_id, violation); + return Err(anyhow::anyhow!("Policy violation: {}", violation)); + } + } + } + } else { + let scope = PolicyScope::Action(action_id); + if let Some(violation) = self + .check_concurrency_limit(concurrency.limit, &scope) + .await? + { + return Err(anyhow::anyhow!("Policy violation: {}", violation)); + } + } + } + + Ok(SchedulingPolicyOutcome::Ready) + } + + async fn resolve_policy(&self, action_id: Id, pack_id: Option) -> Result { + if let Some(policy) = self.action_policies.get(&action_id) { + return Ok(policy.clone()); + } + + if let Some(policy) = PolicyRepository::find_latest_by_action(&self.pool, action_id).await? + { + return Ok(policy.into()); + } + + if let Some(pack_id) = pack_id { + if let Some(policy) = self.pack_policies.get(&pack_id) { + return Ok(policy.clone()); + } + + if let Some(policy) = PolicyRepository::find_latest_by_pack(&self.pool, pack_id).await? + { + return Ok(policy.into()); + } + } + + if let Some(policy) = PolicyRepository::find_latest_global(&self.pool).await? { + return Ok(policy.into()); + } + + Ok(self.global_policy.clone()) + } + + async fn resolve_concurrency_policy( + &self, + action_id: Id, + pack_id: Option, + ) -> Result> { + let policy = self.resolve_policy(action_id, pack_id).await?; + + Ok(policy + .concurrency_limit + .map(|limit| ResolvedConcurrencyPolicy { + limit, + method: policy.concurrency_method, + parameters: policy.concurrency_parameters, + })) + } + + fn build_parameter_group_key( + &self, + parameter_paths: &[String], + config: Option<&JsonValue>, + ) -> Option { + if parameter_paths.is_empty() { + return None; + } + + let values: BTreeMap = parameter_paths + .iter() + .map(|path| (path.clone(), extract_parameter_value(config, path))) + .collect(); + + serde_json::to_string(&values).ok() + } + /// Get the concurrency limit for a specific action /// /// Returns the most specific concurrency limit found: @@ -192,6 +410,7 @@ impl PolicyEnforcer { /// 2. Pack policy /// 3. Global policy /// 4. None (unlimited) + #[allow(dead_code)] pub fn get_concurrency_limit(&self, action_id: Id, pack_id: Option) -> Option { // Check action-specific policy first if let Some(policy) = self.action_policies.get(&action_id) { @@ -229,11 +448,13 @@ impl PolicyEnforcer { /// * `Ok(())` - Policy allows execution and queue slot obtained /// * `Err(PolicyViolation)` - Policy prevents execution /// * `Err(QueueError)` - Queue timeout or other queue error + #[allow(dead_code)] pub async fn enforce_and_wait( &self, action_id: Id, pack_id: Option, execution_id: Id, + config: Option<&JsonValue>, ) -> Result<()> { // First, check for policy violations (rate limit, quotas, etc.) // Note: We skip concurrency check here since queue manages that @@ -246,36 +467,61 @@ impl PolicyEnforcer { } // If queue manager is available, use it for concurrency control - if let Some(queue_manager) = &self.queue_manager { - let concurrency_limit = self - .get_concurrency_limit(action_id, pack_id) - .unwrap_or(u32::MAX); // Default to unlimited if no policy + if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? { + let group_key = self.build_parameter_group_key(&concurrency.parameters, config); - debug!( - "Enqueuing execution {} for action {} with concurrency limit {}", - execution_id, action_id, concurrency_limit - ); + if let Some(queue_manager) = &self.queue_manager { + debug!( + "Applying concurrency policy to execution {} for action {} (limit: {}, method: {:?}, group: {:?})", + execution_id, action_id, concurrency.limit, concurrency.method, group_key + ); - queue_manager - .enqueue_and_wait(action_id, execution_id, concurrency_limit) - .await?; + match concurrency.method { + PolicyMethod::Enqueue => { + queue_manager + .enqueue_and_wait( + action_id, + execution_id, + concurrency.limit, + group_key.clone(), + ) + .await?; + } + PolicyMethod::Cancel => { + let outcome = queue_manager + .try_acquire( + action_id, + execution_id, + concurrency.limit, + group_key.clone(), + ) + .await?; - info!( - "Execution {} obtained queue slot for action {}", - execution_id, action_id - ); - } else { - // No queue manager - use legacy polling behavior - debug!( - "No queue manager configured, using legacy policy wait for action {}", - action_id - ); + if !outcome.acquired { + let violation = PolicyViolation::ConcurrencyLimitExceeded { + limit: concurrency.limit, + current_count: outcome.current_count, + }; + warn!("Policy violation for action {}: {}", action_id, violation); + return Err(anyhow::anyhow!("Policy violation: {}", violation)); + } + } + } + + info!( + "Execution {} obtained queue slot for action {} (group: {:?})", + execution_id, action_id, group_key + ); + } else { + // No queue manager - use legacy polling behavior + debug!( + "No queue manager configured, using legacy policy wait for action {}", + action_id + ); - if let Some(concurrency_limit) = self.get_concurrency_limit(action_id, pack_id) { - // Check concurrency with old method let scope = PolicyScope::Action(action_id); if let Some(violation) = self - .check_concurrency_limit(concurrency_limit, &scope) + .check_concurrency_limit(concurrency.limit, &scope) .await? { return Err(anyhow::anyhow!("Policy violation: {}", violation)); @@ -631,6 +877,25 @@ impl PolicyEnforcer { } } +fn extract_parameter_value(config: Option<&JsonValue>, path: &str) -> JsonValue { + let mut current = match config { + Some(value) => value, + None => return JsonValue::Null, + }; + + for segment in path.split('.') { + match current { + JsonValue::Object(map) => match map.get(segment) { + Some(next) => current = next, + None => return JsonValue::Null, + }, + _ => return JsonValue::Null, + } + } + + current.clone() +} + #[cfg(test)] mod tests { use super::*; @@ -665,6 +930,8 @@ mod tests { let policy = ExecutionPolicy::default(); assert!(policy.rate_limit.is_none()); assert!(policy.concurrency_limit.is_none()); + assert_eq!(policy.concurrency_method, PolicyMethod::Enqueue); + assert!(policy.concurrency_parameters.is_empty()); assert!(policy.quotas.is_none()); } @@ -784,7 +1051,7 @@ mod tests { ); // First execution should proceed immediately - let result = enforcer.enforce_and_wait(1, None, 100).await; + let result = enforcer.enforce_and_wait(1, None, 100, None).await; assert!(result.is_ok()); // Check queue stats @@ -809,7 +1076,7 @@ mod tests { let enforcer = Arc::new(enforcer); // First execution - let result = enforcer.enforce_and_wait(1, None, 100).await; + let result = enforcer.enforce_and_wait(1, None, 100, None).await; assert!(result.is_ok()); // Queue multiple executions @@ -822,11 +1089,14 @@ mod tests { let order = execution_order.clone(); let handle = tokio::spawn(async move { - enforcer.enforce_and_wait(1, None, exec_id).await.unwrap(); + enforcer + .enforce_and_wait(1, None, exec_id, None) + .await + .unwrap(); order.lock().await.push(exec_id); // Simulate work sleep(Duration::from_millis(10)).await; - queue_manager.notify_completion(1).await.unwrap(); + queue_manager.notify_completion(exec_id).await.unwrap(); }); handles.push(handle); @@ -836,7 +1106,7 @@ mod tests { sleep(Duration::from_millis(100)).await; // Release first execution - queue_manager.notify_completion(1).await.unwrap(); + queue_manager.notify_completion(100).await.unwrap(); // Wait for all for handle in handles { @@ -863,7 +1133,7 @@ mod tests { ); // Should work without queue manager (legacy behavior) - let result = enforcer.enforce_and_wait(1, None, 100).await; + let result = enforcer.enforce_and_wait(1, None, 100, None).await; assert!(result.is_ok()); } @@ -889,14 +1159,36 @@ mod tests { ); // First execution proceeds - enforcer.enforce_and_wait(1, None, 100).await.unwrap(); + enforcer.enforce_and_wait(1, None, 100, None).await.unwrap(); // Second execution should timeout - let result = enforcer.enforce_and_wait(1, None, 101).await; + let result = enforcer.enforce_and_wait(1, None, 101, None).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("timeout")); } + #[test] + fn test_build_parameter_group_key_uses_exact_values() { + let pool = sqlx::PgPool::connect_lazy("postgresql://localhost/test").unwrap(); + let enforcer = PolicyEnforcer::new(pool); + let config = serde_json::json!({ + "environment": "prod", + "target": { + "region": "us-east-1" + } + }); + + let group_key = enforcer.build_parameter_group_key( + &["target.region".to_string(), "environment".to_string()], + Some(&config), + ); + + assert_eq!( + group_key.as_deref(), + Some("{\"environment\":\"prod\",\"target.region\":\"us-east-1\"}") + ); + } + // Integration tests would require database setup // Those should be in a separate integration test file } diff --git a/crates/executor/src/queue_manager.rs b/crates/executor/src/queue_manager.rs index 41393d8..741b390 100644 --- a/crates/executor/src/queue_manager.rs +++ b/crates/executor/src/queue_manager.rs @@ -47,7 +47,7 @@ impl Default for QueueConfig { } /// Entry in the execution queue -#[derive(Debug)] +#[derive(Debug, Clone)] struct QueueEntry { /// Execution or enforcement ID being queued execution_id: Id, @@ -57,7 +57,13 @@ struct QueueEntry { notifier: Arc, } -/// Queue state for a single action +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct QueueKey { + action_id: Id, + group_key: Option, +} + +/// Queue state for a single action/group pair struct ActionQueue { /// FIFO queue of waiting executions queue: VecDeque, @@ -93,6 +99,32 @@ impl ActionQueue { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SlotAcquireOutcome { + pub acquired: bool, + pub current_count: u32, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SlotEnqueueOutcome { + Acquired, + Enqueued, +} + +#[derive(Debug, Clone)] +pub struct SlotReleaseOutcome { + pub next_execution_id: Option, + queue_key: QueueKey, +} + +#[derive(Debug, Clone)] +pub struct QueuedRemovalOutcome { + pub next_execution_id: Option, + queue_key: QueueKey, + removed_entry: QueueEntry, + removed_index: usize, +} + /// Statistics about a queue #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueueStats { @@ -114,8 +146,10 @@ pub struct QueueStats { /// Manages execution queues with FIFO ordering guarantees pub struct ExecutionQueueManager { - /// Per-action queues (key: action_id) - queues: DashMap>>, + /// Per-action/per-group queues. + queues: DashMap>>, + /// Tracks which queue key currently owns an active execution slot. + active_execution_keys: DashMap, /// Configuration config: QueueConfig, /// Database connection pool (optional for stats persistence) @@ -128,6 +162,7 @@ impl ExecutionQueueManager { pub fn new(config: QueueConfig) -> Self { Self { queues: DashMap::new(), + active_execution_keys: DashMap::new(), config, db_pool: None, } @@ -137,6 +172,7 @@ impl ExecutionQueueManager { pub fn with_db_pool(config: QueueConfig, db_pool: PgPool) -> Self { Self { queues: DashMap::new(), + active_execution_keys: DashMap::new(), config, db_pool: Some(db_pool), } @@ -148,6 +184,24 @@ impl ExecutionQueueManager { Self::new(QueueConfig::default()) } + fn queue_key(&self, action_id: Id, group_key: Option) -> QueueKey { + QueueKey { + action_id, + group_key, + } + } + + async fn get_or_create_queue( + &self, + queue_key: QueueKey, + max_concurrent: u32, + ) -> Arc> { + self.queues + .entry(queue_key) + .or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent)))) + .clone() + } + /// Enqueue an execution and wait until it can proceed /// /// This method will: @@ -164,23 +218,31 @@ impl ExecutionQueueManager { /// # Returns /// * `Ok(())` - Execution can proceed /// * `Err(_)` - Queue full or timeout + #[allow(dead_code)] pub async fn enqueue_and_wait( &self, action_id: Id, execution_id: Id, max_concurrent: u32, + group_key: Option, ) -> Result<()> { + if self.active_execution_keys.contains_key(&execution_id) { + debug!( + "Execution {} already owns an active slot, skipping queue wait", + execution_id + ); + return Ok(()); + } + debug!( - "Enqueuing execution {} for action {} (max_concurrent: {})", - execution_id, action_id, max_concurrent + "Enqueuing execution {} for action {} (max_concurrent: {}, group: {:?})", + execution_id, action_id, max_concurrent, group_key ); - // Get or create queue for this action + let queue_key = self.queue_key(action_id, group_key); let queue_arc = self - .queues - .entry(action_id) - .or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent)))) - .clone(); + .get_or_create_queue(queue_key.clone(), max_concurrent) + .await; // Create notifier for this execution let notifier = Arc::new(Notify::new()); @@ -192,14 +254,41 @@ impl ExecutionQueueManager { // Update max_concurrent if it changed queue.max_concurrent = max_concurrent; + let queued_index = queue + .queue + .iter() + .position(|entry| entry.execution_id == execution_id); + if let Some(queued_index) = queued_index { + if queued_index == 0 && queue.has_capacity() { + let entry = queue.queue.pop_front().expect("front entry just checked"); + queue.active_count += 1; + self.active_execution_keys + .insert(entry.execution_id, queue_key.clone()); + drop(queue); + self.persist_queue_stats(action_id).await; + return Ok(()); + } + debug!( + "Execution {} is already queued for action {} (group: {:?})", + execution_id, action_id, queue_key.group_key + ); + return Ok(()); + } + // Check if we can run immediately if queue.has_capacity() { debug!( - "Execution {} can run immediately (active: {}/{})", - execution_id, queue.active_count, queue.max_concurrent + "Execution {} can run immediately for action {} (active: {}/{}, group: {:?})", + execution_id, + action_id, + queue.active_count, + queue.max_concurrent, + queue_key.group_key ); queue.active_count += 1; queue.total_enqueued += 1; + self.active_execution_keys + .insert(execution_id, queue_key.clone()); // Persist stats to database if available drop(queue); @@ -211,8 +300,9 @@ impl ExecutionQueueManager { // Check if queue is full if queue.is_full(self.config.max_queue_length) { warn!( - "Queue full for action {}: {} entries (limit: {})", + "Queue full for action {} group {:?}: {} entries (limit: {})", action_id, + queue_key.group_key, queue.queue.len(), self.config.max_queue_length ); @@ -234,12 +324,13 @@ impl ExecutionQueueManager { queue.total_enqueued += 1; info!( - "Execution {} queued for action {} at position {} (active: {}/{})", + "Execution {} queued for action {} at position {} (active: {}/{}, group: {:?})", execution_id, action_id, queue.queue.len() - 1, queue.active_count, - queue.max_concurrent + queue.max_concurrent, + queue_key.group_key ); } @@ -273,6 +364,142 @@ impl ExecutionQueueManager { } } + /// Acquire a slot immediately or enqueue without blocking the caller. + pub async fn enqueue( + &self, + action_id: Id, + execution_id: Id, + max_concurrent: u32, + group_key: Option, + ) -> Result { + if self.active_execution_keys.contains_key(&execution_id) { + debug!( + "Execution {} already owns an active slot, treating as acquired", + execution_id + ); + return Ok(SlotEnqueueOutcome::Acquired); + } + + debug!( + "Enqueuing execution {} for action {} without waiting (max_concurrent: {}, group: {:?})", + execution_id, action_id, max_concurrent, group_key + ); + + let queue_key = self.queue_key(action_id, group_key); + let queue_arc = self + .get_or_create_queue(queue_key.clone(), max_concurrent) + .await; + + { + let mut queue = queue_arc.lock().await; + queue.max_concurrent = max_concurrent; + + let queued_index = queue + .queue + .iter() + .position(|entry| entry.execution_id == execution_id); + if let Some(queued_index) = queued_index { + if queued_index == 0 && queue.has_capacity() { + let entry = queue.queue.pop_front().expect("front entry just checked"); + queue.active_count += 1; + self.active_execution_keys + .insert(entry.execution_id, queue_key.clone()); + drop(queue); + self.persist_queue_stats(action_id).await; + return Ok(SlotEnqueueOutcome::Acquired); + } + debug!( + "Execution {} is already queued for action {} (group: {:?})", + execution_id, action_id, queue_key.group_key + ); + return Ok(SlotEnqueueOutcome::Enqueued); + } + + if queue.has_capacity() { + queue.active_count += 1; + queue.total_enqueued += 1; + self.active_execution_keys + .insert(execution_id, queue_key.clone()); + + drop(queue); + self.persist_queue_stats(action_id).await; + return Ok(SlotEnqueueOutcome::Acquired); + } + + if queue.is_full(self.config.max_queue_length) { + warn!( + "Queue full for action {} group {:?}: {} entries (limit: {})", + action_id, + queue_key.group_key, + queue.queue.len(), + self.config.max_queue_length + ); + return Err(anyhow::anyhow!( + "Queue full for action {}: maximum {} entries", + action_id, + self.config.max_queue_length + )); + } + + queue.queue.push_back(QueueEntry { + execution_id, + enqueued_at: Utc::now(), + notifier: Arc::new(Notify::new()), + }); + queue.total_enqueued += 1; + } + + self.persist_queue_stats(action_id).await; + Ok(SlotEnqueueOutcome::Enqueued) + } + + /// Try to acquire a slot immediately without queueing. + pub async fn try_acquire( + &self, + action_id: Id, + execution_id: Id, + max_concurrent: u32, + group_key: Option, + ) -> Result { + let queue_key = self.queue_key(action_id, group_key); + let queue_arc = self + .get_or_create_queue(queue_key.clone(), max_concurrent) + .await; + let mut queue = queue_arc.lock().await; + + queue.max_concurrent = max_concurrent; + let current_count = queue.active_count; + + if self.active_execution_keys.contains_key(&execution_id) { + debug!( + "Execution {} already owns a slot for action {} (group: {:?})", + execution_id, action_id, queue_key.group_key + ); + return Ok(SlotAcquireOutcome { + acquired: true, + current_count, + }); + } + + if queue.has_capacity() { + queue.active_count += 1; + queue.total_enqueued += 1; + self.active_execution_keys + .insert(execution_id, queue_key.clone()); + drop(queue); + self.persist_queue_stats(action_id).await; + return Ok(SlotAcquireOutcome { + acquired: true, + current_count, + }); + } + + Ok(SlotAcquireOutcome { + acquired: false, + current_count, + }) + } + /// Notify that an execution has completed, releasing a queue slot /// /// This method will: @@ -282,27 +509,64 @@ impl ExecutionQueueManager { /// 4. Increment active count for the notified execution /// /// # Arguments - /// * `action_id` - The action that completed + /// * `execution_id` - The execution that completed /// /// # Returns /// * `Ok(true)` - A queued execution was notified /// * `Ok(false)` - No executions were waiting /// * `Err(_)` - Error accessing queue - pub async fn notify_completion(&self, action_id: Id) -> Result { + pub async fn notify_completion(&self, execution_id: Id) -> Result { + Ok(self + .notify_completion_with_next(execution_id) + .await? + .is_some()) + } + + pub async fn notify_completion_with_next(&self, execution_id: Id) -> Result> { + let release = match self.release_active_slot(execution_id).await? { + Some(release) => release, + None => return Ok(None), + }; + + let Some(next_execution_id) = release.next_execution_id else { + return Ok(None); + }; + + if self.activate_queued_execution(next_execution_id).await? { + Ok(Some(next_execution_id)) + } else { + self.restore_active_slot(execution_id, &release).await?; + Ok(None) + } + } + + pub async fn release_active_slot( + &self, + execution_id: Id, + ) -> Result> { + let Some((_, queue_key)) = self.active_execution_keys.remove(&execution_id) else { + debug!( + "No active queue slot found for execution {} (queue may have been cleared)", + execution_id + ); + return Ok(None); + }; + let action_id = queue_key.action_id; + debug!( - "Processing completion notification for action {}", - action_id + "Processing completion notification for execution {} on action {} (group: {:?})", + execution_id, action_id, queue_key.group_key ); - // Get queue for this action - let queue_arc = match self.queues.get(&action_id) { + // Get queue for this action/group + let queue_arc = match self.queues.get(&queue_key) { Some(q) => q.clone(), None => { debug!( - "No queue found for action {} (no executions queued)", - action_id + "No queue found for action {} group {:?}", + action_id, queue_key.group_key ); - return Ok(false); + return Ok(None); } }; @@ -313,49 +577,162 @@ impl ExecutionQueueManager { queue.active_count -= 1; queue.total_completed += 1; debug!( - "Decremented active count for action {} to {}", - action_id, queue.active_count + "Decremented active count for action {} group {:?} to {}", + action_id, queue_key.group_key, queue.active_count ); } else { warn!( - "Completion notification for action {} but active_count is 0", - action_id + "Completion notification for action {} group {:?} but active_count is 0", + action_id, queue_key.group_key ); } // Check if there are queued executions if queue.queue.is_empty() { debug!( - "No executions queued for action {} after completion", - action_id + "No executions queued for action {} group {:?} after completion", + action_id, queue_key.group_key ); - return Ok(false); + drop(queue); + self.persist_queue_stats(action_id).await; + return Ok(Some(SlotReleaseOutcome { + next_execution_id: None, + queue_key, + })); } - // Pop the first (oldest) entry from queue - if let Some(entry) = queue.queue.pop_front() { + let next_execution_id = queue.queue.front().map(|entry| entry.execution_id); + if let Some(next_execution_id) = next_execution_id { info!( - "Notifying execution {} for action {} (was queued for {:?})", + "Execution {} is next for action {} group {:?}", + next_execution_id, action_id, queue_key.group_key + ); + } + + drop(queue); + self.persist_queue_stats(action_id).await; + + Ok(Some(SlotReleaseOutcome { + next_execution_id, + queue_key, + })) + } + + pub async fn restore_active_slot( + &self, + execution_id: Id, + outcome: &SlotReleaseOutcome, + ) -> Result<()> { + let action_id = outcome.queue_key.action_id; + let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await; + let mut queue = queue_arc.lock().await; + + queue.active_count += 1; + if queue.total_completed > 0 { + queue.total_completed -= 1; + } + self.active_execution_keys + .insert(execution_id, outcome.queue_key.clone()); + + drop(queue); + self.persist_queue_stats(action_id).await; + Ok(()) + } + + pub async fn activate_queued_execution(&self, execution_id: Id) -> Result { + for entry in self.queues.iter() { + let queue_key = entry.key().clone(); + let queue_arc = entry.value().clone(); + let mut queue = queue_arc.lock().await; + + let Some(front) = queue.queue.front() else { + continue; + }; + + if front.execution_id != execution_id { + continue; + } + + if !queue.has_capacity() { + return Ok(false); + } + + let entry = queue.queue.pop_front().expect("front entry just checked"); + info!( + "Activating queued execution {} for action {} group {:?} (queued for {:?})", entry.execution_id, - action_id, + queue_key.action_id, + queue_key.group_key, Utc::now() - entry.enqueued_at ); - - // Increment active count for the execution we're about to notify queue.active_count += 1; + self.active_execution_keys + .insert(entry.execution_id, queue_key.clone()); - // Notify the waiter (after releasing lock) drop(queue); entry.notifier.notify_one(); + self.persist_queue_stats(queue_key.action_id).await; + return Ok(true); + } - // Persist stats to database if available + Ok(false) + } + + pub async fn remove_queued_execution( + &self, + execution_id: Id, + ) -> Result> { + for entry in self.queues.iter() { + let queue_key = entry.key().clone(); + let queue_arc = entry.value().clone(); + let mut queue = queue_arc.lock().await; + + let Some(index) = queue + .queue + .iter() + .position(|queued| queued.execution_id == execution_id) + else { + continue; + }; + + let removed_entry = queue.queue.remove(index).expect("queue index just checked"); + let next_execution_id = if index == 0 { + queue.queue.front().map(|queued| queued.execution_id) + } else { + None + }; + let action_id = queue_key.action_id; + + drop(queue); self.persist_queue_stats(action_id).await; - Ok(true) - } else { - // Race condition check - queue was empty after all - Ok(false) + return Ok(Some(QueuedRemovalOutcome { + next_execution_id, + queue_key, + removed_entry, + removed_index: index, + })); } + + Ok(None) + } + + pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> { + let action_id = outcome.queue_key.action_id; + let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await; + let mut queue = queue_arc.lock().await; + + if outcome.removed_index <= queue.queue.len() { + queue + .queue + .insert(outcome.removed_index, outcome.removed_entry.clone()); + } else { + queue.queue.push_back(outcome.removed_entry.clone()); + } + + drop(queue); + self.persist_queue_stats(action_id).await; + Ok(()) } /// Persist queue statistics to database (if database pool is available) @@ -384,19 +761,48 @@ impl ExecutionQueueManager { /// Get statistics for a specific action's queue pub async fn get_queue_stats(&self, action_id: Id) -> Option { - let queue_arc = self.queues.get(&action_id)?.clone(); - let queue = queue_arc.lock().await; + let queue_arcs: Vec>> = self + .queues + .iter() + .filter(|entry| entry.key().action_id == action_id) + .map(|entry| entry.value().clone()) + .collect(); - let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at); + if queue_arcs.is_empty() { + return None; + } + + let mut queue_length = 0usize; + let mut active_count = 0u32; + let mut max_concurrent = 0u32; + let mut oldest_enqueued_at: Option> = None; + let mut total_enqueued = 0u64; + let mut total_completed = 0u64; + + for queue_arc in queue_arcs { + let queue = queue_arc.lock().await; + queue_length += queue.queue.len(); + active_count += queue.active_count; + max_concurrent += queue.max_concurrent; + total_enqueued += queue.total_enqueued; + total_completed += queue.total_completed; + + if let Some(candidate) = queue.queue.front().map(|e| e.enqueued_at) { + oldest_enqueued_at = Some(match oldest_enqueued_at { + Some(current) => current.min(candidate), + None => candidate, + }); + } + } Some(QueueStats { action_id, - queue_length: queue.queue.len(), - active_count: queue.active_count, - max_concurrent: queue.max_concurrent, + queue_length, + active_count, + max_concurrent, oldest_enqueued_at, - total_enqueued: queue.total_enqueued, - total_completed: queue.total_completed, + total_enqueued, + total_completed, }) } @@ -405,22 +811,15 @@ impl ExecutionQueueManager { pub async fn get_all_queue_stats(&self) -> Vec { let mut stats = Vec::new(); + let mut action_ids = std::collections::BTreeSet::new(); for entry in self.queues.iter() { - let action_id = *entry.key(); - let queue_arc = entry.value().clone(); - let queue = queue_arc.lock().await; + action_ids.insert(entry.key().action_id); + } - let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at); - - stats.push(QueueStats { - action_id, - queue_length: queue.queue.len(), - active_count: queue.active_count, - max_concurrent: queue.max_concurrent, - oldest_enqueued_at, - total_enqueued: queue.total_enqueued, - total_completed: queue.total_completed, - }); + for action_id in action_ids { + if let Some(action_stats) = self.get_queue_stats(action_id).await { + stats.push(action_stats); + } } stats @@ -445,27 +844,29 @@ impl ExecutionQueueManager { execution_id, action_id ); - let queue_arc = match self.queues.get(&action_id) { - Some(q) => q.clone(), - None => return Ok(false), - }; + let queue_arcs: Vec>> = self + .queues + .iter() + .filter(|entry| entry.key().action_id == action_id) + .map(|entry| entry.value().clone()) + .collect(); - let mut queue = queue_arc.lock().await; - - let initial_len = queue.queue.len(); - queue.queue.retain(|e| e.execution_id != execution_id); - let removed = initial_len != queue.queue.len(); - - if removed { - info!("Cancelled execution {} from queue", execution_id); - } else { - debug!( - "Execution {} not found in queue (may be running)", - execution_id - ); + for queue_arc in queue_arcs { + let mut queue = queue_arc.lock().await; + let initial_len = queue.queue.len(); + queue.queue.retain(|e| e.execution_id != execution_id); + if initial_len != queue.queue.len() { + info!("Cancelled execution {} from queue", execution_id); + return Ok(true); + } } - Ok(removed) + debug!( + "Execution {} not found in queue (may be running)", + execution_id + ); + + Ok(false) } /// Clear all queues (for testing or emergency situations) @@ -479,12 +880,17 @@ impl ExecutionQueueManager { queue.queue.clear(); queue.active_count = 0; } + self.active_execution_keys.clear(); } /// Get the number of actions with active queues #[allow(dead_code)] pub fn active_queue_count(&self) -> usize { - self.queues.len() + self.queues + .iter() + .map(|entry| entry.key().action_id) + .collect::>() + .len() } } @@ -504,7 +910,7 @@ mod tests { let manager = ExecutionQueueManager::with_defaults(); // Should execute immediately when there's capacity - let result = manager.enqueue_and_wait(1, 100, 2).await; + let result = manager.enqueue_and_wait(1, 100, 2, None).await; assert!(result.is_ok()); // Check stats @@ -521,7 +927,7 @@ mod tests { // First execution should run immediately let result = manager - .enqueue_and_wait(action_id, 100, max_concurrent) + .enqueue_and_wait(action_id, 100, max_concurrent, None) .await; assert!(result.is_ok()); @@ -535,7 +941,7 @@ mod tests { let handle = tokio::spawn(async move { manager - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .unwrap(); order.lock().await.push(exec_id); @@ -553,9 +959,9 @@ mod tests { assert_eq!(stats.active_count, 1); // Release them one by one - for _ in 0..3 { + for execution_id in 100..103 { sleep(Duration::from_millis(50)).await; - manager.notify_completion(action_id).await.unwrap(); + manager.notify_completion(execution_id).await.unwrap(); } // Wait for all to complete @@ -574,7 +980,10 @@ mod tests { let action_id = 1; // Start first execution - manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); + manager + .enqueue_and_wait(action_id, 100, 1, None) + .await + .unwrap(); // Queue second execution let manager_clone = Arc::new(manager); @@ -582,7 +991,7 @@ mod tests { let handle = tokio::spawn(async move { manager_ref - .enqueue_and_wait(action_id, 101, 1) + .enqueue_and_wait(action_id, 101, 1, None) .await .unwrap(); }); @@ -596,7 +1005,7 @@ mod tests { assert_eq!(stats.active_count, 1); // Notify completion - let notified = manager_clone.notify_completion(action_id).await.unwrap(); + let notified = manager_clone.notify_completion(100).await.unwrap(); assert!(notified); // Wait for queued execution to proceed @@ -613,8 +1022,8 @@ mod tests { let manager = Arc::new(ExecutionQueueManager::with_defaults()); // Start executions on different actions - manager.enqueue_and_wait(1, 100, 1).await.unwrap(); - manager.enqueue_and_wait(2, 200, 1).await.unwrap(); + manager.enqueue_and_wait(1, 100, 1, None).await.unwrap(); + manager.enqueue_and_wait(2, 200, 1, None).await.unwrap(); // Both should be active let stats1 = manager.get_queue_stats(1).await.unwrap(); @@ -624,7 +1033,7 @@ mod tests { assert_eq!(stats2.active_count, 1); // Completion on action 1 shouldn't affect action 2 - manager.notify_completion(1).await.unwrap(); + manager.notify_completion(100).await.unwrap(); let stats1 = manager.get_queue_stats(1).await.unwrap(); let stats2 = manager.get_queue_stats(2).await.unwrap(); @@ -633,20 +1042,43 @@ mod tests { assert_eq!(stats2.active_count, 1); } + #[tokio::test] + async fn test_grouped_queues_are_independent() { + let manager = ExecutionQueueManager::with_defaults(); + let action_id = 1; + + manager + .enqueue_and_wait(action_id, 100, 1, Some("prod".to_string())) + .await + .unwrap(); + manager + .enqueue_and_wait(action_id, 200, 1, Some("staging".to_string())) + .await + .unwrap(); + + let stats = manager.get_queue_stats(action_id).await.unwrap(); + assert_eq!(stats.active_count, 2); + assert_eq!(stats.queue_length, 0); + assert_eq!(stats.max_concurrent, 2); + } + #[tokio::test] async fn test_cancel_execution() { let manager = ExecutionQueueManager::with_defaults(); let action_id = 1; // Fill capacity - manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); + manager + .enqueue_and_wait(action_id, 100, 1, None) + .await + .unwrap(); // Queue more executions let manager_arc = Arc::new(manager); let manager_ref = manager_arc.clone(); let handle = tokio::spawn(async move { - let result = manager_ref.enqueue_and_wait(action_id, 101, 1).await; + let result = manager_ref.enqueue_and_wait(action_id, 101, 1, None).await; result }); @@ -675,7 +1107,10 @@ mod tests { assert!(manager.get_queue_stats(action_id).await.is_none()); // After enqueue, stats should exist - manager.enqueue_and_wait(action_id, 100, 2).await.unwrap(); + manager + .enqueue_and_wait(action_id, 100, 2, None) + .await + .unwrap(); let stats = manager.get_queue_stats(action_id).await.unwrap(); assert_eq!(stats.action_id, action_id); @@ -696,13 +1131,16 @@ mod tests { let action_id = 1; // Fill capacity - manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); + manager + .enqueue_and_wait(action_id, 100, 1, None) + .await + .unwrap(); // Queue 2 more (should reach limit) let manager_ref = manager.clone(); tokio::spawn(async move { manager_ref - .enqueue_and_wait(action_id, 101, 1) + .enqueue_and_wait(action_id, 101, 1, None) .await .unwrap(); }); @@ -710,7 +1148,7 @@ mod tests { let manager_ref = manager.clone(); tokio::spawn(async move { manager_ref - .enqueue_and_wait(action_id, 102, 1) + .enqueue_and_wait(action_id, 102, 1, None) .await .unwrap(); }); @@ -718,7 +1156,7 @@ mod tests { sleep(Duration::from_millis(100)).await; // Next one should fail - let result = manager.enqueue_and_wait(action_id, 103, 1).await; + let result = manager.enqueue_and_wait(action_id, 103, 1, None).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("Queue full")); } @@ -732,7 +1170,7 @@ mod tests { // Start first execution manager - .enqueue_and_wait(action_id, 0, max_concurrent) + .enqueue_and_wait(action_id, 0, max_concurrent, None) .await .unwrap(); @@ -746,7 +1184,7 @@ mod tests { let handle = tokio::spawn(async move { manager - .enqueue_and_wait(action_id, i, max_concurrent) + .enqueue_and_wait(action_id, i, max_concurrent, None) .await .unwrap(); order.lock().await.push(i); @@ -759,9 +1197,9 @@ mod tests { sleep(Duration::from_millis(200)).await; // Release them all - for _ in 0..num_executions { + for execution_id in 0..num_executions { sleep(Duration::from_millis(10)).await; - manager.notify_completion(action_id).await.unwrap(); + manager.notify_completion(execution_id).await.unwrap(); } // Wait for completion diff --git a/crates/executor/src/scheduler.rs b/crates/executor/src/scheduler.rs index 8cb7ee5..62af490 100644 --- a/crates/executor/src/scheduler.rs +++ b/crates/executor/src/scheduler.rs @@ -16,7 +16,7 @@ use attune_common::{ models::{enums::ExecutionStatus, execution::WorkflowTaskMetadata, Action, Execution, Runtime}, mq::{ Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope, - MessageType, Publisher, + MessageType, MqError, Publisher, }, repositories::{ action::ActionRepository, @@ -40,6 +40,7 @@ use std::sync::Arc; use std::time::Duration; use tracing::{debug, error, info, warn}; +use crate::policy_enforcer::{PolicyEnforcer, SchedulingPolicyOutcome}; use crate::workflow::context::{TaskOutcome, WorkflowContext}; use crate::workflow::graph::TaskGraph; @@ -108,6 +109,7 @@ pub struct ExecutionScheduler { pool: PgPool, publisher: Arc, consumer: Arc, + policy_enforcer: Arc, /// Round-robin counter for distributing executions across workers round_robin_counter: AtomicUsize, } @@ -120,12 +122,31 @@ const DEFAULT_HEARTBEAT_INTERVAL: u64 = 30; const HEARTBEAT_STALENESS_MULTIPLIER: u64 = 3; impl ExecutionScheduler { + fn retryable_mq_error(error: &anyhow::Error) -> Option { + let mq_error = error.downcast_ref::()?; + Some(match mq_error { + MqError::Connection(msg) => MqError::Connection(msg.clone()), + MqError::Channel(msg) => MqError::Channel(msg.clone()), + MqError::Publish(msg) => MqError::Publish(msg.clone()), + MqError::Timeout(msg) => MqError::Timeout(msg.clone()), + MqError::Pool(msg) => MqError::Pool(msg.clone()), + MqError::Lapin(err) => MqError::Connection(err.to_string()), + _ => return None, + }) + } + /// Create a new execution scheduler - pub fn new(pool: PgPool, publisher: Arc, consumer: Arc) -> Self { + pub fn new( + pool: PgPool, + publisher: Arc, + consumer: Arc, + policy_enforcer: Arc, + ) -> Self { Self { pool, publisher, consumer, + policy_enforcer, round_robin_counter: AtomicUsize::new(0), } } @@ -136,6 +157,7 @@ impl ExecutionScheduler { let pool = self.pool.clone(); let publisher = self.publisher.clone(); + let policy_enforcer = self.policy_enforcer.clone(); // Share the counter with the handler closure via Arc. // We wrap &self's AtomicUsize in a new Arc by copying the // current value so the closure is 'static. @@ -149,16 +171,24 @@ impl ExecutionScheduler { move |envelope: MessageEnvelope| { let pool = pool.clone(); let publisher = publisher.clone(); + let policy_enforcer = policy_enforcer.clone(); let counter = counter.clone(); async move { if let Err(e) = Self::process_execution_requested( - &pool, &publisher, &counter, &envelope, + &pool, + &publisher, + &policy_enforcer, + &counter, + &envelope, ) .await { error!("Error scheduling execution: {}", e); // Return error to trigger nack with requeue + if let Some(mq_err) = Self::retryable_mq_error(&e) { + return Err(mq_err); + } return Err(format!("Failed to schedule execution: {}", e).into()); } Ok(()) @@ -174,6 +204,7 @@ impl ExecutionScheduler { async fn process_execution_requested( pool: &PgPool, publisher: &Publisher, + policy_enforcer: &PolicyEnforcer, round_robin_counter: &AtomicUsize, envelope: &MessageEnvelope, ) -> Result<()> { @@ -184,9 +215,30 @@ impl ExecutionScheduler { info!("Scheduling execution: {}", execution_id); // Fetch execution from database - let execution = ExecutionRepository::find_by_id(pool, execution_id) - .await? - .ok_or_else(|| anyhow::anyhow!("Execution not found: {}", execution_id))?; + let execution = match ExecutionRepository::find_by_id(pool, execution_id).await? { + Some(execution) => execution, + None => { + warn!("Execution {} not found during scheduling", execution_id); + Self::remove_queued_policy_execution( + policy_enforcer, + pool, + publisher, + execution_id, + ) + .await; + return Ok(()); + } + }; + + if execution.status != ExecutionStatus::Requested { + debug!( + "Skipping execution {} with status {:?}; only Requested executions are schedulable", + execution_id, execution.status + ); + Self::remove_queued_policy_execution(policy_enforcer, pool, publisher, execution_id) + .await; + return Ok(()); + } // Fetch action to determine runtime requirements let action = Self::get_action_for_execution(pool, &execution).await?; @@ -207,30 +259,6 @@ impl ExecutionScheduler { .await; } - // Regular action: select appropriate worker (round-robin among compatible workers) - let worker = match Self::select_worker(pool, &action, round_robin_counter).await { - Ok(worker) => worker, - Err(err) if Self::is_unschedulable_error(&err) => { - Self::fail_unschedulable_execution( - pool, - publisher, - envelope, - execution_id, - action.id, - &action.r#ref, - &err.to_string(), - ) - .await?; - return Ok(()); - } - Err(err) => return Err(err), - }; - - info!( - "Selected worker {} for execution {}", - worker.id, execution_id - ); - // Apply parameter defaults from the action's param_schema. // This mirrors what `process_workflow_execution` does for workflows // so that non-workflow executions also get missing parameters filled @@ -249,16 +277,97 @@ impl ExecutionScheduler { } }; + match policy_enforcer + .enforce_for_scheduling( + action.id, + Some(action.pack), + execution_id, + execution_config.as_ref(), + ) + .await + { + Ok(SchedulingPolicyOutcome::Queued) => { + info!( + "Execution {} queued by policy for action {}; deferring worker selection", + execution_id, action.id + ); + return Ok(()); + } + Ok(SchedulingPolicyOutcome::Ready) => {} + Err(err) => { + if Self::is_policy_cancellation_error(&err) { + Self::remove_queued_policy_execution( + policy_enforcer, + pool, + publisher, + execution_id, + ) + .await; + Self::cancel_execution_for_policy_violation( + pool, + publisher, + envelope, + execution_id, + action.id, + &action.r#ref, + &err.to_string(), + ) + .await?; + return Ok(()); + } + + return Err(err); + } + } + + // Regular action: select appropriate worker only after policy + // readiness is confirmed, so queued executions don't reserve stale + // workers while they wait. + let worker = match Self::select_worker(pool, &action, round_robin_counter).await { + Ok(worker) => worker, + Err(err) if Self::is_unschedulable_error(&err) => { + Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id) + .await?; + Self::fail_unschedulable_execution( + pool, + publisher, + envelope, + execution_id, + action.id, + &action.r#ref, + &err.to_string(), + ) + .await?; + return Ok(()); + } + Err(err) => { + Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id) + .await?; + return Err(err); + } + }; + + info!( + "Selected worker {} for execution {}", + worker.id, execution_id + ); + // Persist the selected worker so later cancellation requests can be // routed to the correct per-worker cancel queue. let mut execution_for_update = execution; execution_for_update.status = ExecutionStatus::Scheduled; execution_for_update.worker = Some(worker.id); - ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into()) - .await?; + if let Err(err) = + ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into()) + .await + { + Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id) + .await?; + return Err(err.into()); + } // Publish message to worker-specific queue - Self::queue_to_worker( + if let Err(err) = Self::queue_to_worker( publisher, &execution_id, &worker.id, @@ -266,7 +375,27 @@ impl ExecutionScheduler { &execution_config, &action, ) - .await?; + .await + { + if let Err(revert_err) = ExecutionRepository::update( + pool, + execution_id, + UpdateExecutionInput { + status: Some(ExecutionStatus::Requested), + ..Default::default() + }, + ) + .await + { + warn!( + "Failed to revert execution {} back to Requested after worker publish error: {}", + execution_id, revert_err + ); + } + Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id) + .await?; + return Err(err); + } info!( "Execution {} scheduled to worker {}", @@ -1698,6 +1827,131 @@ impl ExecutionScheduler { || message.starts_with("No workers with fresh heartbeats available") } + fn is_policy_cancellation_error(error: &anyhow::Error) -> bool { + let message = error.to_string(); + message.contains("Policy violation:") + || message.starts_with("Queue full for action ") + || message.starts_with("Queue timeout for execution ") + } + + async fn release_acquired_policy_slot( + policy_enforcer: &PolicyEnforcer, + pool: &PgPool, + publisher: &Publisher, + execution_id: i64, + ) -> Result<()> { + let release = match policy_enforcer.release_execution_slot(execution_id).await { + Ok(release) => release, + Err(release_err) => { + warn!( + "Failed to release acquired policy slot for execution {} after scheduling error: {}", + execution_id, release_err + ); + return Err(release_err); + } + }; + + let Some(release) = release else { + return Ok(()); + }; + + if let Some(next_execution_id) = release.next_execution_id { + if let Err(republish_err) = + Self::republish_execution_requested(pool, publisher, next_execution_id).await + { + warn!( + "Failed to republish deferred execution {} after releasing slot from execution {}: {}", + next_execution_id, execution_id, republish_err + ); + if let Err(restore_err) = policy_enforcer + .restore_execution_slot(execution_id, &release) + .await + { + warn!( + "Failed to restore policy slot for execution {} after republish error: {}", + execution_id, restore_err + ); + } + return Err(republish_err); + } + } + + Ok(()) + } + + async fn remove_queued_policy_execution( + policy_enforcer: &PolicyEnforcer, + pool: &PgPool, + publisher: &Publisher, + execution_id: i64, + ) { + let removal = match policy_enforcer.remove_queued_execution(execution_id).await { + Ok(removal) => removal, + Err(remove_err) => { + warn!( + "Failed to remove queued policy execution {} during scheduler cleanup: {}", + execution_id, remove_err + ); + return; + } + }; + + let Some(removal) = removal else { + return; + }; + + if let Some(next_execution_id) = removal.next_execution_id { + if let Err(republish_err) = + Self::republish_execution_requested(pool, publisher, next_execution_id).await + { + warn!( + "Failed to republish successor {} after removing queued execution {}: {}", + next_execution_id, execution_id, republish_err + ); + if let Err(restore_err) = policy_enforcer.restore_queued_execution(&removal).await { + warn!( + "Failed to restore queued execution {} after republish error: {}", + execution_id, restore_err + ); + } + } + } + } + + async fn republish_execution_requested( + pool: &PgPool, + publisher: &Publisher, + execution_id: i64, + ) -> Result<()> { + let execution = ExecutionRepository::find_by_id(pool, execution_id) + .await? + .ok_or_else(|| anyhow::anyhow!("Execution {} not found", execution_id))?; + + let action_id = execution + .action + .ok_or_else(|| anyhow::anyhow!("Execution {} has no action", execution_id))?; + let payload = ExecutionRequestedPayload { + execution_id, + action_id: Some(action_id), + action_ref: execution.action_ref.clone(), + parent_id: execution.parent, + enforcement_id: execution.enforcement, + config: execution.config.clone(), + }; + + let envelope = MessageEnvelope::new(MessageType::ExecutionRequested, payload) + .with_source("executor-scheduler"); + + publisher.publish_envelope(&envelope).await?; + + debug!( + "Republished deferred ExecutionRequested for execution {}", + execution_id + ); + + Ok(()) + } + async fn fail_unschedulable_execution( pool: &PgPool, publisher: &Publisher, @@ -1751,6 +2005,59 @@ impl ExecutionScheduler { Ok(()) } + async fn cancel_execution_for_policy_violation( + pool: &PgPool, + publisher: &Publisher, + envelope: &MessageEnvelope, + execution_id: i64, + action_id: i64, + action_ref: &str, + error_message: &str, + ) -> Result<()> { + let completed_at = Utc::now(); + let result = serde_json::json!({ + "error": "Execution cancelled by policy", + "message": error_message, + "action_ref": action_ref, + "cancelled_by": "execution_scheduler", + "cancelled_at": completed_at.to_rfc3339(), + }); + + ExecutionRepository::update( + pool, + execution_id, + UpdateExecutionInput { + status: Some(ExecutionStatus::Cancelled), + result: Some(result.clone()), + ..Default::default() + }, + ) + .await?; + + let completed = MessageEnvelope::new( + MessageType::ExecutionCompleted, + ExecutionCompletedPayload { + execution_id, + action_id, + action_ref: action_ref.to_string(), + status: "cancelled".to_string(), + result: Some(result), + completed_at, + }, + ) + .with_correlation_id(envelope.correlation_id) + .with_source("attune-executor"); + + publisher.publish_envelope(&completed).await?; + + warn!( + "Execution {} cancelled due to policy violation: {}", + execution_id, error_message + ); + + Ok(()) + } + /// Check if a worker's heartbeat is fresh enough to schedule work /// /// A worker is considered fresh if its last heartbeat is within @@ -1980,6 +2287,24 @@ mod tests { )); } + #[test] + fn test_policy_cancellation_error_classification() { + assert!(ExecutionScheduler::is_policy_cancellation_error( + &anyhow::anyhow!( + "Policy violation: Concurrency limit exceeded: 1 running executions (limit: 1)" + ) + )); + assert!(ExecutionScheduler::is_policy_cancellation_error( + &anyhow::anyhow!("Queue full for action 42: maximum 100 entries") + )); + assert!(ExecutionScheduler::is_policy_cancellation_error( + &anyhow::anyhow!("Queue timeout for execution 99: waited 60 seconds") + )); + assert!(!ExecutionScheduler::is_policy_cancellation_error( + &anyhow::anyhow!("rabbitmq publish failed") + )); + } + #[test] fn test_concurrency_limit_dispatch_count() { // Verify the dispatch_count calculation used by dispatch_with_items_task diff --git a/crates/executor/src/service.rs b/crates/executor/src/service.rs index 4143f07..2d20748 100644 --- a/crates/executor/src/service.rs +++ b/crates/executor/src/service.rs @@ -297,6 +297,7 @@ impl ExecutorService { self.inner.pool.clone(), self.inner.publisher.clone(), Arc::new(scheduler_consumer), + self.inner.policy_enforcer.clone(), ); handles.push(tokio::spawn(async move { scheduler.start().await })); diff --git a/crates/executor/tests/fifo_ordering_integration_test.rs b/crates/executor/tests/fifo_ordering_integration_test.rs index 33d3747..f3c387a 100644 --- a/crates/executor/tests/fifo_ordering_integration_test.rs +++ b/crates/executor/tests/fifo_ordering_integration_test.rs @@ -199,7 +199,7 @@ async fn test_fifo_ordering_with_database() { let first_exec_id = create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; manager - .enqueue_and_wait(action_id, first_exec_id, max_concurrent) + .enqueue_and_wait(action_id, first_exec_id, max_concurrent, None) .await .expect("First execution should enqueue"); @@ -222,7 +222,7 @@ async fn test_fifo_ordering_with_database() { // Enqueue and wait manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .expect("Enqueue should succeed"); @@ -316,7 +316,7 @@ async fn test_high_concurrency_stress() { .await; manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .expect("Enqueue should succeed"); @@ -343,7 +343,7 @@ async fn test_high_concurrency_stress() { .await; manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .expect("Enqueue should succeed"); @@ -479,7 +479,7 @@ async fn test_multiple_workers_simulation() { .await; manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .expect("Enqueue should succeed"); @@ -596,7 +596,7 @@ async fn test_cross_action_independence() { .await; manager_clone - .enqueue_and_wait(action_id, exec_id, 1) + .enqueue_and_wait(action_id, exec_id, 1, None) .await .expect("Enqueue should succeed"); @@ -699,7 +699,7 @@ async fn test_cancellation_during_queue() { let exec_id = create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; manager - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .unwrap(); @@ -722,7 +722,7 @@ async fn test_cancellation_during_queue() { ids.lock().await.push(exec_id); manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await }); @@ -808,7 +808,7 @@ async fn test_queue_stats_persistence() { let manager_clone = manager.clone(); tokio::spawn(async move { manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .ok(); }); @@ -888,7 +888,7 @@ async fn test_queue_full_rejection() { let exec_id = create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; manager - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .unwrap(); @@ -900,7 +900,7 @@ async fn test_queue_full_rejection() { tokio::spawn(async move { manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .ok(); }); @@ -917,7 +917,7 @@ async fn test_queue_full_rejection() { let exec_id = create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; let result = manager - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await; assert!(result.is_err(), "Should reject when queue is full"); @@ -976,7 +976,7 @@ async fn test_extreme_stress_10k_executions() { .await; manager_clone - .enqueue_and_wait(action_id, exec_id, max_concurrent) + .enqueue_and_wait(action_id, exec_id, max_concurrent, None) .await .expect("Enqueue should succeed"); diff --git a/crates/executor/tests/policy_enforcer_tests.rs b/crates/executor/tests/policy_enforcer_tests.rs index bf96960..08a86df 100644 --- a/crates/executor/tests/policy_enforcer_tests.rs +++ b/crates/executor/tests/policy_enforcer_tests.rs @@ -9,7 +9,7 @@ use attune_common::{ config::Config, db::Database, - models::enums::ExecutionStatus, + models::enums::{ExecutionStatus, PolicyMethod}, repositories::{ action::{ActionRepository, CreateActionInput}, execution::{CreateExecutionInput, ExecutionRepository}, @@ -190,6 +190,8 @@ async fn test_global_rate_limit() { window_seconds: 60, }), concurrency_limit: None, + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; @@ -242,6 +244,8 @@ async fn test_concurrency_limit() { let policy = ExecutionPolicy { rate_limit: None, concurrency_limit: Some(2), + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; @@ -300,6 +304,8 @@ async fn test_action_specific_policy() { window_seconds: 60, }), concurrency_limit: None, + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; enforcer.set_action_policy(action_id, action_policy); @@ -345,6 +351,8 @@ async fn test_pack_specific_policy() { let pack_policy = ExecutionPolicy { rate_limit: None, concurrency_limit: Some(1), + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; enforcer.set_pack_policy(pack_id, pack_policy); @@ -388,6 +396,8 @@ async fn test_policy_priority() { window_seconds: 60, }), concurrency_limit: None, + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; let mut enforcer = PolicyEnforcer::with_global_policy(pool.clone(), global_policy); @@ -399,6 +409,8 @@ async fn test_policy_priority() { window_seconds: 60, }), concurrency_limit: None, + concurrency_method: PolicyMethod::Enqueue, + concurrency_parameters: Vec::new(), quotas: None, }; enforcer.set_action_policy(action_id, action_policy);