merging semgrep-scan

This commit is contained in:
2026-04-01 20:38:18 -05:00
parent a0f59114a3
commit a4c303ec84
10 changed files with 1434 additions and 248 deletions

View File

@@ -102,7 +102,12 @@ impl MqError {
pub fn is_retriable(&self) -> bool { pub fn is_retriable(&self) -> bool {
matches!( matches!(
self, self,
MqError::Connection(_) | MqError::Channel(_) | MqError::Timeout(_) | MqError::Pool(_) MqError::Connection(_)
| MqError::Channel(_)
| MqError::Publish(_)
| MqError::Timeout(_)
| MqError::Pool(_)
| MqError::Lapin(_)
) )
} }

View File

@@ -571,7 +571,7 @@ impl Repository for PolicyRepository {
type Entity = Policy; type Entity = Policy;
fn table_name() -> &'static str { fn table_name() -> &'static str {
"policies" "policy"
} }
} }
@@ -612,7 +612,7 @@ impl FindById for PolicyRepository {
r#" r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated threshold, name, description, tags, created, updated
FROM policies FROM policy
WHERE id = $1 WHERE id = $1
"#, "#,
) )
@@ -634,7 +634,7 @@ impl FindByRef for PolicyRepository {
r#" r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated threshold, name, description, tags, created, updated
FROM policies FROM policy
WHERE ref = $1 WHERE ref = $1
"#, "#,
) )
@@ -656,7 +656,7 @@ impl List for PolicyRepository {
r#" r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated threshold, name, description, tags, created, updated
FROM policies FROM policy
ORDER BY ref ASC ORDER BY ref ASC
"#, "#,
) )
@@ -678,7 +678,7 @@ impl Create for PolicyRepository {
// Try to insert - database will enforce uniqueness constraint // Try to insert - database will enforce uniqueness constraint
let policy = sqlx::query_as::<_, Policy>( let policy = sqlx::query_as::<_, Policy>(
r#" r#"
INSERT INTO policies (ref, pack, pack_ref, action, action_ref, parameters, INSERT INTO policy (ref, pack, pack_ref, action, action_ref, parameters,
method, threshold, name, description, tags) method, threshold, name, description, tags)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, ref, pack, pack_ref, action, action_ref, parameters, method, RETURNING id, ref, pack, pack_ref, action, action_ref, parameters, method,
@@ -720,7 +720,7 @@ impl Update for PolicyRepository {
where where
E: Executor<'e, Database = Postgres> + 'e, E: Executor<'e, Database = Postgres> + 'e,
{ {
let mut query = QueryBuilder::new("UPDATE policies SET "); let mut query = QueryBuilder::new("UPDATE policy SET ");
let mut has_updates = false; let mut has_updates = false;
if let Some(parameters) = &input.parameters { if let Some(parameters) = &input.parameters {
@@ -798,7 +798,7 @@ impl Delete for PolicyRepository {
where where
E: Executor<'e, Database = Postgres> + 'e, E: Executor<'e, Database = Postgres> + 'e,
{ {
let result = sqlx::query("DELETE FROM policies WHERE id = $1") let result = sqlx::query("DELETE FROM policy WHERE id = $1")
.bind(id) .bind(id)
.execute(executor) .execute(executor)
.await?; .await?;
@@ -817,7 +817,7 @@ impl PolicyRepository {
r#" r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated threshold, name, description, tags, created, updated
FROM policies FROM policy
WHERE action = $1 WHERE action = $1
ORDER BY ref ASC ORDER BY ref ASC
"#, "#,
@@ -838,7 +838,7 @@ impl PolicyRepository {
r#" r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method, SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated threshold, name, description, tags, created, updated
FROM policies FROM policy
WHERE $1 = ANY(tags) WHERE $1 = ANY(tags)
ORDER BY ref ASC ORDER BY ref ASC
"#, "#,
@@ -849,4 +849,69 @@ impl PolicyRepository {
Ok(policies) Ok(policies)
} }
/// Find the most recent action-specific policy.
pub async fn find_latest_by_action<'e, E>(executor: E, action_id: Id) -> Result<Option<Policy>>
where
E: Executor<'e, Database = Postgres> + 'e,
{
let policy = sqlx::query_as::<_, Policy>(
r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated
FROM policy
WHERE action = $1
ORDER BY created DESC
LIMIT 1
"#,
)
.bind(action_id)
.fetch_optional(executor)
.await?;
Ok(policy)
}
/// Find the most recent pack-specific policy.
pub async fn find_latest_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Option<Policy>>
where
E: Executor<'e, Database = Postgres> + 'e,
{
let policy = sqlx::query_as::<_, Policy>(
r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated
FROM policy
WHERE pack = $1 AND action IS NULL
ORDER BY created DESC
LIMIT 1
"#,
)
.bind(pack_id)
.fetch_optional(executor)
.await?;
Ok(policy)
}
/// Find the most recent global policy.
pub async fn find_latest_global<'e, E>(executor: E) -> Result<Option<Policy>>
where
E: Executor<'e, Database = Postgres> + 'e,
{
let policy = sqlx::query_as::<_, Policy>(
r#"
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
threshold, name, description, tags, created, updated
FROM policy
WHERE pack IS NULL AND action IS NULL
ORDER BY created DESC
LIMIT 1
"#,
)
.fetch_optional(executor)
.await?;
Ok(policy)
}
} }

View File

@@ -11,7 +11,10 @@
use anyhow::Result; use anyhow::Result;
use attune_common::{ use attune_common::{
mq::{Consumer, ExecutionCompletedPayload, MessageEnvelope, Publisher}, mq::{
Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope,
MessageType, MqError, Publisher,
},
repositories::{execution::ExecutionRepository, FindById}, repositories::{execution::ExecutionRepository, FindById},
}; };
use sqlx::PgPool; use sqlx::PgPool;
@@ -36,6 +39,19 @@ pub struct CompletionListener {
} }
impl CompletionListener { impl CompletionListener {
fn retryable_mq_error(error: &anyhow::Error) -> Option<MqError> {
let mq_error = error.downcast_ref::<MqError>()?;
Some(match mq_error {
MqError::Connection(msg) => MqError::Connection(msg.clone()),
MqError::Channel(msg) => MqError::Channel(msg.clone()),
MqError::Publish(msg) => MqError::Publish(msg.clone()),
MqError::Timeout(msg) => MqError::Timeout(msg.clone()),
MqError::Pool(msg) => MqError::Pool(msg.clone()),
MqError::Lapin(err) => MqError::Connection(err.to_string()),
_ => return None,
})
}
/// Create a new completion listener /// Create a new completion listener
pub fn new( pub fn new(
pool: PgPool, pool: PgPool,
@@ -82,6 +98,9 @@ impl CompletionListener {
{ {
error!("Error processing execution completion: {}", e); error!("Error processing execution completion: {}", e);
// Return error to trigger nack with requeue // Return error to trigger nack with requeue
if let Some(mq_err) = Self::retryable_mq_error(&e) {
return Err(mq_err);
}
return Err( return Err(
format!("Failed to process execution completion: {}", e).into() format!("Failed to process execution completion: {}", e).into()
); );
@@ -187,19 +206,39 @@ impl CompletionListener {
action_id, execution_id action_id, execution_id
); );
match queue_manager.notify_completion(action_id).await { match queue_manager.release_active_slot(execution_id).await {
Ok(notified) => { Ok(release) => {
if notified { if let Some(release) = release {
if let Some(next_execution_id) = release.next_execution_id {
info!( info!(
"Queue slot released for action {}, next execution notified", "Queue slot released for action {}, next execution {} can proceed",
action_id action_id, next_execution_id
); );
if let Err(republish_err) = Self::publish_execution_requested(
pool,
publisher,
action_id,
next_execution_id,
)
.await
{
queue_manager
.restore_active_slot(execution_id, &release)
.await?;
return Err(republish_err);
}
} else { } else {
debug!( debug!(
"Queue slot released for action {}, no executions waiting", "Queue slot released for action {}, no executions waiting",
action_id action_id
); );
} }
} else {
debug!(
"Execution {} had no active queue slot to release",
execution_id
);
}
} }
Err(e) => { Err(e) => {
error!( error!(
@@ -225,6 +264,38 @@ impl CompletionListener {
Ok(()) Ok(())
} }
async fn publish_execution_requested(
pool: &PgPool,
publisher: &Publisher,
action_id: i64,
execution_id: i64,
) -> Result<()> {
let execution = ExecutionRepository::find_by_id(pool, execution_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Execution {} not found", execution_id))?;
let payload = ExecutionRequestedPayload {
execution_id,
action_id: Some(action_id),
action_ref: execution.action_ref.clone(),
parent_id: execution.parent,
enforcement_id: execution.enforcement,
config: execution.config.clone(),
};
let envelope = MessageEnvelope::new(MessageType::ExecutionRequested, payload)
.with_source("executor-completion-listener");
publisher.publish_envelope(&envelope).await?;
debug!(
"Republished deferred ExecutionRequested for execution {}",
execution_id
);
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
@@ -239,7 +310,7 @@ mod tests {
// Simulate acquiring a slot // Simulate acquiring a slot
queue_manager queue_manager
.enqueue_and_wait(action_id, 100, 1) .enqueue_and_wait(action_id, 100, 1, None)
.await .await
.unwrap(); .unwrap();
@@ -249,7 +320,7 @@ mod tests {
assert_eq!(stats.queue_length, 0); assert_eq!(stats.queue_length, 0);
// Simulate completion notification // Simulate completion notification
let notified = queue_manager.notify_completion(action_id).await.unwrap(); let notified = queue_manager.notify_completion(100).await.unwrap();
assert!(!notified); // No one waiting assert!(!notified); // No one waiting
// Verify slot is released // Verify slot is released
@@ -264,7 +335,7 @@ mod tests {
// Fill capacity // Fill capacity
queue_manager queue_manager
.enqueue_and_wait(action_id, 100, 1) .enqueue_and_wait(action_id, 100, 1, None)
.await .await
.unwrap(); .unwrap();
@@ -272,7 +343,7 @@ mod tests {
let queue_manager_clone = queue_manager.clone(); let queue_manager_clone = queue_manager.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
queue_manager_clone queue_manager_clone
.enqueue_and_wait(action_id, 101, 1) .enqueue_and_wait(action_id, 101, 1, None)
.await .await
.unwrap(); .unwrap();
}); });
@@ -286,7 +357,7 @@ mod tests {
assert_eq!(stats.queue_length, 1); assert_eq!(stats.queue_length, 1);
// Notify completion // Notify completion
let notified = queue_manager.notify_completion(action_id).await.unwrap(); let notified = queue_manager.notify_completion(100).await.unwrap();
assert!(notified); // Should wake the waiting execution assert!(notified); // Should wake the waiting execution
// Wait for queued execution to proceed // Wait for queued execution to proceed
@@ -306,7 +377,7 @@ mod tests {
// Fill capacity // Fill capacity
queue_manager queue_manager
.enqueue_and_wait(action_id, 100, 1) .enqueue_and_wait(action_id, 100, 1, None)
.await .await
.unwrap(); .unwrap();
@@ -320,7 +391,7 @@ mod tests {
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
queue_manager queue_manager
.enqueue_and_wait(action_id, exec_id, 1) .enqueue_and_wait(action_id, exec_id, 1, None)
.await .await
.unwrap(); .unwrap();
order.lock().await.push(exec_id); order.lock().await.push(exec_id);
@@ -333,9 +404,9 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Release them one by one // Release them one by one
for _ in 0..3 { for execution_id in 100..103 {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
queue_manager.notify_completion(action_id).await.unwrap(); queue_manager.notify_completion(execution_id).await.unwrap();
} }
// Wait for all to complete // Wait for all to complete
@@ -351,10 +422,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_completion_with_no_queue() { async fn test_completion_with_no_queue() {
let queue_manager = Arc::new(ExecutionQueueManager::with_defaults()); let queue_manager = Arc::new(ExecutionQueueManager::with_defaults());
let action_id = 999; // Non-existent action let execution_id = 999; // Non-existent execution
// Should succeed but not notify anyone // Should succeed but not notify anyone
let result = queue_manager.notify_completion(action_id).await; let result = queue_manager.notify_completion(execution_id).await;
assert!(result.is_ok()); assert!(result.is_ok());
assert!(!result.unwrap()); assert!(!result.unwrap());
} }

View File

@@ -230,7 +230,7 @@ impl EnforcementProcessor {
async fn create_execution( async fn create_execution(
pool: &PgPool, pool: &PgPool,
publisher: &Publisher, publisher: &Publisher,
policy_enforcer: &PolicyEnforcer, _policy_enforcer: &PolicyEnforcer,
_queue_manager: &ExecutionQueueManager, _queue_manager: &ExecutionQueueManager,
enforcement: &Enforcement, enforcement: &Enforcement,
rule: &Rule, rule: &Rule,
@@ -257,33 +257,10 @@ impl EnforcementProcessor {
enforcement.id, rule.id, action_id enforcement.id, rule.id, action_id
); );
let pack_id = rule.pack;
let action_ref = &rule.action_ref; let action_ref = &rule.action_ref;
// Enforce policies and wait for queue slot if needed // Create the execution row first; scheduler-side policy enforcement
info!( // now handles both rule-triggered and manual executions uniformly.
"Enforcing policies for action {} (enforcement: {})",
action_id, enforcement.id
);
// Use enforcement ID for queue tracking (execution doesn't exist yet)
if let Err(e) = policy_enforcer
.enforce_and_wait(action_id, Some(pack_id), enforcement.id)
.await
{
error!(
"Policy enforcement failed for enforcement {}: {}",
enforcement.id, e
);
return Err(e);
}
info!(
"Policy check passed and queue slot obtained for enforcement: {}",
enforcement.id
);
// Now create execution in database (we have a queue slot)
let execution_input = CreateExecutionInput { let execution_input = CreateExecutionInput {
action: Some(action_id), action: Some(action_id),
action_ref: action_ref.clone(), action_ref: action_ref.clone(),

View File

@@ -10,14 +10,23 @@
use anyhow::Result; use anyhow::Result;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::HashMap; use std::collections::{BTreeMap, HashMap};
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use attune_common::models::{enums::ExecutionStatus, Id}; use attune_common::{
models::{
enums::{ExecutionStatus, PolicyMethod},
Id, Policy,
},
repositories::action::PolicyRepository,
};
use crate::queue_manager::ExecutionQueueManager; use crate::queue_manager::{
ExecutionQueueManager, QueuedRemovalOutcome, SlotEnqueueOutcome, SlotReleaseOutcome,
};
/// Policy violation type /// Policy violation type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -79,16 +88,38 @@ impl std::fmt::Display for PolicyViolation {
} }
/// Execution policy configuration /// Execution policy configuration
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionPolicy { pub struct ExecutionPolicy {
/// Rate limit: maximum executions per time window /// Rate limit: maximum executions per time window
pub rate_limit: Option<RateLimit>, pub rate_limit: Option<RateLimit>,
/// Concurrency limit: maximum concurrent executions /// Concurrency limit: maximum concurrent executions
pub concurrency_limit: Option<u32>, pub concurrency_limit: Option<u32>,
/// How a concurrency violation should be handled.
pub concurrency_method: PolicyMethod,
/// Parameter paths used to scope concurrency grouping.
pub concurrency_parameters: Vec<String>,
/// Resource quotas /// Resource quotas
pub quotas: Option<HashMap<String, u64>>, pub quotas: Option<HashMap<String, u64>>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicyOutcome {
Ready,
Queued,
}
impl Default for ExecutionPolicy {
fn default() -> Self {
Self {
rate_limit: None,
concurrency_limit: None,
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None,
}
}
}
/// Rate limit configuration /// Rate limit configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimit { pub struct RateLimit {
@@ -98,6 +129,25 @@ pub struct RateLimit {
pub window_seconds: u32, pub window_seconds: u32,
} }
#[derive(Debug, Clone)]
struct ResolvedConcurrencyPolicy {
limit: u32,
method: PolicyMethod,
parameters: Vec<String>,
}
impl From<Policy> for ExecutionPolicy {
fn from(policy: Policy) -> Self {
Self {
rate_limit: None,
concurrency_limit: Some(policy.threshold as u32),
concurrency_method: policy.method,
concurrency_parameters: policy.parameters,
quotas: None,
}
}
}
/// Policy enforcement scope /// Policy enforcement scope
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)] // Used in tests #[allow(dead_code)] // Used in tests
@@ -185,6 +235,174 @@ impl PolicyEnforcer {
self.action_policies.insert(action_id, policy); self.action_policies.insert(action_id, policy);
} }
/// Best-effort release for a slot acquired during scheduling when the
/// execution never reaches the worker/completion path.
pub async fn release_execution_slot(
&self,
execution_id: Id,
) -> Result<Option<SlotReleaseOutcome>> {
match &self.queue_manager {
Some(queue_manager) => queue_manager.release_active_slot(execution_id).await,
None => Ok(None),
}
}
pub async fn restore_execution_slot(
&self,
execution_id: Id,
outcome: &SlotReleaseOutcome,
) -> Result<()> {
match &self.queue_manager {
Some(queue_manager) => {
queue_manager
.restore_active_slot(execution_id, outcome)
.await
}
None => Ok(()),
}
}
pub async fn remove_queued_execution(
&self,
execution_id: Id,
) -> Result<Option<QueuedRemovalOutcome>> {
match &self.queue_manager {
Some(queue_manager) => queue_manager.remove_queued_execution(execution_id).await,
None => Ok(None),
}
}
pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> {
match &self.queue_manager {
Some(queue_manager) => queue_manager.restore_queued_execution(outcome).await,
None => Ok(()),
}
}
pub async fn enforce_for_scheduling(
&self,
action_id: Id,
pack_id: Option<Id>,
execution_id: Id,
config: Option<&JsonValue>,
) -> Result<SchedulingPolicyOutcome> {
if let Some(violation) = self
.check_policies_except_concurrency(action_id, pack_id)
.await?
{
warn!("Policy violation for action {}: {}", action_id, violation);
return Err(anyhow::anyhow!("Policy violation: {}", violation));
}
if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? {
let group_key = self.build_parameter_group_key(&concurrency.parameters, config);
if let Some(queue_manager) = &self.queue_manager {
match concurrency.method {
PolicyMethod::Enqueue => {
return match queue_manager
.enqueue(action_id, execution_id, concurrency.limit, group_key)
.await?
{
SlotEnqueueOutcome::Acquired => Ok(SchedulingPolicyOutcome::Ready),
SlotEnqueueOutcome::Enqueued => Ok(SchedulingPolicyOutcome::Queued),
};
}
PolicyMethod::Cancel => {
let outcome = queue_manager
.try_acquire(
action_id,
execution_id,
concurrency.limit,
group_key.clone(),
)
.await?;
if !outcome.acquired {
let violation = PolicyViolation::ConcurrencyLimitExceeded {
limit: concurrency.limit,
current_count: outcome.current_count,
};
warn!("Policy violation for action {}: {}", action_id, violation);
return Err(anyhow::anyhow!("Policy violation: {}", violation));
}
}
}
} else {
let scope = PolicyScope::Action(action_id);
if let Some(violation) = self
.check_concurrency_limit(concurrency.limit, &scope)
.await?
{
return Err(anyhow::anyhow!("Policy violation: {}", violation));
}
}
}
Ok(SchedulingPolicyOutcome::Ready)
}
async fn resolve_policy(&self, action_id: Id, pack_id: Option<Id>) -> Result<ExecutionPolicy> {
if let Some(policy) = self.action_policies.get(&action_id) {
return Ok(policy.clone());
}
if let Some(policy) = PolicyRepository::find_latest_by_action(&self.pool, action_id).await?
{
return Ok(policy.into());
}
if let Some(pack_id) = pack_id {
if let Some(policy) = self.pack_policies.get(&pack_id) {
return Ok(policy.clone());
}
if let Some(policy) = PolicyRepository::find_latest_by_pack(&self.pool, pack_id).await?
{
return Ok(policy.into());
}
}
if let Some(policy) = PolicyRepository::find_latest_global(&self.pool).await? {
return Ok(policy.into());
}
Ok(self.global_policy.clone())
}
async fn resolve_concurrency_policy(
&self,
action_id: Id,
pack_id: Option<Id>,
) -> Result<Option<ResolvedConcurrencyPolicy>> {
let policy = self.resolve_policy(action_id, pack_id).await?;
Ok(policy
.concurrency_limit
.map(|limit| ResolvedConcurrencyPolicy {
limit,
method: policy.concurrency_method,
parameters: policy.concurrency_parameters,
}))
}
fn build_parameter_group_key(
&self,
parameter_paths: &[String],
config: Option<&JsonValue>,
) -> Option<String> {
if parameter_paths.is_empty() {
return None;
}
let values: BTreeMap<String, JsonValue> = parameter_paths
.iter()
.map(|path| (path.clone(), extract_parameter_value(config, path)))
.collect();
serde_json::to_string(&values).ok()
}
/// Get the concurrency limit for a specific action /// Get the concurrency limit for a specific action
/// ///
/// Returns the most specific concurrency limit found: /// Returns the most specific concurrency limit found:
@@ -192,6 +410,7 @@ impl PolicyEnforcer {
/// 2. Pack policy /// 2. Pack policy
/// 3. Global policy /// 3. Global policy
/// 4. None (unlimited) /// 4. None (unlimited)
#[allow(dead_code)]
pub fn get_concurrency_limit(&self, action_id: Id, pack_id: Option<Id>) -> Option<u32> { pub fn get_concurrency_limit(&self, action_id: Id, pack_id: Option<Id>) -> Option<u32> {
// Check action-specific policy first // Check action-specific policy first
if let Some(policy) = self.action_policies.get(&action_id) { if let Some(policy) = self.action_policies.get(&action_id) {
@@ -229,11 +448,13 @@ impl PolicyEnforcer {
/// * `Ok(())` - Policy allows execution and queue slot obtained /// * `Ok(())` - Policy allows execution and queue slot obtained
/// * `Err(PolicyViolation)` - Policy prevents execution /// * `Err(PolicyViolation)` - Policy prevents execution
/// * `Err(QueueError)` - Queue timeout or other queue error /// * `Err(QueueError)` - Queue timeout or other queue error
#[allow(dead_code)]
pub async fn enforce_and_wait( pub async fn enforce_and_wait(
&self, &self,
action_id: Id, action_id: Id,
pack_id: Option<Id>, pack_id: Option<Id>,
execution_id: Id, execution_id: Id,
config: Option<&JsonValue>,
) -> Result<()> { ) -> Result<()> {
// First, check for policy violations (rate limit, quotas, etc.) // First, check for policy violations (rate limit, quotas, etc.)
// Note: We skip concurrency check here since queue manages that // Note: We skip concurrency check here since queue manages that
@@ -246,23 +467,50 @@ impl PolicyEnforcer {
} }
// If queue manager is available, use it for concurrency control // If queue manager is available, use it for concurrency control
if let Some(queue_manager) = &self.queue_manager { if let Some(concurrency) = self.resolve_concurrency_policy(action_id, pack_id).await? {
let concurrency_limit = self let group_key = self.build_parameter_group_key(&concurrency.parameters, config);
.get_concurrency_limit(action_id, pack_id)
.unwrap_or(u32::MAX); // Default to unlimited if no policy
if let Some(queue_manager) = &self.queue_manager {
debug!( debug!(
"Enqueuing execution {} for action {} with concurrency limit {}", "Applying concurrency policy to execution {} for action {} (limit: {}, method: {:?}, group: {:?})",
execution_id, action_id, concurrency_limit execution_id, action_id, concurrency.limit, concurrency.method, group_key
); );
match concurrency.method {
PolicyMethod::Enqueue => {
queue_manager queue_manager
.enqueue_and_wait(action_id, execution_id, concurrency_limit) .enqueue_and_wait(
action_id,
execution_id,
concurrency.limit,
group_key.clone(),
)
.await?;
}
PolicyMethod::Cancel => {
let outcome = queue_manager
.try_acquire(
action_id,
execution_id,
concurrency.limit,
group_key.clone(),
)
.await?; .await?;
if !outcome.acquired {
let violation = PolicyViolation::ConcurrencyLimitExceeded {
limit: concurrency.limit,
current_count: outcome.current_count,
};
warn!("Policy violation for action {}: {}", action_id, violation);
return Err(anyhow::anyhow!("Policy violation: {}", violation));
}
}
}
info!( info!(
"Execution {} obtained queue slot for action {}", "Execution {} obtained queue slot for action {} (group: {:?})",
execution_id, action_id execution_id, action_id, group_key
); );
} else { } else {
// No queue manager - use legacy polling behavior // No queue manager - use legacy polling behavior
@@ -271,11 +519,9 @@ impl PolicyEnforcer {
action_id action_id
); );
if let Some(concurrency_limit) = self.get_concurrency_limit(action_id, pack_id) {
// Check concurrency with old method
let scope = PolicyScope::Action(action_id); let scope = PolicyScope::Action(action_id);
if let Some(violation) = self if let Some(violation) = self
.check_concurrency_limit(concurrency_limit, &scope) .check_concurrency_limit(concurrency.limit, &scope)
.await? .await?
{ {
return Err(anyhow::anyhow!("Policy violation: {}", violation)); return Err(anyhow::anyhow!("Policy violation: {}", violation));
@@ -631,6 +877,25 @@ impl PolicyEnforcer {
} }
} }
fn extract_parameter_value(config: Option<&JsonValue>, path: &str) -> JsonValue {
let mut current = match config {
Some(value) => value,
None => return JsonValue::Null,
};
for segment in path.split('.') {
match current {
JsonValue::Object(map) => match map.get(segment) {
Some(next) => current = next,
None => return JsonValue::Null,
},
_ => return JsonValue::Null,
}
}
current.clone()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -665,6 +930,8 @@ mod tests {
let policy = ExecutionPolicy::default(); let policy = ExecutionPolicy::default();
assert!(policy.rate_limit.is_none()); assert!(policy.rate_limit.is_none());
assert!(policy.concurrency_limit.is_none()); assert!(policy.concurrency_limit.is_none());
assert_eq!(policy.concurrency_method, PolicyMethod::Enqueue);
assert!(policy.concurrency_parameters.is_empty());
assert!(policy.quotas.is_none()); assert!(policy.quotas.is_none());
} }
@@ -784,7 +1051,7 @@ mod tests {
); );
// First execution should proceed immediately // First execution should proceed immediately
let result = enforcer.enforce_and_wait(1, None, 100).await; let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Check queue stats // Check queue stats
@@ -809,7 +1076,7 @@ mod tests {
let enforcer = Arc::new(enforcer); let enforcer = Arc::new(enforcer);
// First execution // First execution
let result = enforcer.enforce_and_wait(1, None, 100).await; let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Queue multiple executions // Queue multiple executions
@@ -822,11 +1089,14 @@ mod tests {
let order = execution_order.clone(); let order = execution_order.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
enforcer.enforce_and_wait(1, None, exec_id).await.unwrap(); enforcer
.enforce_and_wait(1, None, exec_id, None)
.await
.unwrap();
order.lock().await.push(exec_id); order.lock().await.push(exec_id);
// Simulate work // Simulate work
sleep(Duration::from_millis(10)).await; sleep(Duration::from_millis(10)).await;
queue_manager.notify_completion(1).await.unwrap(); queue_manager.notify_completion(exec_id).await.unwrap();
}); });
handles.push(handle); handles.push(handle);
@@ -836,7 +1106,7 @@ mod tests {
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Release first execution // Release first execution
queue_manager.notify_completion(1).await.unwrap(); queue_manager.notify_completion(100).await.unwrap();
// Wait for all // Wait for all
for handle in handles { for handle in handles {
@@ -863,7 +1133,7 @@ mod tests {
); );
// Should work without queue manager (legacy behavior) // Should work without queue manager (legacy behavior)
let result = enforcer.enforce_and_wait(1, None, 100).await; let result = enforcer.enforce_and_wait(1, None, 100, None).await;
assert!(result.is_ok()); assert!(result.is_ok());
} }
@@ -889,14 +1159,36 @@ mod tests {
); );
// First execution proceeds // First execution proceeds
enforcer.enforce_and_wait(1, None, 100).await.unwrap(); enforcer.enforce_and_wait(1, None, 100, None).await.unwrap();
// Second execution should timeout // Second execution should timeout
let result = enforcer.enforce_and_wait(1, None, 101).await; let result = enforcer.enforce_and_wait(1, None, 101, None).await;
assert!(result.is_err()); assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout")); assert!(result.unwrap_err().to_string().contains("timeout"));
} }
#[test]
fn test_build_parameter_group_key_uses_exact_values() {
let pool = sqlx::PgPool::connect_lazy("postgresql://localhost/test").unwrap();
let enforcer = PolicyEnforcer::new(pool);
let config = serde_json::json!({
"environment": "prod",
"target": {
"region": "us-east-1"
}
});
let group_key = enforcer.build_parameter_group_key(
&["target.region".to_string(), "environment".to_string()],
Some(&config),
);
assert_eq!(
group_key.as_deref(),
Some("{\"environment\":\"prod\",\"target.region\":\"us-east-1\"}")
);
}
// Integration tests would require database setup // Integration tests would require database setup
// Those should be in a separate integration test file // Those should be in a separate integration test file
} }

View File

@@ -47,7 +47,7 @@ impl Default for QueueConfig {
} }
/// Entry in the execution queue /// Entry in the execution queue
#[derive(Debug)] #[derive(Debug, Clone)]
struct QueueEntry { struct QueueEntry {
/// Execution or enforcement ID being queued /// Execution or enforcement ID being queued
execution_id: Id, execution_id: Id,
@@ -57,7 +57,13 @@ struct QueueEntry {
notifier: Arc<Notify>, notifier: Arc<Notify>,
} }
/// Queue state for a single action #[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct QueueKey {
action_id: Id,
group_key: Option<String>,
}
/// Queue state for a single action/group pair
struct ActionQueue { struct ActionQueue {
/// FIFO queue of waiting executions /// FIFO queue of waiting executions
queue: VecDeque<QueueEntry>, queue: VecDeque<QueueEntry>,
@@ -93,6 +99,32 @@ impl ActionQueue {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotAcquireOutcome {
pub acquired: bool,
pub current_count: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotEnqueueOutcome {
Acquired,
Enqueued,
}
#[derive(Debug, Clone)]
pub struct SlotReleaseOutcome {
pub next_execution_id: Option<Id>,
queue_key: QueueKey,
}
#[derive(Debug, Clone)]
pub struct QueuedRemovalOutcome {
pub next_execution_id: Option<Id>,
queue_key: QueueKey,
removed_entry: QueueEntry,
removed_index: usize,
}
/// Statistics about a queue /// Statistics about a queue
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStats { pub struct QueueStats {
@@ -114,8 +146,10 @@ pub struct QueueStats {
/// Manages execution queues with FIFO ordering guarantees /// Manages execution queues with FIFO ordering guarantees
pub struct ExecutionQueueManager { pub struct ExecutionQueueManager {
/// Per-action queues (key: action_id) /// Per-action/per-group queues.
queues: DashMap<Id, Arc<Mutex<ActionQueue>>>, queues: DashMap<QueueKey, Arc<Mutex<ActionQueue>>>,
/// Tracks which queue key currently owns an active execution slot.
active_execution_keys: DashMap<Id, QueueKey>,
/// Configuration /// Configuration
config: QueueConfig, config: QueueConfig,
/// Database connection pool (optional for stats persistence) /// Database connection pool (optional for stats persistence)
@@ -128,6 +162,7 @@ impl ExecutionQueueManager {
pub fn new(config: QueueConfig) -> Self { pub fn new(config: QueueConfig) -> Self {
Self { Self {
queues: DashMap::new(), queues: DashMap::new(),
active_execution_keys: DashMap::new(),
config, config,
db_pool: None, db_pool: None,
} }
@@ -137,6 +172,7 @@ impl ExecutionQueueManager {
pub fn with_db_pool(config: QueueConfig, db_pool: PgPool) -> Self { pub fn with_db_pool(config: QueueConfig, db_pool: PgPool) -> Self {
Self { Self {
queues: DashMap::new(), queues: DashMap::new(),
active_execution_keys: DashMap::new(),
config, config,
db_pool: Some(db_pool), db_pool: Some(db_pool),
} }
@@ -148,6 +184,24 @@ impl ExecutionQueueManager {
Self::new(QueueConfig::default()) Self::new(QueueConfig::default())
} }
fn queue_key(&self, action_id: Id, group_key: Option<String>) -> QueueKey {
QueueKey {
action_id,
group_key,
}
}
async fn get_or_create_queue(
&self,
queue_key: QueueKey,
max_concurrent: u32,
) -> Arc<Mutex<ActionQueue>> {
self.queues
.entry(queue_key)
.or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent))))
.clone()
}
/// Enqueue an execution and wait until it can proceed /// Enqueue an execution and wait until it can proceed
/// ///
/// This method will: /// This method will:
@@ -164,23 +218,31 @@ impl ExecutionQueueManager {
/// # Returns /// # Returns
/// * `Ok(())` - Execution can proceed /// * `Ok(())` - Execution can proceed
/// * `Err(_)` - Queue full or timeout /// * `Err(_)` - Queue full or timeout
#[allow(dead_code)]
pub async fn enqueue_and_wait( pub async fn enqueue_and_wait(
&self, &self,
action_id: Id, action_id: Id,
execution_id: Id, execution_id: Id,
max_concurrent: u32, max_concurrent: u32,
group_key: Option<String>,
) -> Result<()> { ) -> Result<()> {
if self.active_execution_keys.contains_key(&execution_id) {
debug!( debug!(
"Enqueuing execution {} for action {} (max_concurrent: {})", "Execution {} already owns an active slot, skipping queue wait",
execution_id, action_id, max_concurrent execution_id
);
return Ok(());
}
debug!(
"Enqueuing execution {} for action {} (max_concurrent: {}, group: {:?})",
execution_id, action_id, max_concurrent, group_key
); );
// Get or create queue for this action let queue_key = self.queue_key(action_id, group_key);
let queue_arc = self let queue_arc = self
.queues .get_or_create_queue(queue_key.clone(), max_concurrent)
.entry(action_id) .await;
.or_insert_with(|| Arc::new(Mutex::new(ActionQueue::new(max_concurrent))))
.clone();
// Create notifier for this execution // Create notifier for this execution
let notifier = Arc::new(Notify::new()); let notifier = Arc::new(Notify::new());
@@ -192,14 +254,41 @@ impl ExecutionQueueManager {
// Update max_concurrent if it changed // Update max_concurrent if it changed
queue.max_concurrent = max_concurrent; queue.max_concurrent = max_concurrent;
let queued_index = queue
.queue
.iter()
.position(|entry| entry.execution_id == execution_id);
if let Some(queued_index) = queued_index {
if queued_index == 0 && queue.has_capacity() {
let entry = queue.queue.pop_front().expect("front entry just checked");
queue.active_count += 1;
self.active_execution_keys
.insert(entry.execution_id, queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(());
}
debug!(
"Execution {} is already queued for action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
);
return Ok(());
}
// Check if we can run immediately // Check if we can run immediately
if queue.has_capacity() { if queue.has_capacity() {
debug!( debug!(
"Execution {} can run immediately (active: {}/{})", "Execution {} can run immediately for action {} (active: {}/{}, group: {:?})",
execution_id, queue.active_count, queue.max_concurrent execution_id,
action_id,
queue.active_count,
queue.max_concurrent,
queue_key.group_key
); );
queue.active_count += 1; queue.active_count += 1;
queue.total_enqueued += 1; queue.total_enqueued += 1;
self.active_execution_keys
.insert(execution_id, queue_key.clone());
// Persist stats to database if available // Persist stats to database if available
drop(queue); drop(queue);
@@ -211,8 +300,9 @@ impl ExecutionQueueManager {
// Check if queue is full // Check if queue is full
if queue.is_full(self.config.max_queue_length) { if queue.is_full(self.config.max_queue_length) {
warn!( warn!(
"Queue full for action {}: {} entries (limit: {})", "Queue full for action {} group {:?}: {} entries (limit: {})",
action_id, action_id,
queue_key.group_key,
queue.queue.len(), queue.queue.len(),
self.config.max_queue_length self.config.max_queue_length
); );
@@ -234,12 +324,13 @@ impl ExecutionQueueManager {
queue.total_enqueued += 1; queue.total_enqueued += 1;
info!( info!(
"Execution {} queued for action {} at position {} (active: {}/{})", "Execution {} queued for action {} at position {} (active: {}/{}, group: {:?})",
execution_id, execution_id,
action_id, action_id,
queue.queue.len() - 1, queue.queue.len() - 1,
queue.active_count, queue.active_count,
queue.max_concurrent queue.max_concurrent,
queue_key.group_key
); );
} }
@@ -273,6 +364,142 @@ impl ExecutionQueueManager {
} }
} }
/// Acquire a slot immediately or enqueue without blocking the caller.
pub async fn enqueue(
&self,
action_id: Id,
execution_id: Id,
max_concurrent: u32,
group_key: Option<String>,
) -> Result<SlotEnqueueOutcome> {
if self.active_execution_keys.contains_key(&execution_id) {
debug!(
"Execution {} already owns an active slot, treating as acquired",
execution_id
);
return Ok(SlotEnqueueOutcome::Acquired);
}
debug!(
"Enqueuing execution {} for action {} without waiting (max_concurrent: {}, group: {:?})",
execution_id, action_id, max_concurrent, group_key
);
let queue_key = self.queue_key(action_id, group_key);
let queue_arc = self
.get_or_create_queue(queue_key.clone(), max_concurrent)
.await;
{
let mut queue = queue_arc.lock().await;
queue.max_concurrent = max_concurrent;
let queued_index = queue
.queue
.iter()
.position(|entry| entry.execution_id == execution_id);
if let Some(queued_index) = queued_index {
if queued_index == 0 && queue.has_capacity() {
let entry = queue.queue.pop_front().expect("front entry just checked");
queue.active_count += 1;
self.active_execution_keys
.insert(entry.execution_id, queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(SlotEnqueueOutcome::Acquired);
}
debug!(
"Execution {} is already queued for action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
);
return Ok(SlotEnqueueOutcome::Enqueued);
}
if queue.has_capacity() {
queue.active_count += 1;
queue.total_enqueued += 1;
self.active_execution_keys
.insert(execution_id, queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(SlotEnqueueOutcome::Acquired);
}
if queue.is_full(self.config.max_queue_length) {
warn!(
"Queue full for action {} group {:?}: {} entries (limit: {})",
action_id,
queue_key.group_key,
queue.queue.len(),
self.config.max_queue_length
);
return Err(anyhow::anyhow!(
"Queue full for action {}: maximum {} entries",
action_id,
self.config.max_queue_length
));
}
queue.queue.push_back(QueueEntry {
execution_id,
enqueued_at: Utc::now(),
notifier: Arc::new(Notify::new()),
});
queue.total_enqueued += 1;
}
self.persist_queue_stats(action_id).await;
Ok(SlotEnqueueOutcome::Enqueued)
}
/// Try to acquire a slot immediately without queueing.
pub async fn try_acquire(
&self,
action_id: Id,
execution_id: Id,
max_concurrent: u32,
group_key: Option<String>,
) -> Result<SlotAcquireOutcome> {
let queue_key = self.queue_key(action_id, group_key);
let queue_arc = self
.get_or_create_queue(queue_key.clone(), max_concurrent)
.await;
let mut queue = queue_arc.lock().await;
queue.max_concurrent = max_concurrent;
let current_count = queue.active_count;
if self.active_execution_keys.contains_key(&execution_id) {
debug!(
"Execution {} already owns a slot for action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
);
return Ok(SlotAcquireOutcome {
acquired: true,
current_count,
});
}
if queue.has_capacity() {
queue.active_count += 1;
queue.total_enqueued += 1;
self.active_execution_keys
.insert(execution_id, queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(SlotAcquireOutcome {
acquired: true,
current_count,
});
}
Ok(SlotAcquireOutcome {
acquired: false,
current_count,
})
}
/// Notify that an execution has completed, releasing a queue slot /// Notify that an execution has completed, releasing a queue slot
/// ///
/// This method will: /// This method will:
@@ -282,27 +509,64 @@ impl ExecutionQueueManager {
/// 4. Increment active count for the notified execution /// 4. Increment active count for the notified execution
/// ///
/// # Arguments /// # Arguments
/// * `action_id` - The action that completed /// * `execution_id` - The execution that completed
/// ///
/// # Returns /// # Returns
/// * `Ok(true)` - A queued execution was notified /// * `Ok(true)` - A queued execution was notified
/// * `Ok(false)` - No executions were waiting /// * `Ok(false)` - No executions were waiting
/// * `Err(_)` - Error accessing queue /// * `Err(_)` - Error accessing queue
pub async fn notify_completion(&self, action_id: Id) -> Result<bool> { pub async fn notify_completion(&self, execution_id: Id) -> Result<bool> {
Ok(self
.notify_completion_with_next(execution_id)
.await?
.is_some())
}
pub async fn notify_completion_with_next(&self, execution_id: Id) -> Result<Option<Id>> {
let release = match self.release_active_slot(execution_id).await? {
Some(release) => release,
None => return Ok(None),
};
let Some(next_execution_id) = release.next_execution_id else {
return Ok(None);
};
if self.activate_queued_execution(next_execution_id).await? {
Ok(Some(next_execution_id))
} else {
self.restore_active_slot(execution_id, &release).await?;
Ok(None)
}
}
pub async fn release_active_slot(
&self,
execution_id: Id,
) -> Result<Option<SlotReleaseOutcome>> {
let Some((_, queue_key)) = self.active_execution_keys.remove(&execution_id) else {
debug!( debug!(
"Processing completion notification for action {}", "No active queue slot found for execution {} (queue may have been cleared)",
action_id execution_id
);
return Ok(None);
};
let action_id = queue_key.action_id;
debug!(
"Processing completion notification for execution {} on action {} (group: {:?})",
execution_id, action_id, queue_key.group_key
); );
// Get queue for this action // Get queue for this action/group
let queue_arc = match self.queues.get(&action_id) { let queue_arc = match self.queues.get(&queue_key) {
Some(q) => q.clone(), Some(q) => q.clone(),
None => { None => {
debug!( debug!(
"No queue found for action {} (no executions queued)", "No queue found for action {} group {:?}",
action_id action_id, queue_key.group_key
); );
return Ok(false); return Ok(None);
} }
}; };
@@ -313,49 +577,162 @@ impl ExecutionQueueManager {
queue.active_count -= 1; queue.active_count -= 1;
queue.total_completed += 1; queue.total_completed += 1;
debug!( debug!(
"Decremented active count for action {} to {}", "Decremented active count for action {} group {:?} to {}",
action_id, queue.active_count action_id, queue_key.group_key, queue.active_count
); );
} else { } else {
warn!( warn!(
"Completion notification for action {} but active_count is 0", "Completion notification for action {} group {:?} but active_count is 0",
action_id action_id, queue_key.group_key
); );
} }
// Check if there are queued executions // Check if there are queued executions
if queue.queue.is_empty() { if queue.queue.is_empty() {
debug!( debug!(
"No executions queued for action {} after completion", "No executions queued for action {} group {:?} after completion",
action_id action_id, queue_key.group_key
); );
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(Some(SlotReleaseOutcome {
next_execution_id: None,
queue_key,
}));
}
let next_execution_id = queue.queue.front().map(|entry| entry.execution_id);
if let Some(next_execution_id) = next_execution_id {
info!(
"Execution {} is next for action {} group {:?}",
next_execution_id, action_id, queue_key.group_key
);
}
drop(queue);
self.persist_queue_stats(action_id).await;
Ok(Some(SlotReleaseOutcome {
next_execution_id,
queue_key,
}))
}
pub async fn restore_active_slot(
&self,
execution_id: Id,
outcome: &SlotReleaseOutcome,
) -> Result<()> {
let action_id = outcome.queue_key.action_id;
let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await;
let mut queue = queue_arc.lock().await;
queue.active_count += 1;
if queue.total_completed > 0 {
queue.total_completed -= 1;
}
self.active_execution_keys
.insert(execution_id, outcome.queue_key.clone());
drop(queue);
self.persist_queue_stats(action_id).await;
Ok(())
}
pub async fn activate_queued_execution(&self, execution_id: Id) -> Result<bool> {
for entry in self.queues.iter() {
let queue_key = entry.key().clone();
let queue_arc = entry.value().clone();
let mut queue = queue_arc.lock().await;
let Some(front) = queue.queue.front() else {
continue;
};
if front.execution_id != execution_id {
continue;
}
if !queue.has_capacity() {
return Ok(false); return Ok(false);
} }
// Pop the first (oldest) entry from queue let entry = queue.queue.pop_front().expect("front entry just checked");
if let Some(entry) = queue.queue.pop_front() {
info!( info!(
"Notifying execution {} for action {} (was queued for {:?})", "Activating queued execution {} for action {} group {:?} (queued for {:?})",
entry.execution_id, entry.execution_id,
action_id, queue_key.action_id,
queue_key.group_key,
Utc::now() - entry.enqueued_at Utc::now() - entry.enqueued_at
); );
// Increment active count for the execution we're about to notify
queue.active_count += 1; queue.active_count += 1;
self.active_execution_keys
.insert(entry.execution_id, queue_key.clone());
// Notify the waiter (after releasing lock)
drop(queue); drop(queue);
entry.notifier.notify_one(); entry.notifier.notify_one();
self.persist_queue_stats(queue_key.action_id).await;
return Ok(true);
}
// Persist stats to database if available
self.persist_queue_stats(action_id).await;
Ok(true)
} else {
// Race condition check - queue was empty after all
Ok(false) Ok(false)
} }
pub async fn remove_queued_execution(
&self,
execution_id: Id,
) -> Result<Option<QueuedRemovalOutcome>> {
for entry in self.queues.iter() {
let queue_key = entry.key().clone();
let queue_arc = entry.value().clone();
let mut queue = queue_arc.lock().await;
let Some(index) = queue
.queue
.iter()
.position(|queued| queued.execution_id == execution_id)
else {
continue;
};
let removed_entry = queue.queue.remove(index).expect("queue index just checked");
let next_execution_id = if index == 0 {
queue.queue.front().map(|queued| queued.execution_id)
} else {
None
};
let action_id = queue_key.action_id;
drop(queue);
self.persist_queue_stats(action_id).await;
return Ok(Some(QueuedRemovalOutcome {
next_execution_id,
queue_key,
removed_entry,
removed_index: index,
}));
}
Ok(None)
}
pub async fn restore_queued_execution(&self, outcome: &QueuedRemovalOutcome) -> Result<()> {
let action_id = outcome.queue_key.action_id;
let queue_arc = self.get_or_create_queue(outcome.queue_key.clone(), 1).await;
let mut queue = queue_arc.lock().await;
if outcome.removed_index <= queue.queue.len() {
queue
.queue
.insert(outcome.removed_index, outcome.removed_entry.clone());
} else {
queue.queue.push_back(outcome.removed_entry.clone());
}
drop(queue);
self.persist_queue_stats(action_id).await;
Ok(())
} }
/// Persist queue statistics to database (if database pool is available) /// Persist queue statistics to database (if database pool is available)
@@ -384,19 +761,48 @@ impl ExecutionQueueManager {
/// Get statistics for a specific action's queue /// Get statistics for a specific action's queue
pub async fn get_queue_stats(&self, action_id: Id) -> Option<QueueStats> { pub async fn get_queue_stats(&self, action_id: Id) -> Option<QueueStats> {
let queue_arc = self.queues.get(&action_id)?.clone(); let queue_arcs: Vec<Arc<Mutex<ActionQueue>>> = self
let queue = queue_arc.lock().await; .queues
.iter()
.filter(|entry| entry.key().action_id == action_id)
.map(|entry| entry.value().clone())
.collect();
let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at); if queue_arcs.is_empty() {
return None;
}
let mut queue_length = 0usize;
let mut active_count = 0u32;
let mut max_concurrent = 0u32;
let mut oldest_enqueued_at: Option<DateTime<Utc>> = None;
let mut total_enqueued = 0u64;
let mut total_completed = 0u64;
for queue_arc in queue_arcs {
let queue = queue_arc.lock().await;
queue_length += queue.queue.len();
active_count += queue.active_count;
max_concurrent += queue.max_concurrent;
total_enqueued += queue.total_enqueued;
total_completed += queue.total_completed;
if let Some(candidate) = queue.queue.front().map(|e| e.enqueued_at) {
oldest_enqueued_at = Some(match oldest_enqueued_at {
Some(current) => current.min(candidate),
None => candidate,
});
}
}
Some(QueueStats { Some(QueueStats {
action_id, action_id,
queue_length: queue.queue.len(), queue_length,
active_count: queue.active_count, active_count,
max_concurrent: queue.max_concurrent, max_concurrent,
oldest_enqueued_at, oldest_enqueued_at,
total_enqueued: queue.total_enqueued, total_enqueued,
total_completed: queue.total_completed, total_completed,
}) })
} }
@@ -405,22 +811,15 @@ impl ExecutionQueueManager {
pub async fn get_all_queue_stats(&self) -> Vec<QueueStats> { pub async fn get_all_queue_stats(&self) -> Vec<QueueStats> {
let mut stats = Vec::new(); let mut stats = Vec::new();
let mut action_ids = std::collections::BTreeSet::new();
for entry in self.queues.iter() { for entry in self.queues.iter() {
let action_id = *entry.key(); action_ids.insert(entry.key().action_id);
let queue_arc = entry.value().clone(); }
let queue = queue_arc.lock().await;
let oldest_enqueued_at = queue.queue.front().map(|e| e.enqueued_at); for action_id in action_ids {
if let Some(action_stats) = self.get_queue_stats(action_id).await {
stats.push(QueueStats { stats.push(action_stats);
action_id, }
queue_length: queue.queue.len(),
active_count: queue.active_count,
max_concurrent: queue.max_concurrent,
oldest_enqueued_at,
total_enqueued: queue.total_enqueued,
total_completed: queue.total_completed,
});
} }
stats stats
@@ -445,27 +844,29 @@ impl ExecutionQueueManager {
execution_id, action_id execution_id, action_id
); );
let queue_arc = match self.queues.get(&action_id) { let queue_arcs: Vec<Arc<Mutex<ActionQueue>>> = self
Some(q) => q.clone(), .queues
None => return Ok(false), .iter()
}; .filter(|entry| entry.key().action_id == action_id)
.map(|entry| entry.value().clone())
.collect();
for queue_arc in queue_arcs {
let mut queue = queue_arc.lock().await; let mut queue = queue_arc.lock().await;
let initial_len = queue.queue.len(); let initial_len = queue.queue.len();
queue.queue.retain(|e| e.execution_id != execution_id); queue.queue.retain(|e| e.execution_id != execution_id);
let removed = initial_len != queue.queue.len(); if initial_len != queue.queue.len() {
if removed {
info!("Cancelled execution {} from queue", execution_id); info!("Cancelled execution {} from queue", execution_id);
} else { return Ok(true);
}
}
debug!( debug!(
"Execution {} not found in queue (may be running)", "Execution {} not found in queue (may be running)",
execution_id execution_id
); );
}
Ok(removed) Ok(false)
} }
/// Clear all queues (for testing or emergency situations) /// Clear all queues (for testing or emergency situations)
@@ -479,12 +880,17 @@ impl ExecutionQueueManager {
queue.queue.clear(); queue.queue.clear();
queue.active_count = 0; queue.active_count = 0;
} }
self.active_execution_keys.clear();
} }
/// Get the number of actions with active queues /// Get the number of actions with active queues
#[allow(dead_code)] #[allow(dead_code)]
pub fn active_queue_count(&self) -> usize { pub fn active_queue_count(&self) -> usize {
self.queues.len() self.queues
.iter()
.map(|entry| entry.key().action_id)
.collect::<std::collections::BTreeSet<_>>()
.len()
} }
} }
@@ -504,7 +910,7 @@ mod tests {
let manager = ExecutionQueueManager::with_defaults(); let manager = ExecutionQueueManager::with_defaults();
// Should execute immediately when there's capacity // Should execute immediately when there's capacity
let result = manager.enqueue_and_wait(1, 100, 2).await; let result = manager.enqueue_and_wait(1, 100, 2, None).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Check stats // Check stats
@@ -521,7 +927,7 @@ mod tests {
// First execution should run immediately // First execution should run immediately
let result = manager let result = manager
.enqueue_and_wait(action_id, 100, max_concurrent) .enqueue_and_wait(action_id, 100, max_concurrent, None)
.await; .await;
assert!(result.is_ok()); assert!(result.is_ok());
@@ -535,7 +941,7 @@ mod tests {
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
manager manager
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.unwrap(); .unwrap();
order.lock().await.push(exec_id); order.lock().await.push(exec_id);
@@ -553,9 +959,9 @@ mod tests {
assert_eq!(stats.active_count, 1); assert_eq!(stats.active_count, 1);
// Release them one by one // Release them one by one
for _ in 0..3 { for execution_id in 100..103 {
sleep(Duration::from_millis(50)).await; sleep(Duration::from_millis(50)).await;
manager.notify_completion(action_id).await.unwrap(); manager.notify_completion(execution_id).await.unwrap();
} }
// Wait for all to complete // Wait for all to complete
@@ -574,7 +980,10 @@ mod tests {
let action_id = 1; let action_id = 1;
// Start first execution // Start first execution
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue second execution // Queue second execution
let manager_clone = Arc::new(manager); let manager_clone = Arc::new(manager);
@@ -582,7 +991,7 @@ mod tests {
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
manager_ref manager_ref
.enqueue_and_wait(action_id, 101, 1) .enqueue_and_wait(action_id, 101, 1, None)
.await .await
.unwrap(); .unwrap();
}); });
@@ -596,7 +1005,7 @@ mod tests {
assert_eq!(stats.active_count, 1); assert_eq!(stats.active_count, 1);
// Notify completion // Notify completion
let notified = manager_clone.notify_completion(action_id).await.unwrap(); let notified = manager_clone.notify_completion(100).await.unwrap();
assert!(notified); assert!(notified);
// Wait for queued execution to proceed // Wait for queued execution to proceed
@@ -613,8 +1022,8 @@ mod tests {
let manager = Arc::new(ExecutionQueueManager::with_defaults()); let manager = Arc::new(ExecutionQueueManager::with_defaults());
// Start executions on different actions // Start executions on different actions
manager.enqueue_and_wait(1, 100, 1).await.unwrap(); manager.enqueue_and_wait(1, 100, 1, None).await.unwrap();
manager.enqueue_and_wait(2, 200, 1).await.unwrap(); manager.enqueue_and_wait(2, 200, 1, None).await.unwrap();
// Both should be active // Both should be active
let stats1 = manager.get_queue_stats(1).await.unwrap(); let stats1 = manager.get_queue_stats(1).await.unwrap();
@@ -624,7 +1033,7 @@ mod tests {
assert_eq!(stats2.active_count, 1); assert_eq!(stats2.active_count, 1);
// Completion on action 1 shouldn't affect action 2 // Completion on action 1 shouldn't affect action 2
manager.notify_completion(1).await.unwrap(); manager.notify_completion(100).await.unwrap();
let stats1 = manager.get_queue_stats(1).await.unwrap(); let stats1 = manager.get_queue_stats(1).await.unwrap();
let stats2 = manager.get_queue_stats(2).await.unwrap(); let stats2 = manager.get_queue_stats(2).await.unwrap();
@@ -633,20 +1042,43 @@ mod tests {
assert_eq!(stats2.active_count, 1); assert_eq!(stats2.active_count, 1);
} }
#[tokio::test]
async fn test_grouped_queues_are_independent() {
let manager = ExecutionQueueManager::with_defaults();
let action_id = 1;
manager
.enqueue_and_wait(action_id, 100, 1, Some("prod".to_string()))
.await
.unwrap();
manager
.enqueue_and_wait(action_id, 200, 1, Some("staging".to_string()))
.await
.unwrap();
let stats = manager.get_queue_stats(action_id).await.unwrap();
assert_eq!(stats.active_count, 2);
assert_eq!(stats.queue_length, 0);
assert_eq!(stats.max_concurrent, 2);
}
#[tokio::test] #[tokio::test]
async fn test_cancel_execution() { async fn test_cancel_execution() {
let manager = ExecutionQueueManager::with_defaults(); let manager = ExecutionQueueManager::with_defaults();
let action_id = 1; let action_id = 1;
// Fill capacity // Fill capacity
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue more executions // Queue more executions
let manager_arc = Arc::new(manager); let manager_arc = Arc::new(manager);
let manager_ref = manager_arc.clone(); let manager_ref = manager_arc.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
let result = manager_ref.enqueue_and_wait(action_id, 101, 1).await; let result = manager_ref.enqueue_and_wait(action_id, 101, 1, None).await;
result result
}); });
@@ -675,7 +1107,10 @@ mod tests {
assert!(manager.get_queue_stats(action_id).await.is_none()); assert!(manager.get_queue_stats(action_id).await.is_none());
// After enqueue, stats should exist // After enqueue, stats should exist
manager.enqueue_and_wait(action_id, 100, 2).await.unwrap(); manager
.enqueue_and_wait(action_id, 100, 2, None)
.await
.unwrap();
let stats = manager.get_queue_stats(action_id).await.unwrap(); let stats = manager.get_queue_stats(action_id).await.unwrap();
assert_eq!(stats.action_id, action_id); assert_eq!(stats.action_id, action_id);
@@ -696,13 +1131,16 @@ mod tests {
let action_id = 1; let action_id = 1;
// Fill capacity // Fill capacity
manager.enqueue_and_wait(action_id, 100, 1).await.unwrap(); manager
.enqueue_and_wait(action_id, 100, 1, None)
.await
.unwrap();
// Queue 2 more (should reach limit) // Queue 2 more (should reach limit)
let manager_ref = manager.clone(); let manager_ref = manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
manager_ref manager_ref
.enqueue_and_wait(action_id, 101, 1) .enqueue_and_wait(action_id, 101, 1, None)
.await .await
.unwrap(); .unwrap();
}); });
@@ -710,7 +1148,7 @@ mod tests {
let manager_ref = manager.clone(); let manager_ref = manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
manager_ref manager_ref
.enqueue_and_wait(action_id, 102, 1) .enqueue_and_wait(action_id, 102, 1, None)
.await .await
.unwrap(); .unwrap();
}); });
@@ -718,7 +1156,7 @@ mod tests {
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Next one should fail // Next one should fail
let result = manager.enqueue_and_wait(action_id, 103, 1).await; let result = manager.enqueue_and_wait(action_id, 103, 1, None).await;
assert!(result.is_err()); assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Queue full")); assert!(result.unwrap_err().to_string().contains("Queue full"));
} }
@@ -732,7 +1170,7 @@ mod tests {
// Start first execution // Start first execution
manager manager
.enqueue_and_wait(action_id, 0, max_concurrent) .enqueue_and_wait(action_id, 0, max_concurrent, None)
.await .await
.unwrap(); .unwrap();
@@ -746,7 +1184,7 @@ mod tests {
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
manager manager
.enqueue_and_wait(action_id, i, max_concurrent) .enqueue_and_wait(action_id, i, max_concurrent, None)
.await .await
.unwrap(); .unwrap();
order.lock().await.push(i); order.lock().await.push(i);
@@ -759,9 +1197,9 @@ mod tests {
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
// Release them all // Release them all
for _ in 0..num_executions { for execution_id in 0..num_executions {
sleep(Duration::from_millis(10)).await; sleep(Duration::from_millis(10)).await;
manager.notify_completion(action_id).await.unwrap(); manager.notify_completion(execution_id).await.unwrap();
} }
// Wait for completion // Wait for completion

View File

@@ -16,7 +16,7 @@ use attune_common::{
models::{enums::ExecutionStatus, execution::WorkflowTaskMetadata, Action, Execution, Runtime}, models::{enums::ExecutionStatus, execution::WorkflowTaskMetadata, Action, Execution, Runtime},
mq::{ mq::{
Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope, Consumer, ExecutionCompletedPayload, ExecutionRequestedPayload, MessageEnvelope,
MessageType, Publisher, MessageType, MqError, Publisher,
}, },
repositories::{ repositories::{
action::ActionRepository, action::ActionRepository,
@@ -40,6 +40,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::policy_enforcer::{PolicyEnforcer, SchedulingPolicyOutcome};
use crate::workflow::context::{TaskOutcome, WorkflowContext}; use crate::workflow::context::{TaskOutcome, WorkflowContext};
use crate::workflow::graph::TaskGraph; use crate::workflow::graph::TaskGraph;
@@ -108,6 +109,7 @@ pub struct ExecutionScheduler {
pool: PgPool, pool: PgPool,
publisher: Arc<Publisher>, publisher: Arc<Publisher>,
consumer: Arc<Consumer>, consumer: Arc<Consumer>,
policy_enforcer: Arc<PolicyEnforcer>,
/// Round-robin counter for distributing executions across workers /// Round-robin counter for distributing executions across workers
round_robin_counter: AtomicUsize, round_robin_counter: AtomicUsize,
} }
@@ -120,12 +122,31 @@ const DEFAULT_HEARTBEAT_INTERVAL: u64 = 30;
const HEARTBEAT_STALENESS_MULTIPLIER: u64 = 3; const HEARTBEAT_STALENESS_MULTIPLIER: u64 = 3;
impl ExecutionScheduler { impl ExecutionScheduler {
fn retryable_mq_error(error: &anyhow::Error) -> Option<MqError> {
let mq_error = error.downcast_ref::<MqError>()?;
Some(match mq_error {
MqError::Connection(msg) => MqError::Connection(msg.clone()),
MqError::Channel(msg) => MqError::Channel(msg.clone()),
MqError::Publish(msg) => MqError::Publish(msg.clone()),
MqError::Timeout(msg) => MqError::Timeout(msg.clone()),
MqError::Pool(msg) => MqError::Pool(msg.clone()),
MqError::Lapin(err) => MqError::Connection(err.to_string()),
_ => return None,
})
}
/// Create a new execution scheduler /// Create a new execution scheduler
pub fn new(pool: PgPool, publisher: Arc<Publisher>, consumer: Arc<Consumer>) -> Self { pub fn new(
pool: PgPool,
publisher: Arc<Publisher>,
consumer: Arc<Consumer>,
policy_enforcer: Arc<PolicyEnforcer>,
) -> Self {
Self { Self {
pool, pool,
publisher, publisher,
consumer, consumer,
policy_enforcer,
round_robin_counter: AtomicUsize::new(0), round_robin_counter: AtomicUsize::new(0),
} }
} }
@@ -136,6 +157,7 @@ impl ExecutionScheduler {
let pool = self.pool.clone(); let pool = self.pool.clone();
let publisher = self.publisher.clone(); let publisher = self.publisher.clone();
let policy_enforcer = self.policy_enforcer.clone();
// Share the counter with the handler closure via Arc. // Share the counter with the handler closure via Arc.
// We wrap &self's AtomicUsize in a new Arc<AtomicUsize> by copying the // We wrap &self's AtomicUsize in a new Arc<AtomicUsize> by copying the
// current value so the closure is 'static. // current value so the closure is 'static.
@@ -149,16 +171,24 @@ impl ExecutionScheduler {
move |envelope: MessageEnvelope<ExecutionRequestedPayload>| { move |envelope: MessageEnvelope<ExecutionRequestedPayload>| {
let pool = pool.clone(); let pool = pool.clone();
let publisher = publisher.clone(); let publisher = publisher.clone();
let policy_enforcer = policy_enforcer.clone();
let counter = counter.clone(); let counter = counter.clone();
async move { async move {
if let Err(e) = Self::process_execution_requested( if let Err(e) = Self::process_execution_requested(
&pool, &publisher, &counter, &envelope, &pool,
&publisher,
&policy_enforcer,
&counter,
&envelope,
) )
.await .await
{ {
error!("Error scheduling execution: {}", e); error!("Error scheduling execution: {}", e);
// Return error to trigger nack with requeue // Return error to trigger nack with requeue
if let Some(mq_err) = Self::retryable_mq_error(&e) {
return Err(mq_err);
}
return Err(format!("Failed to schedule execution: {}", e).into()); return Err(format!("Failed to schedule execution: {}", e).into());
} }
Ok(()) Ok(())
@@ -174,6 +204,7 @@ impl ExecutionScheduler {
async fn process_execution_requested( async fn process_execution_requested(
pool: &PgPool, pool: &PgPool,
publisher: &Publisher, publisher: &Publisher,
policy_enforcer: &PolicyEnforcer,
round_robin_counter: &AtomicUsize, round_robin_counter: &AtomicUsize,
envelope: &MessageEnvelope<ExecutionRequestedPayload>, envelope: &MessageEnvelope<ExecutionRequestedPayload>,
) -> Result<()> { ) -> Result<()> {
@@ -184,9 +215,30 @@ impl ExecutionScheduler {
info!("Scheduling execution: {}", execution_id); info!("Scheduling execution: {}", execution_id);
// Fetch execution from database // Fetch execution from database
let execution = ExecutionRepository::find_by_id(pool, execution_id) let execution = match ExecutionRepository::find_by_id(pool, execution_id).await? {
.await? Some(execution) => execution,
.ok_or_else(|| anyhow::anyhow!("Execution not found: {}", execution_id))?; None => {
warn!("Execution {} not found during scheduling", execution_id);
Self::remove_queued_policy_execution(
policy_enforcer,
pool,
publisher,
execution_id,
)
.await;
return Ok(());
}
};
if execution.status != ExecutionStatus::Requested {
debug!(
"Skipping execution {} with status {:?}; only Requested executions are schedulable",
execution_id, execution.status
);
Self::remove_queued_policy_execution(policy_enforcer, pool, publisher, execution_id)
.await;
return Ok(());
}
// Fetch action to determine runtime requirements // Fetch action to determine runtime requirements
let action = Self::get_action_for_execution(pool, &execution).await?; let action = Self::get_action_for_execution(pool, &execution).await?;
@@ -207,30 +259,6 @@ impl ExecutionScheduler {
.await; .await;
} }
// Regular action: select appropriate worker (round-robin among compatible workers)
let worker = match Self::select_worker(pool, &action, round_robin_counter).await {
Ok(worker) => worker,
Err(err) if Self::is_unschedulable_error(&err) => {
Self::fail_unschedulable_execution(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
Err(err) => return Err(err),
};
info!(
"Selected worker {} for execution {}",
worker.id, execution_id
);
// Apply parameter defaults from the action's param_schema. // Apply parameter defaults from the action's param_schema.
// This mirrors what `process_workflow_execution` does for workflows // This mirrors what `process_workflow_execution` does for workflows
// so that non-workflow executions also get missing parameters filled // so that non-workflow executions also get missing parameters filled
@@ -249,16 +277,97 @@ impl ExecutionScheduler {
} }
}; };
match policy_enforcer
.enforce_for_scheduling(
action.id,
Some(action.pack),
execution_id,
execution_config.as_ref(),
)
.await
{
Ok(SchedulingPolicyOutcome::Queued) => {
info!(
"Execution {} queued by policy for action {}; deferring worker selection",
execution_id, action.id
);
return Ok(());
}
Ok(SchedulingPolicyOutcome::Ready) => {}
Err(err) => {
if Self::is_policy_cancellation_error(&err) {
Self::remove_queued_policy_execution(
policy_enforcer,
pool,
publisher,
execution_id,
)
.await;
Self::cancel_execution_for_policy_violation(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
return Err(err);
}
}
// Regular action: select appropriate worker only after policy
// readiness is confirmed, so queued executions don't reserve stale
// workers while they wait.
let worker = match Self::select_worker(pool, &action, round_robin_counter).await {
Ok(worker) => worker,
Err(err) if Self::is_unschedulable_error(&err) => {
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
Self::fail_unschedulable_execution(
pool,
publisher,
envelope,
execution_id,
action.id,
&action.r#ref,
&err.to_string(),
)
.await?;
return Ok(());
}
Err(err) => {
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?;
return Err(err);
}
};
info!(
"Selected worker {} for execution {}",
worker.id, execution_id
);
// Persist the selected worker so later cancellation requests can be // Persist the selected worker so later cancellation requests can be
// routed to the correct per-worker cancel queue. // routed to the correct per-worker cancel queue.
let mut execution_for_update = execution; let mut execution_for_update = execution;
execution_for_update.status = ExecutionStatus::Scheduled; execution_for_update.status = ExecutionStatus::Scheduled;
execution_for_update.worker = Some(worker.id); execution_for_update.worker = Some(worker.id);
if let Err(err) =
ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into()) ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into())
.await
{
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?; .await?;
return Err(err.into());
}
// Publish message to worker-specific queue // Publish message to worker-specific queue
Self::queue_to_worker( if let Err(err) = Self::queue_to_worker(
publisher, publisher,
&execution_id, &execution_id,
&worker.id, &worker.id,
@@ -266,7 +375,27 @@ impl ExecutionScheduler {
&execution_config, &execution_config,
&action, &action,
) )
.await
{
if let Err(revert_err) = ExecutionRepository::update(
pool,
execution_id,
UpdateExecutionInput {
status: Some(ExecutionStatus::Requested),
..Default::default()
},
)
.await
{
warn!(
"Failed to revert execution {} back to Requested after worker publish error: {}",
execution_id, revert_err
);
}
Self::release_acquired_policy_slot(policy_enforcer, pool, publisher, execution_id)
.await?; .await?;
return Err(err);
}
info!( info!(
"Execution {} scheduled to worker {}", "Execution {} scheduled to worker {}",
@@ -1698,6 +1827,131 @@ impl ExecutionScheduler {
|| message.starts_with("No workers with fresh heartbeats available") || message.starts_with("No workers with fresh heartbeats available")
} }
/// Decide whether a scheduling error means the execution should be
/// cancelled (policy rejection) rather than retried.
///
/// Matches the message formats emitted by the policy enforcer: explicit
/// policy violations, a full queue, or a queue-wait timeout.
fn is_policy_cancellation_error(error: &anyhow::Error) -> bool {
    let message = error.to_string();
    ["Queue full for action ", "Queue timeout for execution "]
        .iter()
        .any(|prefix| message.starts_with(prefix))
        || message.contains("Policy violation:")
}
/// Release the policy slot acquired for `execution_id` and, when releasing
/// it unblocks a queued successor, republish that successor's
/// `ExecutionRequested` message so it gets scheduled.
///
/// If republishing fails, the slot is restored (best effort, failure only
/// logged) so the successor stays queued, and the republish error is
/// propagated to the caller.
async fn release_acquired_policy_slot(
    policy_enforcer: &PolicyEnforcer,
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) -> Result<()> {
    let release = policy_enforcer
        .release_execution_slot(execution_id)
        .await
        .map_err(|release_err| {
            warn!(
                "Failed to release acquired policy slot for execution {} after scheduling error: {}",
                execution_id, release_err
            );
            release_err
        })?;

    // No slot was held, or nothing is waiting on it — nothing more to do.
    let Some(release) = release else {
        return Ok(());
    };
    let Some(next_execution_id) = release.next_execution_id else {
        return Ok(());
    };

    match Self::republish_execution_requested(pool, publisher, next_execution_id).await {
        Ok(()) => Ok(()),
        Err(republish_err) => {
            warn!(
                "Failed to republish deferred execution {} after releasing slot from execution {}: {}",
                next_execution_id, execution_id, republish_err
            );
            // Best-effort restore so the deferred successor is not lost.
            if let Err(restore_err) = policy_enforcer
                .restore_execution_slot(execution_id, &release)
                .await
            {
                warn!(
                    "Failed to restore policy slot for execution {} after republish error: {}",
                    execution_id, restore_err
                );
            }
            Err(republish_err)
        }
    }
}
/// Drop `execution_id` from the policy queue (used when the execution has
/// vanished or is no longer schedulable) and republish any successor the
/// removal unblocks.
///
/// Entirely best-effort: every failure is logged and swallowed; on a
/// republish failure the queued entry is restored so the successor is not
/// dropped.
async fn remove_queued_policy_execution(
    policy_enforcer: &PolicyEnforcer,
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) {
    let removal = match policy_enforcer.remove_queued_execution(execution_id).await {
        Ok(Some(removal)) => removal,
        // Nothing was queued for this execution.
        Ok(None) => return,
        Err(remove_err) => {
            warn!(
                "Failed to remove queued policy execution {} during scheduler cleanup: {}",
                execution_id, remove_err
            );
            return;
        }
    };

    let Some(next_execution_id) = removal.next_execution_id else {
        return;
    };

    if let Err(republish_err) =
        Self::republish_execution_requested(pool, publisher, next_execution_id).await
    {
        warn!(
            "Failed to republish successor {} after removing queued execution {}: {}",
            next_execution_id, execution_id, republish_err
        );
        // Put the queue entry back so the successor can be retried later.
        if let Err(restore_err) = policy_enforcer.restore_queued_execution(&removal).await {
            warn!(
                "Failed to restore queued execution {} after republish error: {}",
                execution_id, restore_err
            );
        }
    }
}
/// Load `execution_id` from the database and publish a fresh
/// `ExecutionRequested` envelope for it so the scheduler consumer picks it
/// up again.
///
/// # Errors
/// Fails when the execution row is missing, the execution has no action,
/// or the publish to the message queue fails.
async fn republish_execution_requested(
    pool: &PgPool,
    publisher: &Publisher,
    execution_id: i64,
) -> Result<()> {
    let Some(execution) = ExecutionRepository::find_by_id(pool, execution_id).await? else {
        return Err(anyhow::anyhow!("Execution {} not found", execution_id));
    };
    let Some(action_id) = execution.action else {
        return Err(anyhow::anyhow!("Execution {} has no action", execution_id));
    };

    let payload = ExecutionRequestedPayload {
        execution_id,
        action_id: Some(action_id),
        action_ref: execution.action_ref.clone(),
        parent_id: execution.parent,
        enforcement_id: execution.enforcement,
        config: execution.config.clone(),
    };
    publisher
        .publish_envelope(
            &MessageEnvelope::new(MessageType::ExecutionRequested, payload)
                .with_source("executor-scheduler"),
        )
        .await?;

    debug!(
        "Republished deferred ExecutionRequested for execution {}",
        execution_id
    );
    Ok(())
}
async fn fail_unschedulable_execution( async fn fail_unschedulable_execution(
pool: &PgPool, pool: &PgPool,
publisher: &Publisher, publisher: &Publisher,
@@ -1751,6 +2005,59 @@ impl ExecutionScheduler {
Ok(()) Ok(())
} }
/// Cancel an execution that a policy rejected: persist the cancelled status
/// with structured result details, then broadcast an `ExecutionCompleted`
/// event (status "cancelled") correlated to the original request envelope.
async fn cancel_execution_for_policy_violation(
    pool: &PgPool,
    publisher: &Publisher,
    envelope: &MessageEnvelope<ExecutionRequestedPayload>,
    execution_id: i64,
    action_id: i64,
    action_ref: &str,
    error_message: &str,
) -> Result<()> {
    let completed_at = Utc::now();
    // Stored on the execution row and echoed in the completion event.
    let result = serde_json::json!({
        "error": "Execution cancelled by policy",
        "message": error_message,
        "action_ref": action_ref,
        "cancelled_by": "execution_scheduler",
        "cancelled_at": completed_at.to_rfc3339(),
    });

    let update = UpdateExecutionInput {
        status: Some(ExecutionStatus::Cancelled),
        result: Some(result.clone()),
        ..Default::default()
    };
    ExecutionRepository::update(pool, execution_id, update).await?;

    let payload = ExecutionCompletedPayload {
        execution_id,
        action_id,
        action_ref: action_ref.to_string(),
        status: "cancelled".to_string(),
        result: Some(result),
        completed_at,
    };
    let completed = MessageEnvelope::new(MessageType::ExecutionCompleted, payload)
        .with_correlation_id(envelope.correlation_id)
        .with_source("attune-executor");
    publisher.publish_envelope(&completed).await?;

    warn!(
        "Execution {} cancelled due to policy violation: {}",
        execution_id, error_message
    );
    Ok(())
}
/// Check if a worker's heartbeat is fresh enough to schedule work /// Check if a worker's heartbeat is fresh enough to schedule work
/// ///
/// A worker is considered fresh if its last heartbeat is within /// A worker is considered fresh if its last heartbeat is within
@@ -1980,6 +2287,24 @@ mod tests {
)); ));
} }
#[test]
fn test_policy_cancellation_error_classification() {
    // Messages the scheduler must treat as policy cancellations.
    let cancelling = [
        "Policy violation: Concurrency limit exceeded: 1 running executions (limit: 1)",
        "Queue full for action 42: maximum 100 entries",
        "Queue timeout for execution 99: waited 60 seconds",
    ];
    for message in cancelling {
        assert!(ExecutionScheduler::is_policy_cancellation_error(
            &anyhow::anyhow!("{}", message)
        ));
    }
    // Unrelated infrastructure errors must not be classified as cancellations.
    assert!(!ExecutionScheduler::is_policy_cancellation_error(
        &anyhow::anyhow!("rabbitmq publish failed")
    ));
}
#[test] #[test]
fn test_concurrency_limit_dispatch_count() { fn test_concurrency_limit_dispatch_count() {
// Verify the dispatch_count calculation used by dispatch_with_items_task // Verify the dispatch_count calculation used by dispatch_with_items_task

View File

@@ -297,6 +297,7 @@ impl ExecutorService {
self.inner.pool.clone(), self.inner.pool.clone(),
self.inner.publisher.clone(), self.inner.publisher.clone(),
Arc::new(scheduler_consumer), Arc::new(scheduler_consumer),
self.inner.policy_enforcer.clone(),
); );
handles.push(tokio::spawn(async move { scheduler.start().await })); handles.push(tokio::spawn(async move { scheduler.start().await }));

View File

@@ -199,7 +199,7 @@ async fn test_fifo_ordering_with_database() {
let first_exec_id = let first_exec_id =
create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await;
manager manager
.enqueue_and_wait(action_id, first_exec_id, max_concurrent) .enqueue_and_wait(action_id, first_exec_id, max_concurrent, None)
.await .await
.expect("First execution should enqueue"); .expect("First execution should enqueue");
@@ -222,7 +222,7 @@ async fn test_fifo_ordering_with_database() {
// Enqueue and wait // Enqueue and wait
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");
@@ -316,7 +316,7 @@ async fn test_high_concurrency_stress() {
.await; .await;
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");
@@ -343,7 +343,7 @@ async fn test_high_concurrency_stress() {
.await; .await;
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");
@@ -479,7 +479,7 @@ async fn test_multiple_workers_simulation() {
.await; .await;
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");
@@ -596,7 +596,7 @@ async fn test_cross_action_independence() {
.await; .await;
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, 1) .enqueue_and_wait(action_id, exec_id, 1, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");
@@ -699,7 +699,7 @@ async fn test_cancellation_during_queue() {
let exec_id = let exec_id =
create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await;
manager manager
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.unwrap(); .unwrap();
@@ -722,7 +722,7 @@ async fn test_cancellation_during_queue() {
ids.lock().await.push(exec_id); ids.lock().await.push(exec_id);
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
}); });
@@ -808,7 +808,7 @@ async fn test_queue_stats_persistence() {
let manager_clone = manager.clone(); let manager_clone = manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.ok(); .ok();
}); });
@@ -888,7 +888,7 @@ async fn test_queue_full_rejection() {
let exec_id = let exec_id =
create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await;
manager manager
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.unwrap(); .unwrap();
@@ -900,7 +900,7 @@ async fn test_queue_full_rejection() {
tokio::spawn(async move { tokio::spawn(async move {
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.ok(); .ok();
}); });
@@ -917,7 +917,7 @@ async fn test_queue_full_rejection() {
let exec_id = let exec_id =
create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await; create_test_execution(&pool, action_id, &action_ref, ExecutionStatus::Requested).await;
let result = manager let result = manager
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await; .await;
assert!(result.is_err(), "Should reject when queue is full"); assert!(result.is_err(), "Should reject when queue is full");
@@ -976,7 +976,7 @@ async fn test_extreme_stress_10k_executions() {
.await; .await;
manager_clone manager_clone
.enqueue_and_wait(action_id, exec_id, max_concurrent) .enqueue_and_wait(action_id, exec_id, max_concurrent, None)
.await .await
.expect("Enqueue should succeed"); .expect("Enqueue should succeed");

View File

@@ -9,7 +9,7 @@
use attune_common::{ use attune_common::{
config::Config, config::Config,
db::Database, db::Database,
models::enums::ExecutionStatus, models::enums::{ExecutionStatus, PolicyMethod},
repositories::{ repositories::{
action::{ActionRepository, CreateActionInput}, action::{ActionRepository, CreateActionInput},
execution::{CreateExecutionInput, ExecutionRepository}, execution::{CreateExecutionInput, ExecutionRepository},
@@ -190,6 +190,8 @@ async fn test_global_rate_limit() {
window_seconds: 60, window_seconds: 60,
}), }),
concurrency_limit: None, concurrency_limit: None,
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
@@ -242,6 +244,8 @@ async fn test_concurrency_limit() {
let policy = ExecutionPolicy { let policy = ExecutionPolicy {
rate_limit: None, rate_limit: None,
concurrency_limit: Some(2), concurrency_limit: Some(2),
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
@@ -300,6 +304,8 @@ async fn test_action_specific_policy() {
window_seconds: 60, window_seconds: 60,
}), }),
concurrency_limit: None, concurrency_limit: None,
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
enforcer.set_action_policy(action_id, action_policy); enforcer.set_action_policy(action_id, action_policy);
@@ -345,6 +351,8 @@ async fn test_pack_specific_policy() {
let pack_policy = ExecutionPolicy { let pack_policy = ExecutionPolicy {
rate_limit: None, rate_limit: None,
concurrency_limit: Some(1), concurrency_limit: Some(1),
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
enforcer.set_pack_policy(pack_id, pack_policy); enforcer.set_pack_policy(pack_id, pack_policy);
@@ -388,6 +396,8 @@ async fn test_policy_priority() {
window_seconds: 60, window_seconds: 60,
}), }),
concurrency_limit: None, concurrency_limit: None,
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
let mut enforcer = PolicyEnforcer::with_global_policy(pool.clone(), global_policy); let mut enforcer = PolicyEnforcer::with_global_policy(pool.clone(), global_policy);
@@ -399,6 +409,8 @@ async fn test_policy_priority() {
window_seconds: 60, window_seconds: 60,
}), }),
concurrency_limit: None, concurrency_limit: None,
concurrency_method: PolicyMethod::Enqueue,
concurrency_parameters: Vec::new(),
quotas: None, quotas: None,
}; };
enforcer.set_action_policy(action_id, action_policy); enforcer.set_action_policy(action_id, action_policy);