re-uploading work
This commit is contained in:
35
crates/worker/Cargo.toml
Normal file
35
crates/worker/Cargo.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
[package]
|
||||
name = "attune-worker"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "attune-worker"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
attune-common = { path = "../common" }
|
||||
tokio = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
config = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
lapin = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
hostname = "0.4"
|
||||
async-trait = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
365
crates/worker/src/artifacts.rs
Normal file
365
crates/worker/src/artifacts.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Artifacts Module
|
||||
//!
|
||||
//! Handles storage and retrieval of execution artifacts (logs, outputs, results).
|
||||
|
||||
use attune_common::error::{Error, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Artifact type
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ArtifactType {
|
||||
/// Execution logs (stdout/stderr)
|
||||
Log,
|
||||
/// Execution result data
|
||||
Result,
|
||||
/// Custom file output
|
||||
File,
|
||||
/// Trace/debug information
|
||||
Trace,
|
||||
}
|
||||
|
||||
/// Artifact metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Artifact {
|
||||
/// Artifact ID
|
||||
pub id: String,
|
||||
/// Execution ID
|
||||
pub execution_id: i64,
|
||||
/// Artifact type
|
||||
pub artifact_type: ArtifactType,
|
||||
/// File path
|
||||
pub path: PathBuf,
|
||||
/// Content type (MIME type)
|
||||
pub content_type: String,
|
||||
/// Size in bytes
|
||||
pub size: u64,
|
||||
/// Creation timestamp
|
||||
pub created: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
/// Artifact manager for storing execution artifacts
|
||||
pub struct ArtifactManager {
|
||||
/// Base directory for artifact storage
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ArtifactManager {
|
||||
/// Create a new artifact manager
|
||||
pub fn new(base_dir: PathBuf) -> Self {
|
||||
Self { base_dir }
|
||||
}
|
||||
|
||||
/// Initialize the artifact storage directory
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
fs::create_dir_all(&self.base_dir)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create artifact directory: {}", e)))?;
|
||||
|
||||
info!("Artifact storage initialized at: {:?}", self.base_dir);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the directory path for an execution
|
||||
pub fn get_execution_dir(&self, execution_id: i64) -> PathBuf {
|
||||
self.base_dir.join(format!("execution_{}", execution_id))
|
||||
}
|
||||
|
||||
/// Store execution logs
|
||||
pub async fn store_logs(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
stdout: &str,
|
||||
stderr: &str,
|
||||
) -> Result<Vec<Artifact>> {
|
||||
let exec_dir = self.get_execution_dir(execution_id);
|
||||
fs::create_dir_all(&exec_dir)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create execution directory: {}", e)))?;
|
||||
|
||||
let mut artifacts = Vec::new();
|
||||
|
||||
// Store stdout
|
||||
if !stdout.is_empty() {
|
||||
let stdout_path = exec_dir.join("stdout.log");
|
||||
let mut file = fs::File::create(&stdout_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create stdout file: {}", e)))?;
|
||||
file.write_all(stdout.as_bytes())
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to write stdout: {}", e)))?;
|
||||
file.sync_all()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to sync stdout file: {}", e)))?;
|
||||
|
||||
let metadata = fs::metadata(&stdout_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to get stdout metadata: {}", e)))?;
|
||||
artifacts.push(Artifact {
|
||||
id: format!("{}_stdout", execution_id),
|
||||
execution_id,
|
||||
artifact_type: ArtifactType::Log,
|
||||
path: stdout_path,
|
||||
content_type: "text/plain".to_string(),
|
||||
size: metadata.len(),
|
||||
created: chrono::Utc::now(),
|
||||
});
|
||||
|
||||
debug!(
|
||||
"Stored stdout log for execution {} ({} bytes)",
|
||||
execution_id,
|
||||
metadata.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Store stderr
|
||||
if !stderr.is_empty() {
|
||||
let stderr_path = exec_dir.join("stderr.log");
|
||||
let mut file = fs::File::create(&stderr_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create stderr file: {}", e)))?;
|
||||
file.write_all(stderr.as_bytes())
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to write stderr: {}", e)))?;
|
||||
file.sync_all()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to sync stderr file: {}", e)))?;
|
||||
|
||||
let metadata = fs::metadata(&stderr_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to get stderr metadata: {}", e)))?;
|
||||
artifacts.push(Artifact {
|
||||
id: format!("{}_stderr", execution_id),
|
||||
execution_id,
|
||||
artifact_type: ArtifactType::Log,
|
||||
path: stderr_path,
|
||||
content_type: "text/plain".to_string(),
|
||||
size: metadata.len(),
|
||||
created: chrono::Utc::now(),
|
||||
});
|
||||
|
||||
debug!(
|
||||
"Stored stderr log for execution {} ({} bytes)",
|
||||
execution_id,
|
||||
metadata.len()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(artifacts)
|
||||
}
|
||||
|
||||
/// Store execution result
|
||||
pub async fn store_result(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
result: &serde_json::Value,
|
||||
) -> Result<Artifact> {
|
||||
let exec_dir = self.get_execution_dir(execution_id);
|
||||
fs::create_dir_all(&exec_dir)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create execution directory: {}", e)))?;
|
||||
|
||||
let result_path = exec_dir.join("result.json");
|
||||
let result_json = serde_json::to_string_pretty(result)?;
|
||||
|
||||
let mut file = fs::File::create(&result_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create result file: {}", e)))?;
|
||||
file.write_all(result_json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to write result: {}", e)))?;
|
||||
file.sync_all()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to sync result file: {}", e)))?;
|
||||
|
||||
let metadata = fs::metadata(&result_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to get result metadata: {}", e)))?;
|
||||
|
||||
debug!(
|
||||
"Stored result for execution {} ({} bytes)",
|
||||
execution_id,
|
||||
metadata.len()
|
||||
);
|
||||
|
||||
Ok(Artifact {
|
||||
id: format!("{}_result", execution_id),
|
||||
execution_id,
|
||||
artifact_type: ArtifactType::Result,
|
||||
path: result_path,
|
||||
content_type: "application/json".to_string(),
|
||||
size: metadata.len(),
|
||||
created: chrono::Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Store a custom file artifact
|
||||
pub async fn store_file(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
filename: &str,
|
||||
content: &[u8],
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Artifact> {
|
||||
let exec_dir = self.get_execution_dir(execution_id);
|
||||
fs::create_dir_all(&exec_dir)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create execution directory: {}", e)))?;
|
||||
|
||||
let file_path = exec_dir.join(filename);
|
||||
let mut file = fs::File::create(&file_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create file: {}", e)))?;
|
||||
file.write_all(content)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to write file: {}", e)))?;
|
||||
file.sync_all()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to sync file: {}", e)))?;
|
||||
|
||||
let metadata = fs::metadata(&file_path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to get file metadata: {}", e)))?;
|
||||
|
||||
debug!(
|
||||
"Stored file artifact {} for execution {} ({} bytes)",
|
||||
filename,
|
||||
execution_id,
|
||||
metadata.len()
|
||||
);
|
||||
|
||||
Ok(Artifact {
|
||||
id: format!("{}_{}", execution_id, filename),
|
||||
execution_id,
|
||||
artifact_type: ArtifactType::File,
|
||||
path: file_path,
|
||||
content_type: content_type
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string(),
|
||||
size: metadata.len(),
|
||||
created: chrono::Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Read an artifact
|
||||
pub async fn read_artifact(&self, artifact: &Artifact) -> Result<Vec<u8>> {
|
||||
fs::read(&artifact.path)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to read artifact: {}", e)))
|
||||
}
|
||||
|
||||
/// Delete artifacts for an execution
|
||||
pub async fn delete_execution_artifacts(&self, execution_id: i64) -> Result<()> {
|
||||
let exec_dir = self.get_execution_dir(execution_id);
|
||||
|
||||
if exec_dir.exists() {
|
||||
fs::remove_dir_all(&exec_dir).await.map_err(|e| {
|
||||
Error::Internal(format!("Failed to delete execution artifacts: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Deleted artifacts for execution {}", execution_id);
|
||||
} else {
|
||||
warn!(
|
||||
"No artifacts found for execution {} (directory does not exist)",
|
||||
execution_id
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up old artifacts (retention policy)
|
||||
pub async fn cleanup_old_artifacts(&self, retention_days: u64) -> Result<usize> {
|
||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||
let mut deleted_count = 0;
|
||||
|
||||
let mut entries = fs::read_dir(&self.base_dir)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to read artifact directory: {}", e)))?;
|
||||
|
||||
while let Some(entry) = entries
|
||||
.next_entry()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to read directory entry: {}", e)))?
|
||||
{
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
if let Ok(metadata) = fs::metadata(&path).await {
|
||||
if let Ok(modified) = metadata.modified() {
|
||||
let modified_time: chrono::DateTime<chrono::Utc> = modified.into();
|
||||
if modified_time < cutoff {
|
||||
if let Err(e) = fs::remove_dir_all(&path).await {
|
||||
warn!("Failed to delete old artifact directory {:?}: {}", path, e);
|
||||
} else {
|
||||
deleted_count += 1;
|
||||
debug!("Deleted old artifact directory: {:?}", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Cleaned up {} old artifact directories (retention: {} days)",
|
||||
deleted_count, retention_days
|
||||
);
|
||||
|
||||
Ok(deleted_count)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ArtifactManager {
|
||||
fn default() -> Self {
|
||||
Self::new(PathBuf::from("/tmp/attune/artifacts"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_artifact_manager_store_logs() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ArtifactManager::new(temp_dir.path().to_path_buf());
|
||||
manager.initialize().await.unwrap();
|
||||
|
||||
let artifacts = manager
|
||||
.store_logs(1, "stdout output", "stderr output")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(artifacts.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_artifact_manager_store_result() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ArtifactManager::new(temp_dir.path().to_path_buf());
|
||||
manager.initialize().await.unwrap();
|
||||
|
||||
let result = serde_json::json!({"status": "success", "value": 42});
|
||||
let artifact = manager.store_result(1, &result).await.unwrap();
|
||||
|
||||
assert_eq!(artifact.execution_id, 1);
|
||||
assert_eq!(artifact.content_type, "application/json");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_artifact_manager_delete() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ArtifactManager::new(temp_dir.path().to_path_buf());
|
||||
manager.initialize().await.unwrap();
|
||||
|
||||
manager.store_logs(1, "test", "test").await.unwrap();
|
||||
assert!(manager.get_execution_dir(1).exists());
|
||||
|
||||
manager.delete_execution_artifacts(1).await.unwrap();
|
||||
assert!(!manager.get_execution_dir(1).exists());
|
||||
}
|
||||
}
|
||||
596
crates/worker/src/executor.rs
Normal file
596
crates/worker/src/executor.rs
Normal file
@@ -0,0 +1,596 @@
|
||||
//! Action Executor Module
|
||||
//!
|
||||
//! Coordinates the execution of actions by managing the runtime,
|
||||
//! loading action data, preparing execution context, and collecting results.
|
||||
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::{runtime::Runtime as RuntimeModel, Action, Execution, ExecutionStatus};
|
||||
use attune_common::repositories::execution::{ExecutionRepository, UpdateExecutionInput};
|
||||
use attune_common::repositories::{FindById, Update};
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::artifacts::ArtifactManager;
|
||||
use crate::runtime::{ExecutionContext, ExecutionResult, RuntimeRegistry};
|
||||
use crate::secrets::SecretManager;
|
||||
|
||||
/// Action executor that orchestrates execution flow
|
||||
pub struct ActionExecutor {
|
||||
pool: PgPool,
|
||||
runtime_registry: RuntimeRegistry,
|
||||
artifact_manager: ArtifactManager,
|
||||
secret_manager: SecretManager,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
packs_base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ActionExecutor {
|
||||
/// Create a new action executor
|
||||
pub fn new(
|
||||
pool: PgPool,
|
||||
runtime_registry: RuntimeRegistry,
|
||||
artifact_manager: ArtifactManager,
|
||||
secret_manager: SecretManager,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
packs_base_dir: PathBuf,
|
||||
) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
runtime_registry,
|
||||
artifact_manager,
|
||||
secret_manager,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
packs_base_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute an action for the given execution
|
||||
pub async fn execute(&self, execution_id: i64) -> Result<ExecutionResult> {
|
||||
info!("Starting execution: {}", execution_id);
|
||||
|
||||
// Update execution status to running
|
||||
if let Err(e) = self
|
||||
.update_execution_status(execution_id, ExecutionStatus::Running)
|
||||
.await
|
||||
{
|
||||
error!("Failed to update execution status to running: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Load execution from database
|
||||
let execution = self.load_execution(execution_id).await?;
|
||||
|
||||
// Load action from database
|
||||
let action = self.load_action(&execution).await?;
|
||||
|
||||
// Prepare execution context
|
||||
let context = match self.prepare_execution_context(&execution, &action).await {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => {
|
||||
error!("Failed to prepare execution context: {}", e);
|
||||
self.handle_execution_failure(execution_id, None).await?;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Execute the action
|
||||
// Note: execute_action should rarely return Err - most failures should be
|
||||
// captured in ExecutionResult with non-zero exit codes
|
||||
let result = match self.execute_action(context).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
error!("Action execution failed catastrophically: {}", e);
|
||||
// This should only happen for unrecoverable errors like runtime not found
|
||||
self.handle_execution_failure(execution_id, None).await?;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Store artifacts
|
||||
if let Err(e) = self.store_execution_artifacts(execution_id, &result).await {
|
||||
warn!("Failed to store artifacts: {}", e);
|
||||
// Don't fail the execution just because artifact storage failed
|
||||
}
|
||||
|
||||
// Update execution with result
|
||||
if result.is_success() {
|
||||
self.handle_execution_success(execution_id, &result).await?;
|
||||
} else {
|
||||
self.handle_execution_failure(execution_id, Some(&result))
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Execution {} completed: {}",
|
||||
execution_id,
|
||||
if result.is_success() {
|
||||
"success"
|
||||
} else {
|
||||
"failed"
|
||||
}
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Load execution from database
|
||||
async fn load_execution(&self, execution_id: i64) -> Result<Execution> {
|
||||
debug!("Loading execution: {}", execution_id);
|
||||
|
||||
ExecutionRepository::find_by_id(&self.pool, execution_id)
|
||||
.await?
|
||||
.ok_or_else(|| Error::not_found("Execution", "id", execution_id.to_string()))
|
||||
}
|
||||
|
||||
/// Load action from database using execution data
|
||||
async fn load_action(&self, execution: &Execution) -> Result<Action> {
|
||||
debug!("Loading action: {}", execution.action_ref);
|
||||
|
||||
// Try to load by action ID if available
|
||||
if let Some(action_id) = execution.action {
|
||||
let action = sqlx::query_as::<_, Action>("SELECT * FROM action WHERE id = $1")
|
||||
.bind(action_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(action) = action {
|
||||
return Ok(action);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, parse action_ref and query by pack.ref + action.ref
|
||||
let parts: Vec<&str> = execution.action_ref.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(Error::validation(format!(
|
||||
"Invalid action reference format: {}. Expected format: pack.action",
|
||||
execution.action_ref
|
||||
)));
|
||||
}
|
||||
|
||||
let pack_ref = parts[0];
|
||||
let action_ref = parts[1];
|
||||
|
||||
// Query action by pack ref and action ref
|
||||
let action = sqlx::query_as::<_, Action>(
|
||||
r#"
|
||||
SELECT a.*
|
||||
FROM action a
|
||||
JOIN pack p ON a.pack = p.id
|
||||
WHERE p.ref = $1 AND a.ref = $2
|
||||
"#,
|
||||
)
|
||||
.bind(pack_ref)
|
||||
.bind(action_ref)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| Error::not_found("Action", "ref", execution.action_ref.clone()))?;
|
||||
|
||||
Ok(action)
|
||||
}
|
||||
|
||||
/// Prepare execution context from execution and action data
|
||||
async fn prepare_execution_context(
|
||||
&self,
|
||||
execution: &Execution,
|
||||
action: &Action,
|
||||
) -> Result<ExecutionContext> {
|
||||
debug!(
|
||||
"Preparing execution context for execution: {}",
|
||||
execution.id
|
||||
);
|
||||
|
||||
// Extract parameters from execution config
|
||||
let mut parameters = HashMap::new();
|
||||
|
||||
if let Some(config) = &execution.config {
|
||||
// Try to get parameters from config.parameters first
|
||||
if let Some(params) = config.get("parameters") {
|
||||
if let JsonValue::Object(map) = params {
|
||||
for (key, value) in map {
|
||||
parameters.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
} else if let JsonValue::Object(map) = config {
|
||||
// If no parameters key, treat entire config as parameters
|
||||
// (this handles rule action_params being placed at root level)
|
||||
for (key, value) in map {
|
||||
// Skip special keys that aren't action parameters
|
||||
if key != "context" && key != "env" {
|
||||
parameters.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare environment variables
|
||||
let mut env = HashMap::new();
|
||||
env.insert("ATTUNE_EXECUTION_ID".to_string(), execution.id.to_string());
|
||||
env.insert(
|
||||
"ATTUNE_ACTION_REF".to_string(),
|
||||
execution.action_ref.clone(),
|
||||
);
|
||||
|
||||
if let Some(action_id) = execution.action {
|
||||
env.insert("ATTUNE_ACTION_ID".to_string(), action_id.to_string());
|
||||
}
|
||||
|
||||
// Add context data as environment variables from config
|
||||
if let Some(config) = &execution.config {
|
||||
if let Some(context) = config.get("context") {
|
||||
if let JsonValue::Object(map) = context {
|
||||
for (key, value) in map {
|
||||
let env_key = format!("ATTUNE_CONTEXT_{}", key.to_uppercase());
|
||||
let env_value = match value {
|
||||
JsonValue::String(s) => s.clone(),
|
||||
JsonValue::Number(n) => n.to_string(),
|
||||
JsonValue::Bool(b) => b.to_string(),
|
||||
_ => serde_json::to_string(value)?,
|
||||
};
|
||||
env.insert(env_key, env_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch secrets (passed securely via stdin, not environment variables)
|
||||
let secrets = match self.secret_manager.fetch_secrets_for_action(action).await {
|
||||
Ok(secrets) => {
|
||||
debug!(
|
||||
"Fetched {} secrets for action {} (will be passed via stdin)",
|
||||
secrets.len(),
|
||||
action.r#ref
|
||||
);
|
||||
secrets
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch secrets for action {}: {}", action.r#ref, e);
|
||||
// Don't fail the execution if secrets can't be fetched
|
||||
// Some actions may not require secrets
|
||||
HashMap::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Determine entry point from action
|
||||
let entry_point = action.entrypoint.clone();
|
||||
|
||||
// Default timeout: 5 minutes (300 seconds)
|
||||
// In the future, this could come from action metadata or execution config
|
||||
let timeout = Some(300_u64);
|
||||
|
||||
// Load runtime information if specified
|
||||
let runtime_name = if let Some(runtime_id) = action.runtime {
|
||||
match sqlx::query_as::<_, RuntimeModel>("SELECT * FROM runtime WHERE id = $1")
|
||||
.bind(runtime_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
{
|
||||
Ok(Some(runtime)) => {
|
||||
debug!(
|
||||
"Loaded runtime '{}' for action '{}'",
|
||||
runtime.name, action.r#ref
|
||||
);
|
||||
Some(runtime.name.to_lowercase())
|
||||
}
|
||||
Ok(None) => {
|
||||
warn!(
|
||||
"Runtime ID {} not found for action '{}'",
|
||||
runtime_id, action.r#ref
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to load runtime {} for action '{}': {}",
|
||||
runtime_id, action.r#ref, e
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Construct code_path for pack actions
|
||||
// Pack actions have their script files in packs/{pack_ref}/actions/{entrypoint}
|
||||
let code_path = if action.pack_ref.starts_with("core") || !action.is_adhoc {
|
||||
// This is a pack action, construct the file path
|
||||
let action_file_path = self
|
||||
.packs_base_dir
|
||||
.join(&action.pack_ref)
|
||||
.join("actions")
|
||||
.join(&entry_point);
|
||||
|
||||
if action_file_path.exists() {
|
||||
Some(action_file_path)
|
||||
} else {
|
||||
warn!(
|
||||
"Action file not found at {:?} for action {}",
|
||||
action_file_path, action.r#ref
|
||||
);
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None // Ad-hoc actions don't have files
|
||||
};
|
||||
|
||||
// For shell actions without a file, use the entrypoint as inline code
|
||||
let code = if runtime_name.as_deref() == Some("shell") && code_path.is_none() {
|
||||
Some(entry_point.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: execution.id,
|
||||
action_ref: execution.action_ref.clone(),
|
||||
parameters,
|
||||
env,
|
||||
secrets, // Passed securely via stdin
|
||||
timeout,
|
||||
working_dir: None, // Could be configured per action
|
||||
entry_point,
|
||||
code,
|
||||
code_path,
|
||||
runtime_name,
|
||||
max_stdout_bytes: self.max_stdout_bytes,
|
||||
max_stderr_bytes: self.max_stderr_bytes,
|
||||
};
|
||||
|
||||
Ok(context)
|
||||
}
|
||||
|
||||
/// Execute the action using the runtime registry
|
||||
async fn execute_action(&self, context: ExecutionContext) -> Result<ExecutionResult> {
|
||||
debug!("Executing action: {}", context.action_ref);
|
||||
|
||||
let runtime = self
|
||||
.runtime_registry
|
||||
.get_runtime(&context)
|
||||
.map_err(|e| Error::Internal(e.to_string()))?;
|
||||
|
||||
let result = runtime
|
||||
.execute(context)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(e.to_string()))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Store execution artifacts (logs, results)
|
||||
async fn store_execution_artifacts(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
result: &ExecutionResult,
|
||||
) -> Result<()> {
|
||||
debug!("Storing artifacts for execution: {}", execution_id);
|
||||
|
||||
// Store logs
|
||||
self.artifact_manager
|
||||
.store_logs(execution_id, &result.stdout, &result.stderr)
|
||||
.await?;
|
||||
|
||||
// Store result if available
|
||||
if let Some(result_data) = &result.result {
|
||||
self.artifact_manager
|
||||
.store_result(execution_id, result_data)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle successful execution
|
||||
async fn handle_execution_success(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
result: &ExecutionResult,
|
||||
) -> Result<()> {
|
||||
info!("Execution {} succeeded", execution_id);
|
||||
|
||||
// Build comprehensive result with execution metadata
|
||||
let exec_dir = self.artifact_manager.get_execution_dir(execution_id);
|
||||
let mut result_data = serde_json::json!({
|
||||
"exit_code": result.exit_code,
|
||||
"duration_ms": result.duration_ms,
|
||||
"succeeded": true,
|
||||
});
|
||||
|
||||
// Add log file paths if logs exist
|
||||
if !result.stdout.is_empty() {
|
||||
let stdout_path = exec_dir.join("stdout.log");
|
||||
result_data["stdout_log"] = serde_json::json!(stdout_path.to_string_lossy());
|
||||
// Include stdout preview (first 1000 chars)
|
||||
let stdout_preview = if result.stdout.len() > 1000 {
|
||||
format!("{}...", &result.stdout[..1000])
|
||||
} else {
|
||||
result.stdout.clone()
|
||||
};
|
||||
result_data["stdout"] = serde_json::json!(stdout_preview);
|
||||
}
|
||||
|
||||
if !result.stderr.is_empty() {
|
||||
let stderr_path = exec_dir.join("stderr.log");
|
||||
result_data["stderr_log"] = serde_json::json!(stderr_path.to_string_lossy());
|
||||
// Include stderr preview (first 1000 chars)
|
||||
let stderr_preview = if result.stderr.len() > 1000 {
|
||||
format!("{}...", &result.stderr[..1000])
|
||||
} else {
|
||||
result.stderr.clone()
|
||||
};
|
||||
result_data["stderr"] = serde_json::json!(stderr_preview);
|
||||
}
|
||||
|
||||
// Include parsed result if available
|
||||
if let Some(parsed_result) = &result.result {
|
||||
result_data["data"] = parsed_result.clone();
|
||||
}
|
||||
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(ExecutionStatus::Completed),
|
||||
result: Some(result_data),
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle failed execution
|
||||
async fn handle_execution_failure(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
result: Option<&ExecutionResult>,
|
||||
) -> Result<()> {
|
||||
error!("Execution {} failed", execution_id);
|
||||
|
||||
let exec_dir = self.artifact_manager.get_execution_dir(execution_id);
|
||||
let mut result_data = serde_json::json!({
|
||||
"succeeded": false,
|
||||
});
|
||||
|
||||
// If we have execution result, include detailed information
|
||||
if let Some(exec_result) = result {
|
||||
result_data["exit_code"] = serde_json::json!(exec_result.exit_code);
|
||||
result_data["duration_ms"] = serde_json::json!(exec_result.duration_ms);
|
||||
|
||||
if let Some(ref error) = exec_result.error {
|
||||
result_data["error"] = serde_json::json!(error);
|
||||
}
|
||||
|
||||
// Add log file paths and previews if logs exist
|
||||
if !exec_result.stdout.is_empty() {
|
||||
let stdout_path = exec_dir.join("stdout.log");
|
||||
result_data["stdout_log"] = serde_json::json!(stdout_path.to_string_lossy());
|
||||
// Include stdout preview (first 1000 chars)
|
||||
let stdout_preview = if exec_result.stdout.len() > 1000 {
|
||||
format!("{}...", &exec_result.stdout[..1000])
|
||||
} else {
|
||||
exec_result.stdout.clone()
|
||||
};
|
||||
result_data["stdout"] = serde_json::json!(stdout_preview);
|
||||
}
|
||||
|
||||
if !exec_result.stderr.is_empty() {
|
||||
let stderr_path = exec_dir.join("stderr.log");
|
||||
result_data["stderr_log"] = serde_json::json!(stderr_path.to_string_lossy());
|
||||
// Include stderr preview (first 1000 chars)
|
||||
let stderr_preview = if exec_result.stderr.len() > 1000 {
|
||||
format!("{}...", &exec_result.stderr[..1000])
|
||||
} else {
|
||||
exec_result.stderr.clone()
|
||||
};
|
||||
result_data["stderr"] = serde_json::json!(stderr_preview);
|
||||
}
|
||||
|
||||
// Add truncation warnings if applicable
|
||||
if exec_result.stdout_truncated {
|
||||
result_data["stdout_truncated"] = serde_json::json!(true);
|
||||
result_data["stdout_bytes_truncated"] =
|
||||
serde_json::json!(exec_result.stdout_bytes_truncated);
|
||||
}
|
||||
if exec_result.stderr_truncated {
|
||||
result_data["stderr_truncated"] = serde_json::json!(true);
|
||||
result_data["stderr_bytes_truncated"] =
|
||||
serde_json::json!(exec_result.stderr_bytes_truncated);
|
||||
}
|
||||
} else {
|
||||
// No execution result available (early failure during setup/preparation)
|
||||
// This should be rare - most errors should be captured in ExecutionResult
|
||||
result_data["error"] = serde_json::json!("Execution failed during preparation");
|
||||
|
||||
warn!("Execution {} failed without ExecutionResult - this indicates an early/catastrophic failure", execution_id);
|
||||
|
||||
// Check if stderr log exists from artifact storage
|
||||
let stderr_path = exec_dir.join("stderr.log");
|
||||
if stderr_path.exists() {
|
||||
result_data["stderr_log"] = serde_json::json!(stderr_path.to_string_lossy());
|
||||
// Try to read a preview if file exists
|
||||
if let Ok(contents) = tokio::fs::read_to_string(&stderr_path).await {
|
||||
let preview = if contents.len() > 1000 {
|
||||
format!("{}...", &contents[..1000])
|
||||
} else {
|
||||
contents
|
||||
};
|
||||
result_data["stderr"] = serde_json::json!(preview);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout log exists from artifact storage
|
||||
let stdout_path = exec_dir.join("stdout.log");
|
||||
if stdout_path.exists() {
|
||||
result_data["stdout_log"] = serde_json::json!(stdout_path.to_string_lossy());
|
||||
// Try to read a preview if file exists
|
||||
if let Ok(contents) = tokio::fs::read_to_string(&stdout_path).await {
|
||||
let preview = if contents.len() > 1000 {
|
||||
format!("{}...", &contents[..1000])
|
||||
} else {
|
||||
contents
|
||||
};
|
||||
result_data["stdout"] = serde_json::json!(preview);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(ExecutionStatus::Failed),
|
||||
result: Some(result_data),
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update execution status
|
||||
async fn update_execution_status(
|
||||
&self,
|
||||
execution_id: i64,
|
||||
status: ExecutionStatus,
|
||||
) -> Result<()> {
|
||||
debug!(
|
||||
"Updating execution {} status to: {:?}",
|
||||
execution_id, status
|
||||
);
|
||||
|
||||
let input = UpdateExecutionInput {
|
||||
status: Some(status),
|
||||
result: None,
|
||||
executor: None,
|
||||
workflow_task: None, // Not updating workflow metadata
|
||||
};
|
||||
|
||||
ExecutionRepository::update(&self.pool, execution_id, input).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn test_parse_action_reference() {
|
||||
let action_ref = "mypack.myaction";
|
||||
let parts: Vec<&str> = action_ref.split('.').collect();
|
||||
assert_eq!(parts.len(), 2);
|
||||
assert_eq!(parts[0], "mypack");
|
||||
assert_eq!(parts[1], "myaction");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_action_reference() {
|
||||
let action_ref = "invalid";
|
||||
let parts: Vec<&str> = action_ref.split('.').collect();
|
||||
assert_eq!(parts.len(), 1);
|
||||
}
|
||||
}
|
||||
140
crates/worker/src/heartbeat.rs
Normal file
140
crates/worker/src/heartbeat.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
//! Heartbeat Module
|
||||
//!
|
||||
//! Manages periodic heartbeat updates to keep the worker's status fresh in the database.
|
||||
|
||||
use attune_common::error::Result;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::registration::WorkerRegistration;
|
||||
|
||||
/// Heartbeat manager for worker status updates
|
||||
pub struct HeartbeatManager {
|
||||
registration: Arc<RwLock<WorkerRegistration>>,
|
||||
interval: Duration,
|
||||
running: Arc<RwLock<bool>>,
|
||||
}
|
||||
|
||||
impl HeartbeatManager {
|
||||
/// Create a new heartbeat manager
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `registration` - Worker registration instance
|
||||
/// * `interval_secs` - Heartbeat interval in seconds
|
||||
pub fn new(registration: Arc<RwLock<WorkerRegistration>>, interval_secs: u64) -> Self {
|
||||
Self {
|
||||
registration,
|
||||
interval: Duration::from_secs(interval_secs),
|
||||
running: Arc::new(RwLock::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the heartbeat loop
|
||||
///
|
||||
/// This spawns a background task that periodically updates the worker's heartbeat
|
||||
/// in the database. The task will continue running until `stop()` is called.
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
let mut running = self.running.write().await;
|
||||
if *running {
|
||||
warn!("Heartbeat manager is already running");
|
||||
return Ok(());
|
||||
}
|
||||
*running = true;
|
||||
drop(running);
|
||||
|
||||
info!(
|
||||
"Starting heartbeat manager with interval: {:?}",
|
||||
self.interval
|
||||
);
|
||||
|
||||
let registration = self.registration.clone();
|
||||
let interval = self.interval;
|
||||
let running = self.running.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = time::interval(interval);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Check if we should stop
|
||||
{
|
||||
let is_running = running.read().await;
|
||||
if !*is_running {
|
||||
info!("Heartbeat manager stopping");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Send heartbeat
|
||||
let reg = registration.read().await;
|
||||
match reg.update_heartbeat().await {
|
||||
Ok(_) => {
|
||||
debug!("Heartbeat sent successfully");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to send heartbeat: {}", e);
|
||||
// Continue trying - don't break the loop on transient errors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Heartbeat manager stopped");
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the heartbeat loop
|
||||
pub async fn stop(&self) {
|
||||
info!("Stopping heartbeat manager");
|
||||
let mut running = self.running.write().await;
|
||||
*running = false;
|
||||
}
|
||||
|
||||
/// Check if the heartbeat manager is running
|
||||
pub async fn is_running(&self) -> bool {
|
||||
let running = self.running.read().await;
|
||||
*running
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::registration::WorkerRegistration;
|
||||
use attune_common::config::Config;
|
||||
use attune_common::db::Database;
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires database
|
||||
async fn test_heartbeat_manager() {
|
||||
let config = Config::load().unwrap();
|
||||
let db = Database::new(&config.database).await.unwrap();
|
||||
let pool = db.pool().clone();
|
||||
let mut registration = WorkerRegistration::new(pool, &config);
|
||||
registration.register().await.unwrap();
|
||||
|
||||
let registration = Arc::new(RwLock::new(registration));
|
||||
let manager = HeartbeatManager::new(registration.clone(), 1);
|
||||
|
||||
// Start heartbeat
|
||||
manager.start().await.unwrap();
|
||||
assert!(manager.is_running().await);
|
||||
|
||||
// Wait for a few heartbeats
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
// Stop heartbeat
|
||||
manager.stop().await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
assert!(!manager.is_running().await);
|
||||
|
||||
// Deregister worker
|
||||
let reg = registration.read().await;
|
||||
reg.deregister().await.unwrap();
|
||||
}
|
||||
}
|
||||
25
crates/worker/src/lib.rs
Normal file
25
crates/worker/src/lib.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
//! Attune Worker Service Library
|
||||
//!
|
||||
//! This library provides the core functionality for the Attune Worker Service,
|
||||
//! which executes actions in various runtime environments.
|
||||
|
||||
pub mod artifacts;
|
||||
pub mod executor;
|
||||
pub mod heartbeat;
|
||||
pub mod registration;
|
||||
pub mod runtime;
|
||||
pub mod secrets;
|
||||
pub mod service;
|
||||
pub mod test_executor;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use executor::ActionExecutor;
|
||||
pub use heartbeat::HeartbeatManager;
|
||||
pub use registration::WorkerRegistration;
|
||||
pub use runtime::{
|
||||
ExecutionContext, ExecutionResult, LocalRuntime, NativeRuntime, PythonRuntime, Runtime,
|
||||
RuntimeError, RuntimeResult, ShellRuntime,
|
||||
};
|
||||
pub use secrets::SecretManager;
|
||||
pub use service::WorkerService;
|
||||
pub use test_executor::{TestConfig, TestExecutor};
|
||||
79
crates/worker/src/main.rs
Normal file
79
crates/worker/src/main.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
//! Attune Worker Service
|
||||
|
||||
use anyhow::Result;
|
||||
use attune_common::config::Config;
|
||||
use clap::Parser;
|
||||
use tracing::info;
|
||||
|
||||
use attune_worker::service::WorkerService;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "attune-worker")]
|
||||
#[command(about = "Attune Worker Service - Executes automation actions", long_about = None)]
|
||||
struct Args {
|
||||
/// Path to configuration file
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
|
||||
/// Worker name (overrides config)
|
||||
#[arg(short, long)]
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::fmt()
|
||||
.with_target(false)
|
||||
.with_thread_ids(true)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
info!("Starting Attune Worker Service");
|
||||
|
||||
// Load configuration
|
||||
if let Some(config_path) = args.config {
|
||||
std::env::set_var("ATTUNE_CONFIG", config_path);
|
||||
}
|
||||
|
||||
let mut config = Config::load()?;
|
||||
config.validate()?;
|
||||
|
||||
// Override worker name if provided via CLI
|
||||
if let Some(name) = args.name {
|
||||
if let Some(ref mut worker_config) = config.worker {
|
||||
worker_config.name = Some(name);
|
||||
} else {
|
||||
config.worker = Some(attune_common::config::WorkerConfig {
|
||||
name: Some(name),
|
||||
worker_type: None,
|
||||
runtime_id: None,
|
||||
host: None,
|
||||
port: None,
|
||||
capabilities: None,
|
||||
max_concurrent_tasks: 10,
|
||||
heartbeat_interval: 30,
|
||||
task_timeout: 300,
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
stream_logs: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!("Configuration loaded successfully");
|
||||
info!("Environment: {}", config.environment);
|
||||
|
||||
// Initialize and run worker service
|
||||
let mut service = WorkerService::new(config).await?;
|
||||
|
||||
info!("Attune Worker Service is ready");
|
||||
|
||||
// Run until interrupted
|
||||
service.run().await?;
|
||||
|
||||
info!("Attune Worker Service shutdown complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
349
crates/worker/src/registration.rs
Normal file
349
crates/worker/src/registration.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
//! Worker Registration Module
|
||||
//!
|
||||
//! Handles worker registration, discovery, and status management in the database.
|
||||
//! Uses unified runtime detection from the common crate.
|
||||
|
||||
use attune_common::config::Config;
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::{Worker, WorkerRole, WorkerStatus, WorkerType};
|
||||
use attune_common::runtime_detection::RuntimeDetector;
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Worker registration manager
|
||||
pub struct WorkerRegistration {
|
||||
pool: PgPool,
|
||||
worker_id: Option<i64>,
|
||||
worker_name: String,
|
||||
worker_type: WorkerType,
|
||||
worker_role: WorkerRole,
|
||||
runtime_id: Option<i64>,
|
||||
host: Option<String>,
|
||||
port: Option<i32>,
|
||||
capabilities: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl WorkerRegistration {
|
||||
/// Create a new worker registration manager
|
||||
pub fn new(pool: PgPool, config: &Config) -> Self {
|
||||
let worker_name = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.and_then(|w| w.name.clone())
|
||||
.unwrap_or_else(|| {
|
||||
format!(
|
||||
"worker-{}",
|
||||
hostname::get()
|
||||
.unwrap_or_else(|_| "unknown".into())
|
||||
.to_string_lossy()
|
||||
)
|
||||
});
|
||||
|
||||
let worker_type = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.and_then(|w| w.worker_type.clone())
|
||||
.unwrap_or(WorkerType::Local);
|
||||
|
||||
let worker_role = WorkerRole::Action;
|
||||
|
||||
let runtime_id = config.worker.as_ref().and_then(|w| w.runtime_id);
|
||||
|
||||
let host = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.and_then(|w| w.host.clone())
|
||||
.or_else(|| {
|
||||
hostname::get()
|
||||
.ok()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
});
|
||||
|
||||
let port = config.worker.as_ref().and_then(|w| w.port);
|
||||
|
||||
// Initial capabilities (will be populated asynchronously)
|
||||
let mut capabilities = HashMap::new();
|
||||
|
||||
// Set max_concurrent_executions from config
|
||||
let max_concurrent = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.max_concurrent_tasks)
|
||||
.unwrap_or(10);
|
||||
capabilities.insert(
|
||||
"max_concurrent_executions".to_string(),
|
||||
json!(max_concurrent),
|
||||
);
|
||||
|
||||
// Add worker version metadata
|
||||
capabilities.insert(
|
||||
"worker_version".to_string(),
|
||||
json!(env!("CARGO_PKG_VERSION")),
|
||||
);
|
||||
|
||||
// Placeholder for runtimes (will be detected asynchronously)
|
||||
capabilities.insert("runtimes".to_string(), json!(Vec::<String>::new()));
|
||||
|
||||
Self {
|
||||
pool,
|
||||
worker_id: None,
|
||||
worker_name,
|
||||
worker_type,
|
||||
worker_role,
|
||||
runtime_id,
|
||||
host,
|
||||
port,
|
||||
capabilities,
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect available runtimes using the unified runtime detector
|
||||
pub async fn detect_capabilities(&mut self, config: &Config) -> Result<()> {
|
||||
info!("Detecting worker capabilities...");
|
||||
|
||||
let detector = RuntimeDetector::new(self.pool.clone());
|
||||
|
||||
// Get config capabilities if available
|
||||
let config_capabilities = config.worker.as_ref().and_then(|w| w.capabilities.as_ref());
|
||||
|
||||
// Detect capabilities with three-tier priority:
|
||||
// 1. ATTUNE_WORKER_RUNTIMES env var
|
||||
// 2. Config file
|
||||
// 3. Database-driven detection
|
||||
let detected_capabilities = detector
|
||||
.detect_capabilities(config, "ATTUNE_WORKER_RUNTIMES", config_capabilities)
|
||||
.await?;
|
||||
|
||||
// Merge detected capabilities with existing ones
|
||||
for (key, value) in detected_capabilities {
|
||||
self.capabilities.insert(key, value);
|
||||
}
|
||||
|
||||
info!("Worker capabilities detected: {:?}", self.capabilities);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register the worker in the database
|
||||
pub async fn register(&mut self) -> Result<i64> {
|
||||
info!("Registering worker: {}", self.worker_name);
|
||||
|
||||
// Check if worker with this name already exists
|
||||
let existing = sqlx::query_as::<_, Worker>(
|
||||
"SELECT * FROM worker WHERE name = $1 ORDER BY created DESC LIMIT 1",
|
||||
)
|
||||
.bind(&self.worker_name)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
let worker_id = if let Some(existing_worker) = existing {
|
||||
info!(
|
||||
"Worker '{}' already exists (ID: {}), updating status",
|
||||
self.worker_name, existing_worker.id
|
||||
);
|
||||
|
||||
// Update existing worker to active status with new heartbeat
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE worker
|
||||
SET status = $1,
|
||||
last_heartbeat = $2,
|
||||
host = $3,
|
||||
port = $4,
|
||||
capabilities = $5,
|
||||
updated = $2
|
||||
WHERE id = $6
|
||||
"#,
|
||||
)
|
||||
.bind(WorkerStatus::Active)
|
||||
.bind(Utc::now())
|
||||
.bind(&self.host)
|
||||
.bind(self.port)
|
||||
.bind(serde_json::to_value(&self.capabilities)?)
|
||||
.bind(existing_worker.id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
existing_worker.id
|
||||
} else {
|
||||
info!("Creating new worker registration: {}", self.worker_name);
|
||||
|
||||
// Insert new worker
|
||||
let worker = sqlx::query_as::<_, Worker>(
|
||||
r#"
|
||||
INSERT INTO worker (name, worker_type, worker_role, runtime, host, port, status, capabilities, last_heartbeat)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING *
|
||||
"#,
|
||||
)
|
||||
.bind(&self.worker_name)
|
||||
.bind(&self.worker_type)
|
||||
.bind(&self.worker_role)
|
||||
.bind(self.runtime_id)
|
||||
.bind(&self.host)
|
||||
.bind(self.port)
|
||||
.bind(WorkerStatus::Active)
|
||||
.bind(serde_json::to_value(&self.capabilities)?)
|
||||
.bind(Utc::now())
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
worker.id
|
||||
};
|
||||
|
||||
self.worker_id = Some(worker_id);
|
||||
info!("Worker registered successfully with ID: {}", worker_id);
|
||||
|
||||
Ok(worker_id)
|
||||
}
|
||||
|
||||
/// Deregister the worker (mark as inactive)
|
||||
pub async fn deregister(&self) -> Result<()> {
|
||||
if let Some(worker_id) = self.worker_id {
|
||||
info!("Deregistering worker ID: {}", worker_id);
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE worker
|
||||
SET status = $1,
|
||||
updated = $2
|
||||
WHERE id = $3
|
||||
"#,
|
||||
)
|
||||
.bind(WorkerStatus::Inactive)
|
||||
.bind(Utc::now())
|
||||
.bind(worker_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
info!("Worker deregistered successfully");
|
||||
} else {
|
||||
warn!("Cannot deregister: worker not registered");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update worker heartbeat
|
||||
pub async fn update_heartbeat(&self) -> Result<()> {
|
||||
if let Some(worker_id) = self.worker_id {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE worker
|
||||
SET last_heartbeat = $1,
|
||||
updated = $1
|
||||
WHERE id = $2
|
||||
"#,
|
||||
)
|
||||
.bind(Utc::now())
|
||||
.bind(worker_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
} else {
|
||||
return Err(Error::invalid_state("Worker not registered"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the registered worker ID
|
||||
pub fn worker_id(&self) -> Option<i64> {
|
||||
self.worker_id
|
||||
}
|
||||
|
||||
/// Get the worker name
|
||||
pub fn worker_name(&self) -> &str {
|
||||
&self.worker_name
|
||||
}
|
||||
|
||||
/// Add a capability to the worker
|
||||
pub fn add_capability(&mut self, key: String, value: serde_json::Value) {
|
||||
self.capabilities.insert(key, value);
|
||||
}
|
||||
|
||||
/// Update worker capabilities in the database
|
||||
pub async fn update_capabilities(&self) -> Result<()> {
|
||||
if let Some(worker_id) = self.worker_id {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE worker
|
||||
SET capabilities = $1,
|
||||
updated = $2
|
||||
WHERE id = $3
|
||||
"#,
|
||||
)
|
||||
.bind(serde_json::to_value(&self.capabilities)?)
|
||||
.bind(Utc::now())
|
||||
.bind(worker_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
info!("Worker capabilities updated");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WorkerRegistration {
|
||||
fn drop(&mut self) {
|
||||
// Note: We can't make this async, so we just log
|
||||
// The main service should call deregister() explicitly during shutdown
|
||||
if self.worker_id.is_some() {
|
||||
info!("WorkerRegistration dropped - worker should be deregistered");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires database
|
||||
async fn test_worker_registration() {
|
||||
let config = Config::load().unwrap();
|
||||
let db = attune_common::db::Database::new(&config.database)
|
||||
.await
|
||||
.unwrap();
|
||||
let pool = db.pool().clone();
|
||||
let mut registration = WorkerRegistration::new(pool, &config);
|
||||
|
||||
// Detect capabilities
|
||||
registration.detect_capabilities(&config).await.unwrap();
|
||||
|
||||
// Register worker
|
||||
let worker_id = registration.register().await.unwrap();
|
||||
assert!(worker_id > 0);
|
||||
assert_eq!(registration.worker_id(), Some(worker_id));
|
||||
|
||||
// Update heartbeat
|
||||
registration.update_heartbeat().await.unwrap();
|
||||
|
||||
// Deregister worker
|
||||
registration.deregister().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires database
|
||||
async fn test_worker_capabilities() {
|
||||
let config = Config::load().unwrap();
|
||||
let db = attune_common::db::Database::new(&config.database)
|
||||
.await
|
||||
.unwrap();
|
||||
let pool = db.pool().clone();
|
||||
let mut registration = WorkerRegistration::new(pool, &config);
|
||||
|
||||
registration.detect_capabilities(&config).await.unwrap();
|
||||
registration.register().await.unwrap();
|
||||
|
||||
// Add capability
|
||||
registration.add_capability("test_capability".to_string(), json!(true));
|
||||
registration.update_capabilities().await.unwrap();
|
||||
|
||||
registration.deregister().await.unwrap();
|
||||
}
|
||||
}
|
||||
320
crates/worker/src/runtime/dependency.rs
Normal file
320
crates/worker/src/runtime/dependency.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Runtime Dependency Management
|
||||
//!
|
||||
//! Provides generic abstractions for managing runtime dependencies across
|
||||
//! different languages (Python, Node.js, Java, etc.).
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Dependency manager result type
|
||||
pub type DependencyResult<T> = std::result::Result<T, DependencyError>;
|
||||
|
||||
/// Dependency manager errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DependencyError {
|
||||
#[error("Failed to create environment: {0}")]
|
||||
CreateEnvironmentFailed(String),
|
||||
|
||||
#[error("Failed to install dependencies: {0}")]
|
||||
InstallFailed(String),
|
||||
|
||||
#[error("Environment not found: {0}")]
|
||||
EnvironmentNotFound(String),
|
||||
|
||||
#[error("Invalid dependency specification: {0}")]
|
||||
InvalidDependencySpec(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
#[error("Process execution error: {0}")]
|
||||
ProcessError(String),
|
||||
|
||||
#[error("Lock file error: {0}")]
|
||||
LockFileError(String),
|
||||
|
||||
#[error("Environment validation failed: {0}")]
|
||||
ValidationFailed(String),
|
||||
}
|
||||
|
||||
/// Dependency specification for a pack
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DependencySpec {
|
||||
/// Runtime type (python, nodejs, java, etc.)
|
||||
pub runtime: String,
|
||||
|
||||
/// List of dependencies (e.g., ["requests==2.28.0", "flask>=2.0.0"])
|
||||
pub dependencies: Vec<String>,
|
||||
|
||||
/// Requirements file content (alternative to dependencies list)
|
||||
pub requirements_file_content: Option<String>,
|
||||
|
||||
/// Minimum runtime version required
|
||||
pub min_version: Option<String>,
|
||||
|
||||
/// Maximum runtime version required
|
||||
pub max_version: Option<String>,
|
||||
|
||||
/// Additional metadata
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl DependencySpec {
|
||||
/// Create a new dependency specification
|
||||
pub fn new(runtime: impl Into<String>) -> Self {
|
||||
Self {
|
||||
runtime: runtime.into(),
|
||||
dependencies: Vec::new(),
|
||||
requirements_file_content: None,
|
||||
min_version: None,
|
||||
max_version: None,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a dependency
|
||||
pub fn with_dependency(mut self, dep: impl Into<String>) -> Self {
|
||||
self.dependencies.push(dep.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple dependencies
|
||||
pub fn with_dependencies(mut self, deps: Vec<String>) -> Self {
|
||||
self.dependencies.extend(deps);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set requirements file content
|
||||
pub fn with_requirements_file(mut self, content: String) -> Self {
|
||||
self.requirements_file_content = Some(content);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set version constraints
|
||||
pub fn with_version_range(
|
||||
mut self,
|
||||
min_version: Option<String>,
|
||||
max_version: Option<String>,
|
||||
) -> Self {
|
||||
self.min_version = min_version;
|
||||
self.max_version = max_version;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if this spec has any dependencies
|
||||
pub fn has_dependencies(&self) -> bool {
|
||||
!self.dependencies.is_empty() || self.requirements_file_content.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about an isolated environment
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EnvironmentInfo {
|
||||
/// Unique environment identifier (typically pack_ref)
|
||||
pub id: String,
|
||||
|
||||
/// Path to the environment directory
|
||||
pub path: PathBuf,
|
||||
|
||||
/// Runtime type
|
||||
pub runtime: String,
|
||||
|
||||
/// Runtime version in the environment
|
||||
pub runtime_version: String,
|
||||
|
||||
/// List of installed dependencies
|
||||
pub installed_dependencies: Vec<String>,
|
||||
|
||||
/// Timestamp when environment was created
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
|
||||
/// Timestamp when environment was last updated
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
|
||||
/// Whether the environment is valid and ready to use
|
||||
pub is_valid: bool,
|
||||
|
||||
/// Environment-specific executable path (e.g., venv/bin/python)
|
||||
pub executable_path: PathBuf,
|
||||
}
|
||||
|
||||
/// Trait for managing isolated runtime environments
|
||||
#[async_trait]
|
||||
pub trait DependencyManager: Send + Sync {
|
||||
/// Get the runtime type this manager handles (e.g., "python", "nodejs")
|
||||
fn runtime_type(&self) -> &str;
|
||||
|
||||
/// Create or update an isolated environment for a pack
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `pack_ref` - Unique identifier for the pack (e.g., "core.http")
|
||||
/// * `spec` - Dependency specification
|
||||
///
|
||||
/// # Returns
|
||||
/// Information about the created/updated environment
|
||||
async fn ensure_environment(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
spec: &DependencySpec,
|
||||
) -> DependencyResult<EnvironmentInfo>;
|
||||
|
||||
/// Get information about an existing environment
|
||||
async fn get_environment(&self, pack_ref: &str) -> DependencyResult<Option<EnvironmentInfo>>;
|
||||
|
||||
/// Remove an environment
|
||||
async fn remove_environment(&self, pack_ref: &str) -> DependencyResult<()>;
|
||||
|
||||
/// Validate an environment is still functional
|
||||
async fn validate_environment(&self, pack_ref: &str) -> DependencyResult<bool>;
|
||||
|
||||
/// Get the executable path for running actions in this environment
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `pack_ref` - Pack identifier
|
||||
///
|
||||
/// # Returns
|
||||
/// Path to the runtime executable within the isolated environment
|
||||
async fn get_executable_path(&self, pack_ref: &str) -> DependencyResult<PathBuf>;
|
||||
|
||||
/// List all managed environments
|
||||
async fn list_environments(&self) -> DependencyResult<Vec<EnvironmentInfo>>;
|
||||
|
||||
/// Clean up invalid or unused environments
|
||||
async fn cleanup(&self, keep_recent: usize) -> DependencyResult<Vec<String>>;
|
||||
|
||||
/// Check if dependencies have changed and environment needs updating
|
||||
async fn needs_update(&self, pack_ref: &str, _spec: &DependencySpec) -> DependencyResult<bool> {
|
||||
// Default implementation: check if environment exists and validate it
|
||||
match self.get_environment(pack_ref).await? {
|
||||
None => Ok(true), // Doesn't exist, needs creation
|
||||
Some(env_info) => {
|
||||
// Check if environment is valid
|
||||
if !env_info.is_valid {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Could add more sophisticated checks here (dependency hash comparison, etc.)
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Registry for managing multiple dependency managers
|
||||
pub struct DependencyManagerRegistry {
|
||||
managers: HashMap<String, Box<dyn DependencyManager>>,
|
||||
}
|
||||
|
||||
impl DependencyManagerRegistry {
|
||||
/// Create a new registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
managers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a dependency manager
|
||||
pub fn register(&mut self, manager: Box<dyn DependencyManager>) {
|
||||
let runtime_type = manager.runtime_type().to_string();
|
||||
self.managers.insert(runtime_type, manager);
|
||||
}
|
||||
|
||||
/// Get a dependency manager by runtime type
|
||||
pub fn get(&self, runtime_type: &str) -> Option<&dyn DependencyManager> {
|
||||
self.managers.get(runtime_type).map(|m| m.as_ref())
|
||||
}
|
||||
|
||||
/// Check if a runtime type is supported
|
||||
pub fn supports(&self, runtime_type: &str) -> bool {
|
||||
self.managers.contains_key(runtime_type)
|
||||
}
|
||||
|
||||
/// List all supported runtime types
|
||||
pub fn supported_runtimes(&self) -> Vec<String> {
|
||||
self.managers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Ensure environment for a pack with given spec
|
||||
pub async fn ensure_environment(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
spec: &DependencySpec,
|
||||
) -> DependencyResult<EnvironmentInfo> {
|
||||
let manager = self.get(&spec.runtime).ok_or_else(|| {
|
||||
DependencyError::InvalidDependencySpec(format!(
|
||||
"No dependency manager found for runtime: {}",
|
||||
spec.runtime
|
||||
))
|
||||
})?;
|
||||
|
||||
manager.ensure_environment(pack_ref, spec).await
|
||||
}
|
||||
|
||||
/// Get executable path for a pack
|
||||
pub async fn get_executable_path(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
runtime_type: &str,
|
||||
) -> DependencyResult<PathBuf> {
|
||||
let manager = self.get(runtime_type).ok_or_else(|| {
|
||||
DependencyError::InvalidDependencySpec(format!(
|
||||
"No dependency manager found for runtime: {}",
|
||||
runtime_type
|
||||
))
|
||||
})?;
|
||||
|
||||
manager.get_executable_path(pack_ref).await
|
||||
}
|
||||
|
||||
/// Cleanup all managers
|
||||
pub async fn cleanup_all(&self, keep_recent: usize) -> DependencyResult<Vec<String>> {
|
||||
let mut removed = Vec::new();
|
||||
|
||||
for manager in self.managers.values() {
|
||||
let mut cleaned = manager.cleanup(keep_recent).await?;
|
||||
removed.append(&mut cleaned);
|
||||
}
|
||||
|
||||
Ok(removed)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DependencyManagerRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dependency_spec_builder() {
|
||||
let spec = DependencySpec::new("python")
|
||||
.with_dependency("requests==2.28.0")
|
||||
.with_dependency("flask>=2.0.0")
|
||||
.with_version_range(Some("3.8".to_string()), Some("3.11".to_string()));
|
||||
|
||||
assert_eq!(spec.runtime, "python");
|
||||
assert_eq!(spec.dependencies.len(), 2);
|
||||
assert!(spec.has_dependencies());
|
||||
assert_eq!(spec.min_version, Some("3.8".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dependency_spec_empty() {
|
||||
let spec = DependencySpec::new("nodejs");
|
||||
assert!(!spec.has_dependencies());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dependency_manager_registry() {
|
||||
let registry = DependencyManagerRegistry::new();
|
||||
assert_eq!(registry.supported_runtimes().len(), 0);
|
||||
assert!(!registry.supports("python"));
|
||||
}
|
||||
}
|
||||
207
crates/worker/src/runtime/local.rs
Normal file
207
crates/worker/src/runtime/local.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Local Runtime Module
|
||||
//!
|
||||
//! Provides local execution capabilities by combining Python and Shell runtimes.
|
||||
//! This module serves as a facade for all local process-based execution.
|
||||
|
||||
use super::native::NativeRuntime;
|
||||
use super::python::PythonRuntime;
|
||||
use super::shell::ShellRuntime;
|
||||
use super::{ExecutionContext, ExecutionResult, Runtime, RuntimeError, RuntimeResult};
|
||||
use async_trait::async_trait;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Local runtime that delegates to Python, Shell, or Native based on action type
|
||||
pub struct LocalRuntime {
|
||||
native: NativeRuntime,
|
||||
python: PythonRuntime,
|
||||
shell: ShellRuntime,
|
||||
}
|
||||
|
||||
impl LocalRuntime {
|
||||
/// Create a new local runtime with default settings
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
native: NativeRuntime::new(),
|
||||
python: PythonRuntime::new(),
|
||||
shell: ShellRuntime::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a local runtime with custom runtimes
|
||||
pub fn with_runtimes(
|
||||
native: NativeRuntime,
|
||||
python: PythonRuntime,
|
||||
shell: ShellRuntime,
|
||||
) -> Self {
|
||||
Self {
|
||||
native,
|
||||
python,
|
||||
shell,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the appropriate runtime for the given context
|
||||
fn select_runtime(&self, context: &ExecutionContext) -> RuntimeResult<&dyn Runtime> {
|
||||
if self.native.can_execute(context) {
|
||||
debug!("Selected Native runtime for action: {}", context.action_ref);
|
||||
Ok(&self.native)
|
||||
} else if self.python.can_execute(context) {
|
||||
debug!("Selected Python runtime for action: {}", context.action_ref);
|
||||
Ok(&self.python)
|
||||
} else if self.shell.can_execute(context) {
|
||||
debug!("Selected Shell runtime for action: {}", context.action_ref);
|
||||
Ok(&self.shell)
|
||||
} else {
|
||||
Err(RuntimeError::RuntimeNotFound(format!(
|
||||
"No suitable local runtime found for action: {}",
|
||||
context.action_ref
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LocalRuntime {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Runtime for LocalRuntime {
|
||||
fn name(&self) -> &str {
|
||||
"local"
|
||||
}
|
||||
|
||||
fn can_execute(&self, context: &ExecutionContext) -> bool {
|
||||
self.native.can_execute(context)
|
||||
|| self.python.can_execute(context)
|
||||
|| self.shell.can_execute(context)
|
||||
}
|
||||
|
||||
async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult> {
|
||||
info!(
|
||||
"Executing local action: {} (execution_id: {})",
|
||||
context.action_ref, context.execution_id
|
||||
);
|
||||
|
||||
let runtime = self.select_runtime(&context)?;
|
||||
runtime.execute(context).await
|
||||
}
|
||||
|
||||
async fn setup(&self) -> RuntimeResult<()> {
|
||||
info!("Setting up Local runtime");
|
||||
|
||||
self.native.setup().await?;
|
||||
self.python.setup().await?;
|
||||
self.shell.setup().await?;
|
||||
|
||||
info!("Local runtime setup complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> RuntimeResult<()> {
|
||||
info!("Cleaning up Local runtime");
|
||||
|
||||
self.native.cleanup().await?;
|
||||
self.python.cleanup().await?;
|
||||
self.shell.cleanup().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate(&self) -> RuntimeResult<()> {
|
||||
debug!("Validating Local runtime");
|
||||
|
||||
self.native.validate().await?;
|
||||
self.python.validate().await?;
|
||||
self.shell.validate().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_local_runtime_python() {
|
||||
let runtime = LocalRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 1,
|
||||
action_ref: "test.python_action".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
return "hello from python"
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
assert!(runtime.can_execute(&context));
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_local_runtime_shell() {
|
||||
let runtime = LocalRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 2,
|
||||
action_ref: "test.shell_action".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some("echo 'hello from shell'".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
assert!(runtime.can_execute(&context));
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert!(result.stdout.contains("hello from shell"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_local_runtime_unknown() {
|
||||
let runtime = LocalRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 3,
|
||||
action_ref: "test.unknown_action".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "unknown".to_string(),
|
||||
code: Some("some code".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: None,
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
assert!(!runtime.can_execute(&context));
|
||||
}
|
||||
}
|
||||
300
crates/worker/src/runtime/log_writer.rs
Normal file
300
crates/worker/src/runtime/log_writer.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
//! Log Writer Module
|
||||
//!
|
||||
//! Provides bounded log writers that limit output size to prevent OOM issues.
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
const TRUNCATION_NOTICE_STDOUT: &str = "\n\n[OUTPUT TRUNCATED: stdout exceeded size limit]\n";
|
||||
const TRUNCATION_NOTICE_STDERR: &str = "\n\n[OUTPUT TRUNCATED: stderr exceeded size limit]\n";
|
||||
|
||||
// Reserve space for truncation notice so it can always fit
|
||||
const NOTICE_RESERVE_BYTES: usize = 128;
|
||||
|
||||
/// Result of bounded log writing
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BoundedLogResult {
|
||||
/// The captured log content
|
||||
pub content: String,
|
||||
|
||||
/// Whether the log was truncated
|
||||
pub truncated: bool,
|
||||
|
||||
/// Number of bytes truncated (0 if not truncated)
|
||||
pub bytes_truncated: usize,
|
||||
|
||||
/// Total bytes attempted to write
|
||||
pub total_bytes_attempted: usize,
|
||||
}
|
||||
|
||||
impl BoundedLogResult {
|
||||
/// Create a new result with no truncation
|
||||
pub fn new(content: String) -> Self {
|
||||
let len = content.len();
|
||||
Self {
|
||||
content,
|
||||
truncated: false,
|
||||
bytes_truncated: 0,
|
||||
total_bytes_attempted: len,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a truncated result
|
||||
pub fn truncated(
|
||||
content: String,
|
||||
bytes_truncated: usize,
|
||||
total_bytes_attempted: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
content,
|
||||
truncated: true,
|
||||
bytes_truncated,
|
||||
total_bytes_attempted,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A writer that limits the amount of data captured and adds a truncation notice
|
||||
pub struct BoundedLogWriter {
|
||||
/// Internal buffer for captured data
|
||||
buffer: Vec<u8>,
|
||||
|
||||
/// Maximum bytes to capture
|
||||
max_bytes: usize,
|
||||
|
||||
/// Whether we've already truncated and added the notice
|
||||
truncated: bool,
|
||||
|
||||
/// Total bytes attempted to write (including truncated)
|
||||
total_bytes_attempted: usize,
|
||||
|
||||
/// Actual data bytes written to buffer (excluding truncation notice)
|
||||
data_bytes_written: usize,
|
||||
|
||||
/// Truncation notice to append when limit is reached
|
||||
truncation_notice: &'static str,
|
||||
}
|
||||
|
||||
impl BoundedLogWriter {
|
||||
/// Create a new bounded log writer for stdout
|
||||
pub fn new_stdout(max_bytes: usize) -> Self {
|
||||
Self {
|
||||
buffer: Vec::with_capacity(std::cmp::min(max_bytes, 1024 * 1024)),
|
||||
max_bytes,
|
||||
truncated: false,
|
||||
total_bytes_attempted: 0,
|
||||
data_bytes_written: 0,
|
||||
truncation_notice: TRUNCATION_NOTICE_STDOUT,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new bounded log writer for stderr
|
||||
pub fn new_stderr(max_bytes: usize) -> Self {
|
||||
Self {
|
||||
buffer: Vec::with_capacity(std::cmp::min(max_bytes, 1024 * 1024)),
|
||||
max_bytes,
|
||||
truncated: false,
|
||||
total_bytes_attempted: 0,
|
||||
data_bytes_written: 0,
|
||||
truncation_notice: TRUNCATION_NOTICE_STDERR,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the result with truncation information
|
||||
pub fn into_result(self) -> BoundedLogResult {
|
||||
let content = String::from_utf8_lossy(&self.buffer).to_string();
|
||||
|
||||
if self.truncated {
|
||||
BoundedLogResult::truncated(
|
||||
content,
|
||||
self.total_bytes_attempted
|
||||
.saturating_sub(self.data_bytes_written),
|
||||
self.total_bytes_attempted,
|
||||
)
|
||||
} else {
|
||||
BoundedLogResult::new(content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Write data to the buffer, respecting size limits
|
||||
fn write_bounded(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.total_bytes_attempted = self.total_bytes_attempted.saturating_add(buf.len());
|
||||
|
||||
// If already truncated, discard all further writes
|
||||
if self.truncated {
|
||||
return Ok(buf.len()); // Pretend we wrote it all
|
||||
}
|
||||
|
||||
let current_size = self.buffer.len();
|
||||
// Reserve space for truncation notice
|
||||
let effective_limit = self.max_bytes.saturating_sub(NOTICE_RESERVE_BYTES);
|
||||
let remaining_space = effective_limit.saturating_sub(current_size);
|
||||
|
||||
if remaining_space == 0 {
|
||||
// Already at limit, add truncation notice if not already added
|
||||
if !self.truncated {
|
||||
self.add_truncation_notice();
|
||||
}
|
||||
return Ok(buf.len()); // Pretend we wrote it all
|
||||
}
|
||||
|
||||
// Calculate how much we can actually write
|
||||
let bytes_to_write = std::cmp::min(buf.len(), remaining_space);
|
||||
|
||||
if bytes_to_write < buf.len() {
|
||||
// We're about to hit the limit
|
||||
self.buffer.extend_from_slice(&buf[..bytes_to_write]);
|
||||
self.data_bytes_written += bytes_to_write;
|
||||
self.add_truncation_notice();
|
||||
} else {
|
||||
// We can write everything
|
||||
self.buffer.extend_from_slice(&buf[..bytes_to_write]);
|
||||
self.data_bytes_written += bytes_to_write;
|
||||
}
|
||||
|
||||
Ok(buf.len()) // Always report full write to avoid backpressure issues
|
||||
}
|
||||
|
||||
/// Add truncation notice to the buffer
|
||||
fn add_truncation_notice(&mut self) {
|
||||
self.truncated = true;
|
||||
|
||||
let notice_bytes = self.truncation_notice.as_bytes();
|
||||
// We reserved space, so the notice should always fit
|
||||
self.buffer.extend_from_slice(notice_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BoundedLogWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Poll::Ready(self.write_bounded(buf))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_under_limit() {
|
||||
let mut writer = BoundedLogWriter::new_stdout(1024);
|
||||
let data = b"Hello, world!";
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert_eq!(result.content, "Hello, world!");
|
||||
assert!(!result.truncated);
|
||||
assert_eq!(result.bytes_truncated, 0);
|
||||
assert_eq!(result.total_bytes_attempted, 13);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_at_limit() {
|
||||
// With 178 bytes, we can fit 50 bytes (178 - 128 reserve = 50)
|
||||
let mut writer = BoundedLogWriter::new_stdout(178);
|
||||
let data = b"12345678901234567890123456789012345678901234567890"; // 50 bytes
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert_eq!(result.content.len(), 50);
|
||||
assert!(!result.truncated);
|
||||
assert_eq!(result.bytes_truncated, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_exceeds_limit() {
|
||||
// 148 bytes means effective limit is 20 (148 - 128 = 20)
|
||||
let mut writer = BoundedLogWriter::new_stdout(148);
|
||||
let data = b"This is a long message that exceeds the limit";
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert!(result.truncated);
|
||||
assert!(result.content.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.bytes_truncated > 0);
|
||||
assert_eq!(result.total_bytes_attempted, 45);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_multiple_writes() {
|
||||
// 148 bytes means effective limit is 20 (148 - 128 = 20)
|
||||
let mut writer = BoundedLogWriter::new_stdout(148);
|
||||
|
||||
writer.write_all(b"First ").await.unwrap(); // 6 bytes
|
||||
writer.write_all(b"Second ").await.unwrap(); // 7 bytes = 13 total
|
||||
writer.write_all(b"Third ").await.unwrap(); // 6 bytes = 19 total
|
||||
writer.write_all(b"Fourth ").await.unwrap(); // 7 bytes = 26 total, exceeds 20 limit
|
||||
|
||||
let result = writer.into_result();
|
||||
assert!(result.truncated);
|
||||
assert!(result.content.contains("[OUTPUT TRUNCATED"));
|
||||
assert_eq!(result.total_bytes_attempted, 26);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_stderr_notice() {
|
||||
// 143 bytes means effective limit is 15 (143 - 128 = 15)
|
||||
let mut writer = BoundedLogWriter::new_stderr(143);
|
||||
let data = b"Error message that is too long";
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert!(result.truncated);
|
||||
assert!(result.content.contains("stderr exceeded size limit"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_empty() {
|
||||
let writer = BoundedLogWriter::new_stdout(1024);
|
||||
|
||||
let result = writer.into_result();
|
||||
assert_eq!(result.content, "");
|
||||
assert!(!result.truncated);
|
||||
assert_eq!(result.bytes_truncated, 0);
|
||||
assert_eq!(result.total_bytes_attempted, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_exact_limit_no_truncation_notice() {
|
||||
// 138 bytes means effective limit is 10 (138 - 128 = 10)
|
||||
let mut writer = BoundedLogWriter::new_stdout(138);
|
||||
let data = b"1234567890"; // Exactly 10 bytes
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert_eq!(result.content, "1234567890");
|
||||
assert!(!result.truncated);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bounded_writer_one_byte_over() {
|
||||
// 138 bytes means effective limit is 10 (138 - 128 = 10)
|
||||
let mut writer = BoundedLogWriter::new_stdout(138);
|
||||
let data = b"12345678901"; // 11 bytes
|
||||
|
||||
writer.write_all(data).await.unwrap();
|
||||
|
||||
let result = writer.into_result();
|
||||
assert!(result.truncated);
|
||||
assert_eq!(result.bytes_truncated, 1);
|
||||
}
|
||||
}
|
||||
330
crates/worker/src/runtime/mod.rs
Normal file
330
crates/worker/src/runtime/mod.rs
Normal file
@@ -0,0 +1,330 @@
|
||||
//! Runtime Module
|
||||
//!
|
||||
//! Provides runtime abstraction and implementations for executing actions
|
||||
//! in different environments (Python, Shell, Node.js, Containers).
|
||||
|
||||
pub mod dependency;
|
||||
pub mod local;
|
||||
pub mod log_writer;
|
||||
pub mod native;
|
||||
pub mod python;
|
||||
pub mod python_venv;
|
||||
pub mod shell;
|
||||
|
||||
// Re-export runtime implementations
|
||||
pub use local::LocalRuntime;
|
||||
pub use native::NativeRuntime;
|
||||
pub use python::PythonRuntime;
|
||||
pub use shell::ShellRuntime;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
|
||||
// Re-export dependency management types
|
||||
pub use dependency::{
|
||||
DependencyError, DependencyManager, DependencyManagerRegistry, DependencyResult,
|
||||
DependencySpec, EnvironmentInfo,
|
||||
};
|
||||
pub use log_writer::{BoundedLogResult, BoundedLogWriter};
|
||||
pub use python_venv::PythonVenvManager;
|
||||
|
||||
/// Runtime execution result
|
||||
pub type RuntimeResult<T> = std::result::Result<T, RuntimeError>;
|
||||
|
||||
/// Runtime execution errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RuntimeError {
|
||||
#[error("Execution failed: {0}")]
|
||||
ExecutionFailed(String),
|
||||
|
||||
#[error("Timeout after {0} seconds")]
|
||||
Timeout(u64),
|
||||
|
||||
#[error("Runtime not found: {0}")]
|
||||
RuntimeNotFound(String),
|
||||
|
||||
#[error("Invalid action: {0}")]
|
||||
InvalidAction(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Process error: {0}")]
|
||||
ProcessError(String),
|
||||
|
||||
#[error("Setup error: {0}")]
|
||||
SetupError(String),
|
||||
|
||||
#[error("Cleanup error: {0}")]
|
||||
CleanupError(String),
|
||||
}
|
||||
|
||||
/// Action execution context
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionContext {
|
||||
/// Execution ID
|
||||
pub execution_id: i64,
|
||||
|
||||
/// Action reference (pack.action)
|
||||
pub action_ref: String,
|
||||
|
||||
/// Action parameters
|
||||
pub parameters: HashMap<String, serde_json::Value>,
|
||||
|
||||
/// Environment variables
|
||||
pub env: HashMap<String, String>,
|
||||
|
||||
/// Secrets (passed securely via stdin, not environment variables)
|
||||
pub secrets: HashMap<String, String>,
|
||||
|
||||
/// Execution timeout in seconds
|
||||
pub timeout: Option<u64>,
|
||||
|
||||
/// Working directory
|
||||
pub working_dir: Option<PathBuf>,
|
||||
|
||||
/// Action entry point (script, function, etc.)
|
||||
pub entry_point: String,
|
||||
|
||||
/// Action code/script content
|
||||
pub code: Option<String>,
|
||||
|
||||
/// Action code file path (alternative to code)
|
||||
pub code_path: Option<PathBuf>,
|
||||
|
||||
/// Runtime name (python, shell, etc.) - used to select the correct runtime
|
||||
pub runtime_name: Option<String>,
|
||||
|
||||
/// Maximum stdout size in bytes (for log truncation)
|
||||
#[serde(default = "default_max_log_bytes")]
|
||||
pub max_stdout_bytes: usize,
|
||||
|
||||
/// Maximum stderr size in bytes (for log truncation)
|
||||
#[serde(default = "default_max_log_bytes")]
|
||||
pub max_stderr_bytes: usize,
|
||||
}
|
||||
|
||||
fn default_max_log_bytes() -> usize {
|
||||
10 * 1024 * 1024 // 10MB
|
||||
}
|
||||
|
||||
impl ExecutionContext {
|
||||
/// Create a test context with default values (for tests)
|
||||
#[cfg(test)]
|
||||
pub fn test_context(action_ref: String, code: Option<String>) -> Self {
|
||||
use std::collections::HashMap;
|
||||
Self {
|
||||
execution_id: 1,
|
||||
action_ref,
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code,
|
||||
code_path: None,
|
||||
runtime_name: None,
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Action execution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionResult {
|
||||
/// Exit code (0 = success)
|
||||
pub exit_code: i32,
|
||||
|
||||
/// Standard output
|
||||
pub stdout: String,
|
||||
|
||||
/// Standard error
|
||||
pub stderr: String,
|
||||
|
||||
/// Execution result data (parsed from stdout or returned by action)
|
||||
pub result: Option<serde_json::Value>,
|
||||
|
||||
/// Execution duration in milliseconds
|
||||
pub duration_ms: u64,
|
||||
|
||||
/// Error message if execution failed
|
||||
pub error: Option<String>,
|
||||
|
||||
/// Whether stdout was truncated due to size limits
|
||||
#[serde(default)]
|
||||
pub stdout_truncated: bool,
|
||||
|
||||
/// Whether stderr was truncated due to size limits
|
||||
#[serde(default)]
|
||||
pub stderr_truncated: bool,
|
||||
|
||||
/// Number of bytes truncated from stdout (0 if not truncated)
|
||||
#[serde(default)]
|
||||
pub stdout_bytes_truncated: usize,
|
||||
|
||||
/// Number of bytes truncated from stderr (0 if not truncated)
|
||||
#[serde(default)]
|
||||
pub stderr_bytes_truncated: usize,
|
||||
}
|
||||
|
||||
impl ExecutionResult {
|
||||
/// Check if execution was successful
|
||||
pub fn is_success(&self) -> bool {
|
||||
self.exit_code == 0 && self.error.is_none()
|
||||
}
|
||||
|
||||
/// Create a success result
|
||||
pub fn success(stdout: String, result: Option<serde_json::Value>, duration_ms: u64) -> Self {
|
||||
Self {
|
||||
exit_code: 0,
|
||||
stdout,
|
||||
stderr: String::new(),
|
||||
result,
|
||||
duration_ms,
|
||||
error: None,
|
||||
stdout_truncated: false,
|
||||
stderr_truncated: false,
|
||||
stdout_bytes_truncated: 0,
|
||||
stderr_bytes_truncated: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a failure result
|
||||
pub fn failure(exit_code: i32, stderr: String, error: String, duration_ms: u64) -> Self {
|
||||
Self {
|
||||
exit_code,
|
||||
stdout: String::new(),
|
||||
stderr,
|
||||
result: None,
|
||||
duration_ms,
|
||||
error: Some(error),
|
||||
stdout_truncated: false,
|
||||
stderr_truncated: false,
|
||||
stdout_bytes_truncated: 0,
|
||||
stderr_bytes_truncated: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runtime trait for executing actions
|
||||
#[async_trait]
|
||||
pub trait Runtime: Send + Sync {
|
||||
/// Get the runtime name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Check if this runtime can execute the given action
|
||||
fn can_execute(&self, context: &ExecutionContext) -> bool;
|
||||
|
||||
/// Execute an action
|
||||
async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult>;
|
||||
|
||||
/// Setup the runtime environment (called once on worker startup)
|
||||
async fn setup(&self) -> RuntimeResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cleanup the runtime environment (called on worker shutdown)
|
||||
async fn cleanup(&self) -> RuntimeResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate the runtime is properly configured
|
||||
async fn validate(&self) -> RuntimeResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Runtime registry for managing multiple runtime implementations
|
||||
pub struct RuntimeRegistry {
|
||||
runtimes: Vec<Box<dyn Runtime>>,
|
||||
}
|
||||
|
||||
impl RuntimeRegistry {
|
||||
/// Create a new runtime registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
runtimes: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a runtime
|
||||
pub fn register(&mut self, runtime: Box<dyn Runtime>) {
|
||||
self.runtimes.push(runtime);
|
||||
}
|
||||
|
||||
/// Get a runtime that can execute the given context
|
||||
pub fn get_runtime(&self, context: &ExecutionContext) -> RuntimeResult<&dyn Runtime> {
|
||||
// If runtime_name is specified, use it to select the runtime directly
|
||||
if let Some(ref runtime_name) = context.runtime_name {
|
||||
return self
|
||||
.runtimes
|
||||
.iter()
|
||||
.find(|r| r.name() == runtime_name)
|
||||
.map(|r| r.as_ref())
|
||||
.ok_or_else(|| {
|
||||
RuntimeError::RuntimeNotFound(format!(
|
||||
"Runtime '{}' not found for action: {} (available: {})",
|
||||
runtime_name,
|
||||
context.action_ref,
|
||||
self.list_runtimes().join(", ")
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
// Otherwise, fall back to can_execute check
|
||||
self.runtimes
|
||||
.iter()
|
||||
.find(|r| r.can_execute(context))
|
||||
.map(|r| r.as_ref())
|
||||
.ok_or_else(|| {
|
||||
RuntimeError::RuntimeNotFound(format!(
|
||||
"No runtime found for action: {} (available: {})",
|
||||
context.action_ref,
|
||||
self.list_runtimes().join(", ")
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Setup all registered runtimes
|
||||
pub async fn setup_all(&self) -> RuntimeResult<()> {
|
||||
for runtime in &self.runtimes {
|
||||
runtime.setup().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cleanup all registered runtimes
|
||||
pub async fn cleanup_all(&self) -> RuntimeResult<()> {
|
||||
for runtime in &self.runtimes {
|
||||
runtime.cleanup().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate all registered runtimes
|
||||
pub async fn validate_all(&self) -> RuntimeResult<()> {
|
||||
for runtime in &self.runtimes {
|
||||
runtime.validate().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all registered runtimes
|
||||
pub fn list_runtimes(&self) -> Vec<&str> {
|
||||
self.runtimes.iter().map(|r| r.name()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RuntimeRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
493
crates/worker/src/runtime/native.rs
Normal file
493
crates/worker/src/runtime/native.rs
Normal file
@@ -0,0 +1,493 @@
|
||||
//! Native Runtime
|
||||
//!
|
||||
//! Executes compiled native binaries directly without any shell or interpreter wrapper.
|
||||
//! This runtime is used for Rust binaries and other compiled executables.
|
||||
|
||||
use super::{
|
||||
BoundedLogWriter, ExecutionContext, ExecutionResult, Runtime, RuntimeError, RuntimeResult,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::process::Stdio;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Native runtime for executing compiled binaries
|
||||
pub struct NativeRuntime {
|
||||
work_dir: Option<std::path::PathBuf>,
|
||||
}
|
||||
|
||||
impl NativeRuntime {
|
||||
/// Create a new native runtime
|
||||
pub fn new() -> Self {
|
||||
Self { work_dir: None }
|
||||
}
|
||||
|
||||
/// Create a native runtime with custom working directory
|
||||
pub fn with_work_dir(work_dir: std::path::PathBuf) -> Self {
|
||||
Self {
|
||||
work_dir: Some(work_dir),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a native binary with parameters and environment variables
|
||||
async fn execute_binary(
|
||||
&self,
|
||||
binary_path: std::path::PathBuf,
|
||||
parameters: &std::collections::HashMap<String, serde_json::Value>,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
env: &std::collections::HashMap<String, String>,
|
||||
exec_timeout: Option<u64>,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Check if binary exists and is executable
|
||||
if !binary_path.exists() {
|
||||
return Err(RuntimeError::ExecutionFailed(format!(
|
||||
"Binary not found: {}",
|
||||
binary_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let metadata = std::fs::metadata(&binary_path)?;
|
||||
let permissions = metadata.permissions();
|
||||
if permissions.mode() & 0o111 == 0 {
|
||||
return Err(RuntimeError::ExecutionFailed(format!(
|
||||
"Binary is not executable: {}",
|
||||
binary_path.display()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Executing native binary: {}", binary_path.display());
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(&binary_path);
|
||||
|
||||
// Set working directory
|
||||
if let Some(ref work_dir) = self.work_dir {
|
||||
cmd.current_dir(work_dir);
|
||||
}
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
// Add parameters as environment variables with ATTUNE_ACTION_ prefix
|
||||
for (key, value) in parameters {
|
||||
let value_str = match value {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Number(n) => n.to_string(),
|
||||
serde_json::Value::Bool(b) => b.to_string(),
|
||||
_ => serde_json::to_string(value)?,
|
||||
};
|
||||
cmd.env(format!("ATTUNE_ACTION_{}", key.to_uppercase()), value_str);
|
||||
}
|
||||
|
||||
// Configure stdio
|
||||
cmd.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
// Spawn process
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.map_err(|e| RuntimeError::ExecutionFailed(format!("Failed to spawn binary: {}", e)))?;
|
||||
|
||||
// Write secrets to stdin - if this fails, the process has already started
|
||||
// so we should continue and capture whatever output we can
|
||||
let stdin_write_error = if !secrets.is_empty() {
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
match serde_json::to_string(secrets) {
|
||||
Ok(secrets_json) => {
|
||||
if let Err(e) = stdin.write_all(secrets_json.as_bytes()).await {
|
||||
Some(format!("Failed to write secrets to stdin: {}", e))
|
||||
} else if let Err(e) = stdin.shutdown().await {
|
||||
Some(format!("Failed to close stdin: {}", e))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => Some(format!("Failed to serialize secrets: {}", e)),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
if let Some(stdin) = child.stdin.take() {
|
||||
drop(stdin); // Close stdin if no secrets
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
// Capture stdout and stderr with size limits
|
||||
let stdout_handle = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| RuntimeError::ProcessError("Failed to capture stdout".to_string()))?;
|
||||
let stderr_handle = child
|
||||
.stderr
|
||||
.take()
|
||||
.ok_or_else(|| RuntimeError::ProcessError("Failed to capture stderr".to_string()))?;
|
||||
|
||||
let mut stdout_writer = BoundedLogWriter::new_stdout(max_stdout_bytes);
|
||||
let mut stderr_writer = BoundedLogWriter::new_stderr(max_stderr_bytes);
|
||||
|
||||
// Create buffered readers
|
||||
let mut stdout_reader = BufReader::new(stdout_handle);
|
||||
let mut stderr_reader = BufReader::new(stderr_handle);
|
||||
|
||||
// Stream both outputs concurrently
|
||||
let stdout_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stdout_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stdout_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stdout_writer
|
||||
};
|
||||
|
||||
let stderr_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stderr_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stderr_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stderr_writer
|
||||
};
|
||||
|
||||
// Wait for both streams to complete
|
||||
let (stdout_writer, stderr_writer) = tokio::join!(stdout_task, stderr_task);
|
||||
|
||||
// Wait for process with timeout
|
||||
let wait_result = if let Some(timeout_secs) = exec_timeout {
|
||||
match timeout(Duration::from_secs(timeout_secs), child.wait()).await {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!(
|
||||
"Native binary execution timed out after {} seconds",
|
||||
timeout_secs
|
||||
);
|
||||
let _ = child.kill().await;
|
||||
return Err(RuntimeError::Timeout(timeout_secs));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
child.wait().await
|
||||
};
|
||||
|
||||
let status = wait_result.map_err(|e| {
|
||||
RuntimeError::ExecutionFailed(format!("Failed to wait for process: {}", e))
|
||||
})?;
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
let exit_code = status.code().unwrap_or(-1);
|
||||
|
||||
// Extract logs with truncation info
|
||||
let stdout_log = stdout_writer.into_result();
|
||||
let stderr_log = stderr_writer.into_result();
|
||||
|
||||
debug!(
|
||||
"Native binary completed with exit code {} in {}ms",
|
||||
exit_code, duration_ms
|
||||
);
|
||||
|
||||
if stdout_log.truncated {
|
||||
warn!(
|
||||
"stdout truncated: {} bytes over limit",
|
||||
stdout_log.bytes_truncated
|
||||
);
|
||||
}
|
||||
if stderr_log.truncated {
|
||||
warn!(
|
||||
"stderr truncated: {} bytes over limit",
|
||||
stderr_log.bytes_truncated
|
||||
);
|
||||
}
|
||||
|
||||
// Parse result from stdout if successful
|
||||
let result = if exit_code == 0 {
|
||||
serde_json::from_str(&stdout_log.content).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Determine error message
|
||||
let error = if exit_code != 0 {
|
||||
Some(format!(
|
||||
"Native binary exited with code {}: {}",
|
||||
exit_code,
|
||||
stderr_log.content.trim()
|
||||
))
|
||||
} else if let Some(stdin_err) = stdin_write_error {
|
||||
// Ignore broken pipe errors for fast-exiting successful actions
|
||||
// These occur when the process exits before we finish writing secrets to stdin
|
||||
let is_broken_pipe =
|
||||
stdin_err.contains("Broken pipe") || stdin_err.contains("os error 32");
|
||||
let is_fast_exit = duration_ms < 500;
|
||||
let is_success = exit_code == 0;
|
||||
|
||||
if is_broken_pipe && is_fast_exit && is_success {
|
||||
debug!(
|
||||
"Ignoring broken pipe error for fast-exiting successful action ({}ms)",
|
||||
duration_ms
|
||||
);
|
||||
None
|
||||
} else {
|
||||
Some(stdin_err)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ExecutionResult {
|
||||
exit_code,
|
||||
stdout: stdout_log.content,
|
||||
stderr: stderr_log.content,
|
||||
result,
|
||||
duration_ms,
|
||||
error,
|
||||
stdout_truncated: stdout_log.truncated,
|
||||
stderr_truncated: stderr_log.truncated,
|
||||
stdout_bytes_truncated: stdout_log.bytes_truncated,
|
||||
stderr_bytes_truncated: stderr_log.bytes_truncated,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NativeRuntime {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Runtime for NativeRuntime {
|
||||
fn name(&self) -> &str {
|
||||
"native"
|
||||
}
|
||||
|
||||
fn can_execute(&self, context: &ExecutionContext) -> bool {
|
||||
// Check if runtime_name is explicitly set to "native"
|
||||
if let Some(ref runtime_name) = context.runtime_name {
|
||||
return runtime_name.to_lowercase() == "native";
|
||||
}
|
||||
|
||||
// Otherwise, check if code_path points to an executable binary
|
||||
// This is a heuristic - native binaries typically don't have common script extensions
|
||||
if let Some(ref code_path) = context.code_path {
|
||||
let extension = code_path.extension().and_then(|e| e.to_str()).unwrap_or("");
|
||||
|
||||
// Exclude common script extensions
|
||||
let is_script = matches!(
|
||||
extension,
|
||||
"py" | "js" | "sh" | "bash" | "rb" | "pl" | "php" | "lua"
|
||||
);
|
||||
|
||||
// If it's not a script and the file exists, it might be a native binary
|
||||
!is_script && code_path.exists()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult> {
|
||||
info!(
|
||||
"Executing native action: {} (execution_id: {})",
|
||||
context.action_ref, context.execution_id
|
||||
);
|
||||
|
||||
// Get the binary path
|
||||
let binary_path = context.code_path.ok_or_else(|| {
|
||||
RuntimeError::InvalidAction("Native runtime requires code_path to be set".to_string())
|
||||
})?;
|
||||
|
||||
self.execute_binary(
|
||||
binary_path,
|
||||
&context.parameters,
|
||||
&context.secrets,
|
||||
&context.env,
|
||||
context.timeout,
|
||||
context.max_stdout_bytes,
|
||||
context.max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn setup(&self) -> RuntimeResult<()> {
|
||||
info!("Setting up Native runtime");
|
||||
|
||||
// Verify we can execute native binaries (basic check)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::process::Command;
|
||||
let output = Command::new("uname").arg("-s").output().map_err(|e| {
|
||||
RuntimeError::SetupError(format!("Failed to verify native runtime: {}", e))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(RuntimeError::SetupError(
|
||||
"Failed to execute native commands".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Native runtime setup complete");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> RuntimeResult<()> {
|
||||
info!("Cleaning up Native runtime");
|
||||
// No cleanup needed for native runtime
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate(&self) -> RuntimeResult<()> {
|
||||
debug!("Validating Native runtime");
|
||||
|
||||
// Basic validation - ensure we can execute commands
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::process::Command;
|
||||
Command::new("echo").arg("test").output().map_err(|e| {
|
||||
RuntimeError::SetupError(format!("Native runtime validation failed: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_name() {
|
||||
let runtime = NativeRuntime::new();
|
||||
assert_eq!(runtime.name(), "native");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_can_execute() {
|
||||
let runtime = NativeRuntime::new();
|
||||
|
||||
// Test with explicit runtime_name
|
||||
let mut context = ExecutionContext::test_context("test.action".to_string(), None);
|
||||
context.runtime_name = Some("native".to_string());
|
||||
assert!(runtime.can_execute(&context));
|
||||
|
||||
// Test with uppercase runtime_name
|
||||
context.runtime_name = Some("NATIVE".to_string());
|
||||
assert!(runtime.can_execute(&context));
|
||||
|
||||
// Test with wrong runtime_name
|
||||
context.runtime_name = Some("python".to_string());
|
||||
assert!(!runtime.can_execute(&context));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_setup() {
|
||||
let runtime = NativeRuntime::new();
|
||||
let result = runtime.setup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_validate() {
|
||||
let runtime = NativeRuntime::new();
|
||||
let result = runtime.validate().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_execute_simple() {
|
||||
use std::fs;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use tempfile::TempDir;
|
||||
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let binary_path = temp_dir.path().join("test_binary.sh");
|
||||
|
||||
// Create a simple shell script as our "binary"
|
||||
fs::write(
|
||||
&binary_path,
|
||||
"#!/bin/bash\necho 'Hello from native runtime'",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Make it executable
|
||||
let metadata = fs::metadata(&binary_path).unwrap();
|
||||
let mut permissions = metadata.permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&binary_path, permissions).unwrap();
|
||||
|
||||
let runtime = NativeRuntime::new();
|
||||
let mut context = ExecutionContext::test_context("test.native".to_string(), None);
|
||||
context.code_path = Some(binary_path);
|
||||
context.runtime_name = Some("native".to_string());
|
||||
|
||||
let result = runtime.execute(context).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let exec_result = result.unwrap();
|
||||
assert_eq!(exec_result.exit_code, 0);
|
||||
assert!(exec_result.stdout.contains("Hello from native runtime"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_missing_binary() {
|
||||
let runtime = NativeRuntime::new();
|
||||
let mut context = ExecutionContext::test_context("test.native".to_string(), None);
|
||||
context.code_path = Some(std::path::PathBuf::from("/nonexistent/binary"));
|
||||
context.runtime_name = Some("native".to_string());
|
||||
|
||||
let result = runtime.execute(context).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
RuntimeError::ExecutionFailed(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_native_runtime_no_code_path() {
|
||||
let runtime = NativeRuntime::new();
|
||||
let mut context = ExecutionContext::test_context("test.native".to_string(), None);
|
||||
context.runtime_name = Some("native".to_string());
|
||||
// code_path is None
|
||||
|
||||
let result = runtime.execute(context).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
RuntimeError::InvalidAction(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
752
crates/worker/src/runtime/python.rs
Normal file
752
crates/worker/src/runtime/python.rs
Normal file
@@ -0,0 +1,752 @@
|
||||
//! Python Runtime Implementation
|
||||
//!
|
||||
//! Executes Python actions using subprocess execution.
|
||||
|
||||
use super::{
|
||||
BoundedLogWriter, DependencyManagerRegistry, DependencySpec, ExecutionContext, ExecutionResult,
|
||||
Runtime, RuntimeError, RuntimeResult,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Python runtime for executing Python scripts and functions
|
||||
pub struct PythonRuntime {
|
||||
/// Python interpreter path (fallback when no venv exists)
|
||||
python_path: PathBuf,
|
||||
|
||||
/// Base directory for storing action code
|
||||
work_dir: PathBuf,
|
||||
|
||||
/// Optional dependency manager registry for isolated environments
|
||||
dependency_manager: Option<Arc<DependencyManagerRegistry>>,
|
||||
}
|
||||
|
||||
impl PythonRuntime {
|
||||
/// Create a new Python runtime
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
python_path: PathBuf::from("python3"),
|
||||
work_dir: PathBuf::from("/tmp/attune/actions"),
|
||||
dependency_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Python runtime with custom settings
|
||||
pub fn with_config(python_path: PathBuf, work_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
python_path,
|
||||
work_dir,
|
||||
dependency_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Python runtime with dependency manager support
|
||||
pub fn with_dependency_manager(
|
||||
python_path: PathBuf,
|
||||
work_dir: PathBuf,
|
||||
dependency_manager: Arc<DependencyManagerRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
python_path,
|
||||
work_dir,
|
||||
dependency_manager: Some(dependency_manager),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the Python executable path to use for a given context
|
||||
///
|
||||
/// If the action has a pack_ref with dependencies, use the venv Python.
|
||||
/// Otherwise, use the default Python interpreter.
|
||||
async fn get_python_executable(&self, context: &ExecutionContext) -> RuntimeResult<PathBuf> {
|
||||
// Check if we have a dependency manager and can extract pack_ref
|
||||
if let Some(ref dep_mgr) = self.dependency_manager {
|
||||
// Extract pack_ref from action_ref (format: "pack_ref.action_name")
|
||||
if let Some(pack_ref) = context.action_ref.split('.').next() {
|
||||
// Try to get the executable path for this pack
|
||||
match dep_mgr.get_executable_path(pack_ref, "python").await {
|
||||
Ok(python_path) => {
|
||||
debug!(
|
||||
"Using pack-specific Python from venv: {}",
|
||||
python_path.display()
|
||||
);
|
||||
return Ok(python_path);
|
||||
}
|
||||
Err(e) => {
|
||||
// Venv doesn't exist or failed - this is OK if pack has no dependencies
|
||||
debug!(
|
||||
"No venv found for pack {} ({}), using default Python",
|
||||
pack_ref, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default Python interpreter
|
||||
debug!("Using default Python interpreter: {:?}", self.python_path);
|
||||
Ok(self.python_path.clone())
|
||||
}
|
||||
|
||||
/// Generate Python wrapper script that loads parameters and executes the action
|
||||
fn generate_wrapper_script(&self, context: &ExecutionContext) -> RuntimeResult<String> {
|
||||
let params_json = serde_json::to_string(&context.parameters)?;
|
||||
|
||||
// Use base64 encoding for code to avoid any quote/escape issues
|
||||
let code_bytes = context.code.as_deref().unwrap_or("").as_bytes();
|
||||
let code_base64 =
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, code_bytes);
|
||||
|
||||
let wrapper = format!(
|
||||
r#"#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
# Global secrets storage (read from stdin, NOT from environment)
|
||||
_attune_secrets = {{}}
|
||||
|
||||
def get_secret(name):
|
||||
"""
|
||||
Get a secret value by name.
|
||||
|
||||
Secrets are passed securely via stdin and are never exposed in
|
||||
environment variables or process listings.
|
||||
|
||||
Args:
|
||||
name (str): The name of the secret to retrieve
|
||||
|
||||
Returns:
|
||||
str: The secret value, or None if not found
|
||||
"""
|
||||
return _attune_secrets.get(name)
|
||||
|
||||
def main():
|
||||
global _attune_secrets
|
||||
|
||||
try:
|
||||
# Read secrets from stdin FIRST (before executing action code)
|
||||
# This prevents secrets from being visible in process environment
|
||||
secrets_line = sys.stdin.readline().strip()
|
||||
if secrets_line:
|
||||
_attune_secrets = json.loads(secrets_line)
|
||||
|
||||
# Parse parameters
|
||||
parameters = json.loads('''{}''')
|
||||
|
||||
# Decode action code from base64 (avoids quote/escape issues)
|
||||
action_code = base64.b64decode('{}').decode('utf-8')
|
||||
|
||||
# Execute the code in a controlled namespace
|
||||
# Include get_secret helper function
|
||||
namespace = {{
|
||||
'__name__': '__main__',
|
||||
'parameters': parameters,
|
||||
'get_secret': get_secret
|
||||
}}
|
||||
exec(action_code, namespace)
|
||||
|
||||
# Look for main function or run function
|
||||
if '{}' in namespace:
|
||||
result = namespace['{}'](**parameters)
|
||||
elif 'run' in namespace:
|
||||
result = namespace['run'](**parameters)
|
||||
elif 'main' in namespace:
|
||||
result = namespace['main'](**parameters)
|
||||
else:
|
||||
# No entry point found, return the namespace (only JSON-serializable values)
|
||||
def is_json_serializable(obj):
|
||||
"""Check if an object is JSON serializable"""
|
||||
if obj is None:
|
||||
return True
|
||||
if isinstance(obj, (bool, int, float, str)):
|
||||
return True
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return all(is_json_serializable(item) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return all(is_json_serializable(k) and is_json_serializable(v)
|
||||
for k, v in obj.items())
|
||||
return False
|
||||
|
||||
result = {{k: v for k, v in namespace.items()
|
||||
if not k.startswith('__') and is_json_serializable(v)}}
|
||||
|
||||
# Output result as JSON
|
||||
if result is not None:
|
||||
print(json.dumps({{'result': result, 'status': 'success'}}))
|
||||
else:
|
||||
print(json.dumps({{'status': 'success'}}))
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
except Exception as e:
|
||||
error_info = {{
|
||||
'status': 'error',
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
'traceback': traceback.format_exc()
|
||||
}}
|
||||
print(json.dumps(error_info), file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
"#,
|
||||
params_json, code_base64, context.entry_point, context.entry_point
|
||||
);
|
||||
|
||||
Ok(wrapper)
|
||||
}
|
||||
|
||||
/// Execute with streaming and bounded log collection
|
||||
async fn execute_with_streaming(
|
||||
&self,
|
||||
mut cmd: Command,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Spawn process with piped I/O
|
||||
let mut child = cmd
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
// Write secrets to stdin
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
let secrets_json = serde_json::to_string(secrets)?;
|
||||
stdin.write_all(secrets_json.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
drop(stdin);
|
||||
}
|
||||
|
||||
// Create bounded writers
|
||||
let mut stdout_writer = BoundedLogWriter::new_stdout(max_stdout_bytes);
|
||||
let mut stderr_writer = BoundedLogWriter::new_stderr(max_stderr_bytes);
|
||||
|
||||
// Take stdout and stderr streams
|
||||
let stdout = child.stdout.take().expect("stdout not captured");
|
||||
let stderr = child.stderr.take().expect("stderr not captured");
|
||||
|
||||
// Create buffered readers
|
||||
let mut stdout_reader = BufReader::new(stdout);
|
||||
let mut stderr_reader = BufReader::new(stderr);
|
||||
|
||||
// Stream both outputs concurrently
|
||||
let stdout_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stdout_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stdout_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stdout_writer
|
||||
};
|
||||
|
||||
let stderr_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stderr_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stderr_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stderr_writer
|
||||
};
|
||||
|
||||
// Wait for both streams and the process
|
||||
let (stdout_writer, stderr_writer, wait_result) =
|
||||
tokio::join!(stdout_task, stderr_task, async {
|
||||
if let Some(timeout_secs) = timeout_secs {
|
||||
timeout(std::time::Duration::from_secs(timeout_secs), child.wait()).await
|
||||
} else {
|
||||
Ok(child.wait().await)
|
||||
}
|
||||
});
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
// Handle timeout
|
||||
let status = match wait_result {
|
||||
Ok(Ok(status)) => status,
|
||||
Ok(Err(e)) => {
|
||||
return Err(RuntimeError::ProcessError(format!(
|
||||
"Process wait failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
Err(_) => {
|
||||
return Ok(ExecutionResult {
|
||||
exit_code: -1,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
result: None,
|
||||
duration_ms,
|
||||
error: Some(format!(
|
||||
"Execution timed out after {} seconds",
|
||||
timeout_secs.unwrap()
|
||||
)),
|
||||
stdout_truncated: false,
|
||||
stderr_truncated: false,
|
||||
stdout_bytes_truncated: 0,
|
||||
stderr_bytes_truncated: 0,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Get results from bounded writers
|
||||
let stdout_result = stdout_writer.into_result();
|
||||
let stderr_result = stderr_writer.into_result();
|
||||
|
||||
let exit_code = status.code().unwrap_or(-1);
|
||||
|
||||
debug!(
|
||||
"Python execution completed: exit_code={}, duration={}ms, stdout_truncated={}, stderr_truncated={}",
|
||||
exit_code, duration_ms, stdout_result.truncated, stderr_result.truncated
|
||||
);
|
||||
|
||||
// Try to parse result from stdout
|
||||
let result = if exit_code == 0 {
|
||||
stdout_result
|
||||
.content
|
||||
.lines()
|
||||
.last()
|
||||
.and_then(|line| serde_json::from_str(line).ok())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ExecutionResult {
|
||||
exit_code,
|
||||
stdout: stdout_result.content.clone(),
|
||||
stderr: stderr_result.content.clone(),
|
||||
result,
|
||||
duration_ms,
|
||||
error: if exit_code != 0 {
|
||||
Some(stderr_result.content)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
stdout_truncated: stdout_result.truncated,
|
||||
stderr_truncated: stderr_result.truncated,
|
||||
stdout_bytes_truncated: stdout_result.bytes_truncated,
|
||||
stderr_bytes_truncated: stderr_result.bytes_truncated,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_python_code(
|
||||
&self,
|
||||
script: String,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
env: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
python_path: PathBuf,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
debug!(
|
||||
"Executing Python script with {} secrets (passed via stdin)",
|
||||
secrets.len()
|
||||
);
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(&python_path);
|
||||
cmd.arg("-c").arg(&script);
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
self.execute_with_streaming(
|
||||
cmd,
|
||||
secrets,
|
||||
timeout_secs,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute Python script from file
|
||||
async fn execute_python_file(
|
||||
&self,
|
||||
code_path: PathBuf,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
env: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
python_path: PathBuf,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
debug!(
|
||||
"Executing Python file: {:?} with {} secrets",
|
||||
code_path,
|
||||
secrets.len()
|
||||
);
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(&python_path);
|
||||
cmd.arg(&code_path);
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
self.execute_with_streaming(
|
||||
cmd,
|
||||
secrets,
|
||||
timeout_secs,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PythonRuntime {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PythonRuntime {
|
||||
/// Ensure pack dependencies are installed (called before execution if needed)
|
||||
///
|
||||
/// This is a helper method that can be called by the worker service to ensure
|
||||
/// a pack's Python dependencies are set up before executing actions.
|
||||
pub async fn ensure_pack_dependencies(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
spec: &DependencySpec,
|
||||
) -> RuntimeResult<()> {
|
||||
if let Some(ref dep_mgr) = self.dependency_manager {
|
||||
if spec.has_dependencies() {
|
||||
info!(
|
||||
"Ensuring Python dependencies for pack: {} ({} dependencies)",
|
||||
pack_ref,
|
||||
spec.dependencies.len()
|
||||
);
|
||||
|
||||
dep_mgr
|
||||
.ensure_environment(pack_ref, spec)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
RuntimeError::SetupError(format!(
|
||||
"Failed to setup Python environment for {}: {}",
|
||||
pack_ref, e
|
||||
))
|
||||
})?;
|
||||
|
||||
info!("Python dependencies ready for pack: {}", pack_ref);
|
||||
} else {
|
||||
debug!("Pack {} has no Python dependencies", pack_ref);
|
||||
}
|
||||
} else {
|
||||
warn!("Dependency manager not configured, skipping dependency isolation");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Runtime for PythonRuntime {
|
||||
fn name(&self) -> &str {
|
||||
"python"
|
||||
}
|
||||
|
||||
fn can_execute(&self, context: &ExecutionContext) -> bool {
|
||||
// Check if action reference suggests Python
|
||||
let is_python = context.action_ref.contains(".py")
|
||||
|| context.entry_point.ends_with(".py")
|
||||
|| context
|
||||
.code_path
|
||||
.as_ref()
|
||||
.map(|p| p.extension().and_then(|e| e.to_str()) == Some("py"))
|
||||
.unwrap_or(false);
|
||||
|
||||
is_python
|
||||
}
|
||||
|
||||
async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult> {
|
||||
info!(
|
||||
"Executing Python action: {} (execution_id: {})",
|
||||
context.action_ref, context.execution_id
|
||||
);
|
||||
|
||||
// Get the appropriate Python executable (venv or default)
|
||||
let python_path = self.get_python_executable(&context).await?;
|
||||
|
||||
// If code_path is provided, execute the file directly
|
||||
if let Some(code_path) = &context.code_path {
|
||||
return self
|
||||
.execute_python_file(
|
||||
code_path.clone(),
|
||||
&context.secrets,
|
||||
&context.env,
|
||||
context.timeout,
|
||||
python_path,
|
||||
context.max_stdout_bytes,
|
||||
context.max_stderr_bytes,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Otherwise, generate wrapper script and execute
|
||||
let script = self.generate_wrapper_script(&context)?;
|
||||
self.execute_python_code(
|
||||
script,
|
||||
&context.secrets,
|
||||
&context.env,
|
||||
context.timeout,
|
||||
python_path,
|
||||
context.max_stdout_bytes,
|
||||
context.max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn setup(&self) -> RuntimeResult<()> {
|
||||
info!("Setting up Python runtime");
|
||||
|
||||
// Ensure work directory exists
|
||||
tokio::fs::create_dir_all(&self.work_dir)
|
||||
.await
|
||||
.map_err(|e| RuntimeError::SetupError(format!("Failed to create work dir: {}", e)))?;
|
||||
|
||||
// Verify Python is available
|
||||
let output = Command::new(&self.python_path)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
RuntimeError::SetupError(format!(
|
||||
"Python not found at {:?}: {}",
|
||||
self.python_path, e
|
||||
))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(RuntimeError::SetupError(
|
||||
"Python interpreter is not working".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
info!("Python runtime ready: {}", version.trim());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> RuntimeResult<()> {
|
||||
info!("Cleaning up Python runtime");
|
||||
// Could clean up temporary files here
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate(&self) -> RuntimeResult<()> {
|
||||
debug!("Validating Python runtime");
|
||||
|
||||
// Check if Python is available
|
||||
let output = Command::new(&self.python_path)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| RuntimeError::SetupError(format!("Python validation failed: {}", e)))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(RuntimeError::SetupError(
|
||||
"Python interpreter validation failed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_runtime_simple() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 1,
|
||||
action_ref: "test.simple".to_string(),
|
||||
parameters: {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("x".to_string(), serde_json::json!(5));
|
||||
map.insert("y".to_string(), serde_json::json!(10));
|
||||
map
|
||||
},
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run(x, y):
|
||||
return x + y
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert_eq!(result.exit_code, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_runtime_timeout() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 2,
|
||||
action_ref: "test.timeout".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(1),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
import time
|
||||
def run():
|
||||
time.sleep(10)
|
||||
return "done"
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(!result.is_success());
|
||||
assert!(result.error.is_some());
|
||||
let error_msg = result.error.unwrap();
|
||||
assert!(error_msg.contains("timeout") || error_msg.contains("timed out"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_runtime_error() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 3,
|
||||
action_ref: "test.error".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
raise ValueError("Test error")
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(!result.is_success());
|
||||
assert!(result.error.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_runtime_with_secrets() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 4,
|
||||
action_ref: "test.secrets".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert("api_key".to_string(), "secret_key_12345".to_string());
|
||||
s.insert("db_password".to_string(), "super_secret_pass".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
# Access secrets via get_secret() helper
|
||||
api_key = get_secret('api_key')
|
||||
db_pass = get_secret('db_password')
|
||||
missing = get_secret('nonexistent')
|
||||
|
||||
return {
|
||||
'api_key': api_key,
|
||||
'db_pass': db_pass,
|
||||
'missing': missing
|
||||
}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert_eq!(result.exit_code, 0);
|
||||
|
||||
// Verify secrets are accessible in action code
|
||||
let result_data = result.result.unwrap();
|
||||
let result_obj = result_data.get("result").unwrap();
|
||||
assert_eq!(result_obj.get("api_key").unwrap(), "secret_key_12345");
|
||||
assert_eq!(result_obj.get("db_pass").unwrap(), "super_secret_pass");
|
||||
assert_eq!(result_obj.get("missing"), Some(&serde_json::Value::Null));
|
||||
}
|
||||
}
|
||||
653
crates/worker/src/runtime/python_venv.rs
Normal file
653
crates/worker/src/runtime/python_venv.rs
Normal file
@@ -0,0 +1,653 @@
|
||||
//! Python Virtual Environment Manager
|
||||
//!
|
||||
//! Manages isolated Python virtual environments for packs with Python dependencies.
|
||||
//! Each pack gets its own venv to prevent dependency conflicts.
|
||||
|
||||
use super::dependency::{
|
||||
DependencyError, DependencyManager, DependencyResult, DependencySpec, EnvironmentInfo,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Stdio;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Python virtual environment manager
|
||||
pub struct PythonVenvManager {
|
||||
/// Base directory for all virtual environments
|
||||
base_dir: PathBuf,
|
||||
|
||||
/// Python interpreter to use for creating venvs
|
||||
python_path: PathBuf,
|
||||
|
||||
/// Cache of environment info
|
||||
env_cache: tokio::sync::RwLock<HashMap<String, EnvironmentInfo>>,
|
||||
}
|
||||
|
||||
/// Metadata stored with each environment
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct VenvMetadata {
|
||||
pack_ref: String,
|
||||
dependencies: Vec<String>,
|
||||
created_at: chrono::DateTime<chrono::Utc>,
|
||||
updated_at: chrono::DateTime<chrono::Utc>,
|
||||
python_version: String,
|
||||
dependency_hash: String,
|
||||
}
|
||||
|
||||
impl PythonVenvManager {
|
||||
/// Create a new Python venv manager
|
||||
pub fn new(base_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
base_dir,
|
||||
python_path: PathBuf::from("python3"),
|
||||
env_cache: tokio::sync::RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Python venv manager with custom Python path
|
||||
pub fn with_python_path(base_dir: PathBuf, python_path: PathBuf) -> Self {
|
||||
Self {
|
||||
base_dir,
|
||||
python_path,
|
||||
env_cache: tokio::sync::RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the directory path for a pack's venv
|
||||
fn get_venv_path(&self, pack_ref: &str) -> PathBuf {
|
||||
// Sanitize pack_ref to create a valid directory name
|
||||
let safe_name = pack_ref.replace(['/', '\\', '.'], "_");
|
||||
self.base_dir.join(safe_name)
|
||||
}
|
||||
|
||||
/// Get the Python executable path within a venv
|
||||
fn get_venv_python(&self, venv_path: &Path) -> PathBuf {
|
||||
if cfg!(windows) {
|
||||
venv_path.join("Scripts").join("python.exe")
|
||||
} else {
|
||||
venv_path.join("bin").join("python")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the pip executable path within a venv
|
||||
fn get_venv_pip(&self, venv_path: &Path) -> PathBuf {
|
||||
if cfg!(windows) {
|
||||
venv_path.join("Scripts").join("pip.exe")
|
||||
} else {
|
||||
venv_path.join("bin").join("pip")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the metadata file path for a venv
|
||||
fn get_metadata_path(&self, venv_path: &Path) -> PathBuf {
|
||||
venv_path.join("attune_metadata.json")
|
||||
}
|
||||
|
||||
/// Calculate a hash of dependencies for change detection
|
||||
fn calculate_dependency_hash(&self, spec: &DependencySpec) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
||||
// Sort dependencies for consistent hashing
|
||||
let mut deps = spec.dependencies.clone();
|
||||
deps.sort();
|
||||
|
||||
for dep in &deps {
|
||||
dep.hash(&mut hasher);
|
||||
}
|
||||
|
||||
if let Some(ref content) = spec.requirements_file_content {
|
||||
content.hash(&mut hasher);
|
||||
}
|
||||
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// Create a new virtual environment
|
||||
async fn create_venv(&self, venv_path: &Path) -> DependencyResult<()> {
|
||||
info!(
|
||||
"Creating Python virtual environment at: {}",
|
||||
venv_path.display()
|
||||
);
|
||||
|
||||
// Ensure base directory exists
|
||||
if let Some(parent) = venv_path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
// Create venv using python -m venv
|
||||
let output = Command::new(&self.python_path)
|
||||
.arg("-m")
|
||||
.arg("venv")
|
||||
.arg(venv_path)
|
||||
.arg("--clear") // Clear if exists
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DependencyError::CreateEnvironmentFailed(format!(
|
||||
"Failed to spawn venv command: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(DependencyError::CreateEnvironmentFailed(format!(
|
||||
"venv creation failed: {}",
|
||||
stderr
|
||||
)));
|
||||
}
|
||||
|
||||
// Upgrade pip to latest version
|
||||
let pip_path = self.get_venv_pip(venv_path);
|
||||
let output = Command::new(&pip_path)
|
||||
.arg("install")
|
||||
.arg("--upgrade")
|
||||
.arg("pip")
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| DependencyError::InstallFailed(format!("Failed to upgrade pip: {}", e)))?;
|
||||
|
||||
if !output.status.success() {
|
||||
warn!("Failed to upgrade pip, continuing anyway");
|
||||
}
|
||||
|
||||
info!("Virtual environment created successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Install dependencies in a venv
|
||||
async fn install_dependencies(
|
||||
&self,
|
||||
venv_path: &Path,
|
||||
spec: &DependencySpec,
|
||||
) -> DependencyResult<()> {
|
||||
if !spec.has_dependencies() {
|
||||
debug!("No dependencies to install");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Installing dependencies in venv: {}", venv_path.display());
|
||||
|
||||
let pip_path = self.get_venv_pip(venv_path);
|
||||
|
||||
// Install from requirements file content if provided
|
||||
if let Some(ref requirements_content) = spec.requirements_file_content {
|
||||
let req_file = venv_path.join("requirements.txt");
|
||||
fs::write(&req_file, requirements_content).await?;
|
||||
|
||||
let output = Command::new(&pip_path)
|
||||
.arg("install")
|
||||
.arg("-r")
|
||||
.arg(&req_file)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DependencyError::InstallFailed(format!(
|
||||
"Failed to install from requirements.txt: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(DependencyError::InstallFailed(format!(
|
||||
"pip install failed: {}",
|
||||
stderr
|
||||
)));
|
||||
}
|
||||
|
||||
info!("Dependencies installed from requirements.txt");
|
||||
} else if !spec.dependencies.is_empty() {
|
||||
// Install individual dependencies
|
||||
let output = Command::new(&pip_path)
|
||||
.arg("install")
|
||||
.args(&spec.dependencies)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DependencyError::InstallFailed(format!("Failed to install dependencies: {}", e))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(DependencyError::InstallFailed(format!(
|
||||
"pip install failed: {}",
|
||||
stderr
|
||||
)));
|
||||
}
|
||||
|
||||
info!("Installed {} dependencies", spec.dependencies.len());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get Python version from a venv
|
||||
async fn get_python_version(&self, venv_path: &Path) -> DependencyResult<String> {
|
||||
let python_path = self.get_venv_python(venv_path);
|
||||
|
||||
let output = Command::new(&python_path)
|
||||
.arg("--version")
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DependencyError::ProcessError(format!("Failed to get Python version: {}", e))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(DependencyError::ProcessError(
|
||||
"Failed to get Python version".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
Ok(version.trim().to_string())
|
||||
}
|
||||
|
||||
/// List installed packages in a venv
|
||||
async fn list_installed_packages(&self, venv_path: &Path) -> DependencyResult<Vec<String>> {
|
||||
let pip_path = self.get_venv_pip(venv_path);
|
||||
|
||||
let output = Command::new(&pip_path)
|
||||
.arg("list")
|
||||
.arg("--format=freeze")
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DependencyError::ProcessError(format!("Failed to list packages: {}", e))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let packages = String::from_utf8_lossy(&output.stdout)
|
||||
.lines()
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
|
||||
Ok(packages)
|
||||
}
|
||||
|
||||
/// Save metadata for a venv
|
||||
async fn save_metadata(
|
||||
&self,
|
||||
venv_path: &Path,
|
||||
metadata: &VenvMetadata,
|
||||
) -> DependencyResult<()> {
|
||||
let metadata_path = self.get_metadata_path(venv_path);
|
||||
let json = serde_json::to_string_pretty(metadata).map_err(|e| {
|
||||
DependencyError::LockFileError(format!("Failed to serialize metadata: {}", e))
|
||||
})?;
|
||||
|
||||
let mut file = fs::File::create(&metadata_path).await.map_err(|e| {
|
||||
DependencyError::LockFileError(format!("Failed to create metadata file: {}", e))
|
||||
})?;
|
||||
|
||||
file.write_all(json.as_bytes()).await.map_err(|e| {
|
||||
DependencyError::LockFileError(format!("Failed to write metadata: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load metadata for a venv
|
||||
async fn load_metadata(&self, venv_path: &Path) -> DependencyResult<Option<VenvMetadata>> {
|
||||
let metadata_path = self.get_metadata_path(venv_path);
|
||||
|
||||
if !metadata_path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&metadata_path).await.map_err(|e| {
|
||||
DependencyError::LockFileError(format!("Failed to read metadata: {}", e))
|
||||
})?;
|
||||
|
||||
let metadata: VenvMetadata = serde_json::from_str(&content).map_err(|e| {
|
||||
DependencyError::LockFileError(format!("Failed to parse metadata: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Some(metadata))
|
||||
}
|
||||
|
||||
/// Check if a venv exists and is valid
|
||||
async fn is_valid_venv(&self, venv_path: &Path) -> bool {
|
||||
if !venv_path.exists() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let python_path = self.get_venv_python(venv_path);
|
||||
if !python_path.exists() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to run python --version to verify it works
|
||||
let result = Command::new(&python_path)
|
||||
.arg("--version")
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.await;
|
||||
|
||||
matches!(result, Ok(status) if status.success())
|
||||
}
|
||||
|
||||
/// Build environment info from a venv
|
||||
async fn build_env_info(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
venv_path: &Path,
|
||||
) -> DependencyResult<EnvironmentInfo> {
|
||||
let is_valid = self.is_valid_venv(venv_path).await;
|
||||
let python_path = self.get_venv_python(venv_path);
|
||||
|
||||
let (python_version, installed_deps, created_at, updated_at) = if is_valid {
|
||||
let version = self
|
||||
.get_python_version(venv_path)
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown".to_string());
|
||||
let deps = self
|
||||
.list_installed_packages(venv_path)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let metadata = self.load_metadata(venv_path).await.ok().flatten();
|
||||
let created = metadata
|
||||
.as_ref()
|
||||
.map(|m| m.created_at)
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
let updated = metadata
|
||||
.as_ref()
|
||||
.map(|m| m.updated_at)
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
|
||||
(version, deps, created, updated)
|
||||
} else {
|
||||
(
|
||||
"Unknown".to_string(),
|
||||
Vec::new(),
|
||||
chrono::Utc::now(),
|
||||
chrono::Utc::now(),
|
||||
)
|
||||
};
|
||||
|
||||
Ok(EnvironmentInfo {
|
||||
id: pack_ref.to_string(),
|
||||
path: venv_path.to_path_buf(),
|
||||
runtime: "python".to_string(),
|
||||
runtime_version: python_version,
|
||||
installed_dependencies: installed_deps,
|
||||
created_at,
|
||||
updated_at,
|
||||
is_valid,
|
||||
executable_path: python_path,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DependencyManager for PythonVenvManager {
|
||||
fn runtime_type(&self) -> &str {
|
||||
"python"
|
||||
}
|
||||
|
||||
async fn ensure_environment(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
spec: &DependencySpec,
|
||||
) -> DependencyResult<EnvironmentInfo> {
|
||||
info!("Ensuring Python environment for pack: {}", pack_ref);
|
||||
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
let dependency_hash = self.calculate_dependency_hash(spec);
|
||||
|
||||
// Check if environment exists and is up to date
|
||||
if venv_path.exists() {
|
||||
if let Some(metadata) = self.load_metadata(&venv_path).await? {
|
||||
if metadata.dependency_hash == dependency_hash
|
||||
&& self.is_valid_venv(&venv_path).await
|
||||
{
|
||||
debug!("Using existing venv (dependencies unchanged)");
|
||||
let env_info = self.build_env_info(pack_ref, &venv_path).await?;
|
||||
|
||||
// Update cache
|
||||
let mut cache = self.env_cache.write().await;
|
||||
cache.insert(pack_ref.to_string(), env_info.clone());
|
||||
|
||||
return Ok(env_info);
|
||||
}
|
||||
info!("Dependencies changed or venv invalid, recreating environment");
|
||||
}
|
||||
}
|
||||
|
||||
// Create or recreate the venv
|
||||
self.create_venv(&venv_path).await?;
|
||||
|
||||
// Install dependencies
|
||||
self.install_dependencies(&venv_path, spec).await?;
|
||||
|
||||
// Get Python version
|
||||
let python_version = self.get_python_version(&venv_path).await?;
|
||||
|
||||
// Save metadata
|
||||
let metadata = VenvMetadata {
|
||||
pack_ref: pack_ref.to_string(),
|
||||
dependencies: spec.dependencies.clone(),
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
python_version: python_version.clone(),
|
||||
dependency_hash,
|
||||
};
|
||||
self.save_metadata(&venv_path, &metadata).await?;
|
||||
|
||||
// Build environment info
|
||||
let env_info = self.build_env_info(pack_ref, &venv_path).await?;
|
||||
|
||||
// Update cache
|
||||
let mut cache = self.env_cache.write().await;
|
||||
cache.insert(pack_ref.to_string(), env_info.clone());
|
||||
|
||||
info!("Python environment ready for pack: {}", pack_ref);
|
||||
Ok(env_info)
|
||||
}
|
||||
|
||||
async fn get_environment(&self, pack_ref: &str) -> DependencyResult<Option<EnvironmentInfo>> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.env_cache.read().await;
|
||||
if let Some(env_info) = cache.get(pack_ref) {
|
||||
return Ok(Some(env_info.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
if !venv_path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let env_info = self.build_env_info(pack_ref, &venv_path).await?;
|
||||
|
||||
// Update cache
|
||||
let mut cache = self.env_cache.write().await;
|
||||
cache.insert(pack_ref.to_string(), env_info.clone());
|
||||
|
||||
Ok(Some(env_info))
|
||||
}
|
||||
|
||||
async fn remove_environment(&self, pack_ref: &str) -> DependencyResult<()> {
|
||||
info!("Removing Python environment for pack: {}", pack_ref);
|
||||
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
if venv_path.exists() {
|
||||
fs::remove_dir_all(&venv_path).await?;
|
||||
}
|
||||
|
||||
// Remove from cache
|
||||
let mut cache = self.env_cache.write().await;
|
||||
cache.remove(pack_ref);
|
||||
|
||||
info!("Environment removed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate_environment(&self, pack_ref: &str) -> DependencyResult<bool> {
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
Ok(self.is_valid_venv(&venv_path).await)
|
||||
}
|
||||
|
||||
async fn get_executable_path(&self, pack_ref: &str) -> DependencyResult<PathBuf> {
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
let python_path = self.get_venv_python(&venv_path);
|
||||
|
||||
if !python_path.exists() {
|
||||
return Err(DependencyError::EnvironmentNotFound(format!(
|
||||
"Python executable not found for pack: {}",
|
||||
pack_ref
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(python_path)
|
||||
}
|
||||
|
||||
async fn list_environments(&self) -> DependencyResult<Vec<EnvironmentInfo>> {
|
||||
let mut environments = Vec::new();
|
||||
|
||||
let mut entries = fs::read_dir(&self.base_dir).await?;
|
||||
while let Some(entry) = entries.next_entry().await? {
|
||||
if entry.file_type().await?.is_dir() {
|
||||
let venv_path = entry.path();
|
||||
if self.is_valid_venv(&venv_path).await {
|
||||
// Extract pack_ref from directory name
|
||||
if let Some(dir_name) = venv_path.file_name().and_then(|n| n.to_str()) {
|
||||
if let Ok(env_info) = self.build_env_info(dir_name, &venv_path).await {
|
||||
environments.push(env_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(environments)
|
||||
}
|
||||
|
||||
async fn cleanup(&self, keep_recent: usize) -> DependencyResult<Vec<String>> {
|
||||
info!(
|
||||
"Cleaning up Python virtual environments (keeping {} most recent)",
|
||||
keep_recent
|
||||
);
|
||||
|
||||
let mut environments = self.list_environments().await?;
|
||||
|
||||
// Sort by updated_at, newest first
|
||||
environments.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
|
||||
|
||||
let mut removed = Vec::new();
|
||||
|
||||
// Remove environments beyond keep_recent threshold
|
||||
for env in environments.iter().skip(keep_recent) {
|
||||
// Also skip if environment is invalid
|
||||
if !env.is_valid {
|
||||
if let Err(e) = self.remove_environment(&env.id).await {
|
||||
warn!("Failed to remove environment {}: {}", env.id, e);
|
||||
} else {
|
||||
removed.push(env.id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Cleaned up {} environments", removed.len());
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
async fn needs_update(&self, pack_ref: &str, spec: &DependencySpec) -> DependencyResult<bool> {
|
||||
let venv_path = self.get_venv_path(pack_ref);
|
||||
|
||||
if !venv_path.exists() {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
if !self.is_valid_venv(&venv_path).await {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Check if dependency hash matches
|
||||
if let Some(metadata) = self.load_metadata(&venv_path).await? {
|
||||
let current_hash = self.calculate_dependency_hash(spec);
|
||||
Ok(metadata.dependency_hash != current_hash)
|
||||
} else {
|
||||
// No metadata, assume needs update
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_venv_path_sanitization() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let path = manager.get_venv_path("core.http");
|
||||
assert!(path.to_string_lossy().contains("core_http"));
|
||||
|
||||
let path = manager.get_venv_path("my/pack");
|
||||
assert!(path.to_string_lossy().contains("my_pack"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dependency_hash_consistency() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec1 = DependencySpec::new("python")
|
||||
.with_dependency("requests==2.28.0")
|
||||
.with_dependency("flask==2.0.0");
|
||||
|
||||
let spec2 = DependencySpec::new("python")
|
||||
.with_dependency("flask==2.0.0")
|
||||
.with_dependency("requests==2.28.0");
|
||||
|
||||
// Hashes should be the same regardless of order (we sort)
|
||||
let hash1 = manager.calculate_dependency_hash(&spec1);
|
||||
let hash2 = manager.calculate_dependency_hash(&spec2);
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dependency_hash_different() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec1 = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
|
||||
let spec2 = DependencySpec::new("python").with_dependency("requests==2.29.0");
|
||||
|
||||
let hash1 = manager.calculate_dependency_hash(&spec1);
|
||||
let hash2 = manager.calculate_dependency_hash(&spec2);
|
||||
assert_ne!(hash1, hash2);
|
||||
}
|
||||
}
|
||||
672
crates/worker/src/runtime/shell.rs
Normal file
672
crates/worker/src/runtime/shell.rs
Normal file
@@ -0,0 +1,672 @@
|
||||
//! Shell Runtime Implementation
|
||||
//!
|
||||
//! Executes shell scripts and commands using subprocess execution.
|
||||
|
||||
use super::{
|
||||
BoundedLogWriter, ExecutionContext, ExecutionResult, Runtime, RuntimeError, RuntimeResult,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Shell runtime for executing shell scripts and commands
|
||||
pub struct ShellRuntime {
|
||||
/// Shell interpreter path (bash, sh, zsh, etc.)
|
||||
shell_path: PathBuf,
|
||||
|
||||
/// Base directory for storing action code
|
||||
work_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ShellRuntime {
|
||||
/// Create a new Shell runtime with bash
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
shell_path: PathBuf::from("/bin/bash"),
|
||||
work_dir: PathBuf::from("/tmp/attune/actions"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Shell runtime with custom shell
|
||||
pub fn with_shell(shell_path: PathBuf) -> Self {
|
||||
Self {
|
||||
shell_path,
|
||||
work_dir: PathBuf::from("/tmp/attune/actions"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Shell runtime with custom settings
|
||||
pub fn with_config(shell_path: PathBuf, work_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
shell_path,
|
||||
work_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute with streaming and bounded log collection
|
||||
async fn execute_with_streaming(
|
||||
&self,
|
||||
mut cmd: Command,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Spawn process with piped I/O
|
||||
let mut child = cmd
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
// Write secrets to stdin - if this fails, the process has already started
|
||||
// so we should continue and capture whatever output we can
|
||||
let stdin_write_error = if let Some(mut stdin) = child.stdin.take() {
|
||||
match serde_json::to_string(secrets) {
|
||||
Ok(secrets_json) => {
|
||||
if let Err(e) = stdin.write_all(secrets_json.as_bytes()).await {
|
||||
Some(format!("Failed to write secrets to stdin: {}", e))
|
||||
} else if let Err(e) = stdin.write_all(b"\n").await {
|
||||
Some(format!("Failed to write newline to stdin: {}", e))
|
||||
} else {
|
||||
drop(stdin);
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => Some(format!("Failed to serialize secrets: {}", e)),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create bounded writers
|
||||
let mut stdout_writer = BoundedLogWriter::new_stdout(max_stdout_bytes);
|
||||
let mut stderr_writer = BoundedLogWriter::new_stderr(max_stderr_bytes);
|
||||
|
||||
// Take stdout and stderr streams
|
||||
let stdout = child.stdout.take().expect("stdout not captured");
|
||||
let stderr = child.stderr.take().expect("stderr not captured");
|
||||
|
||||
// Create buffered readers
|
||||
let mut stdout_reader = BufReader::new(stdout);
|
||||
let mut stderr_reader = BufReader::new(stderr);
|
||||
|
||||
// Stream both outputs concurrently
|
||||
let stdout_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stdout_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stdout_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stdout_writer
|
||||
};
|
||||
|
||||
let stderr_task = async {
|
||||
let mut line = Vec::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match stderr_reader.read_until(b'\n', &mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
if stderr_writer.write_all(&line).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
stderr_writer
|
||||
};
|
||||
|
||||
// Wait for both streams and the process
|
||||
let (stdout_writer, stderr_writer, wait_result) =
|
||||
tokio::join!(stdout_task, stderr_task, async {
|
||||
if let Some(timeout_secs) = timeout_secs {
|
||||
timeout(std::time::Duration::from_secs(timeout_secs), child.wait()).await
|
||||
} else {
|
||||
Ok(child.wait().await)
|
||||
}
|
||||
});
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
// Get results from bounded writers - we have these regardless of wait() success
|
||||
let stdout_result = stdout_writer.into_result();
|
||||
let stderr_result = stderr_writer.into_result();
|
||||
|
||||
// Handle process wait result
|
||||
let (exit_code, process_error) = match wait_result {
|
||||
Ok(Ok(status)) => (status.code().unwrap_or(-1), None),
|
||||
Ok(Err(e)) => {
|
||||
// Process wait failed, but we have the output - return it with an error
|
||||
warn!("Process wait failed but captured output: {}", e);
|
||||
(-1, Some(format!("Process wait failed: {}", e)))
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout occurred
|
||||
return Ok(ExecutionResult {
|
||||
exit_code: -1,
|
||||
stdout: stdout_result.content.clone(),
|
||||
stderr: stderr_result.content.clone(),
|
||||
result: None,
|
||||
duration_ms,
|
||||
error: Some(format!(
|
||||
"Execution timed out after {} seconds",
|
||||
timeout_secs.unwrap()
|
||||
)),
|
||||
stdout_truncated: stdout_result.truncated,
|
||||
stderr_truncated: stderr_result.truncated,
|
||||
stdout_bytes_truncated: stdout_result.bytes_truncated,
|
||||
stderr_bytes_truncated: stderr_result.bytes_truncated,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Shell execution completed: exit_code={}, duration={}ms, stdout_truncated={}, stderr_truncated={}",
|
||||
exit_code, duration_ms, stdout_result.truncated, stderr_result.truncated
|
||||
);
|
||||
|
||||
// Try to parse result from stdout as JSON
|
||||
let result = if exit_code == 0 && !stdout_result.content.trim().is_empty() {
|
||||
stdout_result
|
||||
.content
|
||||
.trim()
|
||||
.lines()
|
||||
.last()
|
||||
.and_then(|line| serde_json::from_str(line).ok())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Determine error message
|
||||
let error = if let Some(proc_err) = process_error {
|
||||
Some(proc_err)
|
||||
} else if let Some(stdin_err) = stdin_write_error {
|
||||
// Ignore broken pipe errors for fast-exiting successful actions
|
||||
// These occur when the process exits before we finish writing secrets to stdin
|
||||
let is_broken_pipe =
|
||||
stdin_err.contains("Broken pipe") || stdin_err.contains("os error 32");
|
||||
let is_fast_exit = duration_ms < 500;
|
||||
let is_success = exit_code == 0;
|
||||
|
||||
if is_broken_pipe && is_fast_exit && is_success {
|
||||
debug!(
|
||||
"Ignoring broken pipe error for fast-exiting successful action ({}ms)",
|
||||
duration_ms
|
||||
);
|
||||
None
|
||||
} else {
|
||||
Some(stdin_err)
|
||||
}
|
||||
} else if exit_code != 0 {
|
||||
Some(if stderr_result.content.is_empty() {
|
||||
format!("Command exited with code {}", exit_code)
|
||||
} else {
|
||||
// Use last line of stderr as error, or full stderr if short
|
||||
if stderr_result.content.lines().count() > 5 {
|
||||
stderr_result
|
||||
.content
|
||||
.lines()
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
stderr_result.content.clone()
|
||||
}
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ExecutionResult {
|
||||
exit_code,
|
||||
stdout: stdout_result.content.clone(),
|
||||
stderr: stderr_result.content.clone(),
|
||||
result,
|
||||
duration_ms,
|
||||
error,
|
||||
stdout_truncated: stdout_result.truncated,
|
||||
stderr_truncated: stderr_result.truncated,
|
||||
stdout_bytes_truncated: stdout_result.bytes_truncated,
|
||||
stderr_bytes_truncated: stderr_result.bytes_truncated,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate shell wrapper script that injects parameters as environment variables
|
||||
fn generate_wrapper_script(&self, context: &ExecutionContext) -> RuntimeResult<String> {
|
||||
let mut script = String::new();
|
||||
|
||||
// Add shebang
|
||||
script.push_str("#!/bin/bash\n");
|
||||
script.push_str("set -e\n\n"); // Exit on error
|
||||
|
||||
// Read secrets from stdin and store in associative array
|
||||
script.push_str("# Read secrets from stdin (passed securely, not via environment)\n");
|
||||
script.push_str("declare -A ATTUNE_SECRETS\n");
|
||||
script.push_str("read -r ATTUNE_SECRETS_JSON\n");
|
||||
script.push_str("if [ -n \"$ATTUNE_SECRETS_JSON\" ]; then\n");
|
||||
script.push_str(" # Parse JSON secrets using Python (always available)\n");
|
||||
script.push_str(" eval \"$(echo \"$ATTUNE_SECRETS_JSON\" | python3 -c \"\n");
|
||||
script.push_str("import sys, json\n");
|
||||
script.push_str("try:\n");
|
||||
script.push_str(" secrets = json.load(sys.stdin)\n");
|
||||
script.push_str(" for key, value in secrets.items():\n");
|
||||
script.push_str(" # Escape single quotes in value\n");
|
||||
script.push_str(
|
||||
" safe_value = value.replace(\\\"'\\\", \\\"'\\\\\\\\\\\\\\\\'\\\") \n",
|
||||
);
|
||||
script.push_str(" print(f\\\"ATTUNE_SECRETS['{key}']='{safe_value}'\\\")\n");
|
||||
script.push_str("except: pass\n");
|
||||
script.push_str("\")\"\n");
|
||||
script.push_str("fi\n\n");
|
||||
|
||||
// Helper function to get secrets
|
||||
script.push_str("# Helper function to access secrets\n");
|
||||
script.push_str("get_secret() {\n");
|
||||
script.push_str(" local name=\"$1\"\n");
|
||||
script.push_str(" echo \"${ATTUNE_SECRETS[$name]}\"\n");
|
||||
script.push_str("}\n\n");
|
||||
|
||||
// Export parameters as environment variables
|
||||
script.push_str("# Action parameters\n");
|
||||
for (key, value) in &context.parameters {
|
||||
let value_str = match value {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Number(n) => n.to_string(),
|
||||
serde_json::Value::Bool(b) => b.to_string(),
|
||||
_ => serde_json::to_string(value)?,
|
||||
};
|
||||
// Export with PARAM_ prefix for consistency
|
||||
script.push_str(&format!(
|
||||
"export PARAM_{}='{}'\n",
|
||||
key.to_uppercase(),
|
||||
value_str
|
||||
));
|
||||
// Also export without prefix for easier shell script writing
|
||||
script.push_str(&format!("export {}='{}'\n", key, value_str));
|
||||
}
|
||||
script.push_str("\n");
|
||||
|
||||
// Add the action code
|
||||
script.push_str("# Action code\n");
|
||||
if let Some(code) = &context.code {
|
||||
script.push_str(code);
|
||||
}
|
||||
|
||||
Ok(script)
|
||||
}
|
||||
|
||||
/// Execute shell script directly
|
||||
async fn execute_shell_code(
|
||||
&self,
|
||||
script: String,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
env: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
debug!(
|
||||
"Executing shell script with {} secrets (passed via stdin)",
|
||||
secrets.len()
|
||||
);
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(&self.shell_path);
|
||||
cmd.arg("-c").arg(&script);
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
self.execute_with_streaming(
|
||||
cmd,
|
||||
secrets,
|
||||
timeout_secs,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute shell script from file
|
||||
async fn execute_shell_file(
|
||||
&self,
|
||||
code_path: PathBuf,
|
||||
secrets: &std::collections::HashMap<String, String>,
|
||||
env: &std::collections::HashMap<String, String>,
|
||||
timeout_secs: Option<u64>,
|
||||
max_stdout_bytes: usize,
|
||||
max_stderr_bytes: usize,
|
||||
) -> RuntimeResult<ExecutionResult> {
|
||||
debug!(
|
||||
"Executing shell file: {:?} with {} secrets",
|
||||
code_path,
|
||||
secrets.len()
|
||||
);
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(&self.shell_path);
|
||||
cmd.arg(&code_path);
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
self.execute_with_streaming(
|
||||
cmd,
|
||||
secrets,
|
||||
timeout_secs,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ShellRuntime {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Runtime for ShellRuntime {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
|
||||
fn can_execute(&self, context: &ExecutionContext) -> bool {
|
||||
// Check if action reference suggests shell script
|
||||
let is_shell = context.action_ref.contains(".sh")
|
||||
|| context.entry_point.ends_with(".sh")
|
||||
|| context
|
||||
.code_path
|
||||
.as_ref()
|
||||
.map(|p| p.extension().and_then(|e| e.to_str()) == Some("sh"))
|
||||
.unwrap_or(false)
|
||||
|| context.entry_point == "bash"
|
||||
|| context.entry_point == "sh"
|
||||
|| context.entry_point == "shell";
|
||||
|
||||
is_shell
|
||||
}
|
||||
|
||||
async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult> {
|
||||
info!(
|
||||
"Executing shell action: {} (execution_id: {})",
|
||||
context.action_ref, context.execution_id
|
||||
);
|
||||
|
||||
// If code_path is provided, execute the file directly
|
||||
if let Some(code_path) = &context.code_path {
|
||||
// Merge parameters into environment variables with ATTUNE_ACTION_ prefix
|
||||
let mut env = context.env.clone();
|
||||
for (key, value) in &context.parameters {
|
||||
let value_str = match value {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Number(n) => n.to_string(),
|
||||
serde_json::Value::Bool(b) => b.to_string(),
|
||||
_ => serde_json::to_string(value)?,
|
||||
};
|
||||
env.insert(format!("ATTUNE_ACTION_{}", key.to_uppercase()), value_str);
|
||||
}
|
||||
|
||||
return self
|
||||
.execute_shell_file(
|
||||
code_path.clone(),
|
||||
&context.secrets,
|
||||
&env,
|
||||
context.timeout,
|
||||
context.max_stdout_bytes,
|
||||
context.max_stderr_bytes,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Otherwise, generate wrapper script and execute
|
||||
let script = self.generate_wrapper_script(&context)?;
|
||||
self.execute_shell_code(
|
||||
script,
|
||||
&context.secrets,
|
||||
&context.env,
|
||||
context.timeout,
|
||||
context.max_stdout_bytes,
|
||||
context.max_stderr_bytes,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn setup(&self) -> RuntimeResult<()> {
|
||||
info!("Setting up Shell runtime");
|
||||
|
||||
// Ensure work directory exists
|
||||
tokio::fs::create_dir_all(&self.work_dir)
|
||||
.await
|
||||
.map_err(|e| RuntimeError::SetupError(format!("Failed to create work dir: {}", e)))?;
|
||||
|
||||
// Verify shell is available
|
||||
let output = Command::new(&self.shell_path)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
RuntimeError::SetupError(format!("Shell not found at {:?}: {}", self.shell_path, e))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(RuntimeError::SetupError(
|
||||
"Shell interpreter is not working".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
info!("Shell runtime ready: {}", version.trim());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> RuntimeResult<()> {
|
||||
info!("Cleaning up Shell runtime");
|
||||
// Could clean up temporary files here
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate(&self) -> RuntimeResult<()> {
|
||||
debug!("Validating Shell runtime");
|
||||
|
||||
// Check if shell is available
|
||||
let output = Command::new(&self.shell_path)
|
||||
.arg("-c")
|
||||
.arg("echo 'test'")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| RuntimeError::SetupError(format!("Shell validation failed: {}", e)))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(RuntimeError::SetupError(
|
||||
"Shell interpreter validation failed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_runtime_simple() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 1,
|
||||
action_ref: "test.simple".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some("echo 'Hello, World!'".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert!(result.stdout.contains("Hello, World!"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_runtime_with_params() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 2,
|
||||
action_ref: "test.params".to_string(),
|
||||
parameters: {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("name".to_string(), serde_json::json!("Alice"));
|
||||
map
|
||||
},
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some("echo \"Hello, $name!\"".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert!(result.stdout.contains("Hello, Alice!"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_runtime_timeout() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 3,
|
||||
action_ref: "test.timeout".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(1),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some("sleep 10".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(!result.is_success());
|
||||
assert!(result.error.is_some());
|
||||
let error_msg = result.error.unwrap();
|
||||
assert!(error_msg.contains("timeout") || error_msg.contains("timed out"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_runtime_error() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 4,
|
||||
action_ref: "test.error".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some("exit 1".to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(!result.is_success());
|
||||
assert_eq!(result.exit_code, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_runtime_with_secrets() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 5,
|
||||
action_ref: "test.secrets".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert("api_key".to_string(), "secret_key_12345".to_string());
|
||||
s.insert("db_password".to_string(), "super_secret_pass".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
# Access secrets via get_secret function
|
||||
api_key=$(get_secret 'api_key')
|
||||
db_pass=$(get_secret 'db_password')
|
||||
missing=$(get_secret 'nonexistent')
|
||||
|
||||
echo "api_key=$api_key"
|
||||
echo "db_pass=$db_pass"
|
||||
echo "missing=$missing"
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success());
|
||||
assert_eq!(result.exit_code, 0);
|
||||
|
||||
// Verify secrets are accessible in action code
|
||||
assert!(result.stdout.contains("api_key=secret_key_12345"));
|
||||
assert!(result.stdout.contains("db_pass=super_secret_pass"));
|
||||
assert!(result.stdout.contains("missing="));
|
||||
}
|
||||
}
|
||||
386
crates/worker/src/secrets.rs
Normal file
386
crates/worker/src/secrets.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Secret Management Module
|
||||
//!
|
||||
//! Handles fetching, decrypting, and injecting secrets into execution environments.
|
||||
//! Secrets are stored encrypted in the database and decrypted on-demand for execution.
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes256Gcm, Key as AesKey, Nonce,
|
||||
};
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::{key::Key, Action, OwnerType};
|
||||
use attune_common::repositories::key::KeyRepository;
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Secret manager for handling secret operations
|
||||
pub struct SecretManager {
|
||||
pool: PgPool,
|
||||
encryption_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl SecretManager {
|
||||
/// Create a new secret manager
|
||||
pub fn new(pool: PgPool, encryption_key: Option<String>) -> Result<Self> {
|
||||
let encryption_key = encryption_key.map(|key| Self::derive_key(&key));
|
||||
|
||||
if encryption_key.is_none() {
|
||||
warn!("No encryption key configured - encrypted secrets will fail to decrypt");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
pool,
|
||||
encryption_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Derive encryption key from password/key string
|
||||
fn derive_key(key: &str) -> Vec<u8> {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
/// Fetch all secrets relevant to an action execution
|
||||
///
|
||||
/// Secrets are fetched in order of precedence:
|
||||
/// 1. System-level secrets (owner_type='system')
|
||||
/// 2. Pack-level secrets (owner_type='pack')
|
||||
/// 3. Action-level secrets (owner_type='action')
|
||||
///
|
||||
/// More specific secrets override less specific ones with the same name.
|
||||
pub async fn fetch_secrets_for_action(
|
||||
&self,
|
||||
action: &Action,
|
||||
) -> Result<HashMap<String, String>> {
|
||||
debug!("Fetching secrets for action: {}", action.r#ref);
|
||||
|
||||
let mut secrets = HashMap::new();
|
||||
|
||||
// 1. Fetch system-level secrets
|
||||
let system_secrets = self.fetch_secrets_by_owner_type(OwnerType::System).await?;
|
||||
for secret in system_secrets {
|
||||
let value = self.decrypt_if_needed(&secret)?;
|
||||
secrets.insert(secret.name.clone(), value);
|
||||
}
|
||||
debug!("Loaded {} system secrets", secrets.len());
|
||||
|
||||
// 2. Fetch pack-level secrets
|
||||
let pack_secrets = self.fetch_secrets_by_pack(action.pack).await?;
|
||||
for secret in pack_secrets {
|
||||
let value = self.decrypt_if_needed(&secret)?;
|
||||
secrets.insert(secret.name.clone(), value);
|
||||
}
|
||||
debug!("Loaded {} pack secrets", secrets.len());
|
||||
|
||||
// 3. Fetch action-level secrets
|
||||
let action_secrets = self.fetch_secrets_by_action(action.id).await?;
|
||||
for secret in action_secrets {
|
||||
let value = self.decrypt_if_needed(&secret)?;
|
||||
secrets.insert(secret.name.clone(), value);
|
||||
}
|
||||
debug!("Total secrets loaded: {}", secrets.len());
|
||||
|
||||
Ok(secrets)
|
||||
}
|
||||
|
||||
/// Fetch secrets by owner type
|
||||
async fn fetch_secrets_by_owner_type(&self, owner_type: OwnerType) -> Result<Vec<Key>> {
|
||||
KeyRepository::find_by_owner_type(&self.pool, owner_type).await
|
||||
}
|
||||
|
||||
/// Fetch secrets for a specific pack
|
||||
async fn fetch_secrets_by_pack(&self, pack_id: i64) -> Result<Vec<Key>> {
|
||||
sqlx::query_as::<_, Key>(
|
||||
"SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref,
|
||||
owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted,
|
||||
encryption_key_hash, value, created, updated
|
||||
FROM key
|
||||
WHERE owner_type = $1 AND owner_pack = $2
|
||||
ORDER BY name ASC",
|
||||
)
|
||||
.bind(OwnerType::Pack)
|
||||
.bind(pack_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Fetch secrets for a specific action
|
||||
async fn fetch_secrets_by_action(&self, action_id: i64) -> Result<Vec<Key>> {
|
||||
sqlx::query_as::<_, Key>(
|
||||
"SELECT id, ref, owner_type, owner, owner_identity, owner_pack, owner_pack_ref,
|
||||
owner_action, owner_action_ref, owner_sensor, owner_sensor_ref, name, encrypted,
|
||||
encryption_key_hash, value, created, updated
|
||||
FROM key
|
||||
WHERE owner_type = $1 AND owner_action = $2
|
||||
ORDER BY name ASC",
|
||||
)
|
||||
.bind(OwnerType::Action)
|
||||
.bind(action_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Decrypt a secret if it's encrypted, otherwise return the value as-is
|
||||
fn decrypt_if_needed(&self, key: &Key) -> Result<String> {
|
||||
if !key.encrypted {
|
||||
return Ok(key.value.clone());
|
||||
}
|
||||
|
||||
// Encrypted secret requires encryption key
|
||||
let encryption_key = self
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::Internal("No encryption key configured".to_string()))?;
|
||||
|
||||
// Verify encryption key hash if present
|
||||
if let Some(expected_hash) = &key.encryption_key_hash {
|
||||
let actual_hash = Self::compute_key_hash_from_bytes(encryption_key);
|
||||
if &actual_hash != expected_hash {
|
||||
return Err(Error::Internal(format!(
|
||||
"Encryption key hash mismatch for secret '{}'",
|
||||
key.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Self::decrypt_value(&key.value, encryption_key)
|
||||
}
|
||||
|
||||
/// Decrypt an encrypted value
|
||||
///
|
||||
/// Format: "nonce:ciphertext" (both base64-encoded)
|
||||
fn decrypt_value(encrypted_value: &str, key: &[u8]) -> Result<String> {
|
||||
// Parse format: "nonce:ciphertext"
|
||||
let parts: Vec<&str> = encrypted_value.split(':').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(Error::Internal(
|
||||
"Invalid encrypted value format. Expected 'nonce:ciphertext'".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let nonce_bytes = BASE64
|
||||
.decode(parts[0])
|
||||
.map_err(|e| Error::Internal(format!("Failed to decode nonce: {}", e)))?;
|
||||
|
||||
let ciphertext = BASE64
|
||||
.decode(parts[1])
|
||||
.map_err(|e| Error::Internal(format!("Failed to decode ciphertext: {}", e)))?;
|
||||
|
||||
// Create cipher
|
||||
let key_array: [u8; 32] = key
|
||||
.try_into()
|
||||
.map_err(|_| Error::Internal("Invalid key length".to_string()))?;
|
||||
let cipher_key = AesKey::<Aes256Gcm>::from_slice(&key_array);
|
||||
let cipher = Aes256Gcm::new(cipher_key);
|
||||
|
||||
// Create nonce
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
// Decrypt
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext.as_ref())
|
||||
.map_err(|e| Error::Internal(format!("Decryption failed: {}", e)))?;
|
||||
|
||||
String::from_utf8(plaintext)
|
||||
.map_err(|e| Error::Internal(format!("Invalid UTF-8 in decrypted value: {}", e)))
|
||||
}
|
||||
|
||||
/// Encrypt a value (for testing and future use)
|
||||
#[allow(dead_code)]
|
||||
pub fn encrypt_value(&self, plaintext: &str) -> Result<String> {
|
||||
let encryption_key = self
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::Internal("No encryption key configured".to_string()))?;
|
||||
|
||||
Self::encrypt_value_with_key(plaintext, encryption_key)
|
||||
}
|
||||
|
||||
/// Encrypt a value with a specific key (static method)
|
||||
fn encrypt_value_with_key(plaintext: &str, encryption_key: &[u8]) -> Result<String> {
|
||||
// Create cipher
|
||||
let key_array: [u8; 32] = encryption_key
|
||||
.try_into()
|
||||
.map_err(|_| Error::Internal("Invalid key length".to_string()))?;
|
||||
let cipher_key = AesKey::<Aes256Gcm>::from_slice(&key_array);
|
||||
let cipher = Aes256Gcm::new(cipher_key);
|
||||
|
||||
// Generate random nonce
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
|
||||
// Encrypt
|
||||
let ciphertext = cipher
|
||||
.encrypt(&nonce, plaintext.as_bytes())
|
||||
.map_err(|e| Error::Internal(format!("Encryption failed: {}", e)))?;
|
||||
|
||||
// Format: "nonce:ciphertext" (both base64-encoded)
|
||||
let nonce_b64 = BASE64.encode(&nonce);
|
||||
let ciphertext_b64 = BASE64.encode(&ciphertext);
|
||||
|
||||
Ok(format!("{}:{}", nonce_b64, ciphertext_b64))
|
||||
}
|
||||
|
||||
/// Compute hash of the encryption key
|
||||
pub fn compute_key_hash(&self) -> String {
|
||||
if let Some(key) = &self.encryption_key {
|
||||
Self::compute_key_hash_from_bytes(key)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hash from key bytes (static method)
|
||||
fn compute_key_hash_from_bytes(key: &[u8]) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key);
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Prepare secrets as environment variables
|
||||
///
|
||||
/// **DEPRECATED - SECURITY VULNERABILITY**: This method exposes secrets in the process
|
||||
/// environment, making them visible in process listings (`ps auxe`) and `/proc/[pid]/environ`.
|
||||
///
|
||||
/// Secrets should be passed via stdin instead. This method is kept only for backward
|
||||
/// compatibility and will be removed in a future version.
|
||||
///
|
||||
/// Secret names are converted to uppercase and prefixed with "SECRET_"
|
||||
/// Example: "api_key" becomes "SECRET_API_KEY"
|
||||
#[deprecated(
|
||||
since = "0.2.0",
|
||||
note = "Secrets in environment variables are insecure. Pass secrets via stdin instead."
|
||||
)]
|
||||
pub fn prepare_secret_env(&self, secrets: &HashMap<String, String>) -> HashMap<String, String> {
|
||||
secrets
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
let env_name = format!("SECRET_{}", name.to_uppercase().replace('-', "_"));
|
||||
(env_name, value.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper to derive a test encryption key
|
||||
fn derive_test_key(key: &str) -> Vec<u8> {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_roundtrip() {
|
||||
let key = derive_test_key("test-encryption-key-12345");
|
||||
let plaintext = "my-secret-value";
|
||||
let encrypted = SecretManager::encrypt_value_with_key(plaintext, &key).unwrap();
|
||||
|
||||
// Verify format
|
||||
assert!(encrypted.contains(':'));
|
||||
let parts: Vec<&str> = encrypted.split(':').collect();
|
||||
assert_eq!(parts.len(), 2);
|
||||
|
||||
// Decrypt and verify
|
||||
let decrypted = SecretManager::decrypt_value(&encrypted, &key).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_different_values() {
|
||||
let key = derive_test_key("test-encryption-key-12345");
|
||||
|
||||
let plaintext1 = "secret1";
|
||||
let plaintext2 = "secret2";
|
||||
|
||||
let encrypted1 = SecretManager::encrypt_value_with_key(plaintext1, &key).unwrap();
|
||||
let encrypted2 = SecretManager::encrypt_value_with_key(plaintext2, &key).unwrap();
|
||||
|
||||
// Encrypted values should be different (due to random nonces)
|
||||
assert_ne!(encrypted1, encrypted2);
|
||||
|
||||
// Both should decrypt correctly
|
||||
let decrypted1 = SecretManager::decrypt_value(&encrypted1, &key).unwrap();
|
||||
let decrypted2 = SecretManager::decrypt_value(&encrypted2, &key).unwrap();
|
||||
|
||||
assert_eq!(decrypted1, plaintext1);
|
||||
assert_eq!(decrypted2, plaintext2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decrypt_with_wrong_key() {
|
||||
let key1 = derive_test_key("key1");
|
||||
let key2 = derive_test_key("key2");
|
||||
|
||||
let plaintext = "secret";
|
||||
let encrypted = SecretManager::encrypt_value_with_key(plaintext, &key1).unwrap();
|
||||
|
||||
// Decrypting with wrong key should fail
|
||||
let result = SecretManager::decrypt_value(&encrypted, &key2);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepare_secret_env() {
|
||||
// Test the static method directly without creating a SecretManager instance
|
||||
let mut secrets = HashMap::new();
|
||||
secrets.insert("api_key".to_string(), "secret123".to_string());
|
||||
secrets.insert("db-password".to_string(), "pass456".to_string());
|
||||
secrets.insert("oauth_token".to_string(), "token789".to_string());
|
||||
|
||||
// Call prepare_secret_env as a static-like method
|
||||
let env: HashMap<String, String> = secrets
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
let env_name = format!("SECRET_{}", name.to_uppercase().replace('-', "_"));
|
||||
(env_name, value.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(env.get("SECRET_API_KEY"), Some(&"secret123".to_string()));
|
||||
assert_eq!(env.get("SECRET_DB_PASSWORD"), Some(&"pass456".to_string()));
|
||||
assert_eq!(env.get("SECRET_OAUTH_TOKEN"), Some(&"token789".to_string()));
|
||||
assert_eq!(env.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_key_hash() {
|
||||
let key1 = derive_test_key("test-key");
|
||||
let key2 = derive_test_key("test-key");
|
||||
let key3 = derive_test_key("different-key");
|
||||
|
||||
let hash1 = SecretManager::compute_key_hash_from_bytes(&key1);
|
||||
let hash2 = SecretManager::compute_key_hash_from_bytes(&key2);
|
||||
let hash3 = SecretManager::compute_key_hash_from_bytes(&key3);
|
||||
|
||||
// Same key should produce same hash
|
||||
assert_eq!(hash1, hash2);
|
||||
// Different key should produce different hash
|
||||
assert_ne!(hash1, hash3);
|
||||
// Hash should not be empty
|
||||
assert!(!hash1.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_encrypted_format() {
|
||||
let key = derive_test_key("test-key");
|
||||
|
||||
// Invalid formats should fail
|
||||
let result = SecretManager::decrypt_value("no-colon", &key);
|
||||
assert!(result.is_err());
|
||||
|
||||
let result = SecretManager::decrypt_value("too:many:colons", &key);
|
||||
assert!(result.is_err());
|
||||
|
||||
let result = SecretManager::decrypt_value("invalid-base64:also-invalid", &key);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
692
crates/worker/src/service.rs
Normal file
692
crates/worker/src/service.rs
Normal file
@@ -0,0 +1,692 @@
|
||||
//! Worker Service Module
|
||||
//!
|
||||
//! Main service orchestration for the Attune Worker Service.
|
||||
//! Manages worker registration, heartbeat, message consumption, and action execution.
|
||||
|
||||
use attune_common::config::Config;
|
||||
use attune_common::db::Database;
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::ExecutionStatus;
|
||||
use attune_common::mq::{
|
||||
config::MessageQueueConfig as MqConfig, Connection, Consumer, ConsumerConfig,
|
||||
ExecutionCompletedPayload, ExecutionStatusChangedPayload, MessageEnvelope, MessageType,
|
||||
Publisher, PublisherConfig, QueueConfig,
|
||||
};
|
||||
use attune_common::repositories::{execution::ExecutionRepository, FindById};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::artifacts::ArtifactManager;
|
||||
use crate::executor::ActionExecutor;
|
||||
use crate::heartbeat::HeartbeatManager;
|
||||
use crate::registration::WorkerRegistration;
|
||||
use crate::runtime::local::LocalRuntime;
|
||||
use crate::runtime::native::NativeRuntime;
|
||||
use crate::runtime::python::PythonRuntime;
|
||||
use crate::runtime::shell::ShellRuntime;
|
||||
use crate::runtime::{DependencyManagerRegistry, PythonVenvManager, RuntimeRegistry};
|
||||
use crate::secrets::SecretManager;
|
||||
|
||||
/// Message payload for execution.scheduled events
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionScheduledPayload {
|
||||
pub execution_id: i64,
|
||||
pub action_ref: String,
|
||||
pub worker_id: i64,
|
||||
}
|
||||
|
||||
/// Worker service that manages execution lifecycle
|
||||
pub struct WorkerService {
|
||||
#[allow(dead_code)]
|
||||
config: Config,
|
||||
db_pool: PgPool,
|
||||
registration: Arc<RwLock<WorkerRegistration>>,
|
||||
heartbeat: Arc<HeartbeatManager>,
|
||||
executor: Arc<ActionExecutor>,
|
||||
mq_connection: Arc<Connection>,
|
||||
publisher: Arc<Publisher>,
|
||||
consumer: Option<Arc<Consumer>>,
|
||||
worker_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl WorkerService {
|
||||
/// Create a new worker service
|
||||
pub async fn new(config: Config) -> Result<Self> {
|
||||
info!("Initializing Worker Service");
|
||||
|
||||
// Initialize database
|
||||
let db = Database::new(&config.database).await?;
|
||||
let pool = db.pool().clone();
|
||||
info!("Database connection established");
|
||||
|
||||
// Initialize message queue connection
|
||||
let mq_url = config
|
||||
.message_queue
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::Internal("Message queue configuration is required".to_string()))?
|
||||
.url
|
||||
.as_str();
|
||||
|
||||
let mq_connection = Connection::connect(mq_url)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to connect to message queue: {}", e)))?;
|
||||
info!("Message queue connection established");
|
||||
|
||||
// Setup message queue infrastructure (exchanges, queues, bindings)
|
||||
let mq_config = MqConfig::default();
|
||||
match mq_connection.setup_infrastructure(&mq_config).await {
|
||||
Ok(_) => info!("Message queue infrastructure setup completed"),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to setup MQ infrastructure (may already exist): {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize message queue publisher
|
||||
let publisher = Publisher::new(
|
||||
&mq_connection,
|
||||
PublisherConfig {
|
||||
confirm_publish: true,
|
||||
timeout_secs: 30,
|
||||
exchange: "attune.executions".to_string(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create publisher: {}", e)))?;
|
||||
info!("Message queue publisher initialized");
|
||||
|
||||
// Initialize worker registration
|
||||
let registration = Arc::new(RwLock::new(WorkerRegistration::new(pool.clone(), &config)));
|
||||
|
||||
// Initialize artifact manager
|
||||
let artifact_base_dir = std::path::PathBuf::from(
|
||||
config
|
||||
.worker
|
||||
.as_ref()
|
||||
.and_then(|w| w.name.clone())
|
||||
.map(|name| format!("/tmp/attune/artifacts/{}", name))
|
||||
.unwrap_or_else(|| "/tmp/attune/artifacts".to_string()),
|
||||
);
|
||||
let artifact_manager = ArtifactManager::new(artifact_base_dir);
|
||||
artifact_manager.initialize().await?;
|
||||
|
||||
// Determine which runtimes to register based on configuration
|
||||
// This reads from ATTUNE_WORKER_RUNTIMES env var (highest priority)
|
||||
let configured_runtimes = if let Ok(runtimes_env) = std::env::var("ATTUNE_WORKER_RUNTIMES")
|
||||
{
|
||||
info!(
|
||||
"Registering runtimes from ATTUNE_WORKER_RUNTIMES: {}",
|
||||
runtimes_env
|
||||
);
|
||||
runtimes_env
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect::<Vec<String>>()
|
||||
} else {
|
||||
// Fallback to auto-detection if not configured
|
||||
info!("No ATTUNE_WORKER_RUNTIMES found, registering all available runtimes");
|
||||
vec![
|
||||
"shell".to_string(),
|
||||
"python".to_string(),
|
||||
"native".to_string(),
|
||||
]
|
||||
};
|
||||
|
||||
info!("Configured runtimes: {:?}", configured_runtimes);
|
||||
|
||||
// Initialize dependency manager registry for isolated environments
|
||||
let mut dependency_manager_registry = DependencyManagerRegistry::new();
|
||||
|
||||
// Only setup Python virtual environment manager if Python runtime is needed
|
||||
if configured_runtimes.contains(&"python".to_string()) {
|
||||
let venv_base_dir = std::path::PathBuf::from(
|
||||
config
|
||||
.worker
|
||||
.as_ref()
|
||||
.and_then(|w| w.name.clone())
|
||||
.map(|name| format!("/tmp/attune/venvs/{}", name))
|
||||
.unwrap_or_else(|| "/tmp/attune/venvs".to_string()),
|
||||
);
|
||||
let python_venv_manager = PythonVenvManager::new(venv_base_dir);
|
||||
dependency_manager_registry.register(Box::new(python_venv_manager));
|
||||
info!("Dependency manager initialized with Python venv support");
|
||||
}
|
||||
|
||||
let dependency_manager_arc = Arc::new(dependency_manager_registry);
|
||||
|
||||
// Initialize runtime registry
|
||||
let mut runtime_registry = RuntimeRegistry::new();
|
||||
|
||||
// Register runtimes based on configuration
|
||||
for runtime_name in &configured_runtimes {
|
||||
match runtime_name.as_str() {
|
||||
"python" => {
|
||||
let python_runtime = PythonRuntime::with_dependency_manager(
|
||||
std::path::PathBuf::from("python3"),
|
||||
std::path::PathBuf::from("/tmp/attune/actions"),
|
||||
dependency_manager_arc.clone(),
|
||||
);
|
||||
runtime_registry.register(Box::new(python_runtime));
|
||||
info!("Registered Python runtime");
|
||||
}
|
||||
"shell" => {
|
||||
runtime_registry.register(Box::new(ShellRuntime::new()));
|
||||
info!("Registered Shell runtime");
|
||||
}
|
||||
"native" => {
|
||||
runtime_registry.register(Box::new(NativeRuntime::new()));
|
||||
info!("Registered Native runtime");
|
||||
}
|
||||
"node" => {
|
||||
warn!("Node.js runtime requested but not yet implemented, skipping");
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown runtime type '{}', skipping", runtime_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only register local runtime as fallback if no specific runtimes configured
|
||||
// (LocalRuntime contains Python/Shell/Native and tries to validate all)
|
||||
if configured_runtimes.is_empty() {
|
||||
let local_runtime = LocalRuntime::new();
|
||||
runtime_registry.register(Box::new(local_runtime));
|
||||
info!("Registered Local runtime (fallback)");
|
||||
}
|
||||
|
||||
// Validate all registered runtimes
|
||||
runtime_registry
|
||||
.validate_all()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to validate runtimes: {}", e)))?;
|
||||
|
||||
info!(
|
||||
"Successfully validated runtimes: {:?}",
|
||||
runtime_registry.list_runtimes()
|
||||
);
|
||||
|
||||
// Initialize secret manager
|
||||
let encryption_key = config.security.encryption_key.clone();
|
||||
let secret_manager = SecretManager::new(pool.clone(), encryption_key)?;
|
||||
info!("Secret manager initialized");
|
||||
|
||||
// Initialize action executor
|
||||
let max_stdout_bytes = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.max_stdout_bytes)
|
||||
.unwrap_or(10 * 1024 * 1024);
|
||||
let max_stderr_bytes = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.max_stderr_bytes)
|
||||
.unwrap_or(10 * 1024 * 1024);
|
||||
let packs_base_dir = std::path::PathBuf::from(&config.packs_base_dir);
|
||||
let executor = Arc::new(ActionExecutor::new(
|
||||
pool.clone(),
|
||||
runtime_registry,
|
||||
artifact_manager,
|
||||
secret_manager,
|
||||
max_stdout_bytes,
|
||||
max_stderr_bytes,
|
||||
packs_base_dir,
|
||||
));
|
||||
|
||||
// Initialize heartbeat manager
|
||||
let heartbeat_interval = config
|
||||
.worker
|
||||
.as_ref()
|
||||
.map(|w| w.heartbeat_interval)
|
||||
.unwrap_or(30);
|
||||
let heartbeat = Arc::new(HeartbeatManager::new(
|
||||
registration.clone(),
|
||||
heartbeat_interval,
|
||||
));
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
db_pool: pool,
|
||||
registration,
|
||||
heartbeat,
|
||||
executor,
|
||||
mq_connection: Arc::new(mq_connection),
|
||||
publisher: Arc::new(publisher),
|
||||
consumer: None,
|
||||
worker_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Start the worker service
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
info!("Starting Worker Service");
|
||||
|
||||
// Detect runtime capabilities and register worker
|
||||
let worker_id = {
|
||||
let mut reg = self.registration.write().await;
|
||||
reg.detect_capabilities(&self.config).await?;
|
||||
reg.register().await?
|
||||
};
|
||||
self.worker_id = Some(worker_id);
|
||||
|
||||
info!("Worker registered with ID: {}", worker_id);
|
||||
|
||||
// Start heartbeat
|
||||
self.heartbeat.start().await?;
|
||||
|
||||
// Start consuming execution messages
|
||||
self.start_execution_consumer().await?;
|
||||
|
||||
info!("Worker Service started successfully");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the worker service
|
||||
pub async fn stop(&mut self) -> Result<()> {
|
||||
info!("Stopping Worker Service");
|
||||
|
||||
// Stop heartbeat
|
||||
self.heartbeat.stop().await;
|
||||
|
||||
// Wait a bit for heartbeat to stop
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
// Deregister worker
|
||||
{
|
||||
let reg = self.registration.read().await;
|
||||
reg.deregister().await?;
|
||||
}
|
||||
|
||||
info!("Worker Service stopped");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start consuming execution.scheduled messages
|
||||
async fn start_execution_consumer(&mut self) -> Result<()> {
|
||||
let worker_id = self
|
||||
.worker_id
|
||||
.ok_or_else(|| Error::Internal("Worker not registered".to_string()))?;
|
||||
|
||||
// Create queue name for this worker
|
||||
let queue_name = format!("worker.{}.executions", worker_id);
|
||||
|
||||
info!("Creating worker-specific queue: {}", queue_name);
|
||||
|
||||
// Create the worker-specific queue
|
||||
let worker_queue = QueueConfig {
|
||||
name: queue_name.clone(),
|
||||
durable: false, // Worker queues are temporary
|
||||
exclusive: false,
|
||||
auto_delete: true, // Delete when worker disconnects
|
||||
};
|
||||
|
||||
self.mq_connection
|
||||
.declare_queue(&worker_queue)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to declare queue: {}", e)))?;
|
||||
|
||||
info!("Worker queue created: {}", queue_name);
|
||||
|
||||
// Bind the queue to the executions exchange with worker-specific routing key
|
||||
self.mq_connection
|
||||
.bind_queue(
|
||||
&queue_name,
|
||||
"attune.executions",
|
||||
&format!("worker.{}", worker_id),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to bind queue: {}", e)))?;
|
||||
|
||||
info!(
|
||||
"Queue bound to exchange with routing key 'worker.{}'",
|
||||
worker_id
|
||||
);
|
||||
|
||||
// Create consumer
|
||||
let consumer = Consumer::new(
|
||||
&self.mq_connection,
|
||||
ConsumerConfig {
|
||||
queue: queue_name.clone(),
|
||||
tag: format!("worker-{}", worker_id),
|
||||
prefetch_count: 10,
|
||||
auto_ack: false,
|
||||
exclusive: false,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to create consumer: {}", e)))?;
|
||||
|
||||
info!("Consumer started for queue: {}", queue_name);
|
||||
|
||||
info!("Message queue consumer initialized");
|
||||
|
||||
// Clone Arc references for the handler
|
||||
let executor = self.executor.clone();
|
||||
let publisher = self.publisher.clone();
|
||||
let db_pool = self.db_pool.clone();
|
||||
|
||||
// Consume messages with handler
|
||||
consumer
|
||||
.consume_with_handler(
|
||||
move |envelope: MessageEnvelope<ExecutionScheduledPayload>| {
|
||||
let executor = executor.clone();
|
||||
let publisher = publisher.clone();
|
||||
let db_pool = db_pool.clone();
|
||||
|
||||
async move {
|
||||
Self::handle_execution_scheduled(executor, publisher, db_pool, envelope)
|
||||
.await
|
||||
.map_err(|e| format!("Execution handler error: {}", e).into())
|
||||
}
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to start consumer: {}", e)))?;
|
||||
|
||||
// Store consumer reference
|
||||
self.consumer = Some(Arc::new(consumer));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle execution.scheduled message
|
||||
async fn handle_execution_scheduled(
|
||||
executor: Arc<ActionExecutor>,
|
||||
publisher: Arc<Publisher>,
|
||||
db_pool: PgPool,
|
||||
envelope: MessageEnvelope<ExecutionScheduledPayload>,
|
||||
) -> Result<()> {
|
||||
let execution_id = envelope.payload.execution_id;
|
||||
|
||||
info!(
|
||||
"Processing execution.scheduled for execution: {}",
|
||||
execution_id
|
||||
);
|
||||
|
||||
// Publish status: running
|
||||
if let Err(e) = Self::publish_status_update(
|
||||
&db_pool,
|
||||
&publisher,
|
||||
execution_id,
|
||||
ExecutionStatus::Running,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to publish running status: {}", e);
|
||||
// Continue anyway - the executor will update the database
|
||||
}
|
||||
|
||||
// Execute the action
|
||||
match executor.execute(execution_id).await {
|
||||
Ok(result) => {
|
||||
info!(
|
||||
"Execution {} completed successfully in {}ms",
|
||||
execution_id, result.duration_ms
|
||||
);
|
||||
|
||||
// Publish status: completed
|
||||
if let Err(e) = Self::publish_status_update(
|
||||
&db_pool,
|
||||
&publisher,
|
||||
execution_id,
|
||||
ExecutionStatus::Completed,
|
||||
result.result.clone(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to publish success status: {}", e);
|
||||
}
|
||||
|
||||
// Publish completion notification for queue management
|
||||
if let Err(e) =
|
||||
Self::publish_completion_notification(&db_pool, &publisher, execution_id).await
|
||||
{
|
||||
error!(
|
||||
"Failed to publish completion notification for execution {}: {}",
|
||||
execution_id, e
|
||||
);
|
||||
// Continue - this is important for queue management but not fatal
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Execution {} failed: {}", execution_id, e);
|
||||
|
||||
// Publish status: failed
|
||||
if let Err(e) = Self::publish_status_update(
|
||||
&db_pool,
|
||||
&publisher,
|
||||
execution_id,
|
||||
ExecutionStatus::Failed,
|
||||
None,
|
||||
Some(e.to_string()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to publish failure status: {}", e);
|
||||
}
|
||||
|
||||
// Publish completion notification for queue management
|
||||
if let Err(e) =
|
||||
Self::publish_completion_notification(&db_pool, &publisher, execution_id).await
|
||||
{
|
||||
error!(
|
||||
"Failed to publish completion notification for execution {}: {}",
|
||||
execution_id, e
|
||||
);
|
||||
// Continue - this is important for queue management but not fatal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Publish execution status update
|
||||
async fn publish_status_update(
|
||||
db_pool: &PgPool,
|
||||
publisher: &Publisher,
|
||||
execution_id: i64,
|
||||
status: ExecutionStatus,
|
||||
_result: Option<serde_json::Value>,
|
||||
_error: Option<String>,
|
||||
) -> Result<()> {
|
||||
// Fetch execution to get action_ref and previous status
|
||||
let execution = ExecutionRepository::find_by_id(db_pool, execution_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
Error::Internal(format!(
|
||||
"Execution {} not found for status update",
|
||||
execution_id
|
||||
))
|
||||
})?;
|
||||
|
||||
let new_status_str = match status {
|
||||
ExecutionStatus::Running => "running",
|
||||
ExecutionStatus::Completed => "completed",
|
||||
ExecutionStatus::Failed => "failed",
|
||||
ExecutionStatus::Cancelled => "cancelled",
|
||||
ExecutionStatus::Timeout => "timeout",
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let previous_status_str = format!("{:?}", execution.status).to_lowercase();
|
||||
|
||||
let payload = ExecutionStatusChangedPayload {
|
||||
execution_id,
|
||||
action_ref: execution.action_ref,
|
||||
previous_status: previous_status_str,
|
||||
new_status: new_status_str.to_string(),
|
||||
changed_at: Utc::now(),
|
||||
};
|
||||
|
||||
let message_type = MessageType::ExecutionStatusChanged;
|
||||
|
||||
let envelope = MessageEnvelope::new(message_type, payload).with_source("worker");
|
||||
|
||||
publisher
|
||||
.publish_envelope(&envelope)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to publish status update: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Publish execution completion notification for queue management
|
||||
async fn publish_completion_notification(
|
||||
db_pool: &PgPool,
|
||||
publisher: &Publisher,
|
||||
execution_id: i64,
|
||||
) -> Result<()> {
|
||||
// Fetch execution to get action_id and other required fields
|
||||
let execution = ExecutionRepository::find_by_id(db_pool, execution_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
Error::Internal(format!(
|
||||
"Execution {} not found after completion",
|
||||
execution_id
|
||||
))
|
||||
})?;
|
||||
|
||||
// Extract action_id - it should always be present for valid executions
|
||||
let action_id = execution.action.ok_or_else(|| {
|
||||
Error::Internal(format!(
|
||||
"Execution {} has no associated action",
|
||||
execution_id
|
||||
))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"Publishing completion notification for execution {} (action_id: {})",
|
||||
execution_id, action_id
|
||||
);
|
||||
|
||||
let payload = ExecutionCompletedPayload {
|
||||
execution_id: execution.id,
|
||||
action_id,
|
||||
action_ref: execution.action_ref.clone(),
|
||||
status: format!("{:?}", execution.status),
|
||||
result: execution.result.clone(),
|
||||
completed_at: Utc::now(),
|
||||
};
|
||||
|
||||
let envelope =
|
||||
MessageEnvelope::new(MessageType::ExecutionCompleted, payload).with_source("worker");
|
||||
|
||||
publisher.publish_envelope(&envelope).await.map_err(|e| {
|
||||
Error::Internal(format!("Failed to publish completion notification: {}", e))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"Completion notification published for execution {}",
|
||||
execution_id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run the worker service until interrupted
|
||||
pub async fn run(&mut self) -> Result<()> {
|
||||
self.start().await?;
|
||||
|
||||
// Wait for shutdown signal
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to wait for shutdown signal: {}", e)))?;
|
||||
|
||||
info!("Received shutdown signal");
|
||||
|
||||
self.stop().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_queue_name_format() {
|
||||
let worker_id = 42;
|
||||
let queue_name = format!("worker.{}.executions", worker_id);
|
||||
assert_eq!(queue_name, "worker.42.executions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_status_string_conversion() {
|
||||
let status = ExecutionStatus::Running;
|
||||
let status_str = match status {
|
||||
ExecutionStatus::Running => "running",
|
||||
_ => "unknown",
|
||||
};
|
||||
assert_eq!(status_str, "running");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execution_completed_payload_structure() {
|
||||
let payload = ExecutionCompletedPayload {
|
||||
execution_id: 123,
|
||||
action_id: 456,
|
||||
action_ref: "test.action".to_string(),
|
||||
status: "Completed".to_string(),
|
||||
result: Some(serde_json::json!({"output": "test"})),
|
||||
completed_at: Utc::now(),
|
||||
};
|
||||
|
||||
assert_eq!(payload.execution_id, 123);
|
||||
assert_eq!(payload.action_id, 456);
|
||||
assert_eq!(payload.action_ref, "test.action");
|
||||
assert_eq!(payload.status, "Completed");
|
||||
assert!(payload.result.is_some());
|
||||
}
|
||||
|
||||
// Test removed - ExecutionStatusPayload struct doesn't exist
|
||||
// #[test]
|
||||
// fn test_execution_status_payload_structure() {
|
||||
// ...
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_execution_scheduled_payload_structure() {
|
||||
let payload = ExecutionScheduledPayload {
|
||||
execution_id: 111,
|
||||
action_ref: "core.test".to_string(),
|
||||
worker_id: 222,
|
||||
};
|
||||
|
||||
assert_eq!(payload.execution_id, 111);
|
||||
assert_eq!(payload.action_ref, "core.test");
|
||||
assert_eq!(payload.worker_id, 222);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_status_format_for_completion() {
|
||||
let status = ExecutionStatus::Completed;
|
||||
let status_str = format!("{:?}", status);
|
||||
assert_eq!(status_str, "Completed");
|
||||
|
||||
let status = ExecutionStatus::Failed;
|
||||
let status_str = format!("{:?}", status);
|
||||
assert_eq!(status_str, "Failed");
|
||||
|
||||
let status = ExecutionStatus::Timeout;
|
||||
let status_str = format!("{:?}", status);
|
||||
assert_eq!(status_str, "Timeout");
|
||||
|
||||
let status = ExecutionStatus::Cancelled;
|
||||
let status_str = format!("{:?}", status);
|
||||
assert_eq!(status_str, "Cancelled");
|
||||
}
|
||||
}
|
||||
507
crates/worker/src/test_executor.rs
Normal file
507
crates/worker/src/test_executor.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
//! Pack Test Executor Module
|
||||
//!
|
||||
//! Executes pack tests by running test runners and collecting results.
|
||||
|
||||
use attune_common::error::{Error, Result};
|
||||
use attune_common::models::pack_test::{
|
||||
PackTestResult, TestCaseResult, TestStatus, TestSuiteResult,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Stdio;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Test configuration from pack.yaml
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TestConfig {
|
||||
pub enabled: bool,
|
||||
pub discovery: DiscoveryConfig,
|
||||
pub runners: HashMap<String, RunnerConfig>,
|
||||
pub result_format: Option<String>,
|
||||
pub result_path: Option<String>,
|
||||
pub min_pass_rate: Option<f64>,
|
||||
pub on_failure: Option<String>,
|
||||
}
|
||||
|
||||
/// Test discovery configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DiscoveryConfig {
|
||||
pub method: String,
|
||||
pub path: Option<String>,
|
||||
}
|
||||
|
||||
/// Test runner configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RunnerConfig {
|
||||
pub r#type: String,
|
||||
pub entry_point: String,
|
||||
pub timeout: Option<u64>,
|
||||
pub result_format: Option<String>,
|
||||
}
|
||||
|
||||
/// Test executor for running pack tests
|
||||
pub struct TestExecutor {
|
||||
/// Base directory for pack files
|
||||
pack_base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl TestExecutor {
|
||||
/// Create a new test executor
|
||||
pub fn new(pack_base_dir: PathBuf) -> Self {
|
||||
Self { pack_base_dir }
|
||||
}
|
||||
|
||||
/// Execute all tests for a pack
|
||||
pub async fn execute_pack_tests(
|
||||
&self,
|
||||
pack_ref: &str,
|
||||
pack_version: &str,
|
||||
test_config: &TestConfig,
|
||||
) -> Result<PackTestResult> {
|
||||
info!("Executing tests for pack: {} v{}", pack_ref, pack_version);
|
||||
|
||||
if !test_config.enabled {
|
||||
return Err(Error::Validation(
|
||||
"Testing is not enabled for this pack".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let pack_dir = self.pack_base_dir.join(pack_ref);
|
||||
if !pack_dir.exists() {
|
||||
return Err(Error::not_found(
|
||||
"pack_directory",
|
||||
"path",
|
||||
pack_dir.display().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let start_time = Instant::now();
|
||||
let execution_time = Utc::now();
|
||||
let mut test_suites = Vec::new();
|
||||
|
||||
// Execute tests for each runner
|
||||
for (runner_name, runner_config) in &test_config.runners {
|
||||
info!(
|
||||
"Running test suite: {} ({})",
|
||||
runner_name, runner_config.r#type
|
||||
);
|
||||
|
||||
match self
|
||||
.execute_test_suite(&pack_dir, runner_name, runner_config)
|
||||
.await
|
||||
{
|
||||
Ok(suite_result) => {
|
||||
info!(
|
||||
"Test suite '{}' completed: {}/{} passed",
|
||||
runner_name, suite_result.passed, suite_result.total
|
||||
);
|
||||
test_suites.push(suite_result);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Test suite '{}' failed to execute: {}", runner_name, e);
|
||||
// Create a failed suite result
|
||||
test_suites.push(TestSuiteResult {
|
||||
name: runner_name.clone(),
|
||||
runner_type: runner_config.r#type.clone(),
|
||||
total: 0,
|
||||
passed: 0,
|
||||
failed: 1,
|
||||
skipped: 0,
|
||||
duration_ms: 0,
|
||||
test_cases: vec![TestCaseResult {
|
||||
name: format!("{}_execution", runner_name),
|
||||
status: TestStatus::Error,
|
||||
duration_ms: 0,
|
||||
error_message: Some(e.to_string()),
|
||||
stdout: None,
|
||||
stderr: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_duration_ms = start_time.elapsed().as_millis() as i64;
|
||||
|
||||
// Aggregate results
|
||||
let total_tests: i32 = test_suites.iter().map(|s| s.total).sum();
|
||||
let passed: i32 = test_suites.iter().map(|s| s.passed).sum();
|
||||
let failed: i32 = test_suites.iter().map(|s| s.failed).sum();
|
||||
let skipped: i32 = test_suites.iter().map(|s| s.skipped).sum();
|
||||
let pass_rate = if total_tests > 0 {
|
||||
passed as f64 / total_tests as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
info!(
|
||||
"Pack tests completed: {}/{} passed ({:.1}%)",
|
||||
passed,
|
||||
total_tests,
|
||||
pass_rate * 100.0
|
||||
);
|
||||
|
||||
// Determine overall test status
|
||||
let status = if failed > 0 {
|
||||
"failed".to_string()
|
||||
} else if passed == total_tests {
|
||||
"passed".to_string()
|
||||
} else if skipped == total_tests {
|
||||
"skipped".to_string()
|
||||
} else {
|
||||
"partial".to_string()
|
||||
};
|
||||
|
||||
Ok(PackTestResult {
|
||||
pack_ref: pack_ref.to_string(),
|
||||
pack_version: pack_version.to_string(),
|
||||
execution_time,
|
||||
status,
|
||||
total_tests,
|
||||
passed,
|
||||
failed,
|
||||
skipped,
|
||||
pass_rate,
|
||||
duration_ms: total_duration_ms,
|
||||
test_suites,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a single test suite
|
||||
async fn execute_test_suite(
|
||||
&self,
|
||||
pack_dir: &Path,
|
||||
runner_name: &str,
|
||||
runner_config: &RunnerConfig,
|
||||
) -> Result<TestSuiteResult> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Resolve entry point path
|
||||
let entry_point = pack_dir.join(&runner_config.entry_point);
|
||||
if !entry_point.exists() {
|
||||
return Err(Error::not_found(
|
||||
"test_entry_point",
|
||||
"path",
|
||||
entry_point.display().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Determine command based on runner type
|
||||
// Use relative path from pack directory for the entry point
|
||||
let relative_entry_point = entry_point
|
||||
.strip_prefix(pack_dir)
|
||||
.unwrap_or(&entry_point)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
let (command, args) = match runner_config.r#type.as_str() {
|
||||
"script" => {
|
||||
// Execute as shell script
|
||||
let shell = if entry_point.extension().and_then(|s| s.to_str()) == Some("sh") {
|
||||
"/bin/sh"
|
||||
} else {
|
||||
"/bin/bash"
|
||||
};
|
||||
(shell.to_string(), vec![relative_entry_point])
|
||||
}
|
||||
"unittest" => {
|
||||
// Execute as Python unittest
|
||||
(
|
||||
"python3".to_string(),
|
||||
vec![
|
||||
"-m".to_string(),
|
||||
"unittest".to_string(),
|
||||
relative_entry_point,
|
||||
],
|
||||
)
|
||||
}
|
||||
"pytest" => {
|
||||
// Execute with pytest
|
||||
(
|
||||
"pytest".to_string(),
|
||||
vec![relative_entry_point, "-v".to_string()],
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::Validation(format!(
|
||||
"Unsupported runner type: {}",
|
||||
runner_config.r#type
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Execute test command with pack_dir as working directory
|
||||
let timeout_duration = Duration::from_secs(runner_config.timeout.unwrap_or(300));
|
||||
let output = self
|
||||
.run_command(&command, &args, pack_dir, timeout_duration)
|
||||
.await?;
|
||||
|
||||
let duration_ms = start_time.elapsed().as_millis() as i64;
|
||||
|
||||
// Parse output based on result format
|
||||
let result_format = runner_config.result_format.as_deref().unwrap_or("simple");
|
||||
|
||||
let mut suite_result = match result_format {
|
||||
"simple" => self.parse_simple_output(&output, runner_name, &runner_config.r#type)?,
|
||||
"json" => self.parse_json_output(&output.stdout, runner_name)?,
|
||||
_ => {
|
||||
warn!(
|
||||
"Unknown result format '{}', falling back to simple",
|
||||
result_format
|
||||
);
|
||||
self.parse_simple_output(&output, runner_name, &runner_config.r#type)?
|
||||
}
|
||||
};
|
||||
|
||||
suite_result.duration_ms = duration_ms;
|
||||
|
||||
Ok(suite_result)
|
||||
}
|
||||
|
||||
/// Run a command with timeout
|
||||
async fn run_command(
|
||||
&self,
|
||||
command: &str,
|
||||
args: &[String],
|
||||
working_dir: &Path,
|
||||
timeout: Duration,
|
||||
) -> Result<CommandOutput> {
|
||||
debug!(
|
||||
"Executing command: {} {} (timeout: {:?})",
|
||||
command,
|
||||
args.join(" "),
|
||||
timeout
|
||||
);
|
||||
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args)
|
||||
.current_dir(working_dir)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.stdin(Stdio::null());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut child = cmd.spawn().map_err(|e| {
|
||||
Error::Internal(format!("Failed to spawn command '{}': {}", command, e))
|
||||
})?;
|
||||
|
||||
// Wait for process with timeout
|
||||
let status = tokio::time::timeout(timeout, child.wait())
|
||||
.await
|
||||
.map_err(|_| Error::Timeout(format!("Test execution timed out after {:?}", timeout)))?
|
||||
.map_err(|e| Error::Internal(format!("Process wait failed: {}", e)))?;
|
||||
|
||||
// Read output
|
||||
let stdout_handle = child.stdout.take();
|
||||
let stderr_handle = child.stderr.take();
|
||||
|
||||
let stdout = if let Some(stdout) = stdout_handle {
|
||||
self.read_stream(stdout).await?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let stderr = if let Some(stderr) = stderr_handle {
|
||||
self.read_stream(stderr).await?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
let exit_code = status.code().unwrap_or(-1);
|
||||
|
||||
Ok(CommandOutput {
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
duration_ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Read from an async stream
|
||||
async fn read_stream(&self, stream: impl tokio::io::AsyncRead + Unpin) -> Result<String> {
|
||||
let mut reader = BufReader::new(stream);
|
||||
let mut output = String::new();
|
||||
let mut line = String::new();
|
||||
|
||||
while reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to read stream: {}", e)))?
|
||||
> 0
|
||||
{
|
||||
output.push_str(&line);
|
||||
line.clear();
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Parse simple test output format
|
||||
fn parse_simple_output(
|
||||
&self,
|
||||
output: &CommandOutput,
|
||||
runner_name: &str,
|
||||
runner_type: &str,
|
||||
) -> Result<TestSuiteResult> {
|
||||
let text = format!("{}\n{}", output.stdout, output.stderr);
|
||||
|
||||
// Parse test counts from output
|
||||
let total = self.extract_number(&text, "Total Tests:");
|
||||
let passed = self.extract_number(&text, "Passed:");
|
||||
let failed = self.extract_number(&text, "Failed:");
|
||||
let skipped = self.extract_number(&text, "Skipped:").or_else(|| Some(0));
|
||||
|
||||
// If we couldn't parse counts, use exit code
|
||||
let (total, passed, failed, skipped) = if total.is_none() || passed.is_none() {
|
||||
if output.exit_code == 0 {
|
||||
(1, 1, 0, 0)
|
||||
} else {
|
||||
(1, 0, 1, 0)
|
||||
}
|
||||
} else {
|
||||
(
|
||||
total.unwrap_or(0),
|
||||
passed.unwrap_or(0),
|
||||
failed.unwrap_or(0),
|
||||
skipped.unwrap_or(0),
|
||||
)
|
||||
};
|
||||
|
||||
// Create a single test case representing the entire suite
|
||||
let test_case = TestCaseResult {
|
||||
name: format!("{}_suite", runner_name),
|
||||
status: if output.exit_code == 0 {
|
||||
TestStatus::Passed
|
||||
} else {
|
||||
TestStatus::Failed
|
||||
},
|
||||
duration_ms: output.duration_ms as i64,
|
||||
error_message: if output.exit_code != 0 {
|
||||
Some(format!("Exit code: {}", output.exit_code))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
stdout: if !output.stdout.is_empty() {
|
||||
Some(output.stdout.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
stderr: if !output.stderr.is_empty() {
|
||||
Some(output.stderr.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
Ok(TestSuiteResult {
|
||||
name: runner_name.to_string(),
|
||||
runner_type: runner_type.to_string(),
|
||||
total,
|
||||
passed,
|
||||
failed,
|
||||
skipped,
|
||||
duration_ms: output.duration_ms as i64,
|
||||
test_cases: vec![test_case],
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse JSON test output format
|
||||
fn parse_json_output(&self, _json_str: &str, _runner_name: &str) -> Result<TestSuiteResult> {
|
||||
// TODO: Implement JSON parsing for structured test results
|
||||
// For now, return a basic result
|
||||
Err(Error::Validation(
|
||||
"JSON result format not yet implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Extract a number from text after a label
|
||||
fn extract_number(&self, text: &str, label: &str) -> Option<i32> {
|
||||
text.lines()
|
||||
.find(|line| line.contains(label))
|
||||
.and_then(|line| {
|
||||
line.split(label)
|
||||
.nth(1)?
|
||||
.trim()
|
||||
.split_whitespace()
|
||||
.next()?
|
||||
.parse::<i32>()
|
||||
.ok()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Command execution output
|
||||
#[derive(Debug)]
|
||||
struct CommandOutput {
|
||||
exit_code: i32,
|
||||
stdout: String,
|
||||
stderr: String,
|
||||
duration_ms: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_number() {
|
||||
let executor = TestExecutor::new(PathBuf::from("/tmp"));
|
||||
|
||||
let text = "Total Tests: 36\nPassed: 35\nFailed: 1";
|
||||
|
||||
assert_eq!(executor.extract_number(text, "Total Tests:"), Some(36));
|
||||
assert_eq!(executor.extract_number(text, "Passed:"), Some(35));
|
||||
assert_eq!(executor.extract_number(text, "Failed:"), Some(1));
|
||||
assert_eq!(executor.extract_number(text, "Skipped:"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_simple_output() {
|
||||
let executor = TestExecutor::new(PathBuf::from("/tmp"));
|
||||
|
||||
let output = CommandOutput {
|
||||
exit_code: 0,
|
||||
stdout: "Total Tests: 36\nPassed: 36\nFailed: 0\n".to_string(),
|
||||
stderr: String::new(),
|
||||
duration_ms: 1234,
|
||||
};
|
||||
|
||||
let result = executor
|
||||
.parse_simple_output(&output, "shell", "script")
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.total, 36);
|
||||
assert_eq!(result.passed, 36);
|
||||
assert_eq!(result.failed, 0);
|
||||
assert_eq!(result.skipped, 0);
|
||||
assert_eq!(result.duration_ms, 1234);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_simple_output_with_failures() {
|
||||
let executor = TestExecutor::new(PathBuf::from("/tmp"));
|
||||
|
||||
let output = CommandOutput {
|
||||
exit_code: 1,
|
||||
stdout: "Total Tests: 10\nPassed: 8\nFailed: 2\n".to_string(),
|
||||
stderr: "Some tests failed\n".to_string(),
|
||||
duration_ms: 5000,
|
||||
};
|
||||
|
||||
let result = executor
|
||||
.parse_simple_output(&output, "python", "unittest")
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.total, 10);
|
||||
assert_eq!(result.passed, 8);
|
||||
assert_eq!(result.failed, 2);
|
||||
assert_eq!(result.test_cases.len(), 1);
|
||||
assert_eq!(result.test_cases[0].status, TestStatus::Failed);
|
||||
}
|
||||
}
|
||||
377
crates/worker/tests/dependency_isolation_test.rs
Normal file
377
crates/worker/tests/dependency_isolation_test.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
//! Integration tests for Python virtual environment dependency isolation
|
||||
//!
|
||||
//! Tests the end-to-end flow of creating isolated Python environments
|
||||
//! for packs with dependencies.
|
||||
|
||||
use attune_worker::runtime::{
|
||||
DependencyManager, DependencyManagerRegistry, DependencySpec, PythonVenvManager,
|
||||
};
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_venv_creation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
|
||||
let env_info = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
assert_eq!(env_info.runtime, "python");
|
||||
assert!(env_info.is_valid);
|
||||
assert!(env_info.path.exists());
|
||||
assert!(env_info.executable_path.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_venv_idempotency() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
|
||||
// Create environment first time
|
||||
let env_info1 = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
let created_at1 = env_info1.created_at;
|
||||
|
||||
// Call ensure_environment again with same dependencies
|
||||
let env_info2 = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to ensure environment");
|
||||
|
||||
// Should return existing environment (same created_at)
|
||||
assert_eq!(env_info1.created_at, env_info2.created_at);
|
||||
assert_eq!(created_at1, env_info2.created_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_venv_update_on_dependency_change() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec1 = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
|
||||
// Create environment with first set of dependencies
|
||||
let env_info1 = manager
|
||||
.ensure_environment("test_pack", &spec1)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
let created_at1 = env_info1.created_at;
|
||||
|
||||
// Give it a moment to ensure timestamp difference
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Change dependencies
|
||||
let spec2 = DependencySpec::new("python").with_dependency("requests==2.29.0");
|
||||
|
||||
// Should recreate environment
|
||||
let env_info2 = manager
|
||||
.ensure_environment("test_pack", &spec2)
|
||||
.await
|
||||
.expect("Failed to update environment");
|
||||
|
||||
// Updated timestamp should be newer
|
||||
assert!(env_info2.updated_at >= created_at1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_pack_isolation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec1 = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
let spec2 = DependencySpec::new("python").with_dependency("flask==2.3.0");
|
||||
|
||||
// Create environments for two different packs
|
||||
let env1 = manager
|
||||
.ensure_environment("pack_a", &spec1)
|
||||
.await
|
||||
.expect("Failed to create environment for pack_a");
|
||||
|
||||
let env2 = manager
|
||||
.ensure_environment("pack_b", &spec2)
|
||||
.await
|
||||
.expect("Failed to create environment for pack_b");
|
||||
|
||||
// Should have different paths
|
||||
assert_ne!(env1.path, env2.path);
|
||||
assert_ne!(env1.executable_path, env2.executable_path);
|
||||
|
||||
// Both should be valid
|
||||
assert!(env1.is_valid);
|
||||
assert!(env2.is_valid);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_executable_path() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python");
|
||||
|
||||
manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
let python_path = manager
|
||||
.get_executable_path("test_pack")
|
||||
.await
|
||||
.expect("Failed to get executable path");
|
||||
|
||||
assert!(python_path.exists());
|
||||
assert!(python_path.to_string_lossy().contains("test_pack"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_environment() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// Non-existent environment should not be valid
|
||||
let is_valid = manager
|
||||
.validate_environment("nonexistent")
|
||||
.await
|
||||
.expect("Validation check failed");
|
||||
assert!(!is_valid);
|
||||
|
||||
// Create environment
|
||||
let spec = DependencySpec::new("python");
|
||||
manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
// Should now be valid
|
||||
let is_valid = manager
|
||||
.validate_environment("test_pack")
|
||||
.await
|
||||
.expect("Validation check failed");
|
||||
assert!(is_valid);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_environment() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python");
|
||||
|
||||
// Create environment
|
||||
let env_info = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
let path = env_info.path.clone();
|
||||
assert!(path.exists());
|
||||
|
||||
// Remove environment
|
||||
manager
|
||||
.remove_environment("test_pack")
|
||||
.await
|
||||
.expect("Failed to remove environment");
|
||||
|
||||
assert!(!path.exists());
|
||||
|
||||
// Get environment should return None
|
||||
let env = manager
|
||||
.get_environment("test_pack")
|
||||
.await
|
||||
.expect("Failed to get environment");
|
||||
assert!(env.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_environments() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python");
|
||||
|
||||
// Create multiple environments
|
||||
manager
|
||||
.ensure_environment("pack_a", &spec)
|
||||
.await
|
||||
.expect("Failed to create pack_a");
|
||||
|
||||
manager
|
||||
.ensure_environment("pack_b", &spec)
|
||||
.await
|
||||
.expect("Failed to create pack_b");
|
||||
|
||||
manager
|
||||
.ensure_environment("pack_c", &spec)
|
||||
.await
|
||||
.expect("Failed to create pack_c");
|
||||
|
||||
// List should return all three
|
||||
let environments = manager
|
||||
.list_environments()
|
||||
.await
|
||||
.expect("Failed to list environments");
|
||||
|
||||
assert_eq!(environments.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dependency_manager_registry() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let mut registry = DependencyManagerRegistry::new();
|
||||
|
||||
let python_manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
registry.register(Box::new(python_manager));
|
||||
|
||||
// Should support python
|
||||
assert!(registry.supports("python"));
|
||||
assert!(!registry.supports("nodejs"));
|
||||
|
||||
// Should be able to get manager
|
||||
let manager = registry.get("python");
|
||||
assert!(manager.is_some());
|
||||
assert_eq!(manager.unwrap().runtime_type(), "python");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dependency_spec_builder() {
|
||||
let spec = DependencySpec::new("python")
|
||||
.with_dependency("requests==2.28.0")
|
||||
.with_dependency("flask>=2.0.0")
|
||||
.with_version_range(Some("3.8".to_string()), Some("3.11".to_string()));
|
||||
|
||||
assert_eq!(spec.runtime, "python");
|
||||
assert_eq!(spec.dependencies.len(), 2);
|
||||
assert!(spec.has_dependencies());
|
||||
assert_eq!(spec.min_version, Some("3.8".to_string()));
|
||||
assert_eq!(spec.max_version, Some("3.11".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_requirements_file_content() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let requirements = "requests==2.28.0\nflask==2.3.0\npydantic>=2.0.0";
|
||||
let spec = DependencySpec::new("python").with_requirements_file(requirements.to_string());
|
||||
|
||||
let env_info = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment with requirements file");
|
||||
|
||||
assert!(env_info.is_valid);
|
||||
assert!(env_info.installed_dependencies.len() > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pack_ref_sanitization() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python");
|
||||
|
||||
// Pack refs with special characters should be sanitized
|
||||
let env_info = manager
|
||||
.ensure_environment("core.http", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
// Path should not contain dots
|
||||
let path_str = env_info.path.to_string_lossy();
|
||||
assert!(path_str.contains("core_http"));
|
||||
assert!(!path_str.contains("core.http"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_needs_update_detection() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec1 = DependencySpec::new("python").with_dependency("requests==2.28.0");
|
||||
|
||||
// Non-existent environment needs update
|
||||
let needs_update = manager
|
||||
.needs_update("test_pack", &spec1)
|
||||
.await
|
||||
.expect("Failed to check update status");
|
||||
assert!(needs_update);
|
||||
|
||||
// Create environment
|
||||
manager
|
||||
.ensure_environment("test_pack", &spec1)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
// Same spec should not need update
|
||||
let needs_update = manager
|
||||
.needs_update("test_pack", &spec1)
|
||||
.await
|
||||
.expect("Failed to check update status");
|
||||
assert!(!needs_update);
|
||||
|
||||
// Different spec should need update
|
||||
let spec2 = DependencySpec::new("python").with_dependency("requests==2.29.0");
|
||||
let needs_update = manager
|
||||
.needs_update("test_pack", &spec2)
|
||||
.await
|
||||
.expect("Failed to check update status");
|
||||
assert!(needs_update);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_dependencies() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// Pack with no dependencies should still create venv
|
||||
let spec = DependencySpec::new("python");
|
||||
assert!(!spec.has_dependencies());
|
||||
|
||||
let env_info = manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment without dependencies");
|
||||
|
||||
assert!(env_info.is_valid);
|
||||
assert!(env_info.path.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_environment_caching() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PythonVenvManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let spec = DependencySpec::new("python");
|
||||
|
||||
// Create environment
|
||||
manager
|
||||
.ensure_environment("test_pack", &spec)
|
||||
.await
|
||||
.expect("Failed to create environment");
|
||||
|
||||
// First get_environment should read from disk
|
||||
let env1 = manager
|
||||
.get_environment("test_pack")
|
||||
.await
|
||||
.expect("Failed to get environment")
|
||||
.expect("Environment not found");
|
||||
|
||||
// Second get_environment should use cache
|
||||
let env2 = manager
|
||||
.get_environment("test_pack")
|
||||
.await
|
||||
.expect("Failed to get environment")
|
||||
.expect("Environment not found");
|
||||
|
||||
assert_eq!(env1.id, env2.id);
|
||||
assert_eq!(env1.path, env2.path);
|
||||
}
|
||||
277
crates/worker/tests/log_truncation_test.rs
Normal file
277
crates/worker/tests/log_truncation_test.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! Integration tests for log size truncation
|
||||
//!
|
||||
//! Tests that verify stdout/stderr are properly truncated when they exceed
|
||||
//! configured size limits, preventing OOM issues with large output.
|
||||
|
||||
use attune_worker::runtime::{ExecutionContext, PythonRuntime, Runtime, ShellRuntime};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_stdout_truncation() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Create a Python script that outputs more than the limit
|
||||
let code = r#"
|
||||
import sys
|
||||
# Output 1KB of data (will exceed 500 byte limit)
|
||||
for i in range(100):
|
||||
print("x" * 10)
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 1,
|
||||
action_ref: "test.large_output".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 500, // Small limit to trigger truncation
|
||||
max_stderr_bytes: 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed but with truncated output
|
||||
assert!(result.is_success());
|
||||
assert!(result.stdout_truncated);
|
||||
assert!(result.stdout.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.stdout_bytes_truncated > 0);
|
||||
assert!(result.stdout.len() <= 500);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_stderr_truncation() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Create a Python script that outputs to stderr
|
||||
let code = r#"
|
||||
import sys
|
||||
# Output 1KB of data to stderr
|
||||
for i in range(100):
|
||||
sys.stderr.write("error message line\n")
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 2,
|
||||
action_ref: "test.large_stderr".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 300, // Small limit for stderr
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed but with truncated stderr
|
||||
assert!(result.is_success());
|
||||
assert!(!result.stdout_truncated);
|
||||
assert!(result.stderr_truncated);
|
||||
assert!(result.stderr.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.stderr.contains("stderr exceeded size limit"));
|
||||
assert!(result.stderr_bytes_truncated > 0);
|
||||
assert!(result.stderr.len() <= 300);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_stdout_truncation() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
// Shell script that outputs more than the limit
|
||||
let code = r#"
|
||||
for i in {1..100}; do
|
||||
echo "This is a long line of text that will add up quickly"
|
||||
done
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 3,
|
||||
action_ref: "test.shell_large_output".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 400, // Small limit
|
||||
max_stderr_bytes: 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed but with truncated output
|
||||
assert!(result.is_success());
|
||||
assert!(result.stdout_truncated);
|
||||
assert!(result.stdout.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.stdout_bytes_truncated > 0);
|
||||
assert!(result.stdout.len() <= 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_truncation_under_limit() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Small output that won't trigger truncation
|
||||
let code = r#"
|
||||
print("Hello, World!")
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 4,
|
||||
action_ref: "test.small_output".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024, // Large limit
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed without truncation
|
||||
assert!(result.is_success());
|
||||
assert!(!result.stdout_truncated);
|
||||
assert!(!result.stderr_truncated);
|
||||
assert_eq!(result.stdout_bytes_truncated, 0);
|
||||
assert_eq!(result.stderr_bytes_truncated, 0);
|
||||
assert!(result.stdout.contains("Hello, World!"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_both_streams_truncated() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Script that outputs to both stdout and stderr
|
||||
let code = r#"
|
||||
import sys
|
||||
# Output to both streams
|
||||
for i in range(50):
|
||||
print("stdout line " + str(i))
|
||||
sys.stderr.write("stderr line " + str(i) + "\n")
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 5,
|
||||
action_ref: "test.dual_truncation".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 300, // Both limits are small
|
||||
max_stderr_bytes: 300,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed but with both streams truncated
|
||||
assert!(result.is_success());
|
||||
assert!(result.stdout_truncated);
|
||||
assert!(result.stderr_truncated);
|
||||
assert!(result.stdout.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.stderr.contains("[OUTPUT TRUNCATED"));
|
||||
assert!(result.stdout_bytes_truncated > 0);
|
||||
assert!(result.stderr_bytes_truncated > 0);
|
||||
assert!(result.stdout.len() <= 300);
|
||||
assert!(result.stderr.len() <= 300);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncation_with_timeout() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Script that times out but should still capture truncated logs
|
||||
let code = r#"
|
||||
import time
|
||||
for i in range(1000):
|
||||
print(f"Line {i}")
|
||||
time.sleep(30) # Will timeout before this
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 6,
|
||||
action_ref: "test.timeout_truncation".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(2), // Short timeout
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 500,
|
||||
max_stderr_bytes: 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should timeout with truncated logs
|
||||
assert!(!result.is_success());
|
||||
assert!(result.error.is_some());
|
||||
assert!(result.error.as_ref().unwrap().contains("timed out"));
|
||||
// Logs may or may not be truncated depending on how fast it runs
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exact_limit_no_truncation() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// Output a small amount that won't trigger truncation
|
||||
// The Python wrapper adds JSON result output, so we need headroom
|
||||
let code = r#"
|
||||
import sys
|
||||
sys.stdout.write("Small output")
|
||||
"#;
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 7,
|
||||
action_ref: "test.exact_limit".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(),
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "test_script".to_string(),
|
||||
code: Some(code.to_string()),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024, // Large limit to avoid truncation
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Should succeed without truncation
|
||||
eprintln!(
|
||||
"test_exact_limit_no_truncation: exit_code={}, error={:?}, stdout={:?}, stderr={:?}",
|
||||
result.exit_code, result.error, result.stdout, result.stderr
|
||||
);
|
||||
assert!(result.is_success());
|
||||
assert!(!result.stdout_truncated);
|
||||
assert!(result.stdout.contains("Small output"));
|
||||
}
|
||||
415
crates/worker/tests/security_tests.rs
Normal file
415
crates/worker/tests/security_tests.rs
Normal file
@@ -0,0 +1,415 @@
|
||||
//! Security Tests for Secret Handling
|
||||
//!
|
||||
//! These tests verify that secrets are NOT exposed in process environment
|
||||
//! or command-line arguments, ensuring secure secret passing via stdin.
|
||||
|
||||
use attune_worker::runtime::python::PythonRuntime;
|
||||
use attune_worker::runtime::shell::ShellRuntime;
|
||||
use attune_worker::runtime::{ExecutionContext, Runtime};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_secrets_not_in_environ() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 1,
|
||||
action_ref: "security.test_environ".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert(
|
||||
"api_key".to_string(),
|
||||
"super_secret_key_do_not_expose".to_string(),
|
||||
);
|
||||
s.insert("password".to_string(), "secret_pass_123".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
import os
|
||||
|
||||
def run():
|
||||
# Check if secrets are in environment variables
|
||||
environ_str = str(os.environ)
|
||||
|
||||
# Secrets should NOT be in environment
|
||||
has_secret_in_env = 'super_secret_key_do_not_expose' in environ_str
|
||||
has_password_in_env = 'secret_pass_123' in environ_str
|
||||
has_secret_prefix = 'SECRET_API_KEY' in os.environ or 'SECRET_PASSWORD' in os.environ
|
||||
|
||||
# But they SHOULD be accessible via get_secret()
|
||||
api_key_accessible = get_secret('api_key') == 'super_secret_key_do_not_expose'
|
||||
password_accessible = get_secret('password') == 'secret_pass_123'
|
||||
|
||||
return {
|
||||
'secrets_in_environ': has_secret_in_env or has_password_in_env or has_secret_prefix,
|
||||
'api_key_accessible': api_key_accessible,
|
||||
'password_accessible': password_accessible,
|
||||
'environ_check': 'SECRET_' not in environ_str
|
||||
}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(result.is_success(), "Execution should succeed");
|
||||
|
||||
let result_data = result.result.unwrap();
|
||||
let result_obj = result_data.get("result").unwrap();
|
||||
|
||||
// Critical security check: secrets should NOT be in environment
|
||||
assert_eq!(
|
||||
result_obj.get("secrets_in_environ").unwrap(),
|
||||
&serde_json::json!(false),
|
||||
"SECURITY FAILURE: Secrets found in process environment!"
|
||||
);
|
||||
|
||||
// Verify secrets ARE accessible via secure method
|
||||
assert_eq!(
|
||||
result_obj.get("api_key_accessible").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Secrets should be accessible via get_secret()"
|
||||
);
|
||||
assert_eq!(
|
||||
result_obj.get("password_accessible").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Secrets should be accessible via get_secret()"
|
||||
);
|
||||
|
||||
// Verify no SECRET_ prefix in environment
|
||||
assert_eq!(
|
||||
result_obj.get("environ_check").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Environment should not contain SECRET_ prefix variables"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_secrets_not_in_environ() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 2,
|
||||
action_ref: "security.test_shell_environ".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert(
|
||||
"api_key".to_string(),
|
||||
"super_secret_key_do_not_expose".to_string(),
|
||||
);
|
||||
s.insert("password".to_string(), "secret_pass_123".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
# Check if secrets are in environment variables
|
||||
if printenv | grep -q "super_secret_key_do_not_expose"; then
|
||||
echo "SECURITY_FAIL: Secret found in environment"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if printenv | grep -q "secret_pass_123"; then
|
||||
echo "SECURITY_FAIL: Password found in environment"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if printenv | grep -q "SECRET_API_KEY"; then
|
||||
echo "SECURITY_FAIL: SECRET_ prefix found in environment"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# But secrets SHOULD be accessible via get_secret function
|
||||
api_key=$(get_secret 'api_key')
|
||||
password=$(get_secret 'password')
|
||||
|
||||
if [ "$api_key" != "super_secret_key_do_not_expose" ]; then
|
||||
echo "ERROR: Secret not accessible via get_secret"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$password" != "secret_pass_123" ]; then
|
||||
echo "ERROR: Password not accessible via get_secret"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "SECURITY_PASS: Secrets not in environment but accessible via get_secret"
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
|
||||
// Check execution succeeded
|
||||
assert!(result.is_success(), "Execution should succeed");
|
||||
assert_eq!(result.exit_code, 0, "Exit code should be 0");
|
||||
|
||||
// Verify security pass message
|
||||
assert!(
|
||||
result.stdout.contains("SECURITY_PASS"),
|
||||
"Security checks should pass"
|
||||
);
|
||||
assert!(
|
||||
!result.stdout.contains("SECURITY_FAIL"),
|
||||
"Should not have security failures"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_secret_isolation_between_actions() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
// First action with secret A
|
||||
let context1 = ExecutionContext {
|
||||
execution_id: 3,
|
||||
action_ref: "security.action1".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert("secret_a".to_string(), "value_a".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
return {'secret_a': get_secret('secret_a')}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result1 = runtime.execute(context1).await.unwrap();
|
||||
assert!(result1.is_success());
|
||||
|
||||
// Second action with secret B (should not see secret A)
|
||||
let context2 = ExecutionContext {
|
||||
execution_id: 4,
|
||||
action_ref: "security.action2".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert("secret_b".to_string(), "value_b".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
# Should NOT see secret_a from previous action
|
||||
secret_a = get_secret('secret_a')
|
||||
secret_b = get_secret('secret_b')
|
||||
return {
|
||||
'secret_a_leaked': secret_a is not None,
|
||||
'secret_b_present': secret_b == 'value_b'
|
||||
}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result2 = runtime.execute(context2).await.unwrap();
|
||||
assert!(result2.is_success());
|
||||
|
||||
let result_data = result2.result.unwrap();
|
||||
let result_obj = result_data.get("result").unwrap();
|
||||
|
||||
// Verify secrets don't leak between actions
|
||||
assert_eq!(
|
||||
result_obj.get("secret_a_leaked").unwrap(),
|
||||
&serde_json::json!(false),
|
||||
"Secret from previous action should not leak"
|
||||
);
|
||||
assert_eq!(
|
||||
result_obj.get("secret_b_present").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Current action's secret should be present"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_empty_secrets() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 5,
|
||||
action_ref: "security.no_secrets".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(), // No secrets
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
# get_secret should return None for non-existent secrets
|
||||
result = get_secret('nonexistent')
|
||||
return {'result': result}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(
|
||||
result.is_success(),
|
||||
"Should handle empty secrets gracefully"
|
||||
);
|
||||
|
||||
let result_data = result.result.unwrap();
|
||||
let result_obj = result_data.get("result").unwrap();
|
||||
assert_eq!(result_obj.get("result").unwrap(), &serde_json::Value::Null);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shell_empty_secrets() {
|
||||
let runtime = ShellRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 6,
|
||||
action_ref: "security.no_secrets".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: HashMap::new(), // No secrets
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "shell".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
# get_secret should return empty string for non-existent secrets
|
||||
result=$(get_secret 'nonexistent')
|
||||
if [ -z "$result" ]; then
|
||||
echo "PASS: Empty secret returns empty string"
|
||||
else
|
||||
echo "FAIL: Expected empty string"
|
||||
exit 1
|
||||
fi
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("shell".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(
|
||||
result.is_success(),
|
||||
"Should handle empty secrets gracefully"
|
||||
);
|
||||
assert!(result.stdout.contains("PASS"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_special_characters_in_secrets() {
|
||||
let runtime = PythonRuntime::new();
|
||||
|
||||
let context = ExecutionContext {
|
||||
execution_id: 7,
|
||||
action_ref: "security.special_chars".to_string(),
|
||||
parameters: HashMap::new(),
|
||||
env: HashMap::new(),
|
||||
secrets: {
|
||||
let mut s = HashMap::new();
|
||||
s.insert("special_chars".to_string(), "test!@#$%^&*()".to_string());
|
||||
s.insert("with_newline".to_string(), "line1\nline2".to_string());
|
||||
s
|
||||
},
|
||||
timeout: Some(10),
|
||||
working_dir: None,
|
||||
entry_point: "run".to_string(),
|
||||
code: Some(
|
||||
r#"
|
||||
def run():
|
||||
special = get_secret('special_chars')
|
||||
newline = get_secret('with_newline')
|
||||
|
||||
newline_char = chr(10)
|
||||
newline_parts = newline.split(newline_char) if newline else []
|
||||
|
||||
return {
|
||||
'special_correct': special == 'test!@#$%^&*()',
|
||||
'newline_has_two_parts': len(newline_parts) == 2,
|
||||
'newline_first_part': newline_parts[0] if len(newline_parts) > 0 else '',
|
||||
'newline_second_part': newline_parts[1] if len(newline_parts) > 1 else '',
|
||||
'special_len': len(special) if special else 0
|
||||
}
|
||||
"#
|
||||
.to_string(),
|
||||
),
|
||||
code_path: None,
|
||||
runtime_name: Some("python".to_string()),
|
||||
max_stdout_bytes: 10 * 1024 * 1024,
|
||||
max_stderr_bytes: 10 * 1024 * 1024,
|
||||
};
|
||||
|
||||
let result = runtime.execute(context).await.unwrap();
|
||||
assert!(
|
||||
result.is_success(),
|
||||
"Should handle special characters: {:?}",
|
||||
result.error
|
||||
);
|
||||
|
||||
let result_data = result.result.unwrap();
|
||||
let result_obj = result_data.get("result").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result_obj.get("special_correct").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Special characters should be preserved"
|
||||
);
|
||||
assert_eq!(
|
||||
result_obj.get("newline_has_two_parts").unwrap(),
|
||||
&serde_json::json!(true),
|
||||
"Newline should split into two parts"
|
||||
);
|
||||
assert_eq!(
|
||||
result_obj.get("newline_first_part").unwrap(),
|
||||
&serde_json::json!("line1"),
|
||||
"First part should be 'line1'"
|
||||
);
|
||||
assert_eq!(
|
||||
result_obj.get("newline_second_part").unwrap(),
|
||||
&serde_json::json!("line2"),
|
||||
"Second part should be 'line2'"
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user