concurrent action execution

This commit is contained in:
2026-02-25 14:16:56 -06:00
parent adb9f30464
commit e89b5991ec
6 changed files with 257 additions and 74 deletions

View File

@@ -19,8 +19,8 @@ use sqlx::PgPool;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio::sync::{Mutex, RwLock, Semaphore};
use tokio::task::{JoinHandle, JoinSet};
use tracing::{debug, error, info, warn};
use crate::artifacts::ArtifactManager;
@@ -67,6 +67,10 @@ pub struct WorkerService {
packs_base_dir: PathBuf,
/// Base directory for isolated runtime environments
runtime_envs_dir: PathBuf,
/// Semaphore to limit concurrent executions
execution_semaphore: Arc<Semaphore>,
/// Tracks in-flight execution tasks for graceful shutdown
in_flight_tasks: Arc<Mutex<JoinSet<()>>>,
}
impl WorkerService {
@@ -292,6 +296,17 @@ impl WorkerService {
// Capture the runtime filter for use in env setup
let runtime_filter_for_service = runtime_filter.clone();
// Read max concurrent tasks from config
let max_concurrent_tasks = config
.worker
.as_ref()
.map(|w| w.max_concurrent_tasks)
.unwrap_or(10);
info!(
"Worker configured for max {} concurrent executions",
max_concurrent_tasks
);
Ok(Self {
config,
db_pool: pool,
@@ -308,6 +323,8 @@ impl WorkerService {
runtime_filter: runtime_filter_for_service,
packs_base_dir,
runtime_envs_dir,
execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)),
in_flight_tasks: Arc::new(Mutex::new(JoinSet::new())),
})
}
@@ -563,17 +580,24 @@ impl WorkerService {
/// Wait for in-flight tasks to complete
async fn wait_for_in_flight_tasks(&self) {
// Poll for active executions with short intervals
loop {
// Check if executor has any active tasks
// Note: This is a simplified check. In a real implementation,
// we would track active execution count in the executor.
tokio::time::sleep(Duration::from_millis(500)).await;
let remaining = {
let mut tasks = self.in_flight_tasks.lock().await;
// Drain any already-completed tasks
while tasks.try_join_next().is_some() {}
tasks.len()
};
// TODO: Add proper tracking of active executions in ActionExecutor
// For now, we just wait a reasonable amount of time
// This will be improved when we add execution tracking
break;
if remaining == 0 {
info!("All in-flight execution tasks have completed");
break;
}
info!(
"Waiting for {} in-flight execution task(s) to complete...",
remaining
);
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
@@ -581,6 +605,10 @@ impl WorkerService {
///
/// Spawns the consumer loop as a background task so that `start()` returns
/// immediately, allowing the caller to set up signal handlers.
///
/// Executions are spawned as concurrent background tasks, limited by the
/// `execution_semaphore`. The consumer blocks when the concurrency limit is
/// reached, providing natural backpressure via RabbitMQ.
async fn start_execution_consumer(&mut self) -> Result<()> {
let worker_id = self
.worker_id
@@ -589,7 +617,19 @@ impl WorkerService {
// Queue name for this worker (already created in setup_worker_infrastructure)
let queue_name = format!("worker.{}.executions", worker_id);
info!("Starting consumer for worker queue: {}", queue_name);
// Set prefetch slightly above max concurrent tasks to keep the pipeline filled
let max_concurrent = self
.config
.worker
.as_ref()
.map(|w| w.max_concurrent_tasks)
.unwrap_or(10);
let prefetch_count = (max_concurrent as u16).saturating_add(2);
info!(
"Starting consumer for worker queue: {} (prefetch: {}, max_concurrent: {})",
queue_name, prefetch_count, max_concurrent
);
// Create consumer
let consumer = Arc::new(
@@ -598,7 +638,7 @@ impl WorkerService {
ConsumerConfig {
queue: queue_name.clone(),
tag: format!("worker-{}", worker_id),
prefetch_count: 10,
prefetch_count,
auto_ack: false,
exclusive: false,
},
@@ -615,6 +655,8 @@ impl WorkerService {
let db_pool = self.db_pool.clone();
let consumer_for_task = consumer.clone();
let queue_name_for_log = queue_name.clone();
let semaphore = self.execution_semaphore.clone();
let in_flight = self.in_flight_tasks.clone();
// Spawn the consumer loop as a background task so start() can return
let handle = tokio::spawn(async move {
@@ -625,11 +667,47 @@ impl WorkerService {
let executor = executor.clone();
let publisher = publisher.clone();
let db_pool = db_pool.clone();
let semaphore = semaphore.clone();
let in_flight = in_flight.clone();
async move {
Self::handle_execution_scheduled(executor, publisher, db_pool, envelope)
let execution_id = envelope.payload.execution_id;
// Acquire a concurrency permit. This blocks if we're at the
// max concurrent execution limit, providing natural backpressure:
// the message won't be acked until we can actually start working,
// so RabbitMQ will stop delivering once prefetch is exhausted.
let permit = semaphore.clone().acquire_owned().await.map_err(|_| {
attune_common::mq::error::MqError::Channel(
"Execution semaphore closed".to_string(),
)
})?;
info!(
"Acquired execution permit for execution {} ({} permits remaining)",
execution_id,
semaphore.available_permits()
);
// Spawn the actual execution as a background task so this
// handler returns immediately, acking the message and freeing
// the consumer loop to process the next delivery.
let mut tasks = in_flight.lock().await;
tasks.spawn(async move {
// The permit is moved into this task and will be released
// when the task completes (on drop).
let _permit = permit;
if let Err(e) = Self::handle_execution_scheduled(
executor, publisher, db_pool, envelope,
)
.await
.map_err(|e| format!("Execution handler error: {}", e).into())
{
error!("Execution {} handler error: {}", execution_id, e);
}
});
Ok(())
}
},
)