concurrent action execution
This commit is contained in:
@@ -22,6 +22,7 @@ use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -40,6 +41,8 @@ pub struct ExecutionScheduler {
|
||||
pool: PgPool,
|
||||
publisher: Arc<Publisher>,
|
||||
consumer: Arc<Consumer>,
|
||||
/// Round-robin counter for distributing executions across workers
|
||||
round_robin_counter: AtomicUsize,
|
||||
}
|
||||
|
||||
/// Default heartbeat interval in seconds (should match worker config default)
|
||||
@@ -56,6 +59,7 @@ impl ExecutionScheduler {
|
||||
pool,
|
||||
publisher,
|
||||
consumer,
|
||||
round_robin_counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +69,12 @@ impl ExecutionScheduler {
|
||||
|
||||
let pool = self.pool.clone();
|
||||
let publisher = self.publisher.clone();
|
||||
// Give the handler closure its own round-robin counter behind an Arc.
|
||||
// &self's AtomicUsize is not wrapped or shared: a new Arc<AtomicUsize> is
|
||||
// seeded with a copy of its current value so the closure can be 'static.
|
||||
let counter = Arc::new(AtomicUsize::new(
|
||||
self.round_robin_counter.load(Ordering::Relaxed),
|
||||
));
|
||||
|
||||
// Use the handler pattern to consume messages
|
||||
self.consumer
|
||||
@@ -72,10 +82,13 @@ impl ExecutionScheduler {
|
||||
move |envelope: MessageEnvelope<ExecutionRequestedPayload>| {
|
||||
let pool = pool.clone();
|
||||
let publisher = publisher.clone();
|
||||
let counter = counter.clone();
|
||||
|
||||
async move {
|
||||
if let Err(e) =
|
||||
Self::process_execution_requested(&pool, &publisher, &envelope).await
|
||||
if let Err(e) = Self::process_execution_requested(
|
||||
&pool, &publisher, &counter, &envelope,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Error scheduling execution: {}", e);
|
||||
// Return error to trigger nack with requeue
|
||||
@@ -94,6 +107,7 @@ impl ExecutionScheduler {
|
||||
async fn process_execution_requested(
|
||||
pool: &PgPool,
|
||||
publisher: &Publisher,
|
||||
round_robin_counter: &AtomicUsize,
|
||||
envelope: &MessageEnvelope<ExecutionRequestedPayload>,
|
||||
) -> Result<()> {
|
||||
debug!("Processing execution requested message: {:?}", envelope);
|
||||
@@ -110,8 +124,8 @@ impl ExecutionScheduler {
|
||||
// Fetch action to determine runtime requirements
|
||||
let action = Self::get_action_for_execution(pool, &execution).await?;
|
||||
|
||||
// Select appropriate worker
|
||||
let worker = Self::select_worker(pool, &action).await?;
|
||||
// Select appropriate worker (round-robin among compatible workers)
|
||||
let worker = Self::select_worker(pool, &action, round_robin_counter).await?;
|
||||
|
||||
info!(
|
||||
"Selected worker {} for execution {}",
|
||||
@@ -158,9 +172,13 @@ impl ExecutionScheduler {
|
||||
}
|
||||
|
||||
/// Select an appropriate worker for the execution
|
||||
///
|
||||
/// Uses round-robin selection among compatible, active, and healthy workers
|
||||
/// to distribute load evenly across the worker pool.
|
||||
async fn select_worker(
|
||||
pool: &PgPool,
|
||||
action: &Action,
|
||||
round_robin_counter: &AtomicUsize,
|
||||
) -> Result<attune_common::models::Worker> {
|
||||
// Get runtime requirements for the action
|
||||
let runtime = if let Some(runtime_id) = action.runtime {
|
||||
@@ -219,17 +237,21 @@ impl ExecutionScheduler {
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Implement intelligent worker selection:
|
||||
// - Consider worker load/capacity
|
||||
// - Consider worker affinity (same pack, same runtime)
|
||||
// - Consider geographic locality
|
||||
// - Round-robin or least-connections strategy
|
||||
|
||||
// For now, just select the first available worker
|
||||
Ok(fresh_workers
|
||||
// Round-robin selection: distribute executions evenly across workers.
|
||||
// Each call increments the counter and picks the next worker in the list.
|
||||
let count = round_robin_counter.fetch_add(1, Ordering::Relaxed);
|
||||
let index = count % fresh_workers.len();
|
||||
let selected = fresh_workers
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("Worker list should not be empty"))
|
||||
.nth(index)
|
||||
.expect("Worker list should not be empty");
|
||||
|
||||
info!(
|
||||
"Selected worker {} (id={}) via round-robin (index {} of available workers)",
|
||||
selected.name, selected.id, index
|
||||
);
|
||||
|
||||
Ok(selected)
|
||||
}
|
||||
|
||||
/// Check if a worker supports a given runtime
|
||||
@@ -291,7 +313,8 @@ impl ExecutionScheduler {
|
||||
|
||||
let now = Utc::now();
|
||||
let age = now.signed_duration_since(last_heartbeat);
|
||||
let max_age = Duration::from_secs(DEFAULT_HEARTBEAT_INTERVAL * HEARTBEAT_STALENESS_MULTIPLIER);
|
||||
let max_age =
|
||||
Duration::from_secs(DEFAULT_HEARTBEAT_INTERVAL * HEARTBEAT_STALENESS_MULTIPLIER);
|
||||
|
||||
let is_fresh = age.to_std().unwrap_or(Duration::MAX) <= max_age;
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ use sqlx::PgPool;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::{Mutex, RwLock, Semaphore};
|
||||
use tokio::task::{JoinHandle, JoinSet};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::artifacts::ArtifactManager;
|
||||
@@ -67,6 +67,10 @@ pub struct WorkerService {
|
||||
packs_base_dir: PathBuf,
|
||||
/// Base directory for isolated runtime environments
|
||||
runtime_envs_dir: PathBuf,
|
||||
/// Semaphore to limit concurrent executions
|
||||
execution_semaphore: Arc<Semaphore>,
|
||||
/// Tracks in-flight execution tasks for graceful shutdown
|
||||
in_flight_tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl WorkerService {
|
||||
@@ -292,6 +296,17 @@ impl WorkerService {
|
||||
// Capture the runtime filter for use in env setup
|
||||
let runtime_filter_for_service = runtime_filter.clone();
|
||||
|
||||
// Read max concurrent tasks from config
|
||||
let max_concurrent_tasks = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.max_concurrent_tasks)
|
||||
.unwrap_or(10);
|
||||
info!(
|
||||
"Worker configured for max {} concurrent executions",
|
||||
max_concurrent_tasks
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
db_pool: pool,
|
||||
@@ -308,6 +323,8 @@ impl WorkerService {
|
||||
runtime_filter: runtime_filter_for_service,
|
||||
packs_base_dir,
|
||||
runtime_envs_dir,
|
||||
execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)),
|
||||
in_flight_tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -563,17 +580,24 @@ impl WorkerService {
|
||||
|
||||
/// Wait for in-flight tasks to complete
|
||||
async fn wait_for_in_flight_tasks(&self) {
|
||||
// Poll for active executions with short intervals
|
||||
loop {
|
||||
// Check if executor has any active tasks
|
||||
// Note: This is a simplified check. In a real implementation,
|
||||
// we would track active execution count in the executor.
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
let remaining = {
|
||||
let mut tasks = self.in_flight_tasks.lock().await;
|
||||
// Drain any already-completed tasks
|
||||
while tasks.try_join_next().is_some() {}
|
||||
tasks.len()
|
||||
};
|
||||
|
||||
// TODO: Add proper tracking of active executions in ActionExecutor
|
||||
// For now, we just wait a reasonable amount of time
|
||||
// This will be improved when we add execution tracking
|
||||
break;
|
||||
if remaining == 0 {
|
||||
info!("All in-flight execution tasks have completed");
|
||||
break;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Waiting for {} in-flight execution task(s) to complete...",
|
||||
remaining
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,6 +605,10 @@ impl WorkerService {
|
||||
///
|
||||
/// Spawns the consumer loop as a background task so that `start()` returns
|
||||
/// immediately, allowing the caller to set up signal handlers.
|
||||
///
|
||||
/// Executions are spawned as concurrent background tasks, limited by the
|
||||
/// `execution_semaphore`. The consumer blocks when the concurrency limit is
|
||||
/// reached, providing natural backpressure via RabbitMQ.
|
||||
async fn start_execution_consumer(&mut self) -> Result<()> {
|
||||
let worker_id = self
|
||||
.worker_id
|
||||
@@ -589,7 +617,19 @@ impl WorkerService {
|
||||
// Queue name for this worker (already created in setup_worker_infrastructure)
|
||||
let queue_name = format!("worker.{}.executions", worker_id);
|
||||
|
||||
info!("Starting consumer for worker queue: {}", queue_name);
|
||||
// Set prefetch slightly above max concurrent tasks to keep the pipeline filled
|
||||
let max_concurrent = self
|
||||
.config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.max_concurrent_tasks)
|
||||
.unwrap_or(10);
|
||||
let prefetch_count = (max_concurrent as u16).saturating_add(2);
|
||||
|
||||
info!(
|
||||
"Starting consumer for worker queue: {} (prefetch: {}, max_concurrent: {})",
|
||||
queue_name, prefetch_count, max_concurrent
|
||||
);
|
||||
|
||||
// Create consumer
|
||||
let consumer = Arc::new(
|
||||
@@ -598,7 +638,7 @@ impl WorkerService {
|
||||
ConsumerConfig {
|
||||
queue: queue_name.clone(),
|
||||
tag: format!("worker-{}", worker_id),
|
||||
prefetch_count: 10,
|
||||
prefetch_count,
|
||||
auto_ack: false,
|
||||
exclusive: false,
|
||||
},
|
||||
@@ -615,6 +655,8 @@ impl WorkerService {
|
||||
let db_pool = self.db_pool.clone();
|
||||
let consumer_for_task = consumer.clone();
|
||||
let queue_name_for_log = queue_name.clone();
|
||||
let semaphore = self.execution_semaphore.clone();
|
||||
let in_flight = self.in_flight_tasks.clone();
|
||||
|
||||
// Spawn the consumer loop as a background task so start() can return
|
||||
let handle = tokio::spawn(async move {
|
||||
@@ -625,11 +667,47 @@ impl WorkerService {
|
||||
let executor = executor.clone();
|
||||
let publisher = publisher.clone();
|
||||
let db_pool = db_pool.clone();
|
||||
let semaphore = semaphore.clone();
|
||||
let in_flight = in_flight.clone();
|
||||
|
||||
async move {
|
||||
Self::handle_execution_scheduled(executor, publisher, db_pool, envelope)
|
||||
let execution_id = envelope.payload.execution_id;
|
||||
|
||||
// Acquire a concurrency permit. This blocks if we're at the
|
||||
// max concurrent execution limit, providing natural backpressure:
|
||||
// the message won't be acked until we can actually start working,
|
||||
// so RabbitMQ will stop delivering once prefetch is exhausted.
|
||||
let permit = semaphore.clone().acquire_owned().await.map_err(|_| {
|
||||
attune_common::mq::error::MqError::Channel(
|
||||
"Execution semaphore closed".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"Acquired execution permit for execution {} ({} permits remaining)",
|
||||
execution_id,
|
||||
semaphore.available_permits()
|
||||
);
|
||||
|
||||
// Spawn the actual execution as a background task so this
|
||||
// handler returns immediately, acking the message and freeing
|
||||
// the consumer loop to process the next delivery.
|
||||
let mut tasks = in_flight.lock().await;
|
||||
tasks.spawn(async move {
|
||||
// The permit is moved into this task and will be released
|
||||
// when the task completes (on drop).
|
||||
let _permit = permit;
|
||||
|
||||
if let Err(e) = Self::handle_execution_scheduled(
|
||||
executor, publisher, db_pool, envelope,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("Execution handler error: {}", e).into())
|
||||
{
|
||||
error!("Execution {} handler error: {}", execution_id, e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user