merging semgrep-scan

This commit is contained in:
2026-04-01 20:38:18 -05:00
parent a0f59114a3
commit a4c303ec84
10 changed files with 1434 additions and 248 deletions

View File

@@ -11,7 +11,10 @@
use anyhow::Result;
use attune_common::{
mq::{Consumer, ExecutionCompletedPayload, MessageEnvelope, Publisher},
mq::{
Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope,
MessageType, MqError, Publisher,
},
repositories::{execution::ExecutionRepository, FindById},
};
use sqlx::PgPool;
@@ -36,6 +39,19 @@ pub struct CompletionListener {
}
impl CompletionListener {
/// If `error` wraps an `MqError`, reproduce the transient transport-level
/// variants (connection, channel, publish, timeout, pool, lapin) as an owned
/// `MqError`; any other variant — or a non-`MqError` cause — yields `None`.
fn retryable_mq_error(error: &anyhow::Error) -> Option<MqError> {
    // Errors that are not MqError at all are never retryable.
    let mq_error = error.downcast_ref::<MqError>()?;
    match mq_error {
        MqError::Connection(msg) => Some(MqError::Connection(msg.clone())),
        MqError::Channel(msg) => Some(MqError::Channel(msg.clone())),
        MqError::Publish(msg) => Some(MqError::Publish(msg.clone())),
        MqError::Timeout(msg) => Some(MqError::Timeout(msg.clone())),
        MqError::Pool(msg) => Some(MqError::Pool(msg.clone())),
        // Lapin errors are re-wrapped as a connection failure in string form.
        MqError::Lapin(err) => Some(MqError::Connection(err.to_string())),
        _ => None,
    }
}
/// Create a new completion listener
pub fn new(
pool: PgPool,
@@ -82,6 +98,9 @@ impl CompletionListener {
{
error!("Error processing execution completion: {}", e);
// Return error to trigger nack with requeue
if let Some(mq_err) = Self::retryable_mq_error(&e) {
return Err(mq_err);
}
return Err(
format!("Failed to process execution completion: {}", e).into()
);
@@ -187,17 +206,37 @@ impl CompletionListener {
action_id, execution_id
);
match queue_manager.notify_completion(action_id).await {
Ok(notified) => {
if notified {
info!(
"Queue slot released for action {}, next execution notified",
action_id
);
match queue_manager.release_active_slot(execution_id).await {
Ok(release) => {
if let Some(release) = release {
if let Some(next_execution_id) = release.next_execution_id {
info!(
"Queue slot released for action {}, next execution {} can proceed",
action_id, next_execution_id
);
if let Err(republish_err) = Self::publish_execution_requested(
pool,
publisher,
action_id,
next_execution_id,
)
.await
{
queue_manager
.restore_active_slot(execution_id, &release)
.await?;
return Err(republish_err);
}
} else {
debug!(
"Queue slot released for action {}, no executions waiting",
action_id
);
}
} else {
debug!(
"Queue slot released for action {}, no executions waiting",
action_id
"Execution {} had no active queue slot to release",
execution_id
);
}
}
@@ -225,6 +264,38 @@ impl CompletionListener {
Ok(())
}
/// Re-publish an `ExecutionRequested` message for a deferred execution that
/// has just been granted a freed queue slot.
///
/// Loads the execution row to rebuild the payload; errors if the row is
/// missing or the publish fails, in which case the caller rolls back the
/// slot release.
async fn publish_execution_requested(
    pool: &PgPool,
    publisher: &Publisher,
    action_id: i64,
    execution_id: i64,
) -> Result<()> {
    let execution = ExecutionRepository::find_by_id(pool, execution_id)
        .await?
        .ok_or_else(|| anyhow::anyhow!("Execution {} not found", execution_id))?;
    // Rebuild the original request payload from the stored execution row.
    let payload = ExecutionRequestedPayload {
        execution_id,
        action_id: Some(action_id),
        action_ref: execution.action_ref.clone(),
        parent_id: execution.parent,
        enforcement_id: execution.enforcement,
        config: execution.config.clone(),
    };
    // Tag the message so consumers can tell it was re-emitted by this listener.
    let envelope = MessageEnvelope::new(MessageType::ExecutionRequested, payload)
        .with_source("executor-completion-listener");
    publisher.publish_envelope(&envelope).await?;
    debug!(
        "Republished deferred ExecutionRequested for execution {}",
        execution_id
    );
    Ok(())
}
}
#[cfg(test)]
@@ -239,7 +310,7 @@ mod tests {
// Simulate acquiring a slot
queue_manager
.enqueue_and_wait(action_id, 100, 1)
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
@@ -249,7 +320,7 @@ mod tests {
assert_eq!(stats.queue_length, 0);
// Simulate completion notification
let notified = queue_manager.notify_completion(action_id).await.unwrap();
let notified = queue_manager.notify_completion(100).await.unwrap();
assert!(!notified); // No one waiting
// Verify slot is released
@@ -264,7 +335,7 @@ mod tests {
// Fill capacity
queue_manager
.enqueue_and_wait(action_id, 100, 1)
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
@@ -272,7 +343,7 @@ mod tests {
let queue_manager_clone = queue_manager.clone();
let handle = tokio::spawn(async move {
queue_manager_clone
.enqueue_and_wait(action_id, 101, 1)
.enqueue_and_wait(action_id, 101, 1, None)
.await
.unwrap();
});
@@ -286,7 +357,7 @@ mod tests {
assert_eq!(stats.queue_length, 1);
// Notify completion
let notified = queue_manager.notify_completion(action_id).await.unwrap();
let notified = queue_manager.notify_completion(100).await.unwrap();
assert!(notified); // Should wake the waiting execution
// Wait for queued execution to proceed
@@ -306,7 +377,7 @@ mod tests {
// Fill capacity
queue_manager
.enqueue_and_wait(action_id, 100, 1)
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
@@ -320,7 +391,7 @@ mod tests {
let handle = tokio::spawn(async move {
queue_manager
.enqueue_and_wait(action_id, exec_id, 1)
.enqueue_and_wait(action_id, exec_id, 1, None)
.await
.unwrap();
order.lock().await.push(exec_id);
@@ -333,9 +404,9 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Release them one by one
for _ in 0..3 {
for execution_id in 100..103 {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
queue_manager.notify_completion(action_id).await.unwrap();
queue_manager.notify_completion(execution_id).await.unwrap();
}
// Wait for all to complete
@@ -351,10 +422,10 @@ mod tests {
#[tokio::test]
async fn test_completion_with_no_queue() {
let queue_manager = Arc::new(ExecutionQueueManager::with_defaults());
let action_id = 999; // Non-existent action
let execution_id = 999; // Non-existent execution
// Should succeed but not notify anyone
let result = queue_manager.notify_completion(action_id).await;
let result = queue_manager.notify_completion(execution_id).await;
assert!(result.is_ok());
assert!(!result.unwrap());
}

View File

@@ -230,7 +230,7 @@ impl EnforcementProcessor {
async fn create_execution(
pool: &PgPool,
publisher: &Publisher,
policy_enforcer: &PolicyEnforcer,
_policy_enforcer: &PolicyEnforcer,
_queue_manager: &ExecutionQueueManager,
enforcement: &Enforcement,
rule: &Rule,
@@ -257,33 +257,10 @@ impl EnforcementProcessor {
enforcement.id, rule.id, action_id
);
let pack_id = rule.pack;
let action_ref = &rule.action_ref;
// Enforce policies and wait for queue slot if needed
info!(
"Enforcing policies for action {} (enforcement: {})",
action_id, enforcement.id
);
// Use enforcement ID for queue tracking (execution doesn't exist yet)
if let Err(e) = policy_enforcer
.enforce_and_wait(action_id, Some(pack_id), enforcement.id)
.await
{
error!(
"Policy enforcement failed for enforcement {}: {}",
enforcement.id, e
);
return Err(e);
}
info!(
"Policy check passed and queue slot obtained for enforcement: {}",
enforcement.id
);
// Now create execution in database (we have a queue slot)
// Create the execution row first; scheduler-side policy enforcement
// now handles both rule-triggered and manual executions uniformly.
let execution_input = CreateExecutionInput {
action: Some(action_id),
action_ref: action_ref.clone(),

View File

@@ -10,14 +10,23 @@
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sqlx::PgPool;
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tracing::{debug, info, warn};
use attune_common::models::{enums::ExecutionStatus, Id};
use attune_common::{
models::{
enums::{ExecutionStatus, PolicyMethod},
Id, Policy,
},
repositories::action::PolicyRepository,
};
use crate::queue_manager::ExecutionQueueManager;
use crate::queue_manager::{
ExecutionQueueManager, QueuedRemovalOutcome, SlotEnqueueOutcome, SlotReleaseOutcome,
};
/// Policy violation type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -79,16 +88,38 @@ impl std::fmt::Display for PolicyViolation {
}
/// Execution policy configuration
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionPolicy {
/// Rate limit: maximum executions per time window
pub rate_limit: Option<RateLimit>,
/// Concurrency limit: maximum concurrent executions
pub concurrency_limit: Option<u32>,
/// How a concurrency violation should be handled.
pub concurrency_method: PolicyMethod,
/// Parameter paths used to scope concurrency grouping.
pub concurrency_parameters: Vec<String>,
/// Resource quotas
pub quotas: Option<HashMap<String, u64>>,
}
/// Result of applying scheduling-time policies to an execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicyOutcome {
    /// A concurrency slot was obtained (or no policy applies); run now.
    Ready,
    /// The execution was parked in a wait queue for later re-dispatch.
    Queued,
}
impl Default for ExecutionPolicy {
    /// Permissive default: no rate limit, quotas, or concurrency cap, with
    /// `Enqueue` as the method used if a cap is configured later.
    fn default() -> Self {
        Self {
            rate_limit: None,
            concurrency_limit: None,
            concurrency_method: PolicyMethod::Enqueue,
            concurrency_parameters: Vec::new(),
            quotas: None,
        }
    }
}
/// Rate limit configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimit {
@@ -98,6 +129,25 @@ pub struct RateLimit {
pub window_seconds: u32,
}
/// Concurrency settings flattened from the most specific applicable policy.
#[derive(Debug, Clone)]
struct ResolvedConcurrencyPolicy {
    // Maximum concurrent executions for the scoped group.
    limit: u32,
    // How a limit violation is handled (enqueue vs. cancel).
    method: PolicyMethod,
    // Config paths whose values partition executions into groups.
    parameters: Vec<String>,
}
impl From<Policy> for ExecutionPolicy {
    /// Build an execution policy from a stored `Policy` row.
    ///
    /// Only concurrency settings are carried over; rate limits and quotas
    /// are left unset.
    fn from(policy: Policy) -> Self {
        Self {
            rate_limit: None,
            // NOTE(review): `threshold as u32` silently truncates/wraps for
            // values outside u32 range — confirm threshold is validated
            // upstream before relying on this cast.
            concurrency_limit: Some(policy.threshold as u32),
            concurrency_method: policy.method,
            concurrency_parameters: policy.parameters,
            quotas: None,
        }
    }
}
/// Policy enforcement scope
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)] // Used in tests
@@ -185,6 +235,174 @@ impl PolicyEnforcer {
self.action_policies.insert(action_id, policy);
}
/// Best-effort release of a slot acquired during scheduling when the
/// execution never reaches the worker/completion path.
///
/// Returns the release outcome when a slot was held, `None` when no queue
/// manager is configured or no slot was registered.
pub async fn release_execution_slot(
    &self,
    execution_id: Id,
) -> Result<Option<SlotReleaseOutcome>> {
    // Without a queue manager there is no slot bookkeeping to undo.
    let Some(queue_manager) = &self.queue_manager else {
        return Ok(None);
    };
    queue_manager.release_active_slot(execution_id).await
}
/// Re-acquire a slot previously released via `release_execution_slot`.
/// A no-op when no queue manager is configured.
pub async fn restore_execution_slot(
    &self,
    execution_id: Id,
    outcome: &SlotReleaseOutcome,
) -> Result<()> {
    if let Some(queue_manager) = &self.queue_manager {
        queue_manager
            .restore_active_slot(execution_id, outcome)
            .await?;
    }
    Ok(())
}
/// Remove a queued (not yet active) execution from its wait queue.
///
/// Returns the removal outcome when the execution was queued, `None` when
/// no queue manager is configured or the execution was not found.
pub async fn remove_queued_execution(
    &self,
    execution_id: Id,
) -> Result<Option<QueuedRemovalOutcome>> {
    let Some(queue_manager) = &self.queue_manager else {
        return Ok(None);
    };
    queue_manager.remove_queued_execution(execution_id).await
}
/// Put back a queue entry previously removed via `remove_queued_execution`.
/// A no-op when no queue manager is configured.
pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> {
    if let Some(queue_manager) = &self.queue_manager {
        queue_manager.restore_queued_execution(outcome).await?;
    }
    Ok(())
}
/// Apply all policies to an execution at scheduling time, without blocking.
///
/// Rate/quota checks run first and fail hard on violation. A concurrency
/// policy, if one resolves, is then applied per action/group:
/// * `Enqueue` — grab a slot now (`Ready`) or park the execution (`Queued`).
/// * `Cancel` — fail with a `ConcurrencyLimitExceeded` violation if no slot
///   is free.
/// With no queue manager configured, falls back to a plain concurrency
/// limit check. Returns `Ready` when no concurrency policy applies.
pub async fn enforce_for_scheduling(
    &self,
    action_id: Id,
    pack_id: Option<Id>,
    execution_id: Id,
    config: Option<&JsonValue>,
) -> Result<SchedulingPolicyOutcome> {
    // Non-concurrency policies (rate limits, quotas) cannot be deferred;
    // a violation rejects the execution outright.
    if let Some(violation) = self
        .check_policies_except_concurrency(action_id, pack_id)
        .await?
    {
        warn!("Policy violation for action {}: {}", action_id, violation);
        return Err(anyhow::anyhow!("Policy violation: {}", violation));
    }
    if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? {
        // Scope the queue by the configured parameter values, if any.
        let group_key = self.build_parameter_group_key(&concurrency.parameters, config);
        if let Some(queue_manager) = &self.queue_manager {
            match concurrency.method {
                PolicyMethod::Enqueue => {
                    // Non-blocking: either a slot is claimed now or the
                    // execution waits in the queue for a completion event.
                    return match queue_manager
                        .enqueue(action_id, execution_id, concurrency.limit, group_key)
                        .await?
                    {
                        SlotEnqueueOutcome::Acquired => Ok(SchedulingPolicyOutcome::Ready),
                        SlotEnqueueOutcome::Enqueued => Ok(SchedulingPolicyOutcome::Queued),
                    };
                }
                PolicyMethod::Cancel => {
                    // Cancel semantics: no waiting — at capacity the
                    // execution is rejected with a violation error.
                    let outcome = queue_manager
                        .try_acquire(
                            action_id,
                            execution_id,
                            concurrency.limit,
                            group_key.clone(),
                        )
                        .await?;
                    if !outcome.acquired {
                        let violation = PolicyViolation::ConcurrencyLimitExceeded {
                            limit: concurrency.limit,
                            current_count: outcome.current_count,
                        };
                        warn!("Policy violation for action {}: {}", action_id, violation);
                        return Err(anyhow::anyhow!("Policy violation: {}", violation));
                    }
                }
            }
        } else {
            // Legacy path without a queue manager: a plain check against
            // the action-scoped limit, with no queueing support.
            let scope = PolicyScope::Action(action_id);
            if let Some(violation) = self
                .check_concurrency_limit(concurrency.limit, &scope)
                .await?
            {
                return Err(anyhow::anyhow!("Policy violation: {}", violation));
            }
        }
    }
    Ok(SchedulingPolicyOutcome::Ready)
}
/// Resolve the effective policy with most-specific-wins precedence:
/// in-memory action policy → stored action policy → in-memory pack policy →
/// stored pack policy → stored global policy → configured global default.
async fn resolve_policy(&self, action_id: Id, pack_id: Option<Id>) -> Result<ExecutionPolicy> {
    if let Some(policy) = self.action_policies.get(&action_id) {
        return Ok(policy.clone());
    }
    if let Some(policy) = PolicyRepository::find_latest_by_action(&self.pool, action_id).await?
    {
        return Ok(policy.into());
    }
    // Pack-level policies only apply when the action belongs to a pack.
    if let Some(pack_id) = pack_id {
        if let Some(policy) = self.pack_policies.get(&pack_id) {
            return Ok(policy.clone());
        }
        if let Some(policy) = PolicyRepository::find_latest_by_pack(&self.pool, pack_id).await?
        {
            return Ok(policy.into());
        }
    }
    if let Some(policy) = PolicyRepository::find_latest_global(&self.pool).await? {
        return Ok(policy.into());
    }
    // Last resort: the enforcer's configured in-memory global policy.
    Ok(self.global_policy.clone())
}
/// Resolve the effective policy and extract its concurrency settings.
/// Returns `None` when the resolved policy has no concurrency limit.
async fn resolve_concurrency_policy(
    &self,
    action_id: Id,
    pack_id: Option<Id>,
) -> Result<Option<ResolvedConcurrencyPolicy>> {
    let policy = self.resolve_policy(action_id, pack_id).await?;
    let Some(limit) = policy.concurrency_limit else {
        return Ok(None);
    };
    Ok(Some(ResolvedConcurrencyPolicy {
        limit,
        method: policy.concurrency_method,
        parameters: policy.concurrency_parameters,
    }))
}
fn build_parameter_group_key(
&self,
parameter_paths: &[String],
config: Option<&JsonValue>,
) -> Option<String> {
if parameter_paths.is_empty() {
return None;
}
let values: BTreeMap<String, JsonValue> = parameter_paths
.iter()
.map(|path| (path.clone(), extract_parameter_value(config, path)))
.collect();
serde_json::to_string(&values).ok()
}
/// Get the concurrency limit for a specific action
///
/// Returns the most specific concurrency limit found:
@@ -192,6 +410,7 @@ impl PolicyEnforcer {
/// 2. Pack policy
/// 3. Global policy
/// 4. None (unlimited)
#[allow(dead_code)]
pub fn get_concurrency_limit(&self, action_id: Id, pack_id: Option<Id>) -> Option<u32> {
// Check action-specific policy first
if let Some(policy) = self.action_policies.get(&action_id) {
@@ -229,11 +448,13 @@ impl PolicyEnforcer {
/// * `Ok(())` - Policy allows execution and queue slot obtained
/// * `Err(PolicyViolation)` - Policy prevents execution
/// * `Err(QueueError)` - Queue timeout or other queue error
#[allow(dead_code)]
pub async fn enforce_and_wait(
&self,
action_id: Id,
pack_id: Option<Id>,
execution_id: Id,
config: Option<&JsonValue>,
) -> Result<()> {
// First, check for policy violations (rate limit, quotas, etc.)
// Note: We skip concurrency check here since queue manages that
@@ -246,36 +467,61 @@ impl PolicyEnforcer {
}
// If queue manager is available, use it for concurrency control
if let Some(queue_manager) = &self.queue_manager {
let concurrency_limit = self
.get_concurrency_limit(action_id, pack_id)
.unwrap_or(u32::MAX); // Default to unlimited if no policy
if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? {
let group_key = self.build_parameter_group_key(&concurrency.parameters, config);
debug!(
"Enqueuing execution {} for action {} with concurrency limit {}",
execution_id, action_id, concurrency_limit
);
if let Some(queue_manager) = &self.queue_manager {
debug!(
"Applying concurrency policy to execution {} for action {} (limit: {}, method: {:?}, group: {:?})",
execution_id, action_id, concurrency.limit, concurrency.method, group_key
);
queue_manager
.enqueue_and_wait(action_id, execution_id, concurrency_limit)
.await?;
match concurrency.method {
PolicyMethod::Enqueue => {
queue_manager
.enqueue_and_wait(
action_id,
execution_id,
concurrency.limit,
group_key.clone(),
)
.await?;
}
PolicyMethod::Cancel => {
let outcome = queue_manager
.try_acquire(
action_id,
execution_id,
concurrency.limit,
group_key.clone(),
)
.await?;
info!(
"Execution {} obtained queue slot for action {}",
execution_id, action_id
);
} else {
// No queue manager - use legacy polling behavior
debug!(
"No queue manager configured, using legacy policy wait for action {}",
action_id
);
if !outcome.acquired {
let violation = PolicyViolation::ConcurrencyLimitExceeded {
limit: concurrency.limit,
current_count: outcome.current_count,
};
warn!("Policy violation for action {}: {}", action_id, violation);
return Err(anyhow::anyhow!("Policy violation: {}", violation));
}
}
}
info!(
"Execution {} obtained queue slot for action {} (group: {:?})",
execution_id, action_id, group_key
);
} else {
// No queue manager - use legacy polling behavior
debug!(
"No queue manager configured, using legacy policy wait for action {}",
action_id
);
if let Some(concurrency_limit) = self.get_concurrency_limit(action_id, pack_id) {
// Check concurrency with old method
let scope = PolicyScope::Action(action_id);
if let Some(violation) = self
.check_concurrency_limit(concurrency_limit, &scope)
.check_concurrency_limit(concurrency.limit, &scope)
.await?
{
return Err(anyhow::anyhow!("Policy violation: {}", violation));
@@ -631,6 +877,25 @@ impl PolicyEnforcer {
}
}
/// Walk a dot-separated `path` (e.g. `"target.region"`) through nested JSON
/// objects in `config`, returning a clone of the value found.
///
/// Returns `JsonValue::Null` when `config` is absent, a segment is missing,
/// or an intermediate value is not an object.
fn extract_parameter_value(config: Option<&JsonValue>, path: &str) -> JsonValue {
    let Some(mut current) = config else {
        return JsonValue::Null;
    };
    for segment in path.split('.') {
        // Only objects can be descended into; anything else ends the walk.
        let JsonValue::Object(map) = current else {
            return JsonValue::Null;
        };
        match map.get(segment) {
            Some(next) => current = next,
            None => return JsonValue::Null,
        }
    }
    current.clone()
}
#[cfg(test)]
mod tests {
use super::*;
@@ -665,6 +930,8 @@ mod tests {
let policy = ExecutionPolicy::default();
assert!(policy.rate_limit.is_none());
assert!(policy.concurrency_limit.is_none());
assert_eq!(policy.concurrency_method, PolicyMethod::Enqueue);
assert!(policy.concurrency_parameters.is_empty());
assert!(policy.quotas.is_none());
}
@@ -784,7 +1051,7 @@ mod tests {
);
// First execution should proceed immediately
let result = enforcer.enforce_and_wait(1, None, 100).await;
let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok());
// Check queue stats
@@ -809,7 +1076,7 @@ mod tests {
let enforcer = Arc::new(enforcer);
// First execution
let result = enforcer.enforce_and_wait(1, None, 100).await;
let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok());
// Queue multiple executions
@@ -822,11 +1089,14 @@ mod tests {
let order = execution_order.clone();
let handle = tokio::spawn(async move {
enforcer.enforce_and_wait(1, None, exec_id).await.unwrap();
enforcer
.enforce_and_wait(1, None, exec_id, None)
.await
.unwrap();
order.lock().await.push(exec_id);
// Simulate work
sleep(Duration::from_millis(10)).await;
queue_manager.notify_completion(1).await.unwrap();
queue_manager.notify_completion(exec_id).await.unwrap();
});
handles.push(handle);
@@ -836,7 +1106,7 @@ mod tests {
sleep(Duration::from_millis(100)).await;
// Release first execution
queue_manager.notify_completion(1).await.unwrap();
queue_manager.notify_completion(100).await.unwrap();
// Wait for all
for handle in handles {
@@ -863,7 +1133,7 @@ mod tests {
);
// Should work without queue manager (legacy behavior)
let result = enforcer.enforce_and_wait(1, None, 100).await;
let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok());
}
@@ -889,14 +1159,36 @@ mod tests {
);
// First execution proceeds
enforcer.enforce_and_wait(1, None, 100).await.unwrap();
enforcer.enforce_and_wait(1, None, 100, None).await.unwrap();
// Second execution should timeout
let result = enforcer.enforce_and_wait(1, None, 101).await;
let result = enforcer.enforce_and_wait(1, None, 101, None).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout"));
}
#[test]
fn test_build_parameter_group_key_uses_exact_values() {
let pool = sqlx::PgPool::connect_lazy("postgresql://localhost/test").unwrap();
let enforcer = PolicyEnforcer::new(pool);
let config = serde_json::json!({
"environment": "prod",
"target": {
"region": "us-east-1"
}
});
let group_key = enforcer.build_parameter_group_key(
&["target.region".to_string(), "environment".to_string()],
Some(&config),
);
assert_eq!(
group_key.as_deref(),
Some("{\"environment\":\"prod\",\"target.region\":\"us-east-1\"}")
);
}
// Integration tests would require database setup
// Those should be in a separate integration test file
}

View File

@@ -47,7 +47,7 @@ impl Default for QueueConfig {
}
/// Entry in the execution queue
#[derive(Debug)]
#[derive(Debug, Clone)]
struct QueueEntry {
/// Execution or enforcement ID being queued
execution_id: Id,
@@ -57,7 +57,13 @@ struct QueueEntry {
notifier: Arc<Notify>,
}
/// Queue state for a single action
/// Identity of a single concurrency queue: an action plus an optional
/// parameter-derived group discriminator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct QueueKey {
    // Action whose executions this queue throttles.
    action_id: Id,
    // Serialized parameter values scoping the group; `None` = action-wide.
    group_key: Option<String>,
}
/// Queue state for a single action/group pair
struct ActionQueue {
/// FIFO queue of waiting executions
queue: VecDeque<QueueEntry>,
@@ -93,6 +99,32 @@ impl ActionQueue {
}
}
/// Result of a non-blocking slot acquisition attempt (`try_acquire`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotAcquireOutcome {
    // True when a slot was claimed (or was already held by this execution).
    pub acquired: bool,
    // Active count observed before the attempt, for violation reporting.
    pub current_count: u32,
}
/// Result of a non-blocking enqueue: a slot was claimed immediately, or the
/// execution was appended to (or already sat in) the wait queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotEnqueueOutcome {
    Acquired,
    Enqueued,
}
/// Bookkeeping returned when an active slot is released; holds enough state
/// to roll the release back via `restore_active_slot`.
#[derive(Debug, Clone)]
pub struct SlotReleaseOutcome {
    // Head of the wait queue at release time, if any.
    pub next_execution_id: Option<Id>,
    // Queue the released slot belonged to (private: used only for restores).
    queue_key: QueueKey,
}
/// Bookkeeping returned when a queued (not yet active) entry is removed;
/// holds enough state to reinsert the entry at its original position.
#[derive(Debug, Clone)]
pub struct QueuedRemovalOutcome {
    // Head of the wait queue after the removal, if any.
    pub next_execution_id: Option<Id>,
    // Queue the entry was removed from.
    queue_key: QueueKey,
    // The removed entry itself, preserved for restoration.
    removed_entry: QueueEntry,
    // Index the entry occupied, so a restore keeps FIFO order.
    removed_index: usize,
}
/// Statistics about a queue
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStats {
@@ -114,8 +146,10 @@ pub struct QueueStats {
/// Manages execution queues with FIFO ordering guarantees
pub struct ExecutionQueueManager {
/// Per-action queues (key: action_id)
queues: DashMap<Id, Arc<Mutex<ActionQueue>>>,
/// Per-action/per-group queues.
queues: DashMap<QueueKey, Arc<Mutex<ActionQueue>>>,
/// Tracks which queue key currently owns an active execution slot.
active_execution_keys: DashMap<Id, QueueKey>,
/// Configuration
config: QueueConfig,
/// Database connection pool (optional for stats persistence)
@@ -128,6 +162,7 @@ impl ExecutionQueueManager {
pub fn new(config: QueueConfig) -> Self {
Self {
queues: DashMap::new(),
active_execution_keys: DashMap::new(),
config,
db_pool: None,
}
@@ -137,6 +172,7 @@ impl ExecutionQueueManager {
pub fn with_db_pool(config: QueueConfig, db_pool: PgPool) -> Self {
Self {
queues: DashMap::new(),
active_execution_keys: DashMap::new(),
config,
db_pool: Some(db_pool),
}
@@ -148,6 +184,24 @@ impl ExecutionQueueManager {
Self::new(QueueConfig::default())
}
/// Build the composite key identifying an action/group queue.
fn queue_key(&self, action_id: Id, group_key: Option<String>) -> QueueKey {
    QueueKey {
        action_id,
        group_key,
    }
}
/// Fetch the queue for `queue_key`, creating it with `max_concurrent` as its
/// initial limit if it does not exist yet.
///
/// NOTE(review): `max_concurrent` is only applied on creation; existing
/// queues keep their limit until a caller overwrites `queue.max_concurrent`
/// (as `enqueue`/`try_acquire` do) — confirm all call sites refresh it.
async fn get_or_create_queue(
    &self,
    queue_key: QueueKey,
    max_concurrent: u32,
) -> Arc<Mutex<ActionQueue>> {
    self.queues
        .entry(queue_key)
        .or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent))))
        .clone()
}
/// Enqueue an execution and wait until it can proceed
///
/// This method will:
@@ -164,23 +218,31 @@ impl ExecutionQueueManager {
/// # Returns
/// * `Ok(())` - Execution can proceed
/// * `Err(_)` - Queue full or timeout
#[allow(dead_code)]
pub async fn enqueue_and_wait(
&self,
action_id: Id,
execution_id: Id,
max_concurrent: u32,
group_key: Option<String>,
) -> Result<()> {
if self.active_execution_keys.contains_key(&execution_id) {
debug!(
"Execution {} already owns an active slot, skipping queue wait",
execution_id
);
return Ok(());
}
debug!(
"Enqueuing execution {} for action {} (max_concurrent: {})",
execution_id, action_id, max_concurrent
"Enqueuing execution {} for action {} (max_concurrent: {}, group: {:?})",
execution_id, action_id, max_concurrent, group_key
);
// Get or create queue for this action
let queue_key = self.queue_key(action_id, group_key);
let queue_arc = self
.queues
.entry(action_id)
.or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent))))
.clone();
.get_or_create_queue(queue_key.clone(), max_concurrent)
.await;
// Create notifier for this execution
let notifier = Arc::new(Notify::new());
@@ -192,14 +254,41 @@ impl ExecutionQueueManager {
// Update max_concurrent if it changed
queue.max_concurrent = max_concurrent;
let queued_index = queue
.queue
.iter()
.position(|entry| entry.execution_id == execution_id);
if let Some(queued_index) = queued_index {
if queued_index == 0 && queue.has_capacity() {
let entry = queue.queue.pop_front().expect("front entry just checked");
queue.active_count += 1;
self.active_execution_keys
.insert(entry.execution_id, queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(());
}
debug!(
"Execution {} is already queued for action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
);
return Ok(());
}
// Check if we can run immediately
if queue.has_capacity() {
debug!(
"Execution {} can run immediately (active: {}/{})",
execution_id, queue.active_count, queue.max_concurrent
"Execution {} can run immediately for action {} (active: {}/{}, group: {:?})",
execution_id,
action_id,
queue.active_count,
queue.max_concurrent,
queue_key.group_key
);
queue.active_count += 1;
queue.total_enqueued += 1;
self.active_execution_keys
.insert(execution_id, queue_key.clone());
// Persist stats to database if available
drop(queue);
@@ -211,8 +300,9 @@ impl ExecutionQueueManager {
// Check if queue is full
if queue.is_full(self.config.max_queue_length) {
warn!(
"Queue full for action {}: {} entries (limit: {})",
"Queue full for action {} group {:?}: {} entries (limit: {})",
action_id,
queue_key.group_key,
queue.queue.len(),
self.config.max_queue_length
);
@@ -234,12 +324,13 @@ impl ExecutionQueueManager {
queue.total_enqueued += 1;
info!(
"Execution {} queued for action {} at position {} (active: {}/{})",
"Execution {} queued for action {} at position {} (active: {}/{}, group: {:?})",
execution_id,
action_id,
queue.queue.len() - 1,
queue.active_count,
queue.max_concurrent
queue.max_concurrent,
queue_key.group_key
);
}
@@ -273,6 +364,142 @@ impl ExecutionQueueManager {
}
}
/// Acquire a slot immediately or enqueue without blocking the caller.
///
/// Returns `Acquired` when the execution may run now, `Enqueued` when it was
/// appended to (or already sits in) the wait queue. Errors when the wait
/// queue is at `max_queue_length`. Idempotent for executions that already
/// hold a slot or are already queued.
pub async fn enqueue(
    &self,
    action_id: Id,
    execution_id: Id,
    max_concurrent: u32,
    group_key: Option<String>,
) -> Result<SlotEnqueueOutcome> {
    // Idempotency: a repeat request from a slot-holding execution is a no-op.
    if self.active_execution_keys.contains_key(&execution_id) {
        debug!(
            "Execution {} already owns an active slot, treating as acquired",
            execution_id
        );
        return Ok(SlotEnqueueOutcome::Acquired);
    }
    debug!(
        "Enqueuing execution {} for action {} without waiting (max_concurrent: {}, group: {:?})",
        execution_id, action_id, max_concurrent, group_key
    );
    let queue_key = self.queue_key(action_id, group_key);
    let queue_arc = self
        .get_or_create_queue(queue_key.clone(), max_concurrent)
        .await;
    {
        let mut queue = queue_arc.lock().await;
        // Policies may change between calls; always refresh the limit.
        queue.max_concurrent = max_concurrent;
        let queued_index = queue
            .queue
            .iter()
            .position(|entry| entry.execution_id == execution_id);
        if let Some(queued_index) = queued_index {
            // Already queued: promote only from the head and only when a
            // slot is free; otherwise report it as still enqueued.
            if queued_index == 0 && queue.has_capacity() {
                let entry = queue.queue.pop_front().expect("front entry just checked");
                queue.active_count += 1;
                self.active_execution_keys
                    .insert(entry.execution_id, queue_key.clone());
                // Release the lock before the stats write.
                drop(queue);
                self.persist_queue_stats(action_id).await;
                return Ok(SlotEnqueueOutcome::Acquired);
            }
            debug!(
                "Execution {} is already queued for action {} (group: {:?})",
                execution_id, action_id, queue_key.group_key
            );
            return Ok(SlotEnqueueOutcome::Enqueued);
        }
        if queue.has_capacity() {
            queue.active_count += 1;
            queue.total_enqueued += 1;
            self.active_execution_keys
                .insert(execution_id, queue_key.clone());
            // Release the lock before the stats write.
            drop(queue);
            self.persist_queue_stats(action_id).await;
            return Ok(SlotEnqueueOutcome::Acquired);
        }
        if queue.is_full(self.config.max_queue_length) {
            warn!(
                "Queue full for action {} group {:?}: {} entries (limit: {})",
                action_id,
                queue_key.group_key,
                queue.queue.len(),
                self.config.max_queue_length
            );
            return Err(anyhow::anyhow!(
                "Queue full for action {}: maximum {} entries",
                action_id,
                self.config.max_queue_length
            ));
        }
        // Park the execution. The notifier is never awaited by this
        // non-blocking caller; it keeps the entry shape shared with the
        // blocking enqueue_and_wait path.
        queue.queue.push_back(QueueEntry {
            execution_id,
            enqueued_at: Utc::now(),
            notifier: Arc::new(Notify::new()),
        });
        queue.total_enqueued += 1;
    }
    self.persist_queue_stats(action_id).await;
    Ok(SlotEnqueueOutcome::Enqueued)
}
/// Try to acquire a slot immediately without queueing.
///
/// Never enqueues: when the group is at capacity the caller receives
/// `acquired: false` plus the active count observed before the attempt.
pub async fn try_acquire(
    &self,
    action_id: Id,
    execution_id: Id,
    max_concurrent: u32,
    group_key: Option<String>,
) -> Result<SlotAcquireOutcome> {
    let queue_key = self.queue_key(action_id, group_key);
    let queue_arc = self
        .get_or_create_queue(queue_key.clone(), max_concurrent)
        .await;
    let mut queue = queue_arc.lock().await;
    // Refresh the limit in case the policy changed since queue creation.
    queue.max_concurrent = max_concurrent;
    // Snapshot before any increment, for violation reporting.
    let current_count = queue.active_count;
    // Idempotency: a holder of an active slot is reported as acquired.
    // NOTE(review): this does not verify the held slot belongs to THIS
    // action/group key — confirm cross-group re-acquire cannot occur.
    if self.active_execution_keys.contains_key(&execution_id) {
        debug!(
            "Execution {} already owns a slot for action {} (group: {:?})",
            execution_id, action_id, queue_key.group_key
        );
        return Ok(SlotAcquireOutcome {
            acquired: true,
            current_count,
        });
    }
    if queue.has_capacity() {
        queue.active_count += 1;
        queue.total_enqueued += 1;
        self.active_execution_keys
            .insert(execution_id, queue_key.clone());
        // Release the lock before the stats write.
        drop(queue);
        self.persist_queue_stats(action_id).await;
        return Ok(SlotAcquireOutcome {
            acquired: true,
            current_count,
        });
    }
    Ok(SlotAcquireOutcome {
        acquired: false,
        current_count,
    })
}
/// Notify that an execution has completed, releasing a queue slot
///
/// This method will:
@@ -282,27 +509,64 @@ impl ExecutionQueueManager {
/// 4. Increment active count for the notified execution
///
/// # Arguments
/// * `action_id` - The action that completed
/// * `execution_id` - The execution that completed
///
/// # Returns
/// * `Ok(true)` - A queued execution was notified
/// * `Ok(false)` - No executions were waiting
/// * `Err(_)` - Error accessing queue
pub async fn notify_completion(&self, action_id: Id) -> Result<bool> {
/// Release the slot held by `execution_id` and promote the next queued
/// execution, if any.
///
/// Returns `true` when a queued execution was promoted, `false` otherwise.
pub async fn notify_completion(&self, execution_id: Id) -> Result<bool> {
    let promoted = self.notify_completion_with_next(execution_id).await?;
    Ok(promoted.is_some())
}
/// Release the completed execution's slot and, when a successor is waiting,
/// activate it; returns the activated execution's id.
///
/// Returns `None` when no slot was held, no successor was queued, or the
/// successor could not be activated — in that last case the release is
/// rolled back so the active count stays consistent.
pub async fn notify_completion_with_next(&self, execution_id: Id) -> Result<Option<Id>> {
    let release = match self.release_active_slot(execution_id).await? {
        Some(release) => release,
        None => return Ok(None),
    };
    let Some(next_execution_id) = release.next_execution_id else {
        return Ok(None);
    };
    if self.activate_queued_execution(next_execution_id).await? {
        Ok(Some(next_execution_id))
    } else {
        // Rollback: re-claim the slot for the completed execution so the
        // freed capacity is not lost when activation fails.
        self.restore_active_slot(execution_id, &release).await?;
        Ok(None)
    }
}
pub async fn release_active_slot(
&self,
execution_id: Id,
) -> Result<Option<SlotReleaseOutcome>> {
let Some((_, queue_key)) = self.active_execution_keys.remove(&execution_id) else {
debug!(
"No active queue slot found for execution {} (queue may have been cleared)",
execution_id
);
return Ok(None);
};
let action_id = queue_key.action_id;
debug!(
"Processing completion notification for action {}",
action_id
"Processing completion notification for execution {} on action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
);
// Get queue for this action
let queue_arc = match self.queues.get(&action_id) {
// Get queue for this action/group
let queue_arc = match self.queues.get(&queue_key) {
Some(q) => q.clone(),
None => {
debug!(
"No queue found for action {} (no executions queued)",
action_id
"No queue found for action {} group {:?}",
action_id, queue_key.group_key
);
return Ok(false);
return Ok(None);
}
};
@@ -313,49 +577,162 @@ impl ExecutionQueueManager {
queue.active_count -= 1;
queue.total_completed += 1;
debug!(
"Decremented active count for action {} to {}",
action_id, queue.active_count
"Decremented active count for action {} group {:?} to {}",
action_id, queue_key.group_key, queue.active_count
);
} else {
warn!(
"Completion notification for action {} but active_count is 0",
action_id
"Completion notification for action {} group {:?} but active_count is 0",
action_id, queue_key.group_key
);
}
// Check if there are queued executions
if queue.queue.is_empty() {
debug!(
"No executions queued for action {} after completion",
action_id
"No executions queued for action {} group {:?} after completion",
action_id, queue_key.group_key
);
return Ok(false);
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(Some(SlotReleaseOutcome {
next_execution_id: None,
queue_key,
}));
}
// Pop the first (oldest) entry from queue
if let Some(entry) = queue.queue.pop_front() {
let next_execution_id = queue.queue.front().map(|entry| entry.execution_id);
if let Some(next_execution_id) = next_execution_id {
info!(
"Notifying execution {} for action {} (was queued for {:?})",
"Execution {} is next for action {} group {:?}",
next_execution_id, action_id, queue_key.group_key
);
}
drop(queue);
self.persist_queue_stats(action_id).await;
Ok(Some(SlotReleaseOutcome {
next_execution_id,
queue_key,
}))
}
/// Undo a previously-released active slot: re-occupy one active slot for
/// `execution_id` in the queue identified by `outcome.queue_key`, re-register
/// the execution-to-queue mapping, and persist the updated stats.
pub async fn restore_active_slot(
    &self,
    execution_id: Id,
    outcome: &SlotReleaseOutcome,
) -> Result<()> {
    let action_id = outcome.queue_key.action_id;
    let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await;
    {
        let mut guard = queue_arc.lock().await;
        guard.active_count += 1;
        // The matching release bumped total_completed; take that back without underflowing.
        guard.total_completed = guard.total_completed.saturating_sub(1);
        self.active_execution_keys
            .insert(execution_id, outcome.queue_key.clone());
    } // lock released here, before touching the database
    self.persist_queue_stats(action_id).await;
    Ok(())
}
pub async fn activate_queued_execution(&self, execution_id: Id) -> Result<bool> {
for entry in self.queues.iter() {
let queue_key = entry.key().clone();
let queue_arc = entry.value().clone();
let mut queue = queue_arc.lock().await;
let Some(front) = queue.queue.front() else {
continue;
};
if front.execution_id != execution_id {
continue;
}
if !queue.has_capacity() {
return Ok(false);
}
let entry = queue.queue.pop_front().expect("front entry just checked");
info!(
"Activating queued execution {} for action {} group {:?} (queued for {:?})",
entry.execution_id,
action_id,
queue_key.action_id,
queue_key.group_key,
Utc::now() - entry.enqueued_at
);
// Increment active count for the execution we're about to notify
queue.active_count += 1;
self.active_execution_keys
.insert(entry.execution_id, queue_key.clone());
// Notify the waiter (after releasing lock)
drop(queue);
entry.notifier.notify_one();
self.persist_queue_stats(queue_key.action_id).await;
return Ok(true);
}
// Persist stats to database if available
Ok(false)
}
pub async fn remove_queued_execution(
&self,
execution_id: Id,
) -> Result<Option<QueuedRemovalOutcome>> {
for entry in self.queues.iter() {
let queue_key = entry.key().clone();
let queue_arc = entry.value().clone();
let mut queue = queue_arc.lock().await;
let Some(index) = queue
.queue
.iter()
.position(|queued| queued.execution_id == execution_id)
else {
continue;
};
let removed_entry = queue.queue.remove(index).expect("queue index just checked");
let next_execution_id = if index == 0 {
queue.queue.front().map(|queued| queued.execution_id)
} else {
None
};
let action_id = queue_key.action_id;
drop(queue);
self.persist_queue_stats(action_id).await;
Ok(true)
} else {
// Race condition check - queue was empty after all
Ok(false)
return Ok(Some(QueuedRemovalOutcome {
next_execution_id,
queue_key,
removed_entry,
removed_index: index,
}));
}
Ok(None)
}
/// Put a previously-removed queue entry back at its original position (or at
/// the tail if the queue has since shrunk below that index), then persist
/// the updated stats.
pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> {
    let action_id = outcome.queue_key.action_id;
    let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await;
    {
        let mut guard = queue_arc.lock().await;
        // `VecDeque::insert` accepts index == len (equivalent to push_back);
        // clamping keeps a shrunken queue from making the insert panic.
        let insert_at = outcome.removed_index.min(guard.queue.len());
        guard.queue.insert(insert_at, outcome.removed_entry.clone());
    }
    self.persist_queue_stats(action_id).await;
    Ok(())
}
/// Persist queue statistics to database (if database pool is available)
@@ -384,19 +761,48 @@ impl ExecutionQueueManager {
/// Get statistics for a specific action's queue
pub async fn get_queue_stats(&self, action_id: Id) -> Option<QueueStats> {
let queue_arc = self.queues.get(&action_id)?.clone();
let queue = queue_arc.lock().await;
let queue_arcs: Vec<Arc<Mutex<ActionQueue>>> = self
.queues
.iter()
.filter(|entry| entry.key().action_id == action_id)
.map(|entry| entry.value().clone())
.collect();
let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at);
if queue_arcs.is_empty() {
return None;
}
let mut queue_length = 0usize;
let mut active_count = 0u32;
let mut max_concurrent = 0u32;
let mut oldest_enqueued_at: Option<DateTime<Utc>> = None;
let mut total_enqueued = 0u64;
let mut total_completed = 0u64;
for queue_arc in queue_arcs {
let queue = queue_arc.lock().await;
queue_length += queue.queue.len();
active_count += queue.active_count;
max_concurrent += queue.max_concurrent;
total_enqueued += queue.total_enqueued;
total_completed += queue.total_completed;
if let Some(candidate) = queue.queue.front().map(|e| e.enqueued_at) {
oldest_enqueued_at = Some(match oldest_enqueued_at {
Some(current) => current.min(candidate),
None => candidate,
});
}
}
Some(QueueStats {
action_id,
queue_length: queue.queue.len(),
active_count: queue.active_count,
max_concurrent: queue.max_concurrent,
queue_length,
active_count,
max_concurrent,
oldest_enqueued_at,
total_enqueued: queue.total_enqueued,
total_completed: queue.total_completed,
total_enqueued,
total_completed,
})
}
@@ -405,22 +811,15 @@ impl ExecutionQueueManager {
pub async fn get_all_queue_stats(&self) -> Vec<QueueStats> {
let mut stats = Vec::new();
let mut action_ids = std::collections::BTreeSet::new();
for entry in self.queues.iter() {
let action_id = *entry.key();
let queue_arc = entry.value().clone();
let queue = queue_arc.lock().await;
action_ids.insert(entry.key().action_id);
}
let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at);
stats.push(QueueStats {
action_id,
queue_length: queue.queue.len(),
active_count: queue.active_count,
max_concurrent: queue.max_concurrent,
oldest_enqueued_at,
total_enqueued: queue.total_enqueued,
total_completed: queue.total_completed,
});
for action_id in action_ids {
if let Some(action_stats) = self.get_queue_stats(action_id).await {
stats.push(action_stats);
}
}
stats
@@ -445,27 +844,29 @@ impl ExecutionQueueManager {
execution_id, action_id
);
let queue_arc = match self.queues.get(&action_id) {
Some(q) => q.clone(),
None => return Ok(false),
};
let queue_arcs: Vec<Arc<Mutex<ActionQueue>>> = self
.queues
.iter()
.filter(|entry| entry.key().action_id == action_id)
.map(|entry| entry.value().clone())
.collect();
let mut queue = queue_arc.lock().await;
let initial_len = queue.queue.len();
queue.queue.retain(|e| e.execution_id != execution_id);
let removed = initial_len != queue.queue.len();
if removed {
info!("Cancelled execution {} from queue", execution_id);
} else {
debug!(
"Execution {} not found in queue (may be running)",
execution_id
);
for queue_arc in queue_arcs {
let mut queue = queue_arc.lock().await;
let initial_len = queue.queue.len();
queue.queue.retain(|e| e.execution_id != execution_id);
if initial_len != queue.queue.len() {
info!("Cancelled execution {} from queue", execution_id);
return Ok(true);
}
}
Ok(removed)
debug!(
"Execution {} not found in queue (may be running)",
execution_id
);
Ok(false)
}
/// Clear all queues (for testing or emergency situations)
@@ -479,12 +880,17 @@ impl ExecutionQueueManager {
queue.queue.clear();
queue.active_count = 0;
}
self.active_execution_keys.clear();
}
/// Get the number of actions with active queues
#[allow(dead_code)]
pub fn active_queue_count(&self) -> usize {
self.queues.len()
self.queues
.iter()
.map(|entry| entry.key().action_id)
.collect::<std::collections::BTreeSet<_>>()
.len()
}
}
@@ -504,7 +910,7 @@ mod tests {
let manager = ExecutionQueueManager::with_defaults();
// Should execute immediately when there's capacity
let result = manager.enqueue_and_wait(1, 100, 2).await;
let result = manager.enqueue_and_wait(1, 100, 2, None).await;
assert!(result.is_ok());
// Check stats
@@ -521,7 +927,7 @@ mod tests {
// First execution should run immediately
let result = manager
.enqueue_and_wait(action_id, 100, max_concurrent)
.enqueue_and_wait(action_id, 100, max_concurrent, None)
.await;
assert!(result.is_ok());
@@ -535,7 +941,7 @@ mod tests {
let handle = tokio::spawn(async move {
manager
.enqueue_and_wait(action_id, exec_id, max_concurrent)
.enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await
.unwrap();
order.lock().await.push(exec_id);
@@ -553,9 +959,9 @@ mod tests {
assert_eq!(stats.active_count, 1);
// Release them one by one
for _ in 0..3 {
for execution_id in 100..103 {
sleep(Duration::from_millis(50)).await;
manager.notify_completion(action_id).await.unwrap();
manager.notify_completion(execution_id).await.unwrap();
}
// Wait for all to complete
@@ -574,7 +980,10 @@ mod tests {
let action_id = 1;
// Start first execution
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap();
manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue second execution
let manager_clone = Arc::new(manager);
@@ -582,7 +991,7 @@ mod tests {
let handle = tokio::spawn(async move {
manager_ref
.enqueue_and_wait(action_id, 101, 1)
.enqueue_and_wait(action_id, 101, 1, None)
.await
.unwrap();
});
@@ -596,7 +1005,7 @@ mod tests {
assert_eq!(stats.active_count, 1);
// Notify completion
let notified = manager_clone.notify_completion(action_id).await.unwrap();
let notified = manager_clone.notify_completion(100).await.unwrap();
assert!(notified);
// Wait for queued execution to proceed
@@ -613,8 +1022,8 @@ mod tests {
let manager = Arc::new(ExecutionQueueManager::with_defaults());
// Start executions on different actions
manager.enqueue_and_wait(1, 100, 1).await.unwrap();
manager.enqueue_and_wait(2, 200, 1).await.unwrap();
manager.enqueue_and_wait(1, 100, 1, None).await.unwrap();
manager.enqueue_and_wait(2, 200, 1, None).await.unwrap();
// Both should be active
let stats1 = manager.get_queue_stats(1).await.unwrap();
@@ -624,7 +1033,7 @@ mod tests {
assert_eq!(stats2.active_count, 1);
// Completion on action 1 shouldn't affect action 2
manager.notify_completion(1).await.unwrap();
manager.notify_completion(100).await.unwrap();
let stats1 = manager.get_queue_stats(1).await.unwrap();
let stats2 = manager.get_queue_stats(2).await.unwrap();
@@ -633,20 +1042,43 @@ mod tests {
assert_eq!(stats2.active_count, 1);
}
#[tokio::test]
async fn test_grouped_queues_are_independent() {
    let manager = ExecutionQueueManager::with_defaults();
    let action_id = 1;

    // Two executions on the same action but in different groups: each group
    // has its own max_concurrent=1 slot, so both run immediately.
    let cases = [(100, "prod"), (200, "staging")];
    for (execution_id, group) in cases {
        manager
            .enqueue_and_wait(action_id, execution_id, 1, Some(group.to_string()))
            .await
            .unwrap();
    }

    // Stats aggregate across groups for the action.
    let stats = manager.get_queue_stats(action_id).await.unwrap();
    assert_eq!(stats.active_count, 2);
    assert_eq!(stats.queue_length, 0);
    assert_eq!(stats.max_concurrent, 2);
}
#[tokio::test]
async fn test_cancel_execution() {
let manager = ExecutionQueueManager::with_defaults();
let action_id = 1;
// Fill capacity
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap();
manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue more executions
let manager_arc = Arc::new(manager);
let manager_ref = manager_arc.clone();
let handle = tokio::spawn(async move {
let result = manager_ref.enqueue_and_wait(action_id, 101, 1).await;
let result = manager_ref.enqueue_and_wait(action_id, 101, 1, None).await;
result
});
@@ -675,7 +1107,10 @@ mod tests {
assert!(manager.get_queue_stats(action_id).await.is_none());
// After enqueue, stats should exist
manager.enqueue_and_wait(action_id, 100, 2).await.unwrap();
manager
.enqueue_and_wait(action_id, 100, 2, None)
.await
.unwrap();
let stats = manager.get_queue_stats(action_id).await.unwrap();
assert_eq!(stats.action_id, action_id);
@@ -696,13 +1131,16 @@ mod tests {
let action_id = 1;
// Fill capacity
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap();
manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue 2 more (should reach limit)
let manager_ref = manager.clone();
tokio::spawn(async move {
manager_ref
.enqueue_and_wait(action_id, 101, 1)
.enqueue_and_wait(action_id, 101, 1, None)
.await
.unwrap();
});
@@ -710,7 +1148,7 @@ mod tests {
let manager_ref = manager.clone();
tokio::spawn(async move {
manager_ref
.enqueue_and_wait(action_id, 102, 1)
.enqueue_and_wait(action_id, 102, 1, None)
.await
.unwrap();
});
@@ -718,7 +1156,7 @@ mod tests {
sleep(Duration::from_millis(100)).await;
// Next one should fail
let result = manager.enqueue_and_wait(action_id, 103, 1).await;
let result = manager.enqueue_and_wait(action_id, 103, 1, None).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Queue full"));
}
@@ -732,7 +1170,7 @@ mod tests {
// Start first execution
manager
.enqueue_and_wait(action_id, 0, max_concurrent)
.enqueue_and_wait(action_id, 0, max_concurrent, None)
.await
.unwrap();
@@ -746,7 +1184,7 @@ mod tests {
let handle = tokio::spawn(async move {
manager
.enqueue_and_wait(action_id, i, max_concurrent)
.enqueue_and_wait(action_id, i, max_concurrent, None)
.await
.unwrap();
order.lock().await.push(i);
@@ -759,9 +1197,9 @@ mod tests {
sleep(Duration::from_millis(200)).await;
// Release them all
for _ in 0..num_executions {
for execution_id in 0..num_executions {
sleep(Duration::from_millis(10)).await;
manager.notify_completion(action_id).await.unwrap();
manager.notify_completion(execution_id).await.unwrap();
}
// Wait for completion

View File

@@ -16,7 +16,7 @@ use attune_common::{
models::{enums::ExecutionStatus, execution::WorkflowTaskMetadata, Action, Execution, Runtime},
mq::{
Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope,
MessageType, Publisher,
MessageType, MqError, Publisher,
},
repositories::{
action::ActionRepository,
@@ -40,6 +40,7 @@ use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info, warn};
use crate::policy_enforcer::{PolicyEnforcer, SchedulingPolicyOutcome};
use crate::workflow::context::{TaskOutcome, WorkflowContext};
use crate::workflow::graph::TaskGraph;
@@ -108,6 +109,7 @@ pub struct ExecutionScheduler {
pool: PgPool,
publisher: Arc<Publisher>,
consumer: Arc<Consumer>,
policy_enforcer: Arc<PolicyEnforcer>,
/// Round-robin counter for distributing executions across workers
round_robin_counter: AtomicUsize,
}
@@ -120,12 +122,31 @@ const DEFAULT_HEARTBEAT_INTERVAL: u64 = 30;
const HEARTBEAT_STALENESS_MULTIPLIER: u64 = 3;
impl ExecutionScheduler {
/// Classify a scheduling error as a retryable `MqError`, if it is one.
///
/// Returns an owned copy of the error only for transient broker/transport
/// failures worth a nack-and-requeue (connection, channel, publish, timeout,
/// pool, or raw lapin errors — the latter folded into `Connection`). Any
/// other error kind yields `None` and is treated as non-retryable.
fn retryable_mq_error(error: &anyhow::Error) -> Option<MqError> {
    let mq_error = error.downcast_ref::<MqError>()?;
    match mq_error {
        MqError::Connection(msg) => Some(MqError::Connection(msg.clone())),
        MqError::Channel(msg) => Some(MqError::Channel(msg.clone())),
        MqError::Publish(msg) => Some(MqError::Publish(msg.clone())),
        MqError::Timeout(msg) => Some(MqError::Timeout(msg.clone())),
        MqError::Pool(msg) => Some(MqError::Pool(msg.clone())),
        MqError::Lapin(err) => Some(MqError::Connection(err.to_string())),
        _ => None,
    }
}
/// Create a new execution scheduler
pub fn new(pool: PgPool, publisher: Arc<Publisher>, consumer: Arc<Consumer>) -> Self {
pub fn new(
pool: PgPool,
publisher: Arc<Publisher>,
consumer: Arc<Consumer>,
policy_enforcer: Arc<PolicyEnforcer>,
) -> Self {
Self {
pool,
publisher,
consumer,
policy_enforcer,
round_robin_counter: AtomicUsize::new(0),
}
}
@@ -136,6 +157,7 @@ impl ExecutionScheduler {
let pool = self.pool.clone();
let publisher = self.publisher.clone();
let policy_enforcer = self.policy_enforcer.clone();
// Share the counter with the handler closure via Arc.
// We wrap &self's AtomicUsize in a new Arc<AtomicUsize> by copying the
// current value so the closure is 'static.
@@ -149,16 +171,24 @@ impl ExecutionScheduler {
move |envelope: MessageEnvelope<ExecutionRequestedPayload>| {
let pool = pool.clone();
let publisher = publisher.clone();
let policy_enforcer = policy_enforcer.clone();
let counter = counter.clone();
async move {
if let Err(e) = Self::process_execution_requested(
&pool, &publisher, &counter, &envelope,
&pool,
&publisher,
&policy_enforcer,
&counter,
&envelope,
)
.await
{
error!("Error scheduling execution: {}", e);
// Return error to trigger nack with requeue
if let Some(mq_err) = Self::retryable_mq_error(&e) {
return Err(mq_err);
}
return Err(format!("Failed to schedule execution: {}", e).into());
}
Ok(())
@@ -174,6 +204,7 @@ impl ExecutionScheduler {
async fn process_execution_requested(
pool: &PgPool,
publisher: &Publisher,
policy_enforcer: &PolicyEnforcer,
round_robin_counter: &AtomicUsize,
envelope: &MessageEnvelope<ExecutionRequestedPayload>,
) -> Result<()> {
@@ -184,9 +215,30 @@ impl ExecutionScheduler {
info!("Scheduling execution: {}", execution_id);
// Fetch execution from database
let execution = ExecutionRepository::find_by_id(pool, execution_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Execution not found: {}", execution_id))?;
let execution = match ExecutionRepository::find_by_id(pool, execution_id).await? {
Some(execution) => execution,
None => {
warn!("Execution {} not found during scheduling", execution_id);
Self::remove_queued_policy_execution(
policy_enforcer,
pool,
publisher,
execution_id,
)
.await;
return Ok(());
}
};
if execution.status != ExecutionStatus::Requested {
debug!(
"Skipping execution {} with status {:?}; only Requested executions are schedulable",
execution_id, execution.status
);
Self::remove_queued_policy_execution(policy_enforcer, pool, publisher, execution_id)
.await;
return Ok(());
}
// Fetch action to determine runtime requirements
let action = Self::get_action_for_execution(pool, &execution).await?;
@@ -207,30 +259,6 @@ impl ExecutionScheduler {
.await;
}
// Regular action: select appropriate worker (round-robin among compatible workers)
let worker = match Self::select_worker(pool, &action, round_robin_counter).await {
Ok(worker) => worker,
Err(err) if Self::is_unschedulable_error(&err) => {
Self::fail_unschedulable_execution(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
Err(err) => return Err(err),
};
info!(
"Selected worker {} for execution {}",
worker.id, execution_id
);
// Apply parameter defaults from the action's param_schema.
// This mirrors what `process_workflow_execution` does for workflows
// so that non-workflow executions also get missing parameters filled
@@ -249,16 +277,97 @@ impl ExecutionScheduler {
}
};
match policy_enforcer
.enforce_for_scheduling(
action.id,
Some(action.pack),
execution_id,
execution_config.as_ref(),
)
.await
{
Ok(SchedulingPolicyOutcome::Queued) => {
info!(
"Execution {} queued by policy for action {}; deferring worker selection",
execution_id, action.id
);
return Ok(());
}
Ok(SchedulingPolicyOutcome::Ready) => {}
Err(err) => {
if Self::is_policy_cancellation_error(&err) {
Self::remove_queued_policy_execution(
policy_enforcer,
pool,
publisher,
execution_id,
)
.await;
Self::cancel_execution_for_policy_violation(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
return Err(err);
}
}
// Regular action: select appropriate worker only after policy
// readiness is confirmed, so queued executions don't reserve stale
// workers while they wait.
let worker = match Self::select_worker(pool, &action, round_robin_counter).await {
Ok(worker) => worker,
Err(err) if Self::is_unschedulable_error(&err) => {
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
Self::fail_unschedulable_execution(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
Err(err) => {
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
return Err(err);
}
};
info!(
"Selected worker {} for execution {}",
worker.id, execution_id
);
// Persist the selected worker so later cancellation requests can be
// routed to the correct per-worker cancel queue.
let mut execution_for_update = execution;
execution_for_update.status = ExecutionStatus::Scheduled;
execution_for_update.worker = Some(worker.id);
ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into())
.await?;
if let Err(err) =
ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into())
.await
{
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
return Err(err.into());
}
// Publish message to worker-specific queue
Self::queue_to_worker(
if let Err(err) = Self::queue_to_worker(
publisher,
&execution_id,
&worker.id,
@@ -266,7 +375,27 @@ impl ExecutionScheduler {
&execution_config,
&action,
)
.await?;
.await
{
if let Err(revert_err) = ExecutionRepository::update(
pool,
execution_id,
UpdateExecutionInput {
status: Some(ExecutionStatus::Requested),
..Default::default()
},
)
.await
{
warn!(
"Failed to revert execution {} back to Requested after worker publish error: {}",
execution_id, revert_err
);
}
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
return Err(err);
}
info!(
"Execution {} scheduled to worker {}",
@@ -1698,6 +1827,131 @@ impl ExecutionScheduler {
|| message.starts_with("No workers with fresh heartbeats available")
}
/// Decide whether a scheduling error represents a policy-driven cancellation
/// (policy violation, queue overflow, or queue wait timeout) rather than an
/// infrastructure failure. Classification is by the rendered error message.
fn is_policy_cancellation_error(error: &anyhow::Error) -> bool {
    let message = error.to_string();
    if message.contains("Policy violation:") {
        return true;
    }
    ["Queue full for action ", "Queue timeout for execution "]
        .into_iter()
        .any(|prefix| message.starts_with(prefix))
}
/// Release the policy slot acquired for `execution_id` after a downstream
/// scheduling step failed, and wake the next queued execution (if any) by
/// republishing its `ExecutionRequested` message.
///
/// If republishing the successor fails, the slot is restored so queue
/// accounting stays consistent, and the republish error is propagated.
///
/// # Errors
/// Returns the slot-release error, or the republish error when the deferred
/// successor could not be re-announced.
async fn release_acquired_policy_slot(
    policy_enforcer: &PolicyEnforcer,
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) -> Result<()> {
    let release = match policy_enforcer.release_execution_slot(execution_id).await {
        Ok(release) => release,
        Err(release_err) => {
            warn!(
                "Failed to release acquired policy slot for execution {} after scheduling error: {}",
                execution_id, release_err
            );
            return Err(release_err);
        }
    };
    // None: no slot was held for this execution — nothing to unwind.
    let Some(release) = release else {
        return Ok(());
    };
    if let Some(next_execution_id) = release.next_execution_id {
        if let Err(republish_err) =
            Self::republish_execution_requested(pool, publisher, next_execution_id).await
        {
            warn!(
                "Failed to republish deferred execution {} after releasing slot from execution {}: {}",
                next_execution_id, execution_id, republish_err
            );
            // Roll the release back; the restore failure itself can only be logged.
            if let Err(restore_err) = policy_enforcer
                .restore_execution_slot(execution_id, &release)
                .await
            {
                warn!(
                    "Failed to restore policy slot for execution {} after republish error: {}",
                    execution_id, restore_err
                );
            }
            return Err(republish_err);
        }
    }
    Ok(())
}
/// Best-effort cleanup: drop `execution_id` from the policy enforcer's queue
/// when scheduling is being abandoned.
///
/// If removing the entry exposed a successor (`next_execution_id`), republish
/// that successor's `ExecutionRequested` message; on republish failure the
/// removed entry is restored so the queue returns to its pre-removal state.
/// All failures are logged rather than propagated — this runs on paths that
/// are already giving up on the execution, so an error here must not mask
/// the original condition.
async fn remove_queued_policy_execution(
    policy_enforcer: &PolicyEnforcer,
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) {
    // Removal failure is swallowed after logging: cleanup is best-effort.
    let removal = match policy_enforcer.remove_queued_execution(execution_id).await {
        Ok(removal) => removal,
        Err(remove_err) => {
            warn!(
                "Failed to remove queued policy execution {} during scheduler cleanup: {}",
                execution_id, remove_err
            );
            return;
        }
    };
    // None: the execution was not queued anywhere — nothing to do.
    let Some(removal) = removal else {
        return;
    };
    if let Some(next_execution_id) = removal.next_execution_id {
        if let Err(republish_err) =
            Self::republish_execution_requested(pool, publisher, next_execution_id).await
        {
            warn!(
                "Failed to republish successor {} after removing queued execution {}: {}",
                next_execution_id, execution_id, republish_err
            );
            // Undo the removal so the successor is not silently orphaned.
            if let Err(restore_err) = policy_enforcer.restore_queued_execution(&removal).await {
                warn!(
                    "Failed to restore queued execution {} after republish error: {}",
                    execution_id, restore_err
                );
            }
        }
    }
}
/// Re-emit an `ExecutionRequested` envelope for an execution whose turn has
/// come (e.g. a queue slot just opened up), so the scheduler processes it
/// again from the top.
///
/// # Errors
/// Fails if the execution cannot be loaded, has no associated action, or the
/// envelope cannot be published.
async fn republish_execution_requested(
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) -> Result<()> {
    let Some(execution) = ExecutionRepository::find_by_id(pool, execution_id).await? else {
        return Err(anyhow::anyhow!("Execution {} not found", execution_id));
    };
    let Some(action_id) = execution.action else {
        return Err(anyhow::anyhow!("Execution {} has no action", execution_id));
    };

    // Rebuild the request payload from the persisted execution row.
    let payload = ExecutionRequestedPayload {
        execution_id,
        action_id: Some(action_id),
        action_ref: execution.action_ref.clone(),
        parent_id: execution.parent,
        enforcement_id: execution.enforcement,
        config: execution.config.clone(),
    };
    let envelope = MessageEnvelope::new(MessageType::ExecutionRequested, payload)
        .with_source("executor-scheduler");
    publisher.publish_envelope(&envelope).await?;

    debug!(
        "Republished deferred ExecutionRequested for execution {}",
        execution_id
    );
    Ok(())
}
async fn fail_unschedulable_execution(
pool: &PgPool,
publisher: &Publisher,
@@ -1751,6 +2005,59 @@ impl ExecutionScheduler {
Ok(())
}
/// Mark an execution as `Cancelled` because a scheduling policy rejected it,
/// and broadcast a matching `ExecutionCompleted` event.
///
/// The execution row is updated first (status + structured cancellation
/// payload), then the completion envelope is published, correlated with the
/// original request so listeners can tie the terminal state back to it.
///
/// # Errors
/// Returns an error if the database update or the publish fails.
async fn cancel_execution_for_policy_violation(
    pool: &PgPool,
    publisher: &Publisher,
    envelope: &MessageEnvelope<ExecutionRequestedPayload>,
    execution_id: i64,
    action_id: i64,
    action_ref: &str,
    error_message: &str,
) -> Result<()> {
    let completed_at = Utc::now();
    // Structured payload stored on the row AND echoed in the completed event.
    let result = serde_json::json!({
        "error": "Execution cancelled by policy",
        "message": error_message,
        "action_ref": action_ref,
        "cancelled_by": "execution_scheduler",
        "cancelled_at": completed_at.to_rfc3339(),
    });

    // Persist the terminal state before announcing it.
    ExecutionRepository::update(
        pool,
        execution_id,
        UpdateExecutionInput {
            status: Some(ExecutionStatus::Cancelled),
            result: Some(result.clone()),
            ..Default::default()
        },
    )
    .await?;

    let completed = MessageEnvelope::new(
        MessageType::ExecutionCompleted,
        ExecutionCompletedPayload {
            execution_id,
            action_id,
            action_ref: action_ref.to_string(),
            status: "cancelled".to_string(),
            result: Some(result),
            completed_at,
        },
    )
    .with_correlation_id(envelope.correlation_id)
    .with_source("attune-executor");
    publisher.publish_envelope(&completed).await?;

    warn!(
        "Execution {} cancelled due to policy violation: {}",
        execution_id, error_message
    );
    Ok(())
}
/// Check if a worker's heartbeat is fresh enough to schedule work
///
/// A worker is considered fresh if its last heartbeat is within
@@ -1980,6 +2287,24 @@ mod tests {
));
}
#[test]
fn test_policy_cancellation_error_classification() {
    // Every policy-originated message shape must classify as a cancellation.
    let policy_messages = [
        "Policy violation: Concurrency limit exceeded: 1 running executions (limit: 1)",
        "Queue full for action 42: maximum 100 entries",
        "Queue timeout for execution 99: waited 60 seconds",
    ];
    for message in policy_messages {
        assert!(ExecutionScheduler::is_policy_cancellation_error(
            &anyhow::anyhow!("{}", message)
        ));
    }

    // Infrastructure failures must not be mistaken for policy cancellations.
    assert!(!ExecutionScheduler::is_policy_cancellation_error(
        &anyhow::anyhow!("rabbitmq publish failed")
    ));
}
#[test]
fn test_concurrency_limit_dispatch_count() {
// Verify the dispatch_count calculation used by dispatch_with_items_task

View File

@@ -297,6 +297,7 @@ impl ExecutorService {
self.inner.pool.clone(),
self.inner.publisher.clone(),
Arc::new(scheduler_consumer),
self.inner.policy_enforcer.clone(),
);
handles.push(tokio::spawn(async move { scheduler.start().await }));