Compare commits

..

3 Commits

Author SHA1 Message Date
a7ed135af2 more edge case resolution on workflow builder
Some checks failed
CI / Rustfmt (push) Successful in 22s
CI / Cargo Audit & Deny (push) Successful in 32s
CI / Web Blocking Checks (push) Failing after 26s
CI / Security Blocking Checks (push) Successful in 8s
CI / Clippy (push) Failing after 2m0s
CI / Web Advisory Checks (push) Successful in 32s
CI / Security Advisory Checks (push) Successful in 37s
CI / Tests (push) Failing after 7m33s
2026-03-11 09:29:17 -05:00
71ea3f34ca cancelling actions works now 2026-03-10 19:53:20 -05:00
5b45b17fa6 [wip] single runtime handling 2026-03-10 09:30:57 -05:00
68 changed files with 3602 additions and 1284 deletions

View File

@@ -46,6 +46,7 @@ security:
jwt_refresh_expiration: 2592000 # 30 days
encryption_key: test-encryption-key-32-chars-okay
enable_auth: true
allow_self_registration: true
# Packs directory (where pack action files are located)
packs_base_dir: ./packs

View File

@@ -48,6 +48,7 @@ security:
jwt_refresh_expiration: 3600 # 1 hour
encryption_key: test-encryption-key-32-chars-okay
enable_auth: true
allow_self_registration: true
# Test packs directory (use /tmp for tests to avoid permission issues)
packs_base_dir: /tmp/attune-test-packs

149
crates/api/src/authz.rs Normal file
View File

@@ -0,0 +1,149 @@
//! RBAC authorization service for API handlers.
//!
//! This module evaluates grants assigned to user identities via
//! `permission_set` and `permission_assignment`.
use crate::{
auth::{jwt::TokenType, middleware::AuthenticatedUser},
middleware::ApiError,
};
use attune_common::{
rbac::{Action, AuthorizationContext, Grant, Resource},
repositories::{
identity::{IdentityRepository, PermissionSetRepository},
FindById,
},
};
use sqlx::PgPool;
/// A single authorization request: the resource/action pair being attempted,
/// plus the contextual attributes that grants may constrain on.
#[derive(Debug, Clone)]
pub struct AuthorizationCheck {
/// Resource category the caller wants to act on (e.g. packs, keys, actions).
pub resource: Resource,
/// Operation being attempted on the resource (read, create, execute, …).
pub action: Action,
/// Attribute bag (identity id/attributes, target refs, ownership, pack ref)
/// evaluated against attribute constraints on grants.
pub context: AuthorizationContext,
}
/// Evaluates RBAC grants for authenticated identities against the
/// database-backed permission store (`permission_set` / `permission_assignment`).
#[derive(Clone)]
pub struct AuthorizationService {
// Connection pool used to load identities and their permission sets.
db: PgPool,
}
impl AuthorizationService {
    /// Construct a service backed by the given Postgres pool.
    pub fn new(db: PgPool) -> Self {
        Self { db }
    }

    /// Evaluate `check` for `user`, returning `Ok(())` when the action is
    /// permitted and an [`ApiError`] otherwise.
    ///
    /// Non-access tokens bypass identity RBAC entirely; their dedicated scope
    /// checks live in the route handlers themselves.
    pub async fn authorize(
        &self,
        user: &AuthenticatedUser,
        mut check: AuthorizationCheck,
    ) -> Result<(), ApiError> {
        if !matches!(user.claims.token_type, TokenType::Access) {
            return Ok(());
        }

        let identity_id = match user.identity_id() {
            Ok(id) => id,
            Err(_) => {
                return Err(ApiError::Unauthorized(
                    "Invalid authentication subject in access token".to_string(),
                ))
            }
        };

        // The identity must still exist; its stored attributes feed any
        // attribute constraints declared on grants.
        let identity = match IdentityRepository::find_by_id(&self.db, identity_id).await? {
            Some(found) => found,
            None => return Err(ApiError::Unauthorized("Identity not found".to_string())),
        };

        check.context.identity_id = identity_id;
        check.context.identity_attributes =
            if let serde_json::Value::Object(map) = identity.attributes {
                map.into_iter().collect()
            } else {
                // Non-object attribute payloads contribute nothing to matching.
                Default::default()
            };

        let grants = self.load_effective_grants(identity_id).await?;
        if Self::is_allowed(&grants, check.resource, check.action, &check.context) {
            Ok(())
        } else {
            Err(ApiError::Forbidden(format!(
                "Insufficient permissions: {}:{}",
                resource_name(check.resource),
                action_name(check.action)
            )))
        }
    }

    /// Return every grant currently attached to `user`'s identity.
    ///
    /// Non-access tokens carry no identity grants, so they yield an empty list.
    pub async fn effective_grants(&self, user: &AuthenticatedUser) -> Result<Vec<Grant>, ApiError> {
        if !matches!(user.claims.token_type, TokenType::Access) {
            return Ok(Vec::new());
        }
        let identity_id = match user.identity_id() {
            Ok(id) => id,
            Err(_) => {
                return Err(ApiError::Unauthorized(
                    "Invalid authentication subject in access token".to_string(),
                ))
            }
        };
        self.load_effective_grants(identity_id).await
    }

    /// True when at least one grant in `grants` allows `action` on `resource`
    /// under `context` (grant evaluation is permissive-OR).
    pub fn is_allowed(
        grants: &[Grant],
        resource: Resource,
        action: Action,
        context: &AuthorizationContext,
    ) -> bool {
        for grant in grants {
            if grant.allows(resource, action, context) {
                return true;
            }
        }
        false
    }

    /// Load and flatten the grants of every permission set assigned to
    /// `identity_id`, failing if a stored set has an unparseable grant schema.
    async fn load_effective_grants(&self, identity_id: i64) -> Result<Vec<Grant>, ApiError> {
        let assigned_sets =
            PermissionSetRepository::find_by_identity(&self.db, identity_id).await?;

        let mut effective = Vec::new();
        for set in assigned_sets {
            let parsed: Vec<Grant> = serde_json::from_value(set.grants).map_err(|e| {
                ApiError::InternalServerError(format!(
                    "Invalid grant schema in permission set '{}': {}",
                    set.r#ref, e
                ))
            })?;
            effective.extend(parsed);
        }
        Ok(effective)
    }
}
/// Human/wire-facing label for a [`Resource`] variant, used when formatting
/// "Insufficient permissions: {resource}:{action}" error messages.
fn resource_name(resource: Resource) -> &'static str {
    // Arms are kept alphabetical by label; each variant maps to its
    // lowercase plural name.
    match resource {
        Resource::Actions => "actions",
        Resource::Analytics => "analytics",
        Resource::Artifacts => "artifacts",
        Resource::Enforcements => "enforcements",
        Resource::Events => "events",
        Resource::Executions => "executions",
        Resource::History => "history",
        Resource::Identities => "identities",
        Resource::Inquiries => "inquiries",
        Resource::Keys => "keys",
        Resource::Packs => "packs",
        Resource::Permissions => "permissions",
        Resource::Rules => "rules",
        Resource::Triggers => "triggers",
        Resource::Webhooks => "webhooks",
        Resource::Workflows => "workflows",
    }
}
/// Human/wire-facing label for an [`Action`] variant, used when formatting
/// "Insufficient permissions: {resource}:{action}" error messages.
fn action_name(action: Action) -> &'static str {
    // Arms are kept alphabetical by label; each variant maps to its
    // lowercase verb.
    match action {
        Action::Cancel => "cancel",
        Action::Create => "create",
        Action::Delete => "delete",
        Action::Execute => "execute",
        Action::Manage => "manage",
        Action::Read => "read",
        Action::Respond => "respond",
        Action::Update => "update",
    }
}

View File

@@ -52,10 +52,14 @@ pub struct ExecutionResponse {
#[schema(example = 1)]
pub enforcement: Option<i64>,
/// Executor ID (worker/executor that ran this)
/// Identity ID that initiated this execution
#[schema(example = 1)]
pub executor: Option<i64>,
/// Worker ID currently assigned to this execution
#[schema(example = 1)]
pub worker: Option<i64>,
/// Execution status
#[schema(example = "succeeded")]
pub status: ExecutionStatus,
@@ -216,6 +220,7 @@ impl From<attune_common::models::execution::Execution> for ExecutionResponse {
parent: execution.parent,
enforcement: execution.enforcement,
executor: execution.executor,
worker: execution.worker,
status: execution.status,
result: execution
.result

View File

@@ -11,6 +11,7 @@ pub mod history;
pub mod inquiry;
pub mod key;
pub mod pack;
pub mod permission;
pub mod rule;
pub mod trigger;
pub mod webhook;
@@ -48,6 +49,11 @@ pub use inquiry::{
};
pub use key::{CreateKeyRequest, KeyQueryParams, KeyResponse, KeySummary, UpdateKeyRequest};
pub use pack::{CreatePackRequest, PackResponse, PackSummary, UpdatePackRequest};
pub use permission::{
CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse, IdentitySummary,
PermissionAssignmentResponse, PermissionSetQueryParams, PermissionSetSummary,
UpdateIdentityRequest,
};
pub use rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest};
pub use trigger::{
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,

View File

@@ -0,0 +1,65 @@
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use utoipa::{IntoParams, ToSchema};
use validator::Validate;
/// Query parameters accepted when listing permission sets.
#[derive(Debug, Clone, Deserialize, IntoParams)]
pub struct PermissionSetQueryParams {
/// Optional filter restricting results to sets belonging to one pack.
#[serde(default)]
pub pack_ref: Option<String>,
}
/// Minimal identity record returned in list responses.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct IdentitySummary {
pub id: i64,
pub login: String,
pub display_name: Option<String>,
/// Free-form identity attributes; expected to be a JSON object
/// (non-objects are ignored by authorization — see authz attribute loading).
pub attributes: JsonValue,
}
/// Detail responses reuse the summary shape verbatim.
pub type IdentityResponse = IdentitySummary;
/// Permission set as exposed over the API.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct PermissionSetSummary {
pub id: i64,
/// Stable reference used to address the set (e.g. in assignment requests).
pub r#ref: String,
/// Owning pack, when the set was shipped by a pack.
pub pack_ref: Option<String>,
pub label: Option<String>,
pub description: Option<String>,
/// Raw grant definitions as stored; schema is validated when grants are loaded.
pub grants: JsonValue,
}
/// Link between an identity and a permission set.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct PermissionAssignmentResponse {
pub id: i64,
pub identity_id: i64,
pub permission_set_id: i64,
/// Denormalized ref of the assigned set, for display without a second lookup.
pub permission_set_ref: String,
/// Timestamp the assignment was created (UTC).
pub created: chrono::DateTime<chrono::Utc>,
}
/// Request to assign a permission set to an identity.
///
/// NOTE(review): the identity may be addressed by id or by login; presumably
/// the handler rejects requests that supply neither — confirm against the
/// permissions route implementation.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct CreatePermissionAssignmentRequest {
pub identity_id: Option<i64>,
pub identity_login: Option<String>,
pub permission_set_ref: String,
}
/// Request to provision a new identity (administrative creation, as opposed
/// to self-registration).
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
pub struct CreateIdentityRequest {
/// Login name; validated to be 3–255 characters.
#[validate(length(min = 3, max = 255))]
pub login: String,
/// Optional display name; validated to at most 255 characters.
#[validate(length(max = 255))]
pub display_name: Option<String>,
/// Optional initial password; when present, validated to 8–128 characters.
#[validate(length(min = 8, max = 128))]
pub password: Option<String>,
/// Arbitrary identity attributes; defaults to JSON null when omitted.
#[serde(default)]
pub attributes: JsonValue,
}
/// Partial update of an identity; only provided fields are changed
/// (presumably — confirm against the update handler).
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct UpdateIdentityRequest {
pub display_name: Option<String>,
pub password: Option<String>,
pub attributes: Option<JsonValue>,
}

View File

@@ -5,6 +5,7 @@
//! It is primarily used by the binary target and integration tests.
pub mod auth;
pub mod authz;
pub mod dto;
pub mod middleware;
pub mod openapi;

View File

@@ -26,6 +26,10 @@ use crate::dto::{
PackWorkflowSyncResponse, PackWorkflowValidationResponse, RegisterPackRequest,
UpdatePackRequest, WorkflowSyncResult,
},
permission::{
CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse,
IdentitySummary, PermissionAssignmentResponse, PermissionSetSummary, UpdateIdentityRequest,
},
rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest},
trigger::{
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,
@@ -160,6 +164,17 @@ use crate::dto::{
crate::routes::keys::update_key,
crate::routes::keys::delete_key,
// Permissions
crate::routes::permissions::list_identities,
crate::routes::permissions::get_identity,
crate::routes::permissions::create_identity,
crate::routes::permissions::update_identity,
crate::routes::permissions::delete_identity,
crate::routes::permissions::list_permission_sets,
crate::routes::permissions::list_identity_permissions,
crate::routes::permissions::create_permission_assignment,
crate::routes::permissions::delete_permission_assignment,
// Workflows
crate::routes::workflows::list_workflows,
crate::routes::workflows::list_workflows_by_pack,
@@ -190,6 +205,8 @@ use crate::dto::{
ApiResponse<EnforcementResponse>,
ApiResponse<InquiryResponse>,
ApiResponse<KeyResponse>,
ApiResponse<IdentityResponse>,
ApiResponse<PermissionAssignmentResponse>,
ApiResponse<WorkflowResponse>,
ApiResponse<QueueStatsResponse>,
PaginatedResponse<PackSummary>,
@@ -202,6 +219,7 @@ use crate::dto::{
PaginatedResponse<EnforcementSummary>,
PaginatedResponse<InquirySummary>,
PaginatedResponse<KeySummary>,
PaginatedResponse<IdentitySummary>,
PaginatedResponse<WorkflowSummary>,
PaginationMeta,
SuccessResponse,
@@ -233,6 +251,15 @@ use crate::dto::{
attune_common::models::pack_test::PackTestSummary,
PaginatedResponse<attune_common::models::pack_test::PackTestSummary>,
// Permission DTOs
CreateIdentityRequest,
UpdateIdentityRequest,
IdentityResponse,
PermissionSetSummary,
PermissionAssignmentResponse,
CreatePermissionAssignmentRequest,
IdentitySummary,
// Action DTOs
CreateActionRequest,
UpdateActionRequest,

View File

@@ -10,6 +10,7 @@ use axum::{
use std::sync::Arc;
use validator::Validate;
use attune_common::rbac::{Action, AuthorizationContext, Resource};
use attune_common::repositories::{
action::{ActionRepository, ActionSearchFilters, CreateActionInput, UpdateActionInput},
pack::PackRepository,
@@ -19,6 +20,7 @@ use attune_common::repositories::{
use crate::{
auth::middleware::RequireAuth,
authz::{AuthorizationCheck, AuthorizationService},
dto::{
action::{
ActionResponse, ActionSummary, CreateActionRequest, QueueStatsResponse,
@@ -153,7 +155,7 @@ pub async fn get_action(
)]
pub async fn create_action(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Json(request): Json<CreateActionRequest>,
) -> ApiResult<impl IntoResponse> {
// Validate request
@@ -175,6 +177,26 @@ pub async fn create_action(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", request.pack_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.pack_ref = Some(pack.r#ref.clone());
ctx.target_ref = Some(request.r#ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Actions,
action: Action::Create,
context: ctx,
},
)
.await?;
}
// If runtime is specified, we could verify it exists (future enhancement)
// For now, the database foreign key constraint will handle invalid runtime IDs
@@ -219,7 +241,7 @@ pub async fn create_action(
)]
pub async fn update_action(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(action_ref): Path<String>,
Json(request): Json<UpdateActionRequest>,
) -> ApiResult<impl IntoResponse> {
@@ -231,6 +253,27 @@ pub async fn update_action(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(existing_action.id);
ctx.target_ref = Some(existing_action.r#ref.clone());
ctx.pack_ref = Some(existing_action.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Actions,
action: Action::Update,
context: ctx,
},
)
.await?;
}
// Create update input
let update_input = UpdateActionInput {
label: request.label,
@@ -269,7 +312,7 @@ pub async fn update_action(
)]
pub async fn delete_action(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(action_ref): Path<String>,
) -> ApiResult<impl IntoResponse> {
// Check if action exists
@@ -277,6 +320,27 @@ pub async fn delete_action(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", action_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(action.id);
ctx.target_ref = Some(action.r#ref.clone());
ctx.pack_ref = Some(action.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Actions,
action: Action::Delete,
context: ctx,
},
)
.await?;
}
// Delete the action
let deleted = ActionRepository::delete(&state.db, action.id).await?;

View File

@@ -152,6 +152,12 @@ pub async fn register(
State(state): State<SharedState>,
Json(payload): Json<RegisterRequest>,
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
if !state.config.security.allow_self_registration {
return Err(ApiError::Forbidden(
"Self-service registration is disabled; identities must be provisioned by an administrator or identity provider".to_string(),
));
}
// Validate request
payload
.validate()
@@ -171,7 +177,7 @@ pub async fn register(
// Hash password
let password_hash = hash_password(&payload.password)?;
// Create identity with password hash
// Registration creates an identity only; permission assignments are managed separately.
let input = CreateIdentityInput {
login: payload.login.clone(),
display_name: payload.display_name,

View File

@@ -10,6 +10,7 @@ use axum::{
routing::get,
Json, Router,
};
use chrono::Utc;
use futures::stream::{Stream, StreamExt};
use std::sync::Arc;
use tokio_stream::wrappers::BroadcastStream;
@@ -32,6 +33,7 @@ use sqlx::Row;
use crate::{
auth::middleware::RequireAuth,
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
execution::{
@@ -42,6 +44,7 @@ use crate::{
middleware::{ApiError, ApiResult},
state::AppState,
};
use attune_common::rbac::{Action, AuthorizationContext, Resource};
/// Create a new execution (manual execution)
///
@@ -61,7 +64,7 @@ use crate::{
)]
pub async fn create_execution(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Json(request): Json<CreateExecutionRequest>,
) -> ApiResult<impl IntoResponse> {
// Validate that the action exists
@@ -69,6 +72,42 @@ pub async fn create_execution(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Action '{}' not found", request.action_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut action_ctx = AuthorizationContext::new(identity_id);
action_ctx.target_id = Some(action.id);
action_ctx.target_ref = Some(action.r#ref.clone());
action_ctx.pack_ref = Some(action.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Actions,
action: Action::Execute,
context: action_ctx,
},
)
.await?;
let mut execution_ctx = AuthorizationContext::new(identity_id);
execution_ctx.pack_ref = Some(action.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Executions,
action: Action::Create,
context: execution_ctx,
},
)
.await?;
}
// Create execution input
let execution_input = CreateExecutionInput {
action: Some(action.id),
@@ -84,6 +123,7 @@ pub async fn create_execution(
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None, // Non-workflow execution
@@ -440,9 +480,17 @@ pub async fn cancel_execution(
..Default::default()
};
let updated = ExecutionRepository::update(&state.db, id, update).await?;
let delegated_to_executor = publish_status_change_to_executor(
publisher.as_deref(),
&execution,
ExecutionStatus::Cancelled,
"api-service",
)
.await;
// Cascade to workflow children if this is a workflow execution
cancel_workflow_children(&state.db, publisher.as_deref(), id).await;
if !delegated_to_executor {
cancel_workflow_children(&state.db, publisher.as_deref(), id).await;
}
let response = ApiResponse::new(ExecutionResponse::from(updated));
return Ok((StatusCode::OK, Json(response)));
@@ -454,19 +502,27 @@ pub async fn cancel_execution(
..Default::default()
};
let updated = ExecutionRepository::update(&state.db, id, update).await?;
let delegated_to_executor = publish_status_change_to_executor(
publisher.as_deref(),
&execution,
ExecutionStatus::Canceling,
"api-service",
)
.await;
// Send cancel request to the worker via MQ
if let Some(worker_id) = execution.executor {
if let Some(worker_id) = execution.worker {
send_cancel_to_worker(publisher.as_deref(), id, worker_id).await;
} else {
tracing::warn!(
"Execution {} has no executor/worker assigned; marked as canceling but no MQ message sent",
"Execution {} has no worker assigned; marked as canceling but no MQ message sent",
id
);
}
// Cascade to workflow children if this is a workflow execution
cancel_workflow_children(&state.db, publisher.as_deref(), id).await;
if !delegated_to_executor {
cancel_workflow_children(&state.db, publisher.as_deref(), id).await;
}
let response = ApiResponse::new(ExecutionResponse::from(updated));
Ok((StatusCode::OK, Json(response)))
@@ -504,6 +560,53 @@ async fn send_cancel_to_worker(publisher: Option<&Publisher>, execution_id: i64,
}
}
/// Notify the executor service over MQ that an execution's status changed.
///
/// Returns `true` when the notification was published successfully, `false`
/// when no publisher is configured or publishing failed — callers use the
/// return value to decide whether to fall back to local cascade handling.
async fn publish_status_change_to_executor(
    publisher: Option<&Publisher>,
    execution: &attune_common::models::Execution,
    new_status: ExecutionStatus,
    source: &str,
) -> bool {
    let publisher = match publisher {
        Some(p) => p,
        None => return false,
    };

    // Wire-format label for the target status.
    let status_label = match new_status {
        ExecutionStatus::Requested => "requested",
        ExecutionStatus::Scheduling => "scheduling",
        ExecutionStatus::Scheduled => "scheduled",
        ExecutionStatus::Running => "running",
        ExecutionStatus::Completed => "completed",
        ExecutionStatus::Failed => "failed",
        ExecutionStatus::Canceling => "canceling",
        ExecutionStatus::Cancelled => "cancelled",
        ExecutionStatus::Timeout => "timeout",
        ExecutionStatus::Abandoned => "abandoned",
    };

    let payload = attune_common::mq::ExecutionStatusChangedPayload {
        execution_id: execution.id,
        action_ref: execution.action_ref.clone(),
        // Previous status uses the Debug rendering, lowercased, which matches
        // the labels above for all current variants.
        previous_status: format!("{:?}", execution.status).to_lowercase(),
        new_status: status_label.to_string(),
        changed_at: Utc::now(),
    };

    let envelope = MessageEnvelope::new(MessageType::ExecutionStatusChanged, payload)
        .with_source(source)
        .with_correlation_id(uuid::Uuid::new_v4());

    match publisher.publish_envelope(&envelope).await {
        Ok(_) => true,
        Err(e) => {
            tracing::error!(
                "Failed to publish status change for execution {} to executor: {}",
                execution.id,
                e
            );
            false
        }
    }
}
/// Resolve the [`CancellationPolicy`] for a workflow parent execution.
///
/// Looks up the `workflow_execution` → `workflow_definition` chain and
@@ -652,7 +755,7 @@ async fn cancel_workflow_children_with_policy(
}
}
if let Some(worker_id) = child.executor {
if let Some(worker_id) = child.worker {
send_cancel_to_worker(publisher, child_id, worker_id).await;
}
}

View File

@@ -10,7 +10,6 @@ use axum::{
use std::sync::Arc;
use validator::Validate;
use attune_common::models::OwnerType;
use attune_common::repositories::{
action::ActionRepository,
key::{CreateKeyInput, KeyRepository, KeySearchFilters, UpdateKeyInput},
@@ -18,9 +17,14 @@ use attune_common::repositories::{
trigger::SensorRepository,
Create, Delete, FindByRef, Update,
};
use attune_common::{
models::{key::Key, OwnerType},
rbac::{Action, AuthorizationContext, Resource},
};
use crate::auth::RequireAuth;
use crate::auth::{jwt::TokenType, RequireAuth};
use crate::{
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
key::{CreateKeyRequest, KeyQueryParams, KeyResponse, KeySummary, UpdateKeyRequest},
@@ -42,7 +46,7 @@ use crate::{
security(("bearer_auth" = []))
)]
pub async fn list_keys(
_user: RequireAuth,
user: RequireAuth,
State(state): State<Arc<AppState>>,
Query(query): Query<KeyQueryParams>,
) -> ApiResult<impl IntoResponse> {
@@ -55,8 +59,33 @@ pub async fn list_keys(
};
let result = KeyRepository::search(&state.db, &filters).await?;
let mut rows = result.rows;
let paginated_keys: Vec<KeySummary> = result.rows.into_iter().map(KeySummary::from).collect();
if user.0.claims.token_type == TokenType::Access {
let identity_id = user
.0
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let grants = authz.effective_grants(&user.0).await?;
// Ensure the principal can read at least some key records.
let can_read_any_key = grants
.iter()
.any(|g| g.resource == Resource::Keys && g.actions.contains(&Action::Read));
if !can_read_any_key {
return Err(ApiError::Forbidden(
"Insufficient permissions: keys:read".to_string(),
));
}
rows.retain(|key| {
let ctx = key_authorization_context(identity_id, key);
AuthorizationService::is_allowed(&grants, Resource::Keys, Action::Read, &ctx)
});
}
let paginated_keys: Vec<KeySummary> = rows.into_iter().map(KeySummary::from).collect();
let pagination_params = PaginationParams {
page: query.page,
@@ -83,7 +112,7 @@ pub async fn list_keys(
security(("bearer_auth" = []))
)]
pub async fn get_key(
_user: RequireAuth,
user: RequireAuth,
State(state): State<Arc<AppState>>,
Path(key_ref): Path<String>,
) -> ApiResult<impl IntoResponse> {
@@ -91,6 +120,26 @@ pub async fn get_key(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
if user.0.claims.token_type == TokenType::Access {
let identity_id = user
.0
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user.0,
AuthorizationCheck {
resource: Resource::Keys,
action: Action::Read,
context: key_authorization_context(identity_id, &key),
},
)
.await
// Hide unauthorized records behind 404 to reduce enumeration leakage.
.map_err(|_| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
}
// Decrypt value if encrypted
if key.encrypted {
let encryption_key = state
@@ -130,13 +179,37 @@ pub async fn get_key(
security(("bearer_auth" = []))
)]
pub async fn create_key(
_user: RequireAuth,
user: RequireAuth,
State(state): State<Arc<AppState>>,
Json(request): Json<CreateKeyRequest>,
) -> ApiResult<impl IntoResponse> {
// Validate request
request.validate()?;
if user.0.claims.token_type == TokenType::Access {
let identity_id = user
.0
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.owner_identity_id = request.owner_identity;
ctx.owner_type = Some(request.owner_type);
ctx.encrypted = Some(request.encrypted);
ctx.target_ref = Some(request.r#ref.clone());
authz
.authorize(
&user.0,
AuthorizationCheck {
resource: Resource::Keys,
action: Action::Create,
context: ctx,
},
)
.await?;
}
// Check if key with same ref already exists
if KeyRepository::find_by_ref(&state.db, &request.r#ref)
.await?
@@ -299,7 +372,7 @@ pub async fn create_key(
security(("bearer_auth" = []))
)]
pub async fn update_key(
_user: RequireAuth,
user: RequireAuth,
State(state): State<Arc<AppState>>,
Path(key_ref): Path<String>,
Json(request): Json<UpdateKeyRequest>,
@@ -312,6 +385,24 @@ pub async fn update_key(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
if user.0.claims.token_type == TokenType::Access {
let identity_id = user
.0
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user.0,
AuthorizationCheck {
resource: Resource::Keys,
action: Action::Update,
context: key_authorization_context(identity_id, &existing),
},
)
.await?;
}
// Handle value update with encryption
let (value, encrypted, encryption_key_hash) = if let Some(new_value) = request.value {
let should_encrypt = request.encrypted.unwrap_or(existing.encrypted);
@@ -395,7 +486,7 @@ pub async fn update_key(
security(("bearer_auth" = []))
)]
pub async fn delete_key(
_user: RequireAuth,
user: RequireAuth,
State(state): State<Arc<AppState>>,
Path(key_ref): Path<String>,
) -> ApiResult<impl IntoResponse> {
@@ -404,6 +495,24 @@ pub async fn delete_key(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
if user.0.claims.token_type == TokenType::Access {
let identity_id = user
.0
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user.0,
AuthorizationCheck {
resource: Resource::Keys,
action: Action::Delete,
context: key_authorization_context(identity_id, &key),
},
)
.await?;
}
// Delete the key
let deleted = KeyRepository::delete(&state.db, key.id).await?;
@@ -425,3 +534,13 @@ pub fn routes() -> Router<Arc<AppState>> {
get(get_key).put(update_key).delete(delete_key),
)
}
/// Build the [`AuthorizationContext`] used for key-level grant checks.
///
/// Copies the key's identifiers plus its ownership and encryption attributes
/// into the context so attribute-constrained grants can match against them.
fn key_authorization_context(identity_id: i64, key: &Key) -> AuthorizationContext {
    let mut context = AuthorizationContext::new(identity_id);
    context.target_id = Some(key.id);
    context.target_ref = Some(key.r#ref.clone());
    context.owner_identity_id = key.owner_identity;
    context.owner_type = Some(key.owner_type);
    context.encrypted = Some(key.encrypted);
    context
}

View File

@@ -11,6 +11,7 @@ pub mod history;
pub mod inquiries;
pub mod keys;
pub mod packs;
pub mod permissions;
pub mod rules;
pub mod triggers;
pub mod webhooks;
@@ -27,6 +28,7 @@ pub use history::routes as history_routes;
pub use inquiries::routes as inquiry_routes;
pub use keys::routes as key_routes;
pub use packs::routes as pack_routes;
pub use permissions::routes as permission_routes;
pub use rules::routes as rule_routes;
pub use triggers::routes as trigger_routes;
pub use webhooks::routes as webhook_routes;

View File

@@ -13,6 +13,7 @@ use validator::Validate;
use attune_common::models::pack_test::PackTestResult;
use attune_common::mq::{MessageEnvelope, MessageType, PackRegisteredPayload};
use attune_common::rbac::{Action, AuthorizationContext, Resource};
use attune_common::repositories::{
pack::{CreatePackInput, UpdatePackInput},
Create, Delete, FindById, FindByRef, PackRepository, PackTestRepository, Pagination, Update,
@@ -21,6 +22,7 @@ use attune_common::workflow::{PackWorkflowService, PackWorkflowServiceConfig};
use crate::{
auth::middleware::RequireAuth,
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
pack::{
@@ -115,7 +117,7 @@ pub async fn get_pack(
)]
pub async fn create_pack(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Json(request): Json<CreatePackRequest>,
) -> ApiResult<impl IntoResponse> {
// Validate request
@@ -129,6 +131,25 @@ pub async fn create_pack(
)));
}
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_ref = Some(request.r#ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Create,
context: ctx,
},
)
.await?;
}
// Create pack input
let pack_input = CreatePackInput {
r#ref: request.r#ref,
@@ -202,7 +223,7 @@ pub async fn create_pack(
)]
pub async fn update_pack(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(pack_ref): Path<String>,
Json(request): Json<UpdatePackRequest>,
) -> ApiResult<impl IntoResponse> {
@@ -214,6 +235,26 @@ pub async fn update_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(existing_pack.id);
ctx.target_ref = Some(existing_pack.r#ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Update,
context: ctx,
},
)
.await?;
}
// Create update input
let update_input = UpdatePackInput {
label: request.label,
@@ -284,7 +325,7 @@ pub async fn update_pack(
)]
pub async fn delete_pack(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(pack_ref): Path<String>,
) -> ApiResult<impl IntoResponse> {
// Check if pack exists
@@ -292,6 +333,26 @@ pub async fn delete_pack(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(pack.id);
ctx.target_ref = Some(pack.r#ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Delete,
context: ctx,
},
)
.await?;
}
// Delete the pack from the database (cascades to actions, triggers, sensors, rules, etc.
// Foreign keys on execution, event, enforcement, and rule tables use ON DELETE SET NULL
// so historical records are preserved with their text ref fields intact.)
@@ -475,6 +536,23 @@ pub async fn upload_pack(
const MAX_PACK_SIZE: usize = 100 * 1024 * 1024; // 100 MB
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Create,
context: AuthorizationContext::new(identity_id),
},
)
.await?;
}
let mut pack_bytes: Option<Vec<u8>> = None;
let mut force = false;
let mut skip_tests = false;
@@ -649,6 +727,23 @@ pub async fn register_pack(
// Validate request
request.validate()?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Create,
context: AuthorizationContext::new(identity_id),
},
)
.await?;
}
// Call internal registration logic
let pack_id = register_pack_internal(
state.clone(),
@@ -1207,6 +1302,23 @@ pub async fn install_pack(
tracing::info!("Installing pack from source: {}", request.source);
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Create,
context: AuthorizationContext::new(identity_id),
},
)
.await?;
}
// Get user ID early to avoid borrow issues
let user_id = user.identity_id().ok();
let user_sub = user.claims.sub.clone();
@@ -2247,6 +2359,23 @@ pub async fn register_packs_batch(
RequireAuth(user): RequireAuth,
Json(request): Json<RegisterPacksRequest>,
) -> ApiResult<Json<ApiResponse<RegisterPacksResponse>>> {
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Packs,
action: Action::Create,
context: AuthorizationContext::new(identity_id),
},
)
.await?;
}
let start = std::time::Instant::now();
let mut registered = Vec::new();
let mut failed = Vec::new();

View File

@@ -0,0 +1,507 @@
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
routing::{delete, get, post},
Json, Router,
};
use std::sync::Arc;
use validator::Validate;
use attune_common::{
models::identity::{Identity, PermissionSet},
rbac::{Action, AuthorizationContext, Resource},
repositories::{
identity::{
CreateIdentityInput, CreatePermissionAssignmentInput, IdentityRepository,
PermissionAssignmentRepository, PermissionSetRepository, UpdateIdentityInput,
},
Create, Delete, FindById, FindByRef, List, Update,
},
};
use crate::{
auth::hash_password,
auth::middleware::RequireAuth,
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
ApiResponse, CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse,
IdentitySummary, PermissionAssignmentResponse, PermissionSetQueryParams,
PermissionSetSummary, SuccessResponse, UpdateIdentityRequest,
},
middleware::{ApiError, ApiResult},
state::AppState,
};
#[utoipa::path(
get,
path = "/api/v1/identities",
tag = "permissions",
params(PaginationParams),
responses(
(status = 200, description = "List identities", body = PaginatedResponse<IdentitySummary>)
),
security(("bearer_auth" = []))
)]
pub async fn list_identities(
State(state): State<Arc<AppState>>,
RequireAuth(user): RequireAuth,
Query(query): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
authorize_permissions(&state, &user, Resource::Identities, Action::Read).await?;
let identities = IdentityRepository::list(&state.db).await?;
let total = identities.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(identities.len());
let page_items = if start >= identities.len() {
Vec::new()
} else {
identities[start..end]
.iter()
.cloned()
.map(IdentitySummary::from)
.collect()
};
Ok((
StatusCode::OK,
Json(PaginatedResponse::new(page_items, &query, total)),
))
}
#[utoipa::path(
    get,
    path = "/api/v1/identities/{id}",
    tag = "permissions",
    params(
        ("id" = i64, Path, description = "Identity ID")
    ),
    responses(
        (status = 200, description = "Identity details", body = inline(ApiResponse<IdentityResponse>)),
        (status = 404, description = "Identity not found")
    ),
    security(("bearer_auth" = []))
)]
pub async fn get_identity(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Path(identity_id): Path<i64>,
) -> ApiResult<impl IntoResponse> {
    // Viewing a single identity requires the same read grant as listing.
    authorize_permissions(&state, &user, Resource::Identities, Action::Read).await?;
    // Missing rows surface as 404 rather than an empty body.
    let Some(identity) = IdentityRepository::find_by_id(&state.db, identity_id).await? else {
        return Err(ApiError::NotFound(format!(
            "Identity '{}' not found",
            identity_id
        )));
    };
    let body = ApiResponse::new(IdentityResponse::from(identity));
    Ok((StatusCode::OK, Json(body)))
}
#[utoipa::path(
    post,
    path = "/api/v1/identities",
    tag = "permissions",
    request_body = CreateIdentityRequest,
    responses(
        (status = 201, description = "Identity created", body = inline(ApiResponse<IdentityResponse>)),
        (status = 409, description = "Identity already exists")
    ),
    security(("bearer_auth" = []))
)]
pub async fn create_identity(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Json(request): Json<CreateIdentityRequest>,
) -> ApiResult<impl IntoResponse> {
    // Creating identities is gated on the `identities` create grant.
    authorize_permissions(&state, &user, Resource::Identities, Action::Create).await?;
    request.validate()?;
    // Hash the password only when one was supplied; a hashing failure aborts
    // the request before any row is written.
    let password_hash = request
        .password
        .map(|password| hash_password(&password))
        .transpose()?;
    let input = CreateIdentityInput {
        login: request.login,
        display_name: request.display_name,
        password_hash,
        attributes: request.attributes,
    };
    let identity = IdentityRepository::create(&state.db, input).await?;
    Ok((
        StatusCode::CREATED,
        Json(ApiResponse::new(IdentityResponse::from(identity))),
    ))
}
#[utoipa::path(
    put,
    path = "/api/v1/identities/{id}",
    tag = "permissions",
    params(
        ("id" = i64, Path, description = "Identity ID")
    ),
    request_body = UpdateIdentityRequest,
    responses(
        (status = 200, description = "Identity updated", body = inline(ApiResponse<IdentityResponse>)),
        (status = 404, description = "Identity not found")
    ),
    security(("bearer_auth" = []))
)]
pub async fn update_identity(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Path(identity_id): Path<i64>,
    Json(request): Json<UpdateIdentityRequest>,
) -> ApiResult<impl IntoResponse> {
    // Updating identities requires the `identities` update grant.
    authorize_permissions(&state, &user, Resource::Identities, Action::Update).await?;
    // Surface a 404 up front when the target identity does not exist.
    if IdentityRepository::find_by_id(&state.db, identity_id)
        .await?
        .is_none()
    {
        return Err(ApiError::NotFound(format!(
            "Identity '{}' not found",
            identity_id
        )));
    }
    // A supplied password is re-hashed; `None` leaves the stored hash alone.
    let password_hash = request
        .password
        .map(|password| hash_password(&password))
        .transpose()?;
    let identity = IdentityRepository::update(
        &state.db,
        identity_id,
        UpdateIdentityInput {
            display_name: request.display_name,
            password_hash,
            attributes: request.attributes,
        },
    )
    .await?;
    Ok((
        StatusCode::OK,
        Json(ApiResponse::new(IdentityResponse::from(identity))),
    ))
}
#[utoipa::path(
    delete,
    path = "/api/v1/identities/{id}",
    tag = "permissions",
    params(
        ("id" = i64, Path, description = "Identity ID")
    ),
    responses(
        (status = 200, description = "Identity deleted", body = inline(ApiResponse<SuccessResponse>)),
        (status = 404, description = "Identity not found")
    ),
    security(("bearer_auth" = []))
)]
pub async fn delete_identity(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Path(identity_id): Path<i64>,
) -> ApiResult<impl IntoResponse> {
    // Deleting identities requires the `identities` delete grant.
    authorize_permissions(&state, &user, Resource::Identities, Action::Delete).await?;
    let caller_identity_id = user
        .identity_id()
        .map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
    // Guard against an administrator removing the identity they are
    // currently authenticated as.
    if caller_identity_id == identity_id {
        return Err(ApiError::BadRequest(
            "Refusing to delete the currently authenticated identity".to_string(),
        ));
    }
    // The repository reports whether a row was actually removed; absence
    // maps onto a 404 response.
    match IdentityRepository::delete(&state.db, identity_id).await? {
        true => Ok((
            StatusCode::OK,
            Json(ApiResponse::new(SuccessResponse::new(
                "Identity deleted successfully",
            ))),
        )),
        false => Err(ApiError::NotFound(format!(
            "Identity '{}' not found",
            identity_id
        ))),
    }
}
#[utoipa::path(
    get,
    path = "/api/v1/permissions/sets",
    tag = "permissions",
    params(PermissionSetQueryParams),
    responses(
        (status = 200, description = "List permission sets", body = Vec<PermissionSetSummary>)
    ),
    security(("bearer_auth" = []))
)]
pub async fn list_permission_sets(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Query(query): Query<PermissionSetQueryParams>,
) -> ApiResult<impl IntoResponse> {
    // Browsing permission sets requires the `permissions` read grant.
    authorize_permissions(&state, &user, Resource::Permissions, Action::Read).await?;
    let permission_sets = PermissionSetRepository::list(&state.db).await?;
    // An optional `pack_ref` query parameter narrows the listing to sets
    // shipped by a single pack; without it, every set is returned.
    let response: Vec<PermissionSetSummary> = permission_sets
        .into_iter()
        .filter(|ps| match &query.pack_ref {
            Some(pack_ref) => ps.pack_ref.as_deref() == Some(pack_ref.as_str()),
            None => true,
        })
        .map(PermissionSetSummary::from)
        .collect();
    Ok((StatusCode::OK, Json(response)))
}
#[utoipa::path(
    get,
    path = "/api/v1/identities/{id}/permissions",
    tag = "permissions",
    params(
        ("id" = i64, Path, description = "Identity ID")
    ),
    responses(
        (status = 200, description = "List permission assignments for an identity", body = Vec<PermissionAssignmentResponse>),
        (status = 404, description = "Identity not found")
    ),
    security(("bearer_auth" = []))
)]
pub async fn list_identity_permissions(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Path(identity_id): Path<i64>,
) -> ApiResult<impl IntoResponse> {
    // Reading assignments requires the `permissions` read grant.
    authorize_permissions(&state, &user, Resource::Permissions, Action::Read).await?;
    // Distinguish "identity does not exist" (404) from "identity has no
    // assignments" (200 with an empty array).
    if IdentityRepository::find_by_id(&state.db, identity_id)
        .await?
        .is_none()
    {
        return Err(ApiError::NotFound(format!(
            "Identity '{}' not found",
            identity_id
        )));
    }
    let assignments =
        PermissionAssignmentRepository::find_by_identity(&state.db, identity_id).await?;
    let permission_sets = PermissionSetRepository::find_by_identity(&state.db, identity_id).await?;
    // Index the human-readable refs by permission-set id so each assignment
    // row can be labelled without an extra query.
    let mut permission_set_refs = std::collections::HashMap::new();
    for ps in permission_sets {
        permission_set_refs.insert(ps.id, ps.r#ref);
    }
    // Assignments whose permission set is absent from the lookup are omitted,
    // mirroring an inner-join between the two queries.
    let mut response = Vec::with_capacity(assignments.len());
    for assignment in assignments {
        if let Some(permission_set_ref) = permission_set_refs.get(&assignment.permset) {
            response.push(PermissionAssignmentResponse {
                id: assignment.id,
                identity_id: assignment.identity,
                permission_set_id: assignment.permset,
                permission_set_ref: permission_set_ref.clone(),
                created: assignment.created,
            });
        }
    }
    Ok((StatusCode::OK, Json(response)))
}
#[utoipa::path(
    post,
    path = "/api/v1/permissions/assignments",
    tag = "permissions",
    request_body = CreatePermissionAssignmentRequest,
    responses(
        (status = 201, description = "Permission assignment created", body = inline(ApiResponse<PermissionAssignmentResponse>)),
        (status = 404, description = "Identity or permission set not found"),
        (status = 409, description = "Assignment already exists")
    ),
    security(("bearer_auth" = []))
)]
pub async fn create_permission_assignment(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Json(request): Json<CreatePermissionAssignmentRequest>,
) -> ApiResult<impl IntoResponse> {
    // Granting permissions is privileged: requires `permissions` manage.
    authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
    // The target identity can be addressed by id or by login (see
    // `resolve_identity` for the exclusivity rules).
    let identity = resolve_identity(&state, &request).await?;
    let Some(permission_set) =
        PermissionSetRepository::find_by_ref(&state.db, &request.permission_set_ref).await?
    else {
        return Err(ApiError::NotFound(format!(
            "Permission set '{}' not found",
            request.permission_set_ref
        )));
    };
    let assignment = PermissionAssignmentRepository::create(
        &state.db,
        CreatePermissionAssignmentInput {
            identity: identity.id,
            permset: permission_set.id,
        },
    )
    .await?;
    // Echo the stored assignment back, including the human-readable set ref.
    Ok((
        StatusCode::CREATED,
        Json(ApiResponse::new(PermissionAssignmentResponse {
            id: assignment.id,
            identity_id: assignment.identity,
            permission_set_id: assignment.permset,
            permission_set_ref: permission_set.r#ref,
            created: assignment.created,
        })),
    ))
}
#[utoipa::path(
    delete,
    path = "/api/v1/permissions/assignments/{id}",
    tag = "permissions",
    params(
        ("id" = i64, Path, description = "Permission assignment ID")
    ),
    responses(
        (status = 200, description = "Permission assignment deleted", body = inline(ApiResponse<SuccessResponse>)),
        (status = 404, description = "Assignment not found")
    ),
    security(("bearer_auth" = []))
)]
pub async fn delete_permission_assignment(
    State(state): State<Arc<AppState>>,
    RequireAuth(user): RequireAuth,
    Path(assignment_id): Path<i64>,
) -> ApiResult<impl IntoResponse> {
    // Revoking permissions requires the `permissions` manage grant.
    authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
    // Both the lookup miss and a delete that removes no row produce the same
    // 404; build the error in one place.
    let not_found = || {
        ApiError::NotFound(format!(
            "Permission assignment '{}' not found",
            assignment_id
        ))
    };
    let existing = PermissionAssignmentRepository::find_by_id(&state.db, assignment_id)
        .await?
        .ok_or_else(|| not_found())?;
    // The row can disappear between lookup and delete; treat that as 404 too.
    if !PermissionAssignmentRepository::delete(&state.db, existing.id).await? {
        return Err(not_found());
    }
    Ok((
        StatusCode::OK,
        Json(ApiResponse::new(SuccessResponse::new(
            "Permission assignment deleted successfully",
        ))),
    ))
}
/// Build the router for identity and permission administration endpoints.
///
/// Every handler registered here enforces RBAC through `authorize_permissions`
/// before touching the database. Paths are relative ("/identities", ...);
/// the OpenAPI annotations document them under "/api/v1", so presumably the
/// server applies that prefix when merging this router — confirm in server
/// setup.
pub fn routes() -> Router<Arc<AppState>> {
    Router::new()
        // Identity CRUD.
        .route("/identities", get(list_identities).post(create_identity))
        .route(
            "/identities/{id}",
            get(get_identity)
                .put(update_identity)
                .delete(delete_identity),
        )
        // Per-identity view of assigned permission sets.
        .route(
            "/identities/{id}/permissions",
            get(list_identity_permissions),
        )
        // Permission-set catalogue and assignment management.
        .route("/permissions/sets", get(list_permission_sets))
        .route(
            "/permissions/assignments",
            post(create_permission_assignment),
        )
        .route(
            "/permissions/assignments/{id}",
            delete(delete_permission_assignment),
        )
}
/// Run a single RBAC check for the calling user against `resource`/`action`.
///
/// Centralizing the check keeps grant evaluation uniform across all
/// permissions-admin handlers. Fails with 401 when the token carries no
/// usable identity, or with the authorization service's error otherwise.
async fn authorize_permissions(
    state: &Arc<AppState>,
    user: &crate::auth::middleware::AuthenticatedUser,
    resource: Resource,
    action: Action,
) -> ApiResult<()> {
    let identity_id = user
        .identity_id()
        .map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
    let check = AuthorizationCheck {
        resource,
        action,
        context: AuthorizationContext::new(identity_id),
    };
    AuthorizationService::new(state.db.clone())
        .authorize(user, check)
        .await
}
/// Resolve the target identity of an assignment request.
///
/// Exactly one of `identity_id` / `identity_login` must be present: both set
/// or both missing is a 400, an unresolvable reference is a 404.
async fn resolve_identity(
    state: &Arc<AppState>,
    request: &CreatePermissionAssignmentRequest,
) -> ApiResult<Identity> {
    // Reject ambiguous requests before hitting the database.
    if request.identity_id.is_some() && request.identity_login.is_some() {
        return Err(ApiError::BadRequest(
            "Provide either identity_id or identity_login, not both".to_string(),
        ));
    }
    if let Some(identity_id) = request.identity_id {
        return IdentityRepository::find_by_id(&state.db, identity_id)
            .await?
            .ok_or_else(|| ApiError::NotFound(format!("Identity '{}' not found", identity_id)));
    }
    if let Some(identity_login) = request.identity_login.as_deref() {
        return IdentityRepository::find_by_login(&state.db, identity_login)
            .await?
            .ok_or_else(|| {
                ApiError::NotFound(format!("Identity '{}' not found", identity_login))
            });
    }
    Err(ApiError::BadRequest(
        "Either identity_id or identity_login is required".to_string(),
    ))
}
impl From<Identity> for IdentitySummary {
fn from(value: Identity) -> Self {
Self {
id: value.id,
login: value.login,
display_name: value.display_name,
attributes: value.attributes,
}
}
}
impl From<PermissionSet> for PermissionSetSummary {
fn from(value: PermissionSet) -> Self {
Self {
id: value.id,
r#ref: value.r#ref,
pack_ref: value.pack_ref,
label: value.label,
description: value.description,
grants: value.grants,
}
}
}

View File

@@ -14,6 +14,7 @@ use validator::Validate;
use attune_common::mq::{
MessageEnvelope, MessageType, RuleCreatedPayload, RuleDisabledPayload, RuleEnabledPayload,
};
use attune_common::rbac::{Action, AuthorizationContext, Resource};
use attune_common::repositories::{
action::ActionRepository,
pack::PackRepository,
@@ -24,6 +25,7 @@ use attune_common::repositories::{
use crate::{
auth::middleware::RequireAuth,
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest},
@@ -283,7 +285,7 @@ pub async fn get_rule(
)]
pub async fn create_rule(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Json(request): Json<CreateRuleRequest>,
) -> ApiResult<impl IntoResponse> {
// Validate request
@@ -317,6 +319,26 @@ pub async fn create_rule(
ApiError::NotFound(format!("Trigger '{}' not found", request.trigger_ref))
})?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.pack_ref = Some(pack.r#ref.clone());
ctx.target_ref = Some(request.r#ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Rules,
action: Action::Create,
context: ctx,
},
)
.await?;
}
// Validate trigger parameters against schema
validate_trigger_params(&trigger, &request.trigger_params)?;
@@ -392,7 +414,7 @@ pub async fn create_rule(
)]
pub async fn update_rule(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(rule_ref): Path<String>,
Json(request): Json<UpdateRuleRequest>,
) -> ApiResult<impl IntoResponse> {
@@ -404,6 +426,27 @@ pub async fn update_rule(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(existing_rule.id);
ctx.target_ref = Some(existing_rule.r#ref.clone());
ctx.pack_ref = Some(existing_rule.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Rules,
action: Action::Update,
context: ctx,
},
)
.await?;
}
// If action parameters are being updated, validate against the action's schema
if let Some(ref action_params) = request.action_params {
let action = ActionRepository::find_by_ref(&state.db, &existing_rule.action_ref)
@@ -489,7 +532,7 @@ pub async fn update_rule(
)]
pub async fn delete_rule(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
RequireAuth(user): RequireAuth,
Path(rule_ref): Path<String>,
) -> ApiResult<impl IntoResponse> {
// Check if rule exists
@@ -497,6 +540,27 @@ pub async fn delete_rule(
.await?
.ok_or_else(|| ApiError::NotFound(format!("Rule '{}' not found", rule_ref)))?;
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(rule.id);
ctx.target_ref = Some(rule.r#ref.clone());
ctx.pack_ref = Some(rule.pack_ref.clone());
authz
.authorize(
&user,
AuthorizationCheck {
resource: Resource::Rules,
action: Action::Delete,
context: ctx,
},
)
.await?;
}
// Delete the rule
let deleted = RuleRepository::delete(&state.db, rule.id).await?;

View File

@@ -53,6 +53,7 @@ impl Server {
.merge(routes::inquiry_routes())
.merge(routes::event_routes())
.merge(routes::key_routes())
.merge(routes::permission_routes())
.merge(routes::workflow_routes())
.merge(routes::webhook_routes())
.merge(routes::history_routes())

View File

@@ -9,6 +9,10 @@ use attune_common::{
models::*,
repositories::{
action::{ActionRepository, CreateActionInput},
identity::{
CreatePermissionAssignmentInput, CreatePermissionSetInput,
PermissionAssignmentRepository, PermissionSetRepository,
},
pack::{CreatePackInput, PackRepository},
trigger::{CreateTriggerInput, TriggerRepository},
workflow::{CreateWorkflowDefinitionInput, WorkflowDefinitionRepository},
@@ -246,6 +250,47 @@ impl TestContext {
Ok(self)
}
/// Create and authenticate a test user with identity + permission admin grants.
///
/// Registers a uniquely-named user, seeds a `core.admin` permission set that
/// grants full identity/permission access, assigns it to the new user, and
/// stores the resulting access token on the context for subsequent requests.
pub async fn with_admin_auth(mut self) -> Result<Self> {
    // First 8 chars of a dash-stripped v4 UUID keep the login unique per run.
    // Byte-range slicing is safe here: the string is pure ASCII hex.
    let unique_id = uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
    let login = format!("adminuser_{}", unique_id);
    let token = self.create_test_user(&login).await?;
    // Look the identity back up by login; user creation only yields a token.
    let identity = attune_common::repositories::identity::IdentityRepository::find_by_login(
        &self.pool, &login,
    )
    .await?
    .ok_or_else(|| format!("Failed to find newly created identity '{}'", login))?;
    // NOTE(review): the permission-set ref is hard-coded to "core.admin" —
    // tests sharing one database could collide on this ref; confirm test
    // isolation or upsert semantics.
    let permset = PermissionSetRepository::create(
        &self.pool,
        CreatePermissionSetInput {
            r#ref: "core.admin".to_string(),
            pack: None,
            pack_ref: None,
            label: Some("Admin".to_string()),
            description: Some("Test admin permission set".to_string()),
            grants: json!([
                {"resource": "identities", "actions": ["read", "create", "update", "delete"]},
                {"resource": "permissions", "actions": ["read", "create", "update", "delete", "manage"]}
            ]),
        },
    )
    .await?;
    // Attach the admin grant set to the freshly created identity.
    PermissionAssignmentRepository::create(
        &self.pool,
        CreatePermissionAssignmentInput {
            identity: identity.id,
            permset: permset.id,
        },
    )
    .await?;
    self.token = Some(token);
    Ok(self)
}
/// Create a test user and return access token
async fn create_test_user(&self, login: &str) -> Result<String> {
// Register via API to get real token

View File

@@ -0,0 +1,178 @@
use axum::http::StatusCode;
use helpers::*;
use serde_json::json;
mod helpers;
#[tokio::test]
#[ignore = "integration test — requires database"]
async fn test_identity_crud_and_permission_assignment_flow() {
    // End-to-end pass over the identity + permission-assignment API:
    // create -> list -> update -> get -> assign permission set ->
    // list assignments -> delete assignment -> delete identity -> verify 404.
    // Requires an admin-authenticated caller (see `with_admin_auth`).
    let ctx = TestContext::new()
        .await
        .expect("Failed to create test context")
        .with_admin_auth()
        .await
        .expect("Failed to create admin-authenticated test user");
    // Create a managed identity with a password and arbitrary attributes.
    let create_identity_response = ctx
        .post(
            "/api/v1/identities",
            json!({
                "login": "managed_user",
                "display_name": "Managed User",
                "password": "ManagedPass123!",
                "attributes": {
                    "department": "platform"
                }
            }),
            ctx.token(),
        )
        .await
        .expect("Failed to create identity");
    assert_eq!(create_identity_response.status(), StatusCode::CREATED);
    let created_identity: serde_json::Value = create_identity_response
        .json()
        .await
        .expect("Failed to parse identity create response");
    let identity_id = created_identity["data"]["id"]
        .as_i64()
        .expect("Missing identity id");
    // The new identity must appear in the identity listing.
    let list_identities_response = ctx
        .get("/api/v1/identities", ctx.token())
        .await
        .expect("Failed to list identities");
    assert_eq!(list_identities_response.status(), StatusCode::OK);
    let identities_body: serde_json::Value = list_identities_response
        .json()
        .await
        .expect("Failed to parse identities response");
    assert!(identities_body["data"]
        .as_array()
        .expect("Expected data array")
        .iter()
        .any(|item| item["login"] == "managed_user"));
    // Update display name and attributes via PUT.
    let update_identity_response = ctx
        .put(
            &format!("/api/v1/identities/{}", identity_id),
            json!({
                "display_name": "Managed User Updated",
                "attributes": {
                    "department": "security"
                }
            }),
            ctx.token(),
        )
        .await
        .expect("Failed to update identity");
    assert_eq!(update_identity_response.status(), StatusCode::OK);
    // GET must reflect the updated fields.
    let get_identity_response = ctx
        .get(&format!("/api/v1/identities/{}", identity_id), ctx.token())
        .await
        .expect("Failed to get identity");
    assert_eq!(get_identity_response.status(), StatusCode::OK);
    let identity_body: serde_json::Value = get_identity_response
        .json()
        .await
        .expect("Failed to parse get identity response");
    assert_eq!(
        identity_body["data"]["display_name"],
        "Managed User Updated"
    );
    assert_eq!(
        identity_body["data"]["attributes"]["department"],
        "security"
    );
    // The permission-set catalogue endpoint must be reachable.
    let permission_sets_response = ctx
        .get("/api/v1/permissions/sets", ctx.token())
        .await
        .expect("Failed to list permission sets");
    assert_eq!(permission_sets_response.status(), StatusCode::OK);
    // Assign the seeded "core.admin" set to the managed identity.
    let assignment_response = ctx
        .post(
            "/api/v1/permissions/assignments",
            json!({
                "identity_id": identity_id,
                "permission_set_ref": "core.admin"
            }),
            ctx.token(),
        )
        .await
        .expect("Failed to create permission assignment");
    assert_eq!(assignment_response.status(), StatusCode::CREATED);
    let assignment_body: serde_json::Value = assignment_response
        .json()
        .await
        .expect("Failed to parse permission assignment response");
    let assignment_id = assignment_body["data"]["id"]
        .as_i64()
        .expect("Missing assignment id");
    assert_eq!(assignment_body["data"]["permission_set_ref"], "core.admin");
    // The assignment must show up in the identity's permissions listing.
    let list_assignments_response = ctx
        .get(
            &format!("/api/v1/identities/{}/permissions", identity_id),
            ctx.token(),
        )
        .await
        .expect("Failed to list identity permissions");
    assert_eq!(list_assignments_response.status(), StatusCode::OK);
    let assignments_body: serde_json::Value = list_assignments_response
        .json()
        .await
        .expect("Failed to parse identity permissions response");
    assert!(assignments_body
        .as_array()
        .expect("Expected array response")
        .iter()
        .any(|item| item["permission_set_ref"] == "core.admin"));
    // Tear down: revoke the assignment, then delete the identity.
    let delete_assignment_response = ctx
        .delete(
            &format!("/api/v1/permissions/assignments/{}", assignment_id),
            ctx.token(),
        )
        .await
        .expect("Failed to delete assignment");
    assert_eq!(delete_assignment_response.status(), StatusCode::OK);
    let delete_identity_response = ctx
        .delete(&format!("/api/v1/identities/{}", identity_id), ctx.token())
        .await
        .expect("Failed to delete identity");
    assert_eq!(delete_identity_response.status(), StatusCode::OK);
    // A deleted identity must no longer resolve.
    let missing_identity_response = ctx
        .get(&format!("/api/v1/identities/{}", identity_id), ctx.token())
        .await
        .expect("Failed to fetch deleted identity");
    assert_eq!(missing_identity_response.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
#[ignore = "integration test — requires database"]
async fn test_plain_authenticated_user_cannot_manage_identities() {
    // A user who authenticated but holds no permission grants must be
    // rejected by the RBAC layer with 403 (forbidden), not 401 — the token
    // itself is valid.
    let ctx = TestContext::new()
        .await
        .expect("Failed to create test context")
        .with_auth()
        .await
        .expect("Failed to authenticate plain test user");
    let response = ctx
        .get("/api/v1/identities", ctx.token())
        .await
        .expect("Failed to call identities endpoint");
    assert_eq!(response.status(), StatusCode::FORBIDDEN);
}

View File

@@ -75,6 +75,7 @@ async fn create_test_execution(pool: &PgPool, action_id: i64) -> Result<Executio
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Scheduled,
result: None,
workflow_task: None,

View File

@@ -295,6 +295,10 @@ pub struct SecurityConfig {
/// Enable authentication
#[serde(default = "default_true")]
pub enable_auth: bool,
/// Allow unauthenticated self-service user registration
#[serde(default)]
pub allow_self_registration: bool,
}
fn default_jwt_access_expiration() -> u64 {
@@ -676,6 +680,7 @@ impl Default for SecurityConfig {
jwt_refresh_expiration: default_jwt_refresh_expiration(),
encryption_key: None,
enable_auth: true,
allow_self_registration: false,
}
}
}
@@ -924,6 +929,7 @@ mod tests {
jwt_refresh_expiration: 604800,
encryption_key: Some("a".repeat(32)),
enable_auth: true,
allow_self_registration: false,
},
worker: None,
sensor: None,

View File

@@ -15,6 +15,7 @@ pub mod models;
pub mod mq;
pub mod pack_environment;
pub mod pack_registry;
pub mod rbac;
pub mod repositories;
pub mod runtime_detection;
pub mod schema;

View File

@@ -430,6 +430,10 @@ pub mod runtime {
#[serde(default)]
pub interpreter: InterpreterConfig,
/// Strategy for inline code execution.
#[serde(default)]
pub inline_execution: InlineExecutionConfig,
/// Optional isolated environment configuration (venv, node_modules, etc.)
#[serde(default)]
pub environment: Option<EnvironmentConfig>,
@@ -449,6 +453,33 @@ pub mod runtime {
pub env_vars: HashMap<String, String>,
}
/// Controls how inline code is materialized before execution.
///
/// Deserialized as part of a runtime definition; every field is optional in
/// YAML thanks to `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct InlineExecutionConfig {
    /// Whether inline code is passed directly to the interpreter or first
    /// written to a temporary file. Defaults to direct execution.
    #[serde(default)]
    pub strategy: InlineExecutionStrategy,
    /// Optional extension for temporary inline files (e.g. ".sh").
    /// Presumably only consulted when `strategy` is `TempFile` — confirm in
    /// the worker's inline-execution path.
    #[serde(default)]
    pub extension: Option<String>,
    /// When true, inline wrapper files export the merged input map as shell
    /// environment variables (`PARAM_*` and bare names) before executing the
    /// script body.
    #[serde(default)]
    pub inject_shell_helpers: bool,
}
/// How inline action code reaches the interpreter.
///
/// Serialized in snake_case ("direct" / "temp_file") in runtime YAML.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum InlineExecutionStrategy {
    /// Hand the inline code string straight to the interpreter (default).
    #[default]
    Direct,
    /// Write the inline code to a temporary file and execute that file.
    TempFile,
}
/// Describes the interpreter binary and how it invokes action scripts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterpreterConfig {
@@ -1102,6 +1133,7 @@ pub mod execution {
pub enforcement: Option<Id>,
pub executor: Option<Id>,
pub worker: Option<Id>,
pub status: ExecutionStatus,
pub result: Option<JsonDict>,

View File

@@ -481,9 +481,8 @@ pub struct PackRegisteredPayload {
/// Payload for ExecutionCancelRequested message
///
/// Sent by the API to the worker that is running a specific execution,
/// instructing it to gracefully terminate the process (SIGINT, then SIGTERM
/// after a grace period).
/// Sent by the API or executor to the worker that is running a specific
/// execution, instructing it to terminate the process promptly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionCancelRequestedPayload {
/// Execution ID to cancel

View File

@@ -1,14 +1,15 @@
//! Pack Component Loader
//!
//! Reads runtime, action, trigger, and sensor YAML definitions from a pack directory
//! Reads permission set, runtime, action, trigger, and sensor YAML definitions from a pack directory
//! and registers them in the database. This is the Rust-native equivalent of
//! the Python `load_core_pack.py` script used during init-packs.
//!
//! Components are loaded in dependency order:
//! 1. Runtimes (no dependencies)
//! 2. Triggers (no dependencies)
//! 3. Actions (depend on runtime; workflow actions also create workflow_definition records)
//! 4. Sensors (depend on triggers and runtime)
//! 1. Permission sets (no dependencies)
//! 2. Runtimes (no dependencies)
//! 3. Triggers (no dependencies)
//! 4. Actions (depend on runtime; workflow actions also create workflow_definition records)
//! 5. Sensors (depend on triggers and runtime)
//!
//! All loaders use **upsert** semantics: if an entity with the same ref already
//! exists it is updated in place (preserving its database ID); otherwise a new
@@ -38,6 +39,9 @@ use tracing::{debug, info, warn};
use crate::error::{Error, Result};
use crate::models::Id;
use crate::repositories::action::{ActionRepository, UpdateActionInput};
use crate::repositories::identity::{
CreatePermissionSetInput, PermissionSetRepository, UpdatePermissionSetInput,
};
use crate::repositories::runtime::{CreateRuntimeInput, RuntimeRepository, UpdateRuntimeInput};
use crate::repositories::runtime_version::{
CreateRuntimeVersionInput, RuntimeVersionRepository, UpdateRuntimeVersionInput,
@@ -56,6 +60,12 @@ use crate::workflow::parser::parse_workflow_yaml;
/// Result of loading pack components into the database.
#[derive(Debug, Default)]
pub struct PackLoadResult {
/// Number of permission sets created
pub permission_sets_loaded: usize,
/// Number of permission sets updated
pub permission_sets_updated: usize,
/// Number of permission sets skipped
pub permission_sets_skipped: usize,
/// Number of runtimes created
pub runtimes_loaded: usize,
/// Number of runtimes updated (already existed)
@@ -88,15 +98,27 @@ pub struct PackLoadResult {
impl PackLoadResult {
pub fn total_loaded(&self) -> usize {
self.runtimes_loaded + self.triggers_loaded + self.actions_loaded + self.sensors_loaded
self.permission_sets_loaded
+ self.runtimes_loaded
+ self.triggers_loaded
+ self.actions_loaded
+ self.sensors_loaded
}
pub fn total_skipped(&self) -> usize {
self.runtimes_skipped + self.triggers_skipped + self.actions_skipped + self.sensors_skipped
self.permission_sets_skipped
+ self.runtimes_skipped
+ self.triggers_skipped
+ self.actions_skipped
+ self.sensors_skipped
}
pub fn total_updated(&self) -> usize {
self.runtimes_updated + self.triggers_updated + self.actions_updated + self.sensors_updated
self.permission_sets_updated
+ self.runtimes_updated
+ self.triggers_updated
+ self.actions_updated
+ self.sensors_updated
}
}
@@ -132,22 +154,26 @@ impl<'a> PackComponentLoader<'a> {
pack_dir.display()
);
// 1. Load runtimes first (no dependencies)
// 1. Load permission sets first (no dependencies)
let permission_set_refs = self.load_permission_sets(pack_dir, &mut result).await?;
// 2. Load runtimes (no dependencies)
let runtime_refs = self.load_runtimes(pack_dir, &mut result).await?;
// 2. Load triggers (no dependencies)
// 3. Load triggers (no dependencies)
let (trigger_ids, trigger_refs) = self.load_triggers(pack_dir, &mut result).await?;
// 3. Load actions (depend on runtime)
// 4. Load actions (depend on runtime)
let action_refs = self.load_actions(pack_dir, &mut result).await?;
// 4. Load sensors (depend on triggers and runtime)
// 5. Load sensors (depend on triggers and runtime)
let sensor_refs = self
.load_sensors(pack_dir, &trigger_ids, &mut result)
.await?;
// 5. Clean up entities that are no longer in the pack's YAML files
// 6. Clean up entities that are no longer in the pack's YAML files
self.cleanup_removed_entities(
&permission_set_refs,
&runtime_refs,
&trigger_refs,
&action_refs,
@@ -169,6 +195,146 @@ impl<'a> PackComponentLoader<'a> {
Ok(result)
}
/// Load permission set definitions from `pack_dir/permission_sets/*.yaml`.
///
/// Permission sets are pack-scoped authorization metadata. Their `grants`
/// payload is stored verbatim and interpreted by the API authorization
/// layer at request time.
///
/// Per-file behavior:
/// - a YAML parse failure aborts the whole load with a validation error;
/// - a missing `ref` field or a non-array `grants` value logs a warning
///   and counts the file as skipped;
/// - a missing `grants` key defaults to an empty JSON array `[]`;
/// - an existing permission set (matched by `ref`) is updated in place,
///   otherwise a new one is created and linked to this pack.
///
/// Returns the refs of all permission sets seen (created or updated).
/// Note the ref is also pushed after a *failed* update so the subsequent
/// cleanup pass does not delete a set that still exists in the pack.
///
/// NOTE(review): the existence check uses `find_by_ref`, which is not
/// scoped to this pack — a ref collision with another pack's permission
/// set would update that set instead. Verify a uniqueness/scoping
/// constraint covers this.
async fn load_permission_sets(
    &self,
    pack_dir: &Path,
    result: &mut PackLoadResult,
) -> Result<Vec<String>> {
    let permission_sets_dir = pack_dir.join("permission_sets");
    let mut loaded_refs = Vec::new();
    // A pack without a permission_sets directory is valid: nothing to load.
    if !permission_sets_dir.exists() {
        info!(
            "No permission_sets directory found for pack '{}'",
            self.pack_ref
        );
        return Ok(loaded_refs);
    }
    let yaml_files = read_yaml_files(&permission_sets_dir)?;
    info!(
        "Found {} permission set definition(s) for pack '{}'",
        yaml_files.len(),
        self.pack_ref
    );
    for (filename, content) in &yaml_files {
        // Unparseable YAML is a hard error (aborts the load), unlike the
        // soft per-file skips below.
        let data: serde_yaml_ng::Value = serde_yaml_ng::from_str(content).map_err(|e| {
            Error::validation(format!(
                "Failed to parse permission set YAML {}: {}",
                filename, e
            ))
        })?;
        // 'ref' is the only mandatory field.
        let permission_set_ref = match data.get("ref").and_then(|v| v.as_str()) {
            Some(r) => r.to_string(),
            None => {
                let msg = format!(
                    "Permission set YAML {} missing 'ref' field, skipping",
                    filename
                );
                warn!("{}", msg);
                result.warnings.push(msg);
                result.permission_sets_skipped += 1;
                continue;
            }
        };
        let label = data
            .get("label")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let description = data
            .get("description")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        // Convert the YAML grants node to JSON for storage; an absent key
        // (or a YAML value that cannot be represented as JSON) falls back
        // to an empty array. An explicit non-array value (e.g. `grants: null`)
        // is rejected below.
        let grants = data
            .get("grants")
            .and_then(|v| serde_json::to_value(v).ok())
            .unwrap_or_else(|| serde_json::json!([]));
        if !grants.is_array() {
            let msg = format!(
                "Permission set '{}' has non-array 'grants', skipping",
                permission_set_ref
            );
            warn!("{}", msg);
            result.warnings.push(msg);
            result.permission_sets_skipped += 1;
            continue;
        }
        // Upsert path: update if a set with this ref already exists.
        if let Some(existing) =
            PermissionSetRepository::find_by_ref(self.pool, &permission_set_ref).await?
        {
            let update_input = UpdatePermissionSetInput {
                label,
                description,
                grants: Some(grants),
            };
            match PermissionSetRepository::update(self.pool, existing.id, update_input).await {
                Ok(_) => {
                    info!(
                        "Updated permission set '{}' (ID: {})",
                        permission_set_ref, existing.id
                    );
                    result.permission_sets_updated += 1;
                }
                Err(e) => {
                    // A failed update is downgraded to a warning + skip so
                    // one bad row does not abort the whole pack load.
                    let msg = format!(
                        "Failed to update permission set '{}': {}",
                        permission_set_ref, e
                    );
                    warn!("{}", msg);
                    result.warnings.push(msg);
                    result.permission_sets_skipped += 1;
                }
            }
            // Pushed even on update failure: the set still exists and must
            // survive the stale-entity cleanup pass.
            loaded_refs.push(permission_set_ref);
            continue;
        }
        // Create path: new permission set owned by this pack.
        let input = CreatePermissionSetInput {
            r#ref: permission_set_ref.clone(),
            pack: Some(self.pack_id),
            pack_ref: Some(self.pack_ref.clone()),
            label,
            description,
            grants,
        };
        match PermissionSetRepository::create(self.pool, input).await {
            Ok(permission_set) => {
                info!(
                    "Created permission set '{}' (ID: {})",
                    permission_set_ref, permission_set.id
                );
                result.permission_sets_loaded += 1;
                loaded_refs.push(permission_set_ref);
            }
            Err(e) => {
                let msg = format!(
                    "Failed to create permission set '{}': {}",
                    permission_set_ref, e
                );
                warn!("{}", msg);
                result.warnings.push(msg);
                result.permission_sets_skipped += 1;
            }
        }
    }
    Ok(loaded_refs)
}
/// Load runtime definitions from `pack_dir/runtimes/*.yaml`.
///
/// Runtimes define how actions and sensors are executed (interpreter,
@@ -1308,12 +1474,37 @@ impl<'a> PackComponentLoader<'a> {
/// removed.
async fn cleanup_removed_entities(
&self,
permission_set_refs: &[String],
runtime_refs: &[String],
trigger_refs: &[String],
action_refs: &[String],
sensor_refs: &[String],
result: &mut PackLoadResult,
) {
match PermissionSetRepository::delete_by_pack_excluding(
self.pool,
self.pack_id,
permission_set_refs,
)
.await
{
Ok(count) => {
if count > 0 {
info!(
"Removed {} stale permission set(s) from pack '{}'",
count, self.pack_ref
);
result.removed += count as usize;
}
}
Err(e) => {
warn!(
"Failed to clean up stale permission sets for pack '{}': {}",
self.pack_ref, e
);
}
}
// Clean up sensors first (they depend on triggers/runtimes)
match SensorRepository::delete_by_pack_excluding(self.pool, self.pack_id, sensor_refs).await
{

292
crates/common/src/rbac.rs Normal file
View File

@@ -0,0 +1,292 @@
//! Role-based access control (RBAC) model and evaluator.
//!
//! Permission sets store `grants` as a JSON array of [`Grant`].
//! This module defines the canonical grant schema and matching logic.
use crate::models::{ArtifactVisibility, Id, OwnerType};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
/// Resource categories a [`Grant`] can target.
///
/// Serialized in snake_case inside a permission set's `grants` JSON
/// payload (e.g. `"packs"`, `"actions"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Resource {
    Packs,
    Actions,
    Rules,
    Triggers,
    Executions,
    Events,
    Enforcements,
    Inquiries,
    Keys,
    Artifacts,
    Workflows,
    Webhooks,
    Analytics,
    History,
    Identities,
    Permissions,
}
/// Operations a [`Grant`] can authorize on a [`Resource`].
///
/// Serialized in snake_case in the grants JSON payload.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Action {
    Read,
    Create,
    Update,
    Delete,
    Execute,
    Cancel,
    Respond,
    Manage,
}
/// Restricts a grant by the ownership relation between the target and
/// the requesting identity.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OwnerConstraint {
    /// Target must be owned by the requesting identity (serialized as `"self"`).
    #[serde(rename = "self")]
    SelfOnly,
    /// Any owner (or no owner) is acceptable.
    Any,
    /// Target must have no owner at all.
    None,
}
/// Restricts a grant to a portion of the execution tree relative to the
/// requesting identity.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionScopeConstraint {
    /// Only executions the identity itself owns (serialized as `"self"`).
    #[serde(rename = "self")]
    SelfOnly,
    /// Executions the identity owns or that descend from one it owns.
    Descendants,
    /// Any execution.
    Any,
}
/// Optional, additive restrictions attached to a [`Grant`].
///
/// Every populated field must be satisfied by the request context for the
/// grant to apply; absent fields impose no restriction. All fields are
/// omitted from JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct GrantConstraints {
    /// Target must belong to one of these pack refs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pack_refs: Option<Vec<String>>,
    /// Ownership relation required between target and identity.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub owner: Option<OwnerConstraint>,
    /// Target's owner type must be one of these.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub owner_types: Option<Vec<OwnerType>>,
    /// Target's visibility must be one of these.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub visibility: Option<Vec<ArtifactVisibility>>,
    /// Position of the target in the execution tree relative to the identity.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub execution_scope: Option<ExecutionScopeConstraint>,
    /// Target's ref must be one of these exact strings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refs: Option<Vec<String>>,
    /// Target's id must be one of these.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ids: Option<Vec<Id>>,
    /// Target's encrypted flag must equal this value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encrypted: Option<bool>,
    /// Each key must exist on the identity with an exactly equal JSON value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub attributes: Option<HashMap<String, JsonValue>>,
}
/// A single authorization rule: the listed `actions` are permitted on
/// `resource`, optionally narrowed by `constraints`.
///
/// Permission sets store their `grants` column as a JSON array of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Grant {
    /// Resource category this grant applies to.
    pub resource: Resource,
    /// Actions permitted on the resource.
    pub actions: Vec<Action>,
    /// Optional narrowing constraints; `None` means unconditional.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub constraints: Option<GrantConstraints>,
}
/// Per-request facts that grant constraints are evaluated against.
///
/// Built by the API layer for a specific identity and (optionally) a
/// specific target entity; fields left `None`/empty cause the
/// corresponding constraints to fail closed (see [`Grant`]).
#[derive(Debug, Clone)]
pub struct AuthorizationContext {
    /// The identity making the request.
    pub identity_id: Id,
    /// Free-form attributes of the identity, matched by `attributes` constraints.
    pub identity_attributes: HashMap<String, JsonValue>,
    /// Id of the target entity, if any.
    pub target_id: Option<Id>,
    /// Ref of the target entity, if any.
    pub target_ref: Option<String>,
    /// Pack ref the target belongs to, if any.
    pub pack_ref: Option<String>,
    /// Identity that owns the target, if any.
    pub owner_identity_id: Option<Id>,
    /// Owner type of the target, if any.
    pub owner_type: Option<OwnerType>,
    /// Visibility of the target, if any.
    pub visibility: Option<ArtifactVisibility>,
    /// Whether the target is encrypted, if known.
    pub encrypted: Option<bool>,
    /// Identity that owns the target execution, for execution-scope checks.
    pub execution_owner_identity_id: Option<Id>,
    /// Owners of the target execution's ancestors, for `descendants` scope.
    pub execution_ancestor_identity_ids: Vec<Id>,
}
impl AuthorizationContext {
    /// Build a context for `identity_id` with every optional target/owner
    /// field unset and no identity attributes. Callers populate the
    /// relevant fields before evaluating grants.
    pub fn new(identity_id: Id) -> Self {
        Self {
            identity_attributes: HashMap::default(),
            execution_ancestor_identity_ids: Vec::default(),
            target_id: None,
            target_ref: None,
            pack_ref: None,
            owner_identity_id: None,
            owner_type: None,
            visibility: None,
            encrypted: None,
            execution_owner_identity_id: None,
            identity_id,
        }
    }
}
impl Grant {
    /// Returns `true` when this grant authorizes `action` on `resource`
    /// under the request context `ctx`.
    pub fn allows(&self, resource: Resource, action: Action, ctx: &AuthorizationContext) -> bool {
        resource == self.resource && self.actions.contains(&action) && self.constraints_match(ctx)
    }

    /// Evaluates every populated constraint against `ctx`. Absent
    /// constraints (or an absent constraint set) always pass; a populated
    /// constraint with no matching context fact fails closed.
    fn constraints_match(&self, ctx: &AuthorizationContext) -> bool {
        let constraints = match &self.constraints {
            None => return true,
            Some(c) => c,
        };

        // Pack scoping: the target must carry one of the listed pack refs.
        if let Some(pack_refs) = &constraints.pack_refs {
            if !ctx
                .pack_ref
                .as_ref()
                .map_or(false, |p| pack_refs.contains(p))
            {
                return false;
            }
        }

        // Ownership relation between the target and the requesting identity.
        if let Some(owner) = constraints.owner {
            let ok = match owner {
                OwnerConstraint::SelfOnly => ctx.owner_identity_id == Some(ctx.identity_id),
                OwnerConstraint::Any => true,
                OwnerConstraint::None => ctx.owner_identity_id.is_none(),
            };
            if !ok {
                return false;
            }
        }

        if let Some(owner_types) = &constraints.owner_types {
            if !ctx.owner_type.map_or(false, |t| owner_types.contains(&t)) {
                return false;
            }
        }

        if let Some(visibility) = &constraints.visibility {
            if !ctx.visibility.map_or(false, |v| visibility.contains(&v)) {
                return false;
            }
        }

        // Execution tree scoping: self only, self-or-ancestor, or unrestricted.
        if let Some(scope) = constraints.execution_scope {
            let owns = ctx.execution_owner_identity_id == Some(ctx.identity_id);
            let ok = match scope {
                ExecutionScopeConstraint::SelfOnly => owns,
                ExecutionScopeConstraint::Descendants => {
                    owns || ctx.execution_ancestor_identity_ids.contains(&ctx.identity_id)
                }
                ExecutionScopeConstraint::Any => true,
            };
            if !ok {
                return false;
            }
        }

        if let Some(refs) = &constraints.refs {
            if !ctx.target_ref.as_ref().map_or(false, |r| refs.contains(r)) {
                return false;
            }
        }

        if let Some(ids) = &constraints.ids {
            if !ctx.target_id.map_or(false, |id| ids.contains(&id)) {
                return false;
            }
        }

        // `None` context never equals `Some(expected)`, so an unknown
        // encrypted flag fails this constraint.
        if let Some(expected) = constraints.encrypted {
            if ctx.encrypted != Some(expected) {
                return false;
            }
        }

        // Attribute constraints: every listed key must be present on the
        // identity with an exactly equal JSON value.
        if let Some(attributes) = &constraints.attributes {
            let all_match = attributes
                .iter()
                .all(|(key, expected)| ctx.identity_attributes.get(key) == Some(expected));
            if !all_match {
                return false;
            }
        }

        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An unconstrained grant authorizes exactly its listed actions on its resource.
    #[test]
    fn grant_without_constraints_allows() {
        let ctx = AuthorizationContext::new(42);
        let unconstrained = Grant {
            resource: Resource::Actions,
            actions: vec![Action::Read],
            constraints: None,
        };
        assert!(unconstrained.allows(Resource::Actions, Action::Read, &ctx));
        assert!(!unconstrained.allows(Resource::Actions, Action::Create, &ctx));
    }

    /// Owner-type and encrypted constraints must both hold for key access.
    #[test]
    fn key_constraint_owner_type_and_encrypted() {
        let constraints = GrantConstraints {
            owner_types: Some(vec![OwnerType::System]),
            encrypted: Some(false),
            ..Default::default()
        };
        let grant = Grant {
            resource: Resource::Keys,
            actions: vec![Action::Read],
            constraints: Some(constraints),
        };
        let mut ctx = AuthorizationContext::new(1);
        ctx.owner_type = Some(OwnerType::System);
        ctx.encrypted = Some(false);
        assert!(grant.allows(Resource::Keys, Action::Read, &ctx));
        // Flipping the encrypted flag must revoke access.
        ctx.encrypted = Some(true);
        assert!(!grant.allows(Resource::Keys, Action::Read, &ctx));
    }

    /// Attribute constraints demand an exactly equal JSON value per key.
    #[test]
    fn attributes_constraint_requires_exact_value_match() {
        let mut required = HashMap::new();
        required.insert("team".to_string(), json!("platform"));
        let grant = Grant {
            resource: Resource::Packs,
            actions: vec![Action::Read],
            constraints: Some(GrantConstraints {
                attributes: Some(required),
                ..Default::default()
            }),
        };
        let mut ctx = AuthorizationContext::new(1);
        ctx.identity_attributes
            .insert("team".to_string(), json!("platform"));
        assert!(grant.allows(Resource::Packs, Action::Read, &ctx));
        ctx.identity_attributes
            .insert("team".to_string(), json!("infra"));
        assert!(!grant.allows(Resource::Packs, Action::Read, &ctx));
    }
}

View File

@@ -54,6 +54,7 @@ pub struct ExecutionWithRefs {
pub parent: Option<Id>,
pub enforcement: Option<Id>,
pub executor: Option<Id>,
pub worker: Option<Id>,
pub status: ExecutionStatus,
pub result: Option<JsonDict>,
pub started_at: Option<DateTime<Utc>>,
@@ -73,7 +74,7 @@ pub struct ExecutionWithRefs {
/// are NOT in the Rust struct, so `SELECT *` must never be used.
pub const SELECT_COLUMNS: &str = "\
id, action, action_ref, config, env_vars, parent, enforcement, \
executor, status, result, started_at, workflow_task, created, updated";
executor, worker, status, result, started_at, workflow_task, created, updated";
pub struct ExecutionRepository;
@@ -93,6 +94,7 @@ pub struct CreateExecutionInput {
pub parent: Option<Id>,
pub enforcement: Option<Id>,
pub executor: Option<Id>,
pub worker: Option<Id>,
pub status: ExecutionStatus,
pub result: Option<JsonDict>,
pub workflow_task: Option<WorkflowTaskMetadata>,
@@ -103,6 +105,7 @@ pub struct UpdateExecutionInput {
pub status: Option<ExecutionStatus>,
pub result: Option<JsonDict>,
pub executor: Option<Id>,
pub worker: Option<Id>,
pub started_at: Option<DateTime<Utc>>,
pub workflow_task: Option<WorkflowTaskMetadata>,
}
@@ -113,6 +116,7 @@ impl From<Execution> for UpdateExecutionInput {
status: Some(execution.status),
result: execution.result,
executor: execution.executor,
worker: execution.worker,
started_at: execution.started_at,
workflow_task: execution.workflow_task,
}
@@ -158,8 +162,8 @@ impl Create for ExecutionRepository {
{
let sql = format!(
"INSERT INTO execution \
(action, action_ref, config, env_vars, parent, enforcement, executor, status, result, workflow_task) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) \
(action, action_ref, config, env_vars, parent, enforcement, executor, worker, status, result, workflow_task) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) \
RETURNING {SELECT_COLUMNS}"
);
sqlx::query_as::<_, Execution>(&sql)
@@ -170,6 +174,7 @@ impl Create for ExecutionRepository {
.bind(input.parent)
.bind(input.enforcement)
.bind(input.executor)
.bind(input.worker)
.bind(input.status)
.bind(&input.result)
.bind(sqlx::types::Json(&input.workflow_task))
@@ -208,6 +213,13 @@ impl Update for ExecutionRepository {
query.push("executor = ").push_bind(executor_id);
has_updates = true;
}
if let Some(worker_id) = input.worker {
if has_updates {
query.push(", ");
}
query.push("worker = ").push_bind(worker_id);
has_updates = true;
}
if let Some(started_at) = input.started_at {
if has_updates {
query.push(", ");

View File

@@ -4,7 +4,7 @@ use crate::models::{identity::*, Id, JsonDict};
use crate::Result;
use sqlx::{Executor, Postgres, QueryBuilder};
use super::{Create, Delete, FindById, List, Repository, Update};
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
pub struct IdentityRepository;
@@ -200,6 +200,22 @@ impl FindById for PermissionSetRepository {
}
}
#[async_trait::async_trait]
impl FindByRef for PermissionSetRepository {
    /// Look up a permission set by its `ref` string, returning `None`
    /// when no row matches.
    async fn find_by_ref<'e, E>(executor: E, ref_str: &str) -> Result<Option<Self::Entity>>
    where
        E: Executor<'e, Database = Postgres> + 'e,
    {
        let row = sqlx::query_as::<_, PermissionSet>(
            "SELECT id, ref, pack, pack_ref, label, description, grants, created, updated FROM permission_set WHERE ref = $1"
        )
        .bind(ref_str)
        .fetch_optional(executor)
        .await?;
        Ok(row)
    }
}
#[async_trait::async_trait]
impl List for PermissionSetRepository {
async fn list<'e, E>(executor: E) -> Result<Vec<Self::Entity>>
@@ -287,6 +303,54 @@ impl Delete for PermissionSetRepository {
}
}
impl PermissionSetRepository {
    /// Fetch every permission set assigned to `identity_id` through the
    /// `permission_assignment` join table, ordered by ref.
    pub async fn find_by_identity<'e, E>(executor: E, identity_id: Id) -> Result<Vec<PermissionSet>>
    where
        E: Executor<'e, Database = Postgres> + 'e,
    {
        let sets = sqlx::query_as::<_, PermissionSet>(
            "SELECT ps.id, ps.ref, ps.pack, ps.pack_ref, ps.label, ps.description, ps.grants, ps.created, ps.updated
             FROM permission_set ps
             INNER JOIN permission_assignment pa ON pa.permset = ps.id
             WHERE pa.identity = $1
             ORDER BY ps.ref ASC",
        )
        .bind(identity_id)
        .fetch_all(executor)
        .await?;
        Ok(sets)
    }

    /// Delete permission sets belonging to a pack whose refs are NOT in the given set.
    ///
    /// Used during pack reinstallation to clean up permission sets that were
    /// removed from the pack's metadata. Associated permission assignments are
    /// cascade-deleted by the FK constraint. Returns the number of rows removed.
    pub async fn delete_by_pack_excluding<'e, E>(
        executor: E,
        pack_id: Id,
        keep_refs: &[String],
    ) -> Result<u64>
    where
        E: Executor<'e, Database = Postgres> + 'e,
    {
        // An empty keep-list removes every permission set of the pack.
        let outcome = if keep_refs.is_empty() {
            sqlx::query("DELETE FROM permission_set WHERE pack = $1")
                .bind(pack_id)
                .execute(executor)
                .await?
        } else {
            sqlx::query("DELETE FROM permission_set WHERE pack = $1 AND ref != ALL($2)")
                .bind(pack_id)
                .bind(keep_refs)
                .execute(executor)
                .await?
        };
        Ok(outcome.rows_affected())
    }
}
// Permission Assignment Repository
pub struct PermissionAssignmentRepository;

View File

@@ -42,6 +42,7 @@ async fn test_create_execution_basic() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -76,6 +77,7 @@ async fn test_create_execution_without_action() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -110,6 +112,7 @@ async fn test_create_execution_with_all_fields() {
parent: None,
enforcement: None,
executor: None, // Don't reference non-existent identity
worker: None,
status: ExecutionStatus::Scheduled,
result: Some(json!({"status": "ok"})),
workflow_task: None,
@@ -146,6 +149,7 @@ async fn test_create_execution_with_parent() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -164,6 +168,7 @@ async fn test_create_execution_with_parent() {
parent: Some(parent.id),
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -203,6 +208,7 @@ async fn test_find_execution_by_id() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -257,6 +263,7 @@ async fn test_list_executions() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -303,6 +310,7 @@ async fn test_list_executions_ordered_by_created_desc() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -354,6 +362,7 @@ async fn test_update_execution_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -399,6 +408,7 @@ async fn test_update_execution_result() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -445,6 +455,7 @@ async fn test_update_execution_executor() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -489,6 +500,7 @@ async fn test_update_execution_status_transitions() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -580,6 +592,7 @@ async fn test_update_execution_failed_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -625,6 +638,7 @@ async fn test_update_execution_no_changes() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -669,6 +683,7 @@ async fn test_delete_execution() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Completed,
result: None,
workflow_task: None,
@@ -736,6 +751,7 @@ async fn test_find_executions_by_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: *status,
result: None,
workflow_task: None,
@@ -783,6 +799,7 @@ async fn test_find_executions_by_enforcement() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -801,6 +818,7 @@ async fn test_find_executions_by_enforcement() {
parent: None,
enforcement: None, // Can't reference non-existent enforcement
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -845,6 +863,7 @@ async fn test_parent_child_execution_hierarchy() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -865,6 +884,7 @@ async fn test_parent_child_execution_hierarchy() {
parent: Some(parent.id),
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -909,6 +929,7 @@ async fn test_nested_execution_hierarchy() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -927,6 +948,7 @@ async fn test_nested_execution_hierarchy() {
parent: Some(grandparent.id),
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,
@@ -945,6 +967,7 @@ async fn test_nested_execution_hierarchy() {
parent: Some(parent.id),
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -987,6 +1010,7 @@ async fn test_execution_timestamps() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1058,6 +1082,7 @@ async fn test_execution_config_json() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1091,6 +1116,7 @@ async fn test_execution_result_json() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: ExecutionStatus::Running,
result: None,
workflow_task: None,

View File

@@ -49,6 +49,7 @@ async fn test_create_inquiry_minimal() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -109,6 +110,7 @@ async fn test_create_inquiry_with_response_schema() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -167,6 +169,7 @@ async fn test_create_inquiry_with_timeout() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -221,6 +224,7 @@ async fn test_create_inquiry_with_assigned_user() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -310,6 +314,7 @@ async fn test_find_inquiry_by_id() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -372,6 +377,7 @@ async fn test_get_inquiry_by_id() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -443,6 +449,7 @@ async fn test_list_inquiries() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -504,6 +511,7 @@ async fn test_update_inquiry_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -560,6 +568,7 @@ async fn test_update_inquiry_status_transitions() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -645,6 +654,7 @@ async fn test_update_inquiry_response() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -703,6 +713,7 @@ async fn test_update_inquiry_with_response_and_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -761,6 +772,7 @@ async fn test_update_inquiry_assignment() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -828,6 +840,7 @@ async fn test_update_inquiry_no_changes() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -905,6 +918,7 @@ async fn test_delete_inquiry() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -965,6 +979,7 @@ async fn test_delete_execution_cascades_to_inquiries() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1032,6 +1047,7 @@ async fn test_find_inquiries_by_status() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1111,6 +1127,7 @@ async fn test_find_inquiries_by_execution() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1129,6 +1146,7 @@ async fn test_find_inquiries_by_execution() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1193,6 +1211,7 @@ async fn test_inquiry_timestamps_auto_managed() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,
@@ -1260,6 +1279,7 @@ async fn test_inquiry_complex_response_schema() {
parent: None,
enforcement: None,
executor: None,
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None,

View File

@@ -292,6 +292,7 @@ impl EnforcementProcessor {
parent: None, // TODO: Handle workflow parent-child relationships
enforcement: Some(enforcement.id),
executor: None, // Will be assigned during scheduling
worker: None,
status: attune_common::models::enums::ExecutionStatus::Requested,
result: None,
workflow_task: None, // Non-workflow execution

View File

@@ -21,13 +21,17 @@ use anyhow::Result;
use attune_common::{
models::{enums::ExecutionStatus, Execution},
mq::{
Consumer, ExecutionRequestedPayload, ExecutionStatusChangedPayload, MessageEnvelope,
MessageType, Publisher,
Consumer, ExecutionCancelRequestedPayload, ExecutionRequestedPayload,
ExecutionStatusChangedPayload, MessageEnvelope, MessageType, Publisher,
},
repositories::{
execution::{CreateExecutionInput, ExecutionRepository},
Create, FindById,
execution::{CreateExecutionInput, ExecutionRepository, UpdateExecutionInput},
workflow::{
UpdateWorkflowExecutionInput, WorkflowDefinitionRepository, WorkflowExecutionRepository,
},
Create, FindById, Update,
},
workflow::{CancellationPolicy, WorkflowDefinition},
};
use sqlx::PgPool;
@@ -116,8 +120,18 @@ impl ExecutionManager {
"Execution {} reached terminal state: {:?}, handling orchestration",
execution_id, status
);
if status == ExecutionStatus::Cancelled {
Self::handle_workflow_cancellation(pool, publisher, &execution).await?;
}
Self::handle_completion(pool, publisher, &execution).await?;
}
ExecutionStatus::Canceling => {
debug!(
"Execution {} entered canceling state; checking for workflow child cancellation",
execution_id
);
Self::handle_workflow_cancellation(pool, publisher, &execution).await?;
}
ExecutionStatus::Running => {
debug!(
"Execution {} now running (worker has updated DB)",
@@ -135,6 +149,202 @@ impl ExecutionManager {
Ok(())
}
/// Cascade a cancellation through a workflow's child executions.
///
/// No-op for executions that do not anchor a workflow execution.
///
/// NOTE(review): `resolve_cancellation_policy` re-queries the same
/// `workflow_execution` row fetched here; consider passing it through to
/// save a round trip.
async fn handle_workflow_cancellation(
    pool: &PgPool,
    publisher: &Publisher,
    execution: &Execution,
) -> Result<()> {
    if WorkflowExecutionRepository::find_by_execution(pool, execution.id)
        .await?
        .is_none()
    {
        return Ok(());
    }
    let policy = Self::resolve_cancellation_policy(pool, execution.id).await;
    Self::cancel_workflow_children_with_policy(pool, publisher, execution.id, policy).await
}
/// Resolve the cancellation policy from the workflow definition backing
/// `parent_execution_id`.
///
/// Any lookup or deserialization failure degrades to
/// `CancellationPolicy::default()` rather than propagating an error, so
/// cancellation handling always proceeds.
async fn resolve_cancellation_policy(
    pool: &PgPool,
    parent_execution_id: i64,
) -> CancellationPolicy {
    let Ok(Some(wf_exec)) =
        WorkflowExecutionRepository::find_by_execution(pool, parent_execution_id).await
    else {
        return CancellationPolicy::default();
    };
    let Ok(Some(wf_def)) =
        WorkflowDefinitionRepository::find_by_id(pool, wf_exec.workflow_def).await
    else {
        return CancellationPolicy::default();
    };
    serde_json::from_value::<WorkflowDefinition>(wf_def.definition)
        .map(|def| def.cancellation_policy)
        .unwrap_or_else(|e| {
            warn!(
                "Failed to deserialize workflow definition for workflow_def {}: {}. Falling back to default cancellation policy.",
                wf_exec.workflow_def, e
            );
            CancellationPolicy::default()
        })
}
/// Recursively cancel the non-terminal children of `parent_execution_id`
/// according to `policy`, then mark the parent's workflow execution
/// cancelled and finalize the parent if nothing is left running.
///
/// Per-child handling:
/// - not-yet-running (requested/scheduling/scheduled): marked Cancelled
///   directly in the DB;
/// - running/canceling with `CancelRunning`: moved to Canceling (if not
///   already) and a cancel message is routed to its assigned worker;
/// - running/canceling with `AllowFinish`: left alone to complete.
///
/// NOTE(review): the recursion below descends into EVERY child's subtree
/// regardless of policy — under `AllowFinish` a running child is left
/// alone but its own children are still cancelled. Confirm that is the
/// intended semantics of the policy.
///
/// NOTE(review): the terminal-status list in the SQL below must stay in
/// sync with the `ExecutionStatus` enum's terminal variants.
async fn cancel_workflow_children_with_policy(
    pool: &PgPool,
    publisher: &Publisher,
    parent_execution_id: i64,
    policy: CancellationPolicy,
) -> Result<()> {
    // Fetch all children still in a non-terminal state.
    let children: Vec<Execution> = sqlx::query_as::<_, Execution>(&format!(
        "SELECT {} FROM execution WHERE parent = $1 AND status NOT IN ('completed', 'failed', 'timeout', 'cancelled', 'abandoned')",
        attune_common::repositories::execution::SELECT_COLUMNS
    ))
    .bind(parent_execution_id)
    .fetch_all(pool)
    .await?;
    if children.is_empty() {
        // Nothing active below this node; finalize it now if idle.
        return Self::finalize_cancelled_workflow_if_idle(pool, parent_execution_id).await;
    }
    info!(
        "Executor cascading cancellation from workflow execution {} to {} child execution(s) with policy {:?}",
        parent_execution_id,
        children.len(),
        policy,
    );
    for child in &children {
        let child_id = child.id;
        if matches!(
            child.status,
            ExecutionStatus::Requested
                | ExecutionStatus::Scheduling
                | ExecutionStatus::Scheduled
        ) {
            // Not yet running: safe to mark Cancelled outright.
            let update = UpdateExecutionInput {
                status: Some(ExecutionStatus::Cancelled),
                result: Some(serde_json::json!({
                    "error": "Cancelled: parent workflow execution was cancelled"
                })),
                ..Default::default()
            };
            ExecutionRepository::update(pool, child_id, update).await?;
        } else if matches!(
            child.status,
            ExecutionStatus::Running | ExecutionStatus::Canceling
        ) {
            match policy {
                CancellationPolicy::CancelRunning => {
                    // Transition to Canceling once, then ask the owning
                    // worker to stop the execution.
                    if child.status != ExecutionStatus::Canceling {
                        let update = UpdateExecutionInput {
                            status: Some(ExecutionStatus::Canceling),
                            ..Default::default()
                        };
                        ExecutionRepository::update(pool, child_id, update).await?;
                    }
                    if let Some(worker_id) = child.worker {
                        Self::send_cancel_to_worker(publisher, child_id, worker_id).await?;
                    } else {
                        // Running but unassigned: nothing to signal.
                        warn!(
                            "Workflow child execution {} is {:?} but has no assigned worker",
                            child_id, child.status
                        );
                    }
                }
                CancellationPolicy::AllowFinish => {
                    info!(
                        "AllowFinish policy: leaving running workflow child execution {} alone",
                        child_id
                    );
                }
            }
        }
        // Recurse into the child's own subtree (Box::pin: async recursion).
        Box::pin(Self::cancel_workflow_children_with_policy(
            pool, publisher, child_id, policy,
        ))
        .await?;
    }
    // Mark the parent's workflow execution cancelled unless already terminal.
    if let Some(wf_exec) =
        WorkflowExecutionRepository::find_by_execution(pool, parent_execution_id).await?
    {
        if !matches!(
            wf_exec.status,
            ExecutionStatus::Completed | ExecutionStatus::Failed | ExecutionStatus::Cancelled
        ) {
            let wf_update = UpdateWorkflowExecutionInput {
                status: Some(ExecutionStatus::Cancelled),
                error_message: Some(
                    "Cancelled: parent workflow execution was cancelled".to_string(),
                ),
                current_tasks: Some(vec![]),
                ..Default::default()
            };
            WorkflowExecutionRepository::update(pool, wf_exec.id, wf_update).await?;
        }
    }
    Self::finalize_cancelled_workflow_if_idle(pool, parent_execution_id).await
}
/// Finalize the parent workflow execution as `Cancelled` once no child
/// execution remains in an active (non-terminal) state.
///
/// Safe to call repeatedly: it does nothing while any child is still
/// running, canceling, scheduling, scheduled, or requested.
async fn finalize_cancelled_workflow_if_idle(
    pool: &PgPool,
    parent_execution_id: i64,
) -> Result<()> {
    let query = format!(
        "SELECT {} FROM execution WHERE parent = $1 AND status IN ('running', 'canceling', 'scheduling', 'scheduled', 'requested')",
        attune_common::repositories::execution::SELECT_COLUMNS
    );
    let active_children: Vec<Execution> = sqlx::query_as::<_, Execution>(&query)
        .bind(parent_execution_id)
        .fetch_all(pool)
        .await?;
    // Only finalize once every child has reached a terminal state.
    if !active_children.is_empty() {
        return Ok(());
    }
    let finalize = UpdateExecutionInput {
        status: Some(ExecutionStatus::Cancelled),
        result: Some(serde_json::json!({
            "error": "Workflow cancelled",
            "succeeded": false,
        })),
        ..Default::default()
    };
    let _ = ExecutionRepository::update(pool, parent_execution_id, finalize).await?;
    Ok(())
}
/// Publish a cancellation request for `execution_id` onto the per-worker
/// cancel routing key so the owning worker can stop the process.
async fn send_cancel_to_worker(
    publisher: &Publisher,
    execution_id: i64,
    worker_id: i64,
) -> Result<()> {
    // Route to the queue dedicated to the worker currently running the execution.
    let routing_key = format!("execution.cancel.worker.{}", worker_id);
    let envelope = MessageEnvelope::new(
        MessageType::ExecutionCancelRequested,
        ExecutionCancelRequestedPayload {
            execution_id,
            worker_id,
        },
    )
    .with_source("executor-service")
    .with_correlation_id(uuid::Uuid::new_v4());
    publisher
        .publish_envelope_with_routing(&envelope, "attune.executions", &routing_key)
        .await?;
    Ok(())
}
/// Parse execution status from string
fn parse_execution_status(status: &str) -> Result<ExecutionStatus> {
match status.to_lowercase().as_str() {
@@ -213,6 +423,7 @@ impl ExecutionManager {
parent: Some(parent.id), // Link to parent execution
enforcement: parent.enforcement,
executor: None, // Will be assigned during scheduling
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: None, // Non-workflow execution

View File

@@ -298,6 +298,7 @@ impl RetryManager {
parent: original.parent,
enforcement: original.enforcement,
executor: None, // Will be assigned by scheduler
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: original.workflow_task.clone(),

View File

@@ -230,9 +230,11 @@ impl ExecutionScheduler {
}
};
// Update execution status to scheduled
// Persist the selected worker so later cancellation requests can be
// routed to the correct per-worker cancel queue.
let mut execution_for_update = execution;
execution_for_update.status = ExecutionStatus::Scheduled;
execution_for_update.worker = Some(worker.id);
ExecutionRepository::update(pool, execution_for_update.id, execution_for_update.into())
.await?;
@@ -529,6 +531,7 @@ impl ExecutionScheduler {
parent: Some(parent_execution.id),
enforcement: parent_execution.enforcement,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: Some(workflow_task),
@@ -689,6 +692,7 @@ impl ExecutionScheduler {
parent: Some(parent_execution.id),
enforcement: parent_execution.enforcement,
executor: None,
worker: None,
status: ExecutionStatus::Requested,
result: None,
workflow_task: Some(workflow_task),
@@ -1886,4 +1890,32 @@ mod tests {
serde_json::json!({"parameters": {"n": 5}, "context": {"rule": "test"}})
);
}
#[test]
fn test_scheduling_persists_selected_worker() {
    // Mirror the scheduler flow: build a freshly-requested execution, then
    // mark it scheduled on a specific worker before converting it into an
    // update input — the conversion must carry both fields through.
    let now = Utc::now();
    let mut execution = attune_common::models::Execution {
        id: 42,
        action: Some(7),
        action_ref: "core.sleep".to_string(),
        config: None,
        env_vars: None,
        parent: None,
        enforcement: None,
        executor: None,
        worker: None,
        status: ExecutionStatus::Requested,
        result: None,
        started_at: None,
        workflow_task: None,
        created: now,
        updated: now,
    };
    execution.status = ExecutionStatus::Scheduled;
    execution.worker = Some(99);
    let update = UpdateExecutionInput::from(execution);
    assert_eq!(update.status, Some(ExecutionStatus::Scheduled));
    assert_eq!(update.worker, Some(99));
}
}

View File

@@ -126,6 +126,7 @@ async fn create_test_execution(
parent: None,
enforcement: None,
executor: None,
worker: None,
status,
result: None,
workflow_task: None,

View File

@@ -121,6 +121,7 @@ async fn create_test_execution(
parent: None,
enforcement: None,
executor: None,
worker: None,
status,
result: None,
workflow_task: None,

View File

@@ -100,7 +100,7 @@ impl ActionExecutor {
/// Execute an action for the given execution, with cancellation support.
///
/// When the `cancel_token` is triggered, the running process receives
/// SIGINT → SIGTERM → SIGKILL with escalating grace periods.
/// SIGTERM → SIGKILL with a short grace period.
pub async fn execute_with_cancel(
&self,
execution_id: i64,
@@ -139,7 +139,7 @@ impl ActionExecutor {
};
// Attach the cancellation token so the process executor can monitor it
context.cancel_token = Some(cancel_token);
context.cancel_token = Some(cancel_token.clone());
// Execute the action
// Note: execute_action should rarely return Err - most failures should be
@@ -181,7 +181,16 @@ impl ActionExecutor {
execution_id, result.exit_code, result.error, is_success
);
if is_success {
let was_cancelled = cancel_token.is_cancelled()
|| result
.error
.as_deref()
.is_some_and(|e| e.contains("cancelled"));
if was_cancelled {
self.handle_execution_cancelled(execution_id, &result)
.await?;
} else if is_success {
self.handle_execution_success(execution_id, &result).await?;
} else {
self.handle_execution_failure(execution_id, Some(&result), None)
@@ -913,6 +922,51 @@ impl ActionExecutor {
Ok(())
}
/// Persist a terminal `Cancelled` state for an execution, capturing whatever
/// output the process produced before it was stopped.
///
/// stdout is inlined into the result payload; stderr (when non-empty) is
/// referenced by its on-disk log path. Truncation flags are copied through
/// so the UI can indicate partial logs.
async fn handle_execution_cancelled(
    &self,
    execution_id: i64,
    result: &ExecutionResult,
) -> Result<()> {
    let exec_dir = self.artifact_manager.get_execution_dir(execution_id);
    let error_message = result
        .error
        .clone()
        .unwrap_or_else(|| "Execution cancelled by user".to_string());
    let mut result_data = serde_json::json!({
        "succeeded": false,
        "cancelled": true,
        "exit_code": result.exit_code,
        "duration_ms": result.duration_ms,
        "error": error_message,
    });
    if !result.stdout.is_empty() {
        result_data["stdout"] = serde_json::json!(result.stdout);
    }
    if !result.stderr.trim().is_empty() {
        // Point at the on-disk stderr log rather than inlining its contents.
        let stderr_path = exec_dir.join("stderr.log");
        result_data["stderr_log"] = serde_json::json!(stderr_path.to_string_lossy());
    }
    if result.stdout_truncated {
        result_data["stdout_truncated"] = serde_json::json!(true);
        result_data["stdout_bytes_truncated"] =
            serde_json::json!(result.stdout_bytes_truncated);
    }
    if result.stderr_truncated {
        result_data["stderr_truncated"] = serde_json::json!(true);
        result_data["stderr_bytes_truncated"] =
            serde_json::json!(result.stderr_bytes_truncated);
    }
    ExecutionRepository::update(
        &self.pool,
        execution_id,
        UpdateExecutionInput {
            status: Some(ExecutionStatus::Cancelled),
            result: Some(result_data),
            ..Default::default()
        },
    )
    .await?;
    Ok(())
}
/// Update execution status
async fn update_execution_status(
&self,

View File

@@ -19,7 +19,7 @@ pub use heartbeat::HeartbeatManager;
pub use registration::WorkerRegistration;
pub use runtime::{
ExecutionContext, ExecutionResult, LocalRuntime, NativeRuntime, ProcessRuntime, Runtime,
RuntimeError, RuntimeResult, ShellRuntime,
RuntimeError, RuntimeResult,
};
pub use secrets::SecretManager;
pub use service::WorkerService;

View File

@@ -1,6 +1,6 @@
//! Local Runtime Module
//!
//! Provides local execution capabilities by combining Process and Shell runtimes.
//! Provides local execution capabilities by combining Process and Native runtimes.
//! This module serves as a facade for all local process-based execution.
//!
//! The `ProcessRuntime` is used for Python (and other interpreted languages),
@@ -8,10 +8,11 @@
use super::native::NativeRuntime;
use super::process::ProcessRuntime;
use super::shell::ShellRuntime;
use super::{ExecutionContext, ExecutionResult, Runtime, RuntimeError, RuntimeResult};
use async_trait::async_trait;
use attune_common::models::runtime::{InterpreterConfig, RuntimeExecutionConfig};
use attune_common::models::runtime::{
InlineExecutionConfig, InlineExecutionStrategy, InterpreterConfig, RuntimeExecutionConfig,
};
use std::path::PathBuf;
use tracing::{debug, info};
@@ -19,7 +20,7 @@ use tracing::{debug, info};
pub struct LocalRuntime {
native: NativeRuntime,
python: ProcessRuntime,
shell: ShellRuntime,
shell: ProcessRuntime,
}
impl LocalRuntime {
@@ -34,6 +35,23 @@ impl LocalRuntime {
args: vec![],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
};
let shell_config = RuntimeExecutionConfig {
interpreter: InterpreterConfig {
binary: "/bin/bash".to_string(),
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig {
strategy: InlineExecutionStrategy::TempFile,
extension: Some(".sh".to_string()),
inject_shell_helpers: true,
},
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
@@ -47,7 +65,12 @@ impl LocalRuntime {
PathBuf::from("/opt/attune/packs"),
PathBuf::from("/opt/attune/runtime_envs"),
),
shell: ShellRuntime::new(),
shell: ProcessRuntime::new(
"shell".to_string(),
shell_config,
PathBuf::from("/opt/attune/packs"),
PathBuf::from("/opt/attune/runtime_envs"),
),
}
}
@@ -55,7 +78,7 @@ impl LocalRuntime {
pub fn with_runtimes(
native: NativeRuntime,
python: ProcessRuntime,
shell: ShellRuntime,
shell: ProcessRuntime,
) -> Self {
Self {
native,
@@ -76,7 +99,10 @@ impl LocalRuntime {
);
Ok(&self.python)
} else if self.shell.can_execute(context) {
debug!("Selected Shell runtime for action: {}", context.action_ref);
debug!(
"Selected Shell (ProcessRuntime) for action: {}",
context.action_ref
);
Ok(&self.shell)
} else {
Err(RuntimeError::RuntimeNotFound(format!(

View File

@@ -29,13 +29,11 @@ pub mod native;
pub mod parameter_passing;
pub mod process;
pub mod process_executor;
pub mod shell;
// Re-export runtime implementations
pub use local::LocalRuntime;
pub use native::NativeRuntime;
pub use process::ProcessRuntime;
pub use shell::ShellRuntime;
use async_trait::async_trait;
use attune_common::models::runtime::RuntimeExecutionConfig;
@@ -159,9 +157,9 @@ pub struct ExecutionContext {
/// Format for output parsing
pub output_format: OutputFormat,
/// Optional cancellation token for graceful process termination.
/// When triggered, the executor sends SIGINT → SIGTERM → SIGKILL
/// with escalating grace periods.
/// Optional cancellation token for process termination.
/// When triggered, the executor sends SIGTERM → SIGKILL
/// with a short grace period.
pub cancel_token: Option<CancellationToken>,
}

View File

@@ -19,12 +19,18 @@ use super::{
process_executor, ExecutionContext, ExecutionResult, Runtime, RuntimeError, RuntimeResult,
};
use async_trait::async_trait;
use attune_common::models::runtime::{EnvironmentConfig, RuntimeExecutionConfig};
use attune_common::models::runtime::{
EnvironmentConfig, InlineExecutionStrategy, RuntimeExecutionConfig,
};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::process::Command;
use tracing::{debug, error, info, warn};
/// Escape `s` for safe embedding inside a bash single-quoted string.
///
/// Inside single quotes the only character needing treatment is `'` itself:
/// close the current quote, emit an escaped quote, and reopen — so `'`
/// becomes `'\''` (e.g. `foo'bar` renders as `'foo'\''bar'`).
fn bash_single_quote_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\'' => escaped.push_str("'\\''"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// A generic runtime driven by `RuntimeExecutionConfig` from the database.
///
/// Each `ProcessRuntime` instance corresponds to a row in the `runtime` table.
@@ -437,6 +443,90 @@ impl ProcessRuntime {
pub fn config(&self) -> &RuntimeExecutionConfig {
&self.config
}
/// Wrap inline shell code in a `#!/bin/bash` script that exports every
/// merged parameter as both `PARAM_<NAME>` and `<name>` environment
/// variables before the action code runs.
///
/// Scalar JSON values are rendered bare; structured values are serialized
/// as JSON text. All values are single-quote escaped for bash.
fn build_shell_inline_wrapper(
    &self,
    merged_parameters: &HashMap<String, serde_json::Value>,
    code: &str,
) -> RuntimeResult<String> {
    let mut script = String::from("#!/bin/bash\nset -e\n\n# Action parameters\n");
    for (key, value) in merged_parameters {
        let rendered = match value {
            serde_json::Value::String(s) => s.clone(),
            serde_json::Value::Number(n) => n.to_string(),
            serde_json::Value::Bool(b) => b.to_string(),
            other => serde_json::to_string(other)?,
        };
        let escaped = bash_single_quote_escape(&rendered);
        // Export both the conventional PARAM_* name and the raw key.
        script.push_str(&format!(
            "export PARAM_{}='{}'\n",
            key.to_uppercase(),
            escaped
        ));
        script.push_str(&format!("export {}='{}'\n", key, escaped));
    }
    script.push_str("\n# Action code\n");
    script.push_str(code);
    Ok(script)
}
/// Write inline action code to a temp file and return its path together
/// with a flag saying whether the generated script already embeds the
/// parameters (true when shell helpers were injected, so the caller should
/// skip passing parameters via stdin).
async fn materialize_inline_code(
    &self,
    execution_id: i64,
    merged_parameters: &HashMap<String, serde_json::Value>,
    code: &str,
    effective_config: &RuntimeExecutionConfig,
) -> RuntimeResult<(PathBuf, bool)> {
    let inline_dir = std::env::temp_dir().join("attune").join("inline_actions");
    if let Err(e) = tokio::fs::create_dir_all(&inline_dir).await {
        return Err(RuntimeError::ExecutionFailed(format!(
            "Failed to create inline action directory {}: {}",
            inline_dir.display(),
            e
        )));
    }
    // Normalize the configured extension so it always starts with a dot
    // (or is empty when no extension is configured).
    let extension = match effective_config.inline_execution.extension.as_deref() {
        None | Some("") => String::new(),
        Some(ext) if ext.starts_with('.') => ext.to_string(),
        Some(ext) => format!(".{}", ext),
    };
    let inline_path = inline_dir.join(format!("exec_{}{}", execution_id, extension));
    let inject_helpers = effective_config.inline_execution.inject_shell_helpers;
    let inline_code = if inject_helpers {
        self.build_shell_inline_wrapper(merged_parameters, code)?
    } else {
        code.to_string()
    };
    if let Err(e) = tokio::fs::write(&inline_path, inline_code).await {
        return Err(RuntimeError::ExecutionFailed(format!(
            "Failed to write inline action file {}: {}",
            inline_path.display(),
            e
        )));
    }
    Ok((inline_path, inject_helpers))
}
}
#[async_trait]
@@ -661,7 +751,7 @@ impl Runtime for ProcessRuntime {
};
let prepared_params =
parameter_passing::prepare_parameters(&merged_parameters, &mut env, param_config)?;
let parameters_stdin = prepared_params.stdin_content();
let mut parameters_stdin = prepared_params.stdin_content();
// Determine working directory: use context override, or pack dir
let working_dir = context
@@ -677,6 +767,7 @@ impl Runtime for ProcessRuntime {
});
// Build the command based on whether we have a file or inline code
let mut temp_inline_file: Option<PathBuf> = None;
let cmd = if let Some(ref code_path) = context.code_path {
// File-based execution: interpreter [args] <action_file>
debug!("Executing file: {}", code_path.display());
@@ -688,13 +779,38 @@ impl Runtime for ProcessRuntime {
&env,
)
} else if let Some(ref code) = context.code {
// Inline code execution: interpreter -c <code>
debug!("Executing inline code ({} bytes)", code.len());
let mut cmd = process_executor::build_inline_command(&interpreter, code, &env);
if let Some(dir) = working_dir {
cmd.current_dir(dir);
match effective_config.inline_execution.strategy {
InlineExecutionStrategy::Direct => {
debug!("Executing inline code directly ({} bytes)", code.len());
let mut cmd = process_executor::build_inline_command(&interpreter, code, &env);
if let Some(dir) = working_dir {
cmd.current_dir(dir);
}
cmd
}
InlineExecutionStrategy::TempFile => {
debug!("Executing inline code via temp file ({} bytes)", code.len());
let (inline_path, consumes_parameters) = self
.materialize_inline_code(
context.execution_id,
&merged_parameters,
code,
effective_config,
)
.await?;
if consumes_parameters {
parameters_stdin = None;
}
temp_inline_file = Some(inline_path.clone());
process_executor::build_action_command(
&interpreter,
&effective_config.interpreter.args,
&inline_path,
working_dir,
&env,
)
}
}
cmd
} else {
// No code_path and no inline code — try treating entry_point as a file
// relative to the pack's actions directory
@@ -737,7 +853,7 @@ impl Runtime for ProcessRuntime {
// Execute with streaming output capture (with optional cancellation support).
// Secrets are already merged into parameters — no separate secrets arg needed.
process_executor::execute_streaming_cancellable(
let result = process_executor::execute_streaming_cancellable(
cmd,
&HashMap::new(),
parameters_stdin,
@@ -747,7 +863,13 @@ impl Runtime for ProcessRuntime {
context.output_format,
context.cancel_token.clone(),
)
.await
.await;
if let Some(path) = temp_inline_file {
let _ = tokio::fs::remove_file(path).await;
}
result
}
async fn setup(&self) -> RuntimeResult<()> {
@@ -836,7 +958,8 @@ impl Runtime for ProcessRuntime {
mod tests {
use super::*;
use attune_common::models::runtime::{
DependencyConfig, EnvironmentConfig, InterpreterConfig, RuntimeExecutionConfig,
DependencyConfig, EnvironmentConfig, InlineExecutionConfig, InlineExecutionStrategy,
InterpreterConfig, RuntimeExecutionConfig,
};
use attune_common::models::{OutputFormat, ParameterDelivery, ParameterFormat};
use std::collections::HashMap;
@@ -849,6 +972,11 @@ mod tests {
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig {
strategy: InlineExecutionStrategy::TempFile,
extension: Some(".sh".to_string()),
inject_shell_helpers: true,
},
environment: None,
dependencies: None,
env_vars: HashMap::new(),
@@ -862,6 +990,7 @@ mod tests {
args: vec!["-u".to_string()],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: Some(EnvironmentConfig {
env_type: "virtualenv".to_string(),
dir_name: ".venv".to_string(),
@@ -1104,6 +1233,7 @@ mod tests {
args: vec![],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: HashMap::new(),
@@ -1183,6 +1313,53 @@ mod tests {
assert!(result.stdout.contains("inline shell code"));
}
#[tokio::test]
async fn test_execute_inline_code_with_merged_inputs() {
let temp_dir = TempDir::new().unwrap();
let runtime = ProcessRuntime::new(
"shell".to_string(),
make_shell_config(),
temp_dir.path().to_path_buf(),
temp_dir.path().join("runtime_envs"),
);
let context = ExecutionContext {
execution_id: 30,
action_ref: "adhoc.test_inputs".to_string(),
parameters: {
let mut map = HashMap::new();
map.insert("name".to_string(), serde_json::json!("Alice"));
map
},
env: HashMap::new(),
secrets: {
let mut map = HashMap::new();
map.insert("api_key".to_string(), serde_json::json!("secret-123"));
map
},
timeout: Some(10),
working_dir: None,
entry_point: "inline".to_string(),
code: Some("echo \"$name/$api_key/$PARAM_NAME/$PARAM_API_KEY\"".to_string()),
code_path: None,
runtime_name: Some("shell".to_string()),
runtime_config_override: None,
runtime_env_dir_suffix: None,
selected_runtime_version: None,
max_stdout_bytes: 1024 * 1024,
max_stderr_bytes: 1024 * 1024,
parameter_delivery: ParameterDelivery::default(),
parameter_format: ParameterFormat::default(),
output_format: OutputFormat::default(),
cancel_token: None,
};
let result = runtime.execute(context).await.unwrap();
assert_eq!(result.exit_code, 0);
assert!(result.stdout.contains("Alice/secret-123/Alice/secret-123"));
}
#[tokio::test]
async fn test_execute_entry_point_fallback() {
let temp_dir = TempDir::new().unwrap();

View File

@@ -9,12 +9,12 @@
//!
//! When a `CancellationToken` is provided, the executor monitors it alongside
//! the running process. On cancellation:
//! 1. SIGINT is sent to the process (allows graceful shutdown)
//! 2. After a 10-second grace period, SIGTERM is sent if the process hasn't exited
//! 3. After another 5-second grace period, SIGKILL is sent as a last resort
//! 1. SIGTERM is sent to the process immediately
//! 2. After a 5-second grace period, SIGKILL is sent as a last resort
use super::{BoundedLogWriter, ExecutionResult, OutputFormat, RuntimeResult};
use std::collections::HashMap;
use std::io;
use std::path::Path;
use std::time::Instant;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
@@ -71,7 +71,7 @@ pub async fn execute_streaming(
/// - Writing parameters (with secrets merged in) to stdin
/// - Streaming stdout/stderr with bounded log collection
/// - Timeout management
/// - Graceful cancellation via SIGINT → SIGTERM → SIGKILL escalation
/// - Prompt cancellation via SIGTERM → SIGKILL escalation
/// - Output format parsing (JSON, YAML, JSONL, text)
///
/// # Arguments
@@ -96,6 +96,8 @@ pub async fn execute_streaming_cancellable(
) -> RuntimeResult<ExecutionResult> {
let start = Instant::now();
configure_child_process(&mut cmd)?;
// Spawn process with piped I/O
let mut child = cmd
.stdin(std::process::Stdio::piped())
@@ -178,67 +180,56 @@ pub async fn execute_streaming_cancellable(
// Build the wait future that handles timeout, cancellation, and normal completion.
//
// The result is a tuple: (wait_result, was_cancelled)
// - wait_result mirrors the original type: Result<Result<ExitStatus, io::Error>, Elapsed>
// - was_cancelled indicates the process was stopped by a cancel request
// The result is a tuple: (exit_status, was_cancelled, was_timed_out)
let wait_future = async {
// Inner future: wait for the child process to exit
let wait_child = child.wait();
// Apply optional timeout wrapping
let timed_wait = async {
if let Some(timeout_secs) = timeout_secs {
timeout(std::time::Duration::from_secs(timeout_secs), wait_child).await
} else {
Ok(wait_child.await)
}
};
// If we have a cancel token, race it against the (possibly-timed) wait
if let Some(ref token) = cancel_token {
tokio::select! {
result = timed_wait => (result, false),
_ = token.cancelled() => {
// Cancellation requested — escalate signals to the child process.
info!("Cancel signal received, sending SIGINT to process");
if let Some(pid) = child_pid {
send_signal(pid, libc::SIGINT);
}
// Grace period: wait up to 10s for the process to exit after SIGINT.
match timeout(std::time::Duration::from_secs(10), child.wait()).await {
Ok(status) => (Ok(status), true),
Err(_) => {
// Still alive — escalate to SIGTERM
warn!("Process did not exit after SIGINT + 10s grace period, sending SIGTERM");
if let Some(pid) = child_pid {
send_signal(pid, libc::SIGTERM);
}
// Final grace period: wait up to 5s for SIGTERM
match timeout(std::time::Duration::from_secs(5), child.wait()).await {
Ok(status) => (Ok(status), true),
Err(_) => {
// Last resort — SIGKILL
warn!("Process did not exit after SIGTERM + 5s, sending SIGKILL");
if let Some(pid) = child_pid {
send_signal(pid, libc::SIGKILL);
}
// Wait indefinitely for the SIGKILL to take effect
(Ok(child.wait().await), true)
}
}
match (cancel_token.as_ref(), timeout_secs) {
(Some(token), Some(timeout_secs)) => {
tokio::select! {
result = child.wait() => (result, false, false),
_ = token.cancelled() => {
if let Some(pid) = child_pid {
terminate_process(pid, "cancel");
}
(wait_for_terminated_child(&mut child).await, true, false)
}
_ = tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)) => {
if let Some(pid) = child_pid {
warn!("Process timed out after {} seconds, terminating", timeout_secs);
terminate_process(pid, "timeout");
}
(wait_for_terminated_child(&mut child).await, false, true)
}
}
}
} else {
(timed_wait.await, false)
(Some(token), None) => {
tokio::select! {
result = child.wait() => (result, false, false),
_ = token.cancelled() => {
if let Some(pid) = child_pid {
terminate_process(pid, "cancel");
}
(wait_for_terminated_child(&mut child).await, true, false)
}
}
}
(None, Some(timeout_secs)) => {
tokio::select! {
result = child.wait() => (result, false, false),
_ = tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)) => {
if let Some(pid) = child_pid {
warn!("Process timed out after {} seconds, terminating", timeout_secs);
terminate_process(pid, "timeout");
}
(wait_for_terminated_child(&mut child).await, false, true)
}
}
}
(None, None) => (child.wait().await, false, false),
}
};
// Wait for both streams and the process
let (stdout_writer, stderr_writer, (wait_result, was_cancelled)) =
let (stdout_writer, stderr_writer, (wait_result, was_cancelled, was_timed_out)) =
tokio::join!(stdout_task, stderr_task, wait_future);
let duration_ms = start.elapsed().as_millis() as u64;
@@ -249,31 +240,31 @@ pub async fn execute_streaming_cancellable(
// Handle process wait result
let (exit_code, process_error) = match wait_result {
Ok(Ok(status)) => (status.code().unwrap_or(-1), None),
Ok(Err(e)) => {
Ok(status) => (status.code().unwrap_or(-1), None),
Err(e) => {
warn!("Process wait failed but captured output: {}", e);
(-1, Some(format!("Process wait failed: {}", e)))
}
Err(_) => {
// Timeout occurred
return Ok(ExecutionResult {
exit_code: -1,
stdout: stdout_result.content.clone(),
stderr: stderr_result.content.clone(),
result: None,
duration_ms,
error: Some(format!(
"Execution timed out after {} seconds",
timeout_secs.unwrap()
)),
stdout_truncated: stdout_result.truncated,
stderr_truncated: stderr_result.truncated,
stdout_bytes_truncated: stdout_result.bytes_truncated,
stderr_bytes_truncated: stderr_result.bytes_truncated,
});
}
};
if was_timed_out {
return Ok(ExecutionResult {
exit_code: -1,
stdout: stdout_result.content.clone(),
stderr: stderr_result.content.clone(),
result: None,
duration_ms,
error: Some(format!(
"Execution timed out after {} seconds",
timeout_secs.unwrap()
)),
stdout_truncated: stdout_result.truncated,
stderr_truncated: stderr_result.truncated,
stdout_bytes_truncated: stdout_result.bytes_truncated,
stderr_bytes_truncated: stderr_result.bytes_truncated,
});
}
// If the process was cancelled, return a specific result
if was_cancelled {
return Ok(ExecutionResult {
@@ -361,14 +352,65 @@ pub async fn execute_streaming_cancellable(
}
/// Parse stdout content according to the specified output format.
/// Send a Unix signal to a process by PID.
///
/// Uses raw `libc::kill()` to deliver signals for graceful process termination.
/// This is safe because we only send signals to child processes we spawned.
fn send_signal(pid: u32, signal: i32) {
// Safety: we're sending a signal to a known child process PID.
// The PID is valid because we obtained it from `child.id()` before the
// child exited.
fn configure_child_process(cmd: &mut Command) -> io::Result<()> {
#[cfg(unix)]
{
// Run each action in its own process group so cancellation and timeout
// can terminate shell wrappers and any children they spawned.
unsafe {
cmd.pre_exec(|| {
if libc::setpgid(0, 0) == -1 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
});
}
}
Ok(())
}
/// Await a child that has already been signalled to terminate, escalating
/// to SIGKILL on its whole process group if it has not exited within a
/// 5-second grace period.
async fn wait_for_terminated_child(
    child: &mut tokio::process::Child,
) -> io::Result<std::process::ExitStatus> {
    let grace = std::time::Duration::from_secs(5);
    if let Ok(status) = timeout(grace, child.wait()).await {
        return status;
    }
    // Grace period expired — force-kill and wait for the kill to land.
    warn!("Process did not exit after SIGTERM + 5s, sending SIGKILL");
    if let Some(pid) = child.id() {
        kill_process_group_or_process(pid, libc::SIGKILL);
    }
    child.wait().await
}
/// Send SIGTERM to the process group rooted at `pid`.
///
/// `reason` ("cancel" or "timeout" at the call sites visible here) is used
/// only for the log line; it does not change behavior.
fn terminate_process(pid: u32, reason: &str) {
    info!("Sending SIGTERM to {} process group {}", reason, pid);
    kill_process_group_or_process(pid, libc::SIGTERM);
}
fn kill_process_group_or_process(pid: u32, signal: i32) {
#[cfg(unix)]
{
// Negative PID targets the process group created with setpgid(0, 0).
let pgid = -(pid as i32);
// Safety: we only signal processes we spawned.
let rc = unsafe { libc::kill(pgid, signal) };
if rc == 0 {
return;
}
let err = io::Error::last_os_error();
warn!(
"Failed to signal process group {} with signal {}: {}. Falling back to PID {}",
pid, signal, err, pid
);
}
// Safety: fallback to the direct child PID if the process group signal fails
// or on non-Unix targets where process groups are unavailable.
unsafe {
libc::kill(pid as i32, signal);
}
@@ -492,6 +534,9 @@ pub fn build_inline_command(
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use tokio::fs;
use tokio::time::{sleep, Duration};
#[test]
fn test_parse_output_text() {
@@ -621,4 +666,86 @@ mod tests {
// We can't easily inspect Command internals, but at least verify it builds without panic
let _ = cmd;
}
// Verifies that cancelling mid-run terminates the whole process group:
// the wrapper shell and the `sleep` it spawned must both die promptly, and
// the script's final printf must never execute.
#[tokio::test]
async fn test_execute_streaming_cancellation_kills_shell_child_process() {
    // Script sleeps far longer than the test; only cancellation can end it early.
    let script = NamedTempFile::new().unwrap();
    fs::write(
        script.path(),
        "#!/bin/sh\nsleep 30\nprintf 'unexpected completion\\n'\n",
    )
    .await
    .unwrap();
    // Make the script executable (Unix only; mode bits are a no-op elsewhere).
    let mut perms = fs::metadata(script.path()).await.unwrap().permissions();
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        perms.set_mode(0o755);
    }
    fs::set_permissions(script.path(), perms).await.unwrap();
    // Trigger cancellation shortly after the process starts.
    let cancel_token = CancellationToken::new();
    let trigger = cancel_token.clone();
    tokio::spawn(async move {
        sleep(Duration::from_millis(200)).await;
        trigger.cancel();
    });
    let mut cmd = Command::new("/bin/sh");
    cmd.arg(script.path());
    let result = execute_streaming_cancellable(
        cmd,
        &HashMap::new(),
        None,
        Some(60), // generous timeout: cancellation, not timeout, must end the run
        1024 * 1024,
        1024 * 1024,
        OutputFormat::Text,
        Some(cancel_token),
    )
    .await
    .unwrap();
    // The result must report cancellation...
    assert!(result
        .error
        .as_deref()
        .is_some_and(|e| e.contains("cancelled")));
    // ...and do so well under the 5s SIGTERM→SIGKILL grace window.
    assert!(
        result.duration_ms < 5_000,
        "expected prompt cancellation, got {}ms",
        result.duration_ms
    );
    // The line after `sleep 30` must never have run.
    assert!(!result.stdout.contains("unexpected completion"));
}
/// A process that outlives its timeout must be terminated and reported as
/// timed out, with exit code -1 and a descriptive error message.
#[tokio::test]
async fn test_execute_streaming_timeout_terminates_process() {
    // Sleep far longer than the 1-second timeout to force the timeout path.
    let mut cmd = Command::new("/bin/sh");
    cmd.arg("-c").arg("sleep 30");
    let outcome = execute_streaming(
        cmd,
        &HashMap::new(),
        None,
        Some(1),
        1024 * 1024,
        1024 * 1024,
        OutputFormat::Text,
    )
    .await
    .unwrap();
    assert_eq!(outcome.exit_code, -1);
    let message = outcome.error.as_deref().unwrap_or_default();
    assert!(message.contains("timed out after 1 seconds"));
    // Termination should happen shortly after the timeout fires,
    // not after the full 30-second sleep.
    assert!(
        outcome.duration_ms < 7_000,
        "expected timeout termination, got {}ms",
        outcome.duration_ms
    );
}
}

View File

@@ -1,949 +0,0 @@
//! Shell Runtime Implementation
//!
//! Executes shell scripts and commands using subprocess execution.
use super::{
parameter_passing::{self, ParameterDeliveryConfig},
BoundedLogWriter, ExecutionContext, ExecutionResult, OutputFormat, Runtime, RuntimeError,
RuntimeResult,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Instant;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::time::timeout;
use tracing::{debug, info, warn};
/// Escape a string for embedding inside a bash single-quoted string.
///
/// In single-quoted strings the only problematic character is `'` itself.
/// We close the current single-quote, insert an escaped single-quote, and
/// reopen: `'foo'\''bar'` → `foo'bar`.
fn bash_single_quote_escape(s: &str) -> String {
s.replace('\'', "'\\''")
}
/// Shell runtime for executing shell scripts and commands
pub struct ShellRuntime {
/// Shell interpreter path (bash, sh, zsh, etc.)
shell_path: PathBuf,
/// Base directory for storing action code
work_dir: PathBuf,
}
impl ShellRuntime {
    /// Create a new Shell runtime with bash.
    ///
    /// Uses `/bin/bash` as the interpreter and `/tmp/attune/actions` as the
    /// base directory for generated wrapper scripts.
    pub fn new() -> Self {
        Self {
            shell_path: PathBuf::from("/bin/bash"),
            work_dir: PathBuf::from("/tmp/attune/actions"),
        }
    }

    /// Create a Shell runtime with a custom shell interpreter.
    pub fn with_shell(shell_path: PathBuf) -> Self {
        Self {
            shell_path,
            work_dir: PathBuf::from("/tmp/attune/actions"),
        }
    }

    /// Create a Shell runtime with custom interpreter and work directory.
    pub fn with_config(shell_path: PathBuf, work_dir: PathBuf) -> Self {
        Self {
            shell_path,
            work_dir,
        }
    }

    /// Execute with streaming and bounded log collection.
    ///
    /// Spawns `cmd` with piped stdio, optionally writes `parameters_stdin`
    /// (a single JSON line) to the child's stdin, and streams stdout/stderr
    /// into bounded buffers so oversized output cannot exhaust memory.
    ///
    /// On timeout the child is killed (`start_kill`) and reaped; a timeout
    /// error is returned together with whatever output was captured up to
    /// that point.
    #[allow(clippy::too_many_arguments)]
    async fn execute_with_streaming(
        &self,
        mut cmd: Command,
        _secrets: &std::collections::HashMap<String, String>,
        parameters_stdin: Option<&str>,
        timeout_secs: Option<u64>,
        max_stdout_bytes: usize,
        max_stderr_bytes: usize,
        output_format: OutputFormat,
    ) -> RuntimeResult<ExecutionResult> {
        let start = Instant::now();
        // Spawn process with piped I/O
        let mut child = cmd
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        // Write to stdin - parameters (with secrets already merged in by the caller).
        // If this fails, the process has already started, so we continue and capture output.
        let stdin_write_error = if let Some(mut stdin) = child.stdin.take() {
            let mut error = None;
            // Write parameters to stdin as a single JSON line.
            // Secrets are merged into the parameters map by the caller, so the
            // action reads everything with a single readline().
            if let Some(params_data) = parameters_stdin {
                if let Err(e) = stdin.write_all(params_data.as_bytes()).await {
                    error = Some(format!("Failed to write parameters to stdin: {}", e));
                } else if let Err(e) = stdin.write_all(b"\n").await {
                    error = Some(format!("Failed to write newline to stdin: {}", e));
                }
            }
            // Dropping stdin closes the pipe so the child sees EOF.
            drop(stdin);
            error
        } else {
            None
        };
        // Create bounded writers
        let mut stdout_writer = BoundedLogWriter::new_stdout(max_stdout_bytes);
        let mut stderr_writer = BoundedLogWriter::new_stderr(max_stderr_bytes);
        // Take stdout and stderr streams
        let stdout = child.stdout.take().expect("stdout not captured");
        let stderr = child.stderr.take().expect("stderr not captured");
        // Create buffered readers
        let mut stdout_reader = BufReader::new(stdout);
        let mut stderr_reader = BufReader::new(stderr);
        // Stream both outputs concurrently, line by line, until EOF.
        let stdout_task = async {
            let mut line = Vec::new();
            loop {
                line.clear();
                match stdout_reader.read_until(b'\n', &mut line).await {
                    Ok(0) => break, // EOF
                    Ok(_) => {
                        if stdout_writer.write_all(&line).await.is_err() {
                            break;
                        }
                    }
                    Err(_) => break,
                }
            }
            stdout_writer
        };
        let stderr_task = async {
            let mut line = Vec::new();
            loop {
                line.clear();
                match stderr_reader.read_until(b'\n', &mut line).await {
                    Ok(0) => break, // EOF
                    Ok(_) => {
                        if stderr_writer.write_all(&line).await.is_err() {
                            break;
                        }
                    }
                    Err(_) => break,
                }
            }
            stderr_writer
        };
        // Wait for both streams and the process
        let (stdout_writer, stderr_writer, wait_result) =
            tokio::join!(stdout_task, stderr_task, async {
                if let Some(timeout_secs) = timeout_secs {
                    match timeout(std::time::Duration::from_secs(timeout_secs), child.wait()).await
                    {
                        Ok(res) => Ok(res),
                        Err(elapsed) => {
                            // BUGFIX: previously a timed-out child was left
                            // running, which also kept its stdout/stderr pipes
                            // open so the reader tasks (and thus this join!)
                            // blocked until the child exited on its own. Kill
                            // and reap it so the timeout is enforced promptly
                            // and no orphan process or zombie is leaked.
                            let _ = child.start_kill();
                            let _ = child.wait().await;
                            Err(elapsed)
                        }
                    }
                } else {
                    Ok(child.wait().await)
                }
            });
        let duration_ms = start.elapsed().as_millis() as u64;
        // Get results from bounded writers - we have these regardless of wait() success
        let stdout_result = stdout_writer.into_result();
        let stderr_result = stderr_writer.into_result();
        // Handle process wait result
        let (exit_code, process_error) = match wait_result {
            Ok(Ok(status)) => (status.code().unwrap_or(-1), None),
            Ok(Err(e)) => {
                // Process wait failed, but we have the output - return it with an error
                warn!("Process wait failed but captured output: {}", e);
                (-1, Some(format!("Process wait failed: {}", e)))
            }
            Err(_) => {
                // Timeout occurred (the child has already been killed above).
                return Ok(ExecutionResult {
                    exit_code: -1,
                    stdout: stdout_result.content.clone(),
                    stderr: stderr_result.content.clone(),
                    result: None,
                    duration_ms,
                    error: Some(format!(
                        "Execution timed out after {} seconds",
                        // Safe: the timeout branch is only reachable when
                        // timeout_secs was Some.
                        timeout_secs.unwrap()
                    )),
                    stdout_truncated: stdout_result.truncated,
                    stderr_truncated: stderr_result.truncated,
                    stdout_bytes_truncated: stdout_result.bytes_truncated,
                    stderr_bytes_truncated: stderr_result.bytes_truncated,
                });
            }
        };
        debug!(
            "Shell execution completed: exit_code={}, duration={}ms, stdout_truncated={}, stderr_truncated={}",
            exit_code, duration_ms, stdout_result.truncated, stderr_result.truncated
        );
        // Parse result from stdout based on output_format
        let result = if exit_code == 0 && !stdout_result.content.trim().is_empty() {
            match output_format {
                OutputFormat::Text => {
                    // No parsing - text output is captured in stdout field
                    None
                }
                OutputFormat::Json => {
                    // Try to parse full stdout as JSON first (handles multi-line JSON),
                    // then fall back to last line only (for scripts that log before output)
                    let trimmed = stdout_result.content.trim();
                    serde_json::from_str(trimmed).ok().or_else(|| {
                        trimmed
                            .lines()
                            .last()
                            .and_then(|line| serde_json::from_str(line).ok())
                    })
                }
                OutputFormat::Yaml => {
                    // Try to parse stdout as YAML
                    serde_yaml_ng::from_str(stdout_result.content.trim()).ok()
                }
                OutputFormat::Jsonl => {
                    // Parse each line as JSON and collect into array
                    let mut items = Vec::new();
                    for line in stdout_result.content.trim().lines() {
                        if let Ok(value) = serde_json::from_str::<serde_json::Value>(line) {
                            items.push(value);
                        }
                    }
                    if items.is_empty() {
                        None
                    } else {
                        Some(serde_json::Value::Array(items))
                    }
                }
            }
        } else {
            None
        };
        // Determine error message, in priority order:
        // process wait failure > stdin write failure > non-zero exit code.
        let error = if let Some(proc_err) = process_error {
            Some(proc_err)
        } else if let Some(stdin_err) = stdin_write_error {
            // Ignore broken pipe errors for fast-exiting successful actions
            // These occur when the process exits before we finish writing secrets to stdin
            let is_broken_pipe =
                stdin_err.contains("Broken pipe") || stdin_err.contains("os error 32");
            let is_fast_exit = duration_ms < 500;
            let is_success = exit_code == 0;
            if is_broken_pipe && is_fast_exit && is_success {
                debug!(
                    "Ignoring broken pipe error for fast-exiting successful action ({}ms)",
                    duration_ms
                );
                None
            } else {
                Some(stdin_err)
            }
        } else if exit_code != 0 {
            Some(if stderr_result.content.is_empty() {
                format!("Command exited with code {}", exit_code)
            } else {
                // Use last line of stderr as error, or full stderr if short
                if stderr_result.content.lines().count() > 5 {
                    stderr_result
                        .content
                        .lines()
                        .last()
                        .unwrap_or("")
                        .to_string()
                } else {
                    stderr_result.content.clone()
                }
            })
        } else {
            None
        };
        Ok(ExecutionResult {
            exit_code,
            // Only populate stdout if result wasn't parsed (avoid duplication)
            stdout: if result.is_some() {
                String::new()
            } else {
                stdout_result.content.clone()
            },
            stderr: stderr_result.content.clone(),
            result,
            duration_ms,
            error,
            stdout_truncated: stdout_result.truncated,
            stderr_truncated: stderr_result.truncated,
            stdout_bytes_truncated: stdout_result.bytes_truncated,
            stderr_bytes_truncated: stderr_result.bytes_truncated,
        })
    }

    /// Generate shell wrapper script that injects parameters and secrets directly.
    ///
    /// Secrets are embedded as bash associative-array entries at generation time
    /// so the wrapper has **zero external runtime dependencies** (no Python, jq,
    /// etc.). The generated script is written to a temp file by the caller so
    /// that secrets never appear in `/proc/<pid>/cmdline`.
    fn generate_wrapper_script(&self, context: &ExecutionContext) -> RuntimeResult<String> {
        let mut script = String::new();
        // Add shebang
        script.push_str("#!/bin/bash\n");
        script.push_str("set -e\n\n"); // Exit on error
        // Populate secrets associative array directly from Rust — no stdin
        // reading, no JSON parsing, no external interpreters.
        script.push_str("# Secrets (injected at generation time, not via environment)\n");
        script.push_str("declare -A ATTUNE_SECRETS\n");
        for (key, value) in &context.secrets {
            let escaped_key = bash_single_quote_escape(key);
            // Serialize structured JSON values to string for bash; plain strings used directly.
            let val_str = match value {
                serde_json::Value::String(s) => s.clone(),
                other => other.to_string(),
            };
            let escaped_val = bash_single_quote_escape(&val_str);
            script.push_str(&format!(
                "ATTUNE_SECRETS['{}']='{}'\n",
                escaped_key, escaped_val
            ));
        }
        script.push('\n');
        // Helper function to get secrets
        script.push_str("# Helper function to access secrets\n");
        script.push_str("get_secret() {\n");
        script.push_str("    local name=\"$1\"\n");
        script.push_str("    echo \"${ATTUNE_SECRETS[$name]}\"\n");
        script.push_str("}\n\n");
        // Export parameters as environment variables
        script.push_str("# Action parameters\n");
        for (key, value) in &context.parameters {
            let value_str = match value {
                serde_json::Value::String(s) => s.clone(),
                serde_json::Value::Number(n) => n.to_string(),
                serde_json::Value::Bool(b) => b.to_string(),
                _ => serde_json::to_string(value)?,
            };
            let escaped = bash_single_quote_escape(&value_str);
            // Export with PARAM_ prefix for consistency
            script.push_str(&format!(
                "export PARAM_{}='{}'\n",
                key.to_uppercase(),
                escaped
            ));
            // Also export without prefix for easier shell script writing
            script.push_str(&format!("export {}='{}'\n", key, escaped));
        }
        script.push('\n');
        // Add the action code
        script.push_str("# Action code\n");
        if let Some(code) = &context.code {
            script.push_str(code);
        }
        Ok(script)
    }

    /// Execute a shell script from a file via the configured interpreter.
    ///
    /// `env` is applied to the child process; parameters (if any) are
    /// delivered through stdin via `execute_with_streaming`.
    #[allow(clippy::too_many_arguments)]
    async fn execute_shell_file(
        &self,
        script_path: PathBuf,
        _secrets: &std::collections::HashMap<String, String>,
        env: &std::collections::HashMap<String, String>,
        parameters_stdin: Option<&str>,
        timeout_secs: Option<u64>,
        max_stdout_bytes: usize,
        max_stderr_bytes: usize,
        output_format: OutputFormat,
    ) -> RuntimeResult<ExecutionResult> {
        debug!("Executing shell file: {:?}", script_path,);
        // Build command
        let mut cmd = Command::new(&self.shell_path);
        cmd.arg(&script_path);
        // Add environment variables
        for (key, value) in env {
            cmd.env(key, value);
        }
        self.execute_with_streaming(
            cmd,
            &std::collections::HashMap::new(),
            parameters_stdin,
            timeout_secs,
            max_stdout_bytes,
            max_stderr_bytes,
            output_format,
        )
        .await
    }
}
impl Default for ShellRuntime {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Runtime for ShellRuntime {
    /// Runtime identifier used for registry lookup.
    fn name(&self) -> &str {
        "shell"
    }

    /// Heuristic: treat the action as a shell script when the action
    /// reference, entry point, or code path suggests `.sh` / bash / sh.
    fn can_execute(&self, context: &ExecutionContext) -> bool {
        // Return the expression directly (previously a redundant
        // let-binding followed by a return of that binding).
        context.action_ref.contains(".sh")
            || context.entry_point.ends_with(".sh")
            || context
                .code_path
                .as_ref()
                .map(|p| p.extension().and_then(|e| e.to_str()) == Some("sh"))
                .unwrap_or(false)
            || context.entry_point == "bash"
            || context.entry_point == "sh"
            || context.entry_point == "shell"
    }

    /// Execute a shell action.
    ///
    /// Secrets are merged into the parameter map so actions receive a single
    /// JSON document. File-based actions (`code_path`) run directly; inline
    /// code is wrapped via `generate_wrapper_script` and written to a
    /// mode-0600 temp file so the embedded secrets never appear in
    /// `/proc/<pid>/cmdline` or in a world-readable file.
    async fn execute(&self, context: ExecutionContext) -> RuntimeResult<ExecutionResult> {
        info!(
            "Executing shell action: {} (execution_id: {}) with parameter delivery: {:?}, format: {:?}",
            context.action_ref, context.execution_id, context.parameter_delivery, context.parameter_format
        );
        info!(
            "Action parameters (count: {}): {:?}",
            context.parameters.len(),
            context.parameters
        );
        // Merge secrets into parameters as a single JSON document.
        // Actions receive everything via one readline() on stdin.
        let mut merged_parameters = context.parameters.clone();
        for (key, value) in &context.secrets {
            merged_parameters.insert(key.clone(), value.clone());
        }
        // Prepare environment and parameters according to delivery method
        let mut env = context.env.clone();
        let config = ParameterDeliveryConfig {
            delivery: context.parameter_delivery,
            format: context.parameter_format,
        };
        let prepared_params =
            parameter_passing::prepare_parameters(&merged_parameters, &mut env, config)?;
        // Get stdin content if parameters are delivered via stdin
        let parameters_stdin = prepared_params.stdin_content();
        if let Some(stdin_data) = parameters_stdin {
            info!(
                "Parameters to be sent via stdin (length: {} bytes):\n{}",
                stdin_data.len(),
                stdin_data
            );
        } else {
            info!("No parameters will be sent via stdin");
        }
        // If code_path is provided, execute the file directly.
        // Secrets are already merged into parameters — no separate secrets arg needed.
        if let Some(code_path) = &context.code_path {
            return self
                .execute_shell_file(
                    code_path.clone(),
                    &HashMap::new(),
                    &env,
                    parameters_stdin,
                    context.timeout,
                    context.max_stdout_bytes,
                    context.max_stderr_bytes,
                    context.output_format,
                )
                .await;
        }
        // Otherwise, generate wrapper script and execute.
        // Secrets and parameters are embedded directly in the wrapper script
        // by generate_wrapper_script(), so we write it to a temp file (to keep
        // secrets out of /proc/cmdline) and pass no secrets/params via stdin.
        let script = self.generate_wrapper_script(&context)?;
        let wrapper_dir = self.work_dir.join("wrappers");
        tokio::fs::create_dir_all(&wrapper_dir).await.map_err(|e| {
            RuntimeError::ExecutionFailed(format!("Failed to create wrapper directory: {}", e))
        })?;
        let wrapper_path = wrapper_dir.join(format!("wrapper_{}.sh", context.execution_id));
        tokio::fs::write(&wrapper_path, &script)
            .await
            .map_err(|e| {
                RuntimeError::ExecutionFailed(format!("Failed to write wrapper script: {}", e))
            })?;
        // SECURITY FIX: the wrapper embeds plaintext secrets; restrict it to
        // owner read/write only. Previously it was written with the default
        // file mode (typically world-readable), letting any local user read
        // the secrets while the action runs.
        if let Err(e) =
            tokio::fs::set_permissions(&wrapper_path, std::fs::Permissions::from_mode(0o600)).await
        {
            let _ = tokio::fs::remove_file(&wrapper_path).await;
            return Err(RuntimeError::ExecutionFailed(format!(
                "Failed to restrict wrapper script permissions: {}",
                e
            )));
        }
        let result = self
            .execute_shell_file(
                wrapper_path.clone(),
                &HashMap::new(), // secrets are in the script, not stdin
                &env,
                None,
                context.timeout,
                context.max_stdout_bytes,
                context.max_stderr_bytes,
                context.output_format,
            )
            .await;
        // Clean up wrapper file (best-effort)
        let _ = tokio::fs::remove_file(&wrapper_path).await;
        result
    }

    /// One-time setup: create the work directory and verify the configured
    /// shell binary responds to `--version`.
    async fn setup(&self) -> RuntimeResult<()> {
        info!("Setting up Shell runtime");
        // Ensure work directory exists
        tokio::fs::create_dir_all(&self.work_dir)
            .await
            .map_err(|e| RuntimeError::SetupError(format!("Failed to create work dir: {}", e)))?;
        // Verify shell is available
        let output = Command::new(&self.shell_path)
            .arg("--version")
            .output()
            .await
            .map_err(|e| {
                RuntimeError::SetupError(format!("Shell not found at {:?}: {}", self.shell_path, e))
            })?;
        if !output.status.success() {
            return Err(RuntimeError::SetupError(
                "Shell interpreter is not working".to_string(),
            ));
        }
        let version = String::from_utf8_lossy(&output.stdout);
        info!("Shell runtime ready: {}", version.trim());
        Ok(())
    }

    /// Teardown hook; currently nothing to clean up (wrapper files are
    /// removed per-execution in `execute`).
    async fn cleanup(&self) -> RuntimeResult<()> {
        info!("Cleaning up Shell runtime");
        // Could clean up temporary files here
        Ok(())
    }

    /// Health check: run a trivial `echo` through the interpreter.
    async fn validate(&self) -> RuntimeResult<()> {
        debug!("Validating Shell runtime");
        // Check if shell is available
        let output = Command::new(&self.shell_path)
            .arg("-c")
            .arg("echo 'test'")
            .output()
            .await
            .map_err(|e| RuntimeError::SetupError(format!("Shell validation failed: {}", e)))?;
        if !output.status.success() {
            return Err(RuntimeError::SetupError(
                "Shell interpreter validation failed".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the shell runtime: happy path, parameter injection,
    //! timeouts, non-zero exits, secret access via `get_secret`, and
    //! output-format parsing (JSON / JSONL / multi-line JSON / log-prefixed
    //! JSON). All tests spawn real `/bin/bash` processes.
    use super::*;
    use std::collections::HashMap;

    /// Plain `echo`: verifies a successful run captures stdout.
    #[tokio::test]
    async fn test_shell_runtime_simple() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 1,
            action_ref: "test.simple".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some("echo 'Hello, World!'".to_string()),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::default(),
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("Hello, World!"));
    }

    /// Parameters are exported as plain shell variables (`$name`) by the
    /// generated wrapper script.
    #[tokio::test]
    async fn test_shell_runtime_with_params() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 2,
            action_ref: "test.params".to_string(),
            parameters: {
                let mut map = HashMap::new();
                map.insert("name".to_string(), serde_json::json!("Alice"));
                map
            },
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some("echo \"Hello, $name!\"".to_string()),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::default(),
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        assert!(result.stdout.contains("Hello, Alice!"));
    }

    /// A 10-second sleep with a 1-second timeout must fail and report a
    /// timeout error.
    #[tokio::test]
    async fn test_shell_runtime_timeout() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 3,
            action_ref: "test.timeout".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(1),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some("sleep 10".to_string()),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::default(),
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(!result.is_success());
        assert!(result.error.is_some());
        let error_msg = result.error.unwrap();
        assert!(error_msg.contains("timeout") || error_msg.contains("timed out"));
    }

    /// Non-zero exit code is surfaced as a failed result.
    #[tokio::test]
    async fn test_shell_runtime_error() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 4,
            action_ref: "test.error".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some("exit 1".to_string()),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::default(),
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(!result.is_success());
        assert_eq!(result.exit_code, 1);
    }

    /// Secrets embedded by the wrapper are retrievable via `get_secret`;
    /// unknown keys yield an empty string.
    #[tokio::test]
    async fn test_shell_runtime_with_secrets() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 5,
            action_ref: "test.secrets".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: {
                let mut s = HashMap::new();
                s.insert("api_key".to_string(), serde_json::json!("secret_key_12345"));
                s.insert(
                    "db_password".to_string(),
                    serde_json::json!("super_secret_pass"),
                );
                s
            },
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some(
                r#"
# Access secrets via get_secret function
api_key=$(get_secret 'api_key')
db_pass=$(get_secret 'db_password')
missing=$(get_secret 'nonexistent')
echo "api_key=$api_key"
echo "db_pass=$db_pass"
echo "missing=$missing"
"#
                .to_string(),
            ),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::default(),
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.exit_code, 0);
        // Verify secrets are accessible in action code
        assert!(result.stdout.contains("api_key=secret_key_12345"));
        assert!(result.stdout.contains("db_pass=super_secret_pass"));
        assert!(result.stdout.contains("missing="));
    }

    /// JSONL output: each stdout line is parsed as JSON and collected into
    /// an array; stdout is left empty once the result is parsed.
    #[tokio::test]
    async fn test_shell_runtime_jsonl_output() {
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 6,
            action_ref: "test.jsonl".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some(
                r#"
echo '{"id": 1, "name": "Alice"}'
echo '{"id": 2, "name": "Bob"}'
echo '{"id": 3, "name": "Charlie"}'
"#
                .to_string(),
            ),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::Jsonl,
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.exit_code, 0);
        // Verify stdout is not populated when result is parsed (avoid duplication)
        assert!(
            result.stdout.is_empty(),
            "stdout should be empty when result is parsed"
        );
        // Verify result is parsed as an array of JSON objects
        let parsed_result = result.result.expect("Should have parsed result");
        assert!(parsed_result.is_array());
        let items = parsed_result.as_array().unwrap();
        assert_eq!(items.len(), 3);
        // Verify first item
        assert_eq!(items[0]["id"], 1);
        assert_eq!(items[0]["name"], "Alice");
        // Verify second item
        assert_eq!(items[1]["id"], 2);
        assert_eq!(items[1]["name"], "Bob");
        // Verify third item
        assert_eq!(items[2]["id"], 3);
        assert_eq!(items[2]["name"], "Charlie");
    }

    #[tokio::test]
    async fn test_shell_runtime_multiline_json_output() {
        // Regression test: scripts that embed pretty-printed JSON (e.g., http_request.sh
        // embedding a multi-line response body in its "json" field) produce multi-line
        // stdout. The parser must handle this by trying to parse the full stdout as JSON
        // before falling back to last-line parsing.
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 7,
            action_ref: "test.multiline_json".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some(
                r#"
# Simulate http_request.sh output with embedded pretty-printed JSON
printf '{"status_code":200,"body":"hello","json":{\n  "args": {\n    "hello": "world"\n  },\n  "url": "https://example.com"\n},"success":true}\n'
"#
                .to_string(),
            ),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::Json,
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.exit_code, 0);
        // Verify result was parsed (not stored as raw stdout)
        let parsed = result
            .result
            .expect("Multi-line JSON should be parsed successfully");
        assert_eq!(parsed["status_code"], 200);
        assert_eq!(parsed["success"], true);
        assert_eq!(parsed["json"]["args"]["hello"], "world");
        // stdout should be empty when result is successfully parsed
        assert!(
            result.stdout.is_empty(),
            "stdout should be empty when result is parsed, got: {}",
            result.stdout
        );
    }

    #[tokio::test]
    async fn test_shell_runtime_json_with_log_prefix() {
        // Verify last-line fallback still works: scripts that log to stdout
        // before the final JSON line should still parse correctly.
        let runtime = ShellRuntime::new();
        let context = ExecutionContext {
            execution_id: 8,
            action_ref: "test.json_with_logs".to_string(),
            parameters: HashMap::new(),
            env: HashMap::new(),
            secrets: HashMap::new(),
            timeout: Some(10),
            working_dir: None,
            entry_point: "shell".to_string(),
            code: Some(
                r#"
echo "Starting action..."
echo "Processing data..."
echo '{"result": "success", "count": 42}'
"#
                .to_string(),
            ),
            code_path: None,
            runtime_name: Some("shell".to_string()),
            runtime_config_override: None,
            runtime_env_dir_suffix: None,
            selected_runtime_version: None,
            max_stdout_bytes: 10 * 1024 * 1024,
            max_stderr_bytes: 10 * 1024 * 1024,
            parameter_delivery: attune_common::models::ParameterDelivery::default(),
            parameter_format: attune_common::models::ParameterFormat::default(),
            output_format: attune_common::models::OutputFormat::Json,
            cancel_token: None,
        };
        let result = runtime.execute(context).await.unwrap();
        assert!(result.is_success());
        let parsed = result.result.expect("Last-line JSON should be parsed");
        assert_eq!(parsed["result"], "success");
        assert_eq!(parsed["count"], 42);
    }
}

View File

@@ -27,7 +27,7 @@ use attune_common::runtime_detection::runtime_in_filter;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -44,7 +44,6 @@ use crate::registration::WorkerRegistration;
use crate::runtime::local::LocalRuntime;
use crate::runtime::native::NativeRuntime;
use crate::runtime::process::ProcessRuntime;
use crate::runtime::shell::ShellRuntime;
use crate::runtime::RuntimeRegistry;
use crate::secrets::SecretManager;
use crate::version_verify;
@@ -89,8 +88,11 @@ pub struct WorkerService {
in_flight_tasks: Arc<Mutex<JoinSet<()>>>,
/// Maps execution ID → CancellationToken for running processes.
/// When a cancel request arrives, the token is triggered, causing
/// the process executor to send SIGINT → SIGTERM → SIGKILL.
/// the process executor to send SIGTERM → SIGKILL.
cancel_tokens: Arc<Mutex<HashMap<i64, CancellationToken>>>,
/// Tracks cancellation requests that arrived before the in-memory token
/// for an execution had been registered.
pending_cancellations: Arc<Mutex<HashSet<i64>>>,
}
impl WorkerService {
@@ -263,9 +265,29 @@ impl WorkerService {
if runtime_registry.list_runtimes().is_empty() {
info!("No runtimes loaded from database, registering built-in defaults");
// Shell runtime (always available)
runtime_registry.register(Box::new(ShellRuntime::new()));
info!("Registered built-in Shell runtime");
// Shell runtime (always available) via generic ProcessRuntime
let shell_runtime = ProcessRuntime::new(
"shell".to_string(),
attune_common::models::runtime::RuntimeExecutionConfig {
interpreter: attune_common::models::runtime::InterpreterConfig {
binary: "/bin/bash".to_string(),
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: attune_common::models::runtime::InlineExecutionConfig {
strategy: attune_common::models::runtime::InlineExecutionStrategy::TempFile,
extension: Some(".sh".to_string()),
inject_shell_helpers: true,
},
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
},
packs_base_dir.clone(),
runtime_envs_dir.clone(),
);
runtime_registry.register(Box::new(shell_runtime));
info!("Registered built-in shell ProcessRuntime");
// Native runtime (for compiled binaries)
runtime_registry.register(Box::new(NativeRuntime::new()));
@@ -379,6 +401,7 @@ impl WorkerService {
execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)),
in_flight_tasks: Arc::new(Mutex::new(JoinSet::new())),
cancel_tokens: Arc::new(Mutex::new(HashMap::new())),
pending_cancellations: Arc::new(Mutex::new(HashSet::new())),
})
}
@@ -755,6 +778,7 @@ impl WorkerService {
let semaphore = self.execution_semaphore.clone();
let in_flight = self.in_flight_tasks.clone();
let cancel_tokens = self.cancel_tokens.clone();
let pending_cancellations = self.pending_cancellations.clone();
// Spawn the consumer loop as a background task so start() can return
let handle = tokio::spawn(async move {
@@ -768,6 +792,7 @@ impl WorkerService {
let semaphore = semaphore.clone();
let in_flight = in_flight.clone();
let cancel_tokens = cancel_tokens.clone();
let pending_cancellations = pending_cancellations.clone();
async move {
let execution_id = envelope.payload.execution_id;
@@ -794,6 +819,16 @@ impl WorkerService {
let mut tokens = cancel_tokens.lock().await;
tokens.insert(execution_id, cancel_token.clone());
}
{
let pending = pending_cancellations.lock().await;
if pending.contains(&execution_id) {
info!(
"Execution {} already had a pending cancel request; cancelling immediately",
execution_id
);
cancel_token.cancel();
}
}
// Spawn the actual execution as a background task so this
// handler returns immediately, acking the message and freeing
@@ -819,6 +854,8 @@ impl WorkerService {
// Remove the cancel token now that execution is done
let mut tokens = cancel_tokens.lock().await;
tokens.remove(&execution_id);
let mut pending = pending_cancellations.lock().await;
pending.remove(&execution_id);
});
Ok(())
@@ -1060,6 +1097,7 @@ impl WorkerService {
let consumer_for_task = consumer.clone();
let cancel_tokens = self.cancel_tokens.clone();
let pending_cancellations = self.pending_cancellations.clone();
let queue_name_for_log = queue_name.clone();
let handle = tokio::spawn(async move {
@@ -1071,11 +1109,17 @@ impl WorkerService {
.consume_with_handler(
move |envelope: MessageEnvelope<ExecutionCancelRequestedPayload>| {
let cancel_tokens = cancel_tokens.clone();
let pending_cancellations = pending_cancellations.clone();
async move {
let execution_id = envelope.payload.execution_id;
info!("Received cancel request for execution {}", execution_id);
{
let mut pending = pending_cancellations.lock().await;
pending.insert(execution_id);
}
let tokens = cancel_tokens.lock().await;
if let Some(token) = tokens.get(&execution_id) {
info!("Triggering cancellation for execution {}", execution_id);

View File

@@ -9,7 +9,8 @@
//! This keeps the pack directory clean and read-only.
use attune_common::models::runtime::{
DependencyConfig, EnvironmentConfig, InterpreterConfig, RuntimeExecutionConfig,
DependencyConfig, EnvironmentConfig, InlineExecutionConfig, InterpreterConfig,
RuntimeExecutionConfig,
};
use attune_worker::runtime::process::ProcessRuntime;
use attune_worker::runtime::ExecutionContext;
@@ -26,6 +27,7 @@ fn make_python_config() -> RuntimeExecutionConfig {
args: vec!["-u".to_string()],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: Some(EnvironmentConfig {
env_type: "virtualenv".to_string(),
dir_name: ".venv".to_string(),
@@ -59,6 +61,7 @@ fn make_shell_config() -> RuntimeExecutionConfig {
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),

View File

@@ -3,9 +3,11 @@
//! Tests that verify stdout/stderr are properly truncated when they exceed
//! configured size limits, preventing OOM issues with large output.
use attune_common::models::runtime::{InterpreterConfig, RuntimeExecutionConfig};
use attune_common::models::runtime::{
InlineExecutionConfig, InlineExecutionStrategy, InterpreterConfig, RuntimeExecutionConfig,
};
use attune_worker::runtime::process::ProcessRuntime;
use attune_worker::runtime::{ExecutionContext, Runtime, ShellRuntime};
use attune_worker::runtime::{ExecutionContext, Runtime};
use std::collections::HashMap;
use std::path::PathBuf;
use tempfile::TempDir;
@@ -17,6 +19,7 @@ fn make_python_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
args: vec!["-u".to_string()],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
@@ -29,6 +32,30 @@ fn make_python_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
)
}
fn make_shell_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
let config = RuntimeExecutionConfig {
interpreter: InterpreterConfig {
binary: "/bin/bash".to_string(),
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig {
strategy: InlineExecutionStrategy::TempFile,
extension: Some(".sh".to_string()),
inject_shell_helpers: true,
},
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
};
ProcessRuntime::new(
"shell".to_string(),
config,
packs_base_dir.clone(),
packs_base_dir.join("../runtime_envs"),
)
}
fn make_python_context(
execution_id: i64,
action_ref: &str,
@@ -110,7 +137,8 @@ async fn test_python_stderr_truncation() {
#[tokio::test]
async fn test_shell_stdout_truncation() {
let runtime = ShellRuntime::new();
let tmp = TempDir::new().unwrap();
let runtime = make_shell_process_runtime(tmp.path().to_path_buf());
// Shell script that outputs more than the limit
let code = r#"
@@ -270,6 +298,7 @@ async fn test_shell_process_runtime_truncation() {
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),

View File

@@ -3,9 +3,10 @@
//! These tests verify that secrets are NOT exposed in process environment
//! or command-line arguments, ensuring secure secret passing via stdin.
use attune_common::models::runtime::{InterpreterConfig, RuntimeExecutionConfig};
use attune_common::models::runtime::{
InlineExecutionConfig, InlineExecutionStrategy, InterpreterConfig, RuntimeExecutionConfig,
};
use attune_worker::runtime::process::ProcessRuntime;
use attune_worker::runtime::shell::ShellRuntime;
use attune_worker::runtime::{ExecutionContext, Runtime};
use std::collections::HashMap;
use std::path::PathBuf;
@@ -18,6 +19,7 @@ fn make_python_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
args: vec!["-u".to_string()],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
@@ -34,6 +36,34 @@ fn make_python_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
)
}
fn make_shell_process_runtime(packs_base_dir: PathBuf) -> ProcessRuntime {
let config = RuntimeExecutionConfig {
interpreter: InterpreterConfig {
binary: "/bin/bash".to_string(),
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig {
strategy: InlineExecutionStrategy::TempFile,
extension: Some(".sh".to_string()),
inject_shell_helpers: true,
},
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
};
let runtime_envs_dir = packs_base_dir
.parent()
.unwrap_or(&packs_base_dir)
.join("runtime_envs");
ProcessRuntime::new(
"shell".to_string(),
config,
packs_base_dir,
runtime_envs_dir,
)
}
#[tokio::test]
async fn test_python_secrets_not_in_environ() {
let tmp = TempDir::new().unwrap();
@@ -114,7 +144,8 @@ print(json.dumps(result))
#[tokio::test]
async fn test_shell_secrets_not_in_environ() {
let runtime = ShellRuntime::new();
let tmp = TempDir::new().unwrap();
let runtime = make_shell_process_runtime(tmp.path().to_path_buf());
let context = ExecutionContext {
execution_id: 2,
@@ -151,21 +182,21 @@ if printenv | grep -q "SECRET_API_KEY"; then
exit 1
fi
# But secrets SHOULD be accessible via get_secret function
api_key=$(get_secret 'api_key')
password=$(get_secret 'password')
# Shell inline execution receives the merged input set as ordinary variables
api_key="$api_key"
password="$password"
if [ "$api_key" != "super_secret_key_do_not_expose" ]; then
echo "ERROR: Secret not accessible via get_secret"
echo "ERROR: Secret not accessible via merged inputs"
exit 1
fi
if [ "$password" != "secret_pass_123" ]; then
echo "ERROR: Password not accessible via get_secret"
echo "ERROR: Password not accessible via merged inputs"
exit 1
fi
echo "SECURITY_PASS: Secrets not in environment but accessible via get_secret"
echo "SECURITY_PASS: Secrets not in inherited environment and accessible via merged inputs"
"#
.to_string(),
),
@@ -363,7 +394,8 @@ print("ok")
#[tokio::test]
async fn test_shell_empty_secrets() {
let runtime = ShellRuntime::new();
let tmp = TempDir::new().unwrap();
let runtime = make_shell_process_runtime(tmp.path().to_path_buf());
let context = ExecutionContext {
execution_id: 6,
@@ -376,12 +408,11 @@ async fn test_shell_empty_secrets() {
entry_point: "shell".to_string(),
code: Some(
r#"
# get_secret should return empty string for non-existent secrets
result=$(get_secret 'nonexistent')
if [ -z "$result" ]; then
echo "PASS: Empty secret returns empty string"
# Unset merged inputs should expand to empty string
if [ -z "$nonexistent" ] && [ -z "$PARAM_NONEXISTENT" ]; then
echo "PASS: Missing input expands to empty string"
else
echo "FAIL: Expected empty string"
echo "FAIL: Expected empty string for missing input"
exit 1
fi
"#
@@ -440,6 +471,7 @@ echo "PASS: No secrets in environment"
args: vec![],
file_extension: Some(".sh".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),
@@ -520,6 +552,7 @@ print(json.dumps({"leaked": leaked}))
args: vec!["-u".to_string()],
file_extension: Some(".py".to_string()),
},
inline_execution: InlineExecutionConfig::default(),
environment: None,
dependencies: None,
env_vars: std::collections::HashMap::new(),

View File

@@ -109,6 +109,8 @@ services:
SOURCE_PACKS_DIR: /source/packs
TARGET_PACKS_DIR: /opt/attune/packs
LOADER_SCRIPT: /scripts/load_core_pack.py
DEFAULT_ADMIN_LOGIN: test@attune.local
DEFAULT_ADMIN_PERMISSION_SET_REF: core.admin
command: ["/bin/sh", "/init-packs.sh"]
depends_on:
migrations:

View File

@@ -26,6 +26,8 @@ TARGET_PACKS_DIR="${TARGET_PACKS_DIR:-/opt/attune/packs}"
# Python loader script
LOADER_SCRIPT="${LOADER_SCRIPT:-/scripts/load_core_pack.py}"
DEFAULT_ADMIN_LOGIN="${DEFAULT_ADMIN_LOGIN:-}"
DEFAULT_ADMIN_PERMISSION_SET_REF="${DEFAULT_ADMIN_PERMISSION_SET_REF:-core.admin}"
echo ""
echo -e "${BLUE}╔════════════════════════════════════════════════╗${NC}"
@@ -205,6 +207,63 @@ else
echo -e "${BLUE}${NC} You can manually load them later"
fi
if [ -n "$DEFAULT_ADMIN_LOGIN" ] && [ "$LOADED_COUNT" -gt 0 ]; then
echo ""
echo -e "${BLUE}Bootstrapping local admin assignment...${NC}"
if python3 - <<PY
import psycopg2
import sys
conn = psycopg2.connect(
host="${DB_HOST}",
port=${DB_PORT},
user="${DB_USER}",
password="${DB_PASSWORD}",
dbname="${DB_NAME}",
)
conn.autocommit = False
try:
with conn.cursor() as cur:
cur.execute("SET search_path TO ${DB_SCHEMA}, public")
cur.execute("SELECT id FROM identity WHERE login = %s", ("${DEFAULT_ADMIN_LOGIN}",))
identity_row = cur.fetchone()
if identity_row is None:
print(" ⚠ Default admin identity not found; skipping assignment")
conn.rollback()
sys.exit(0)
cur.execute("SELECT id FROM permission_set WHERE ref = %s", ("${DEFAULT_ADMIN_PERMISSION_SET_REF}",))
permset_row = cur.fetchone()
if permset_row is None:
print(" ⚠ Default admin permission set not found; skipping assignment")
conn.rollback()
sys.exit(0)
cur.execute(
"""
INSERT INTO permission_assignment (identity, permset)
VALUES (%s, %s)
ON CONFLICT (identity, permset) DO NOTHING
""",
(identity_row[0], permset_row[0]),
)
conn.commit()
print(" ✓ Default admin permission assignment ensured")
except Exception as exc:
conn.rollback()
print(f" ✗ Failed to ensure default admin assignment: {exc}")
sys.exit(1)
finally:
conn.close()
PY
then
:
else
exit 1
fi
fi
# Summary
echo ""
echo -e "${GREEN}╔════════════════════════════════════════════════╗${NC}"

View File

@@ -26,6 +26,7 @@ CREATE TABLE execution (
parent BIGINT, -- self-reference; no FK because execution becomes a hypertable
enforcement BIGINT, -- references enforcement(id); no FK (both are hypertables)
executor BIGINT, -- references identity(id); no FK because execution becomes a hypertable
worker BIGINT, -- references worker(id); no FK because execution becomes a hypertable
status execution_status_enum NOT NULL DEFAULT 'requested',
result JSONB,
started_at TIMESTAMPTZ, -- set when execution transitions to 'running'
@@ -49,6 +50,7 @@ CREATE INDEX idx_execution_action_ref ON execution(action_ref);
CREATE INDEX idx_execution_parent ON execution(parent);
CREATE INDEX idx_execution_enforcement ON execution(enforcement);
CREATE INDEX idx_execution_executor ON execution(executor);
CREATE INDEX idx_execution_worker ON execution(worker);
CREATE INDEX idx_execution_status ON execution(status);
CREATE INDEX idx_execution_created ON execution(created DESC);
CREATE INDEX idx_execution_updated ON execution(updated DESC);
@@ -56,6 +58,7 @@ CREATE INDEX idx_execution_status_created ON execution(status, created DESC);
CREATE INDEX idx_execution_status_updated ON execution(status, updated DESC);
CREATE INDEX idx_execution_action_status ON execution(action, status);
CREATE INDEX idx_execution_executor_created ON execution(executor, created DESC);
CREATE INDEX idx_execution_worker_created ON execution(worker, created DESC);
CREATE INDEX idx_execution_parent_created ON execution(parent, created DESC);
CREATE INDEX idx_execution_result_gin ON execution USING GIN (result);
CREATE INDEX idx_execution_env_vars_gin ON execution USING GIN (env_vars);
@@ -77,6 +80,7 @@ COMMENT ON COLUMN execution.env_vars IS 'Environment variables for this executio
COMMENT ON COLUMN execution.parent IS 'Parent execution ID for workflow hierarchies (no FK — execution is a hypertable)';
COMMENT ON COLUMN execution.enforcement IS 'Enforcement that triggered this execution (no FK — both are hypertables)';
COMMENT ON COLUMN execution.executor IS 'Identity that initiated the execution (no FK — execution is a hypertable)';
COMMENT ON COLUMN execution.worker IS 'Assigned worker handling this execution (no FK — execution is a hypertable)';
COMMENT ON COLUMN execution.status IS 'Current execution lifecycle status';
COMMENT ON COLUMN execution.result IS 'Execution output/results';
COMMENT ON COLUMN execution.retry_count IS 'Current retry attempt number (0 = first attempt, 1 = first retry, etc.)';

View File

@@ -196,7 +196,7 @@ COMMENT ON TABLE execution IS 'Executions represent action runs with workflow su
-- ----------------------------------------------------------------------------
-- execution history trigger
-- Tracked fields: status, result, executor, workflow_task, env_vars, started_at
-- Tracked fields: status, result, executor, worker, workflow_task, env_vars, started_at
-- Note: result uses _jsonb_digest_summary() to avoid storing large payloads
-- ----------------------------------------------------------------------------
@@ -214,6 +214,7 @@ BEGIN
'status', NEW.status,
'action_ref', NEW.action_ref,
'executor', NEW.executor,
'worker', NEW.worker,
'parent', NEW.parent,
'enforcement', NEW.enforcement,
'started_at', NEW.started_at
@@ -249,6 +250,12 @@ BEGIN
new_vals := new_vals || jsonb_build_object('executor', NEW.executor);
END IF;
IF OLD.worker IS DISTINCT FROM NEW.worker THEN
changed := array_append(changed, 'worker');
old_vals := old_vals || jsonb_build_object('worker', OLD.worker);
new_vals := new_vals || jsonb_build_object('worker', NEW.worker);
END IF;
IF OLD.workflow_task IS DISTINCT FROM NEW.workflow_task THEN
changed := array_append(changed, 'workflow_task');
old_vals := old_vals || jsonb_build_object('workflow_task', OLD.workflow_task);

View File

@@ -0,0 +1,36 @@
ref: core.admin
label: Admin
description: Full administrative access across Attune resources.
grants:
- resource: packs
actions: [read, create, update, delete]
- resource: actions
actions: [read, create, update, delete, execute]
- resource: rules
actions: [read, create, update, delete]
- resource: triggers
actions: [read, create, update, delete]
- resource: executions
actions: [read, create, update, delete, cancel]
- resource: events
actions: [read, create, delete]
- resource: enforcements
actions: [read, create, delete]
- resource: inquiries
actions: [read, create, update, delete, respond]
- resource: keys
actions: [read, create, update, delete]
- resource: artifacts
actions: [read, create, update, delete]
- resource: workflows
actions: [read, create, update, delete]
- resource: webhooks
actions: [read, create, update, delete]
- resource: analytics
actions: [read]
- resource: history
actions: [read]
- resource: identities
actions: [read, create, update, delete]
- resource: permissions
actions: [read, create, update, delete, manage]

View File

@@ -0,0 +1,24 @@
ref: core.editor
label: Editor
description: Create and update operational resources without full administrative control.
grants:
- resource: packs
actions: [read, create, update]
- resource: actions
actions: [read, create, update, execute]
- resource: rules
actions: [read, create, update]
- resource: triggers
actions: [read]
- resource: executions
actions: [read, create, cancel]
- resource: keys
actions: [read, update]
- resource: artifacts
actions: [read]
- resource: workflows
actions: [read, create, update]
- resource: analytics
actions: [read]
- resource: history
actions: [read]

View File

@@ -0,0 +1,20 @@
ref: core.executor
label: Executor
description: Read operational metadata and trigger executions without changing system definitions.
grants:
- resource: packs
actions: [read]
- resource: actions
actions: [read, execute]
- resource: rules
actions: [read]
- resource: triggers
actions: [read]
- resource: executions
actions: [read, create]
- resource: artifacts
actions: [read]
- resource: analytics
actions: [read]
- resource: history
actions: [read]

View File

@@ -0,0 +1,20 @@
ref: core.viewer
label: Viewer
description: Read-only access to operational metadata and execution visibility.
grants:
- resource: packs
actions: [read]
- resource: actions
actions: [read]
- resource: rules
actions: [read]
- resource: triggers
actions: [read]
- resource: executions
actions: [read]
- resource: artifacts
actions: [read]
- resource: analytics
actions: [read]
- resource: history
actions: [read]

View File

@@ -32,3 +32,7 @@ execution_config:
binary: "/bin/bash"
args: []
file_extension: ".sh"
inline_execution:
strategy: temp_file
extension: ".sh"
inject_shell_helpers: true

View File

@@ -3,8 +3,8 @@
Pack Loader for Attune
This script loads a pack from the filesystem into the database.
It reads pack.yaml, action definitions, trigger definitions, and sensor definitions
and creates all necessary database entries.
It reads pack.yaml, permission set definitions, action definitions, trigger
definitions, and sensor definitions and creates all necessary database entries.
Usage:
python3 scripts/load_core_pack.py [--database-url URL] [--pack-dir DIR] [--pack-name NAME]
@@ -147,6 +147,70 @@ class PackLoader:
print(f"✓ Pack '{ref}' loaded (ID: {self.pack_id})")
return self.pack_id
def upsert_permission_sets(self) -> Dict[str, int]:
"""Load permission set definitions from permission_sets/*.yaml."""
print("\n→ Loading permission sets...")
permission_sets_dir = self.pack_dir / "permission_sets"
if not permission_sets_dir.exists():
print(" No permission_sets directory found")
return {}
permission_set_ids = {}
cursor = self.conn.cursor()
for yaml_file in sorted(permission_sets_dir.glob("*.yaml")):
permission_set_data = self.load_yaml(yaml_file)
if not permission_set_data:
continue
ref = permission_set_data.get("ref")
if not ref:
print(
f" ⚠ Permission set YAML {yaml_file.name} missing 'ref' field, skipping"
)
continue
label = permission_set_data.get("label")
description = permission_set_data.get("description")
grants = permission_set_data.get("grants", [])
if not isinstance(grants, list):
print(
f" ⚠ Permission set '{ref}' has non-array grants, skipping"
)
continue
cursor.execute(
"""
INSERT INTO permission_set (
ref, pack, pack_ref, label, description, grants
)
VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (ref) DO UPDATE SET
label = EXCLUDED.label,
description = EXCLUDED.description,
grants = EXCLUDED.grants,
updated = NOW()
RETURNING id
""",
(
ref,
self.pack_id,
self.pack_ref,
label,
description,
json.dumps(grants),
),
)
permission_set_id = cursor.fetchone()[0]
permission_set_ids[ref] = permission_set_id
print(f" ✓ Permission set '{ref}' (ID: {permission_set_id})")
cursor.close()
return permission_set_ids
def upsert_triggers(self) -> Dict[str, int]:
"""Load trigger definitions"""
print("\n→ Loading triggers...")
@@ -708,11 +772,12 @@ class PackLoader:
"""Main loading process.
Components are loaded in dependency order:
1. Runtimes (no dependencies)
2. Triggers (no dependencies)
3. Actions (depend on runtime; workflow actions also create
1. Permission sets (no dependencies)
2. Runtimes (no dependencies)
3. Triggers (no dependencies)
4. Actions (depend on runtime; workflow actions also create
workflow_definition records)
4. Sensors (depend on triggers and runtime)
5. Sensors (depend on triggers and runtime)
"""
print("=" * 60)
print(f"Pack Loader - {self.pack_name}")
@@ -727,7 +792,10 @@ class PackLoader:
# Load pack metadata
self.upsert_pack()
# Load runtimes first (actions and sensors depend on them)
# Load permission sets first (authorization metadata)
permission_set_ids = self.upsert_permission_sets()
# Load runtimes (actions and sensors depend on them)
runtime_ids = self.upsert_runtimes()
# Load triggers
@@ -746,6 +814,7 @@ class PackLoader:
print(f"✓ Pack '{self.pack_name}' loaded successfully!")
print("=" * 60)
print(f" Pack ID: {self.pack_id}")
print(f" Permission sets: {len(permission_set_ids)}")
print(f" Runtimes: {len(set(runtime_ids.values()))}")
print(f" Triggers: {len(trigger_ids)}")
print(f" Actions: {len(action_ids)}")

View File

@@ -32,7 +32,7 @@ export type ApiResponse_ExecutionResponse = {
*/
enforcement?: number | null;
/**
* Executor ID (worker/executor that ran this)
* Identity ID that initiated this execution
*/
executor?: number | null;
/**
@@ -43,6 +43,10 @@ export type ApiResponse_ExecutionResponse = {
* Parent execution ID (for nested/child executions)
*/
parent?: number | null;
/**
* Worker ID currently assigned to this execution
*/
worker?: number | null;
/**
* Execution result/output
*/

View File

@@ -28,7 +28,7 @@ export type ExecutionResponse = {
*/
enforcement?: number | null;
/**
* Executor ID (worker/executor that ran this)
* Identity ID that initiated this execution
*/
executor?: number | null;
/**
@@ -39,6 +39,10 @@ export type ExecutionResponse = {
* Parent execution ID (for nested/child executions)
*/
parent?: number | null;
/**
* Worker ID currently assigned to this execution
*/
worker?: number | null;
/**
* Execution result/output
*/

View File

@@ -224,7 +224,7 @@ export class ExecutionsService {
*/
enforcement?: number | null;
/**
* Executor ID (worker/executor that ran this)
* Identity ID that initiated this execution
*/
executor?: number | null;
/**
@@ -235,6 +235,10 @@ export class ExecutionsService {
* Parent execution ID (for nested/child executions)
*/
parent?: number | null;
/**
* Worker ID currently assigned to this execution
*/
worker?: number | null;
/**
* Execution result/output
*/

View File

@@ -337,11 +337,16 @@ function TextFileDetail({
interface ProgressDetailProps {
artifactId: number;
isRunning?: boolean;
onClose: () => void;
}
function ProgressDetail({ artifactId, onClose }: ProgressDetailProps) {
const { data: artifactData, isLoading } = useArtifact(artifactId);
function ProgressDetail({
artifactId,
isRunning = false,
onClose,
}: ProgressDetailProps) {
const { data: artifactData, isLoading } = useArtifact(artifactId, isRunning);
const artifact = artifactData?.data;
const progressEntries = useMemo(() => {
@@ -707,6 +712,7 @@ export default function ExecutionArtifactsPanel({
<div className="px-3">
<ProgressDetail
artifactId={artifact.id}
isRunning={isRunning}
onClose={() => setExpandedProgressId(null)}
/>
</div>

View File

@@ -69,9 +69,10 @@ const ExecutionPreviewPanel = memo(function ExecutionPreviewPanel({
execution?.status === "running" ||
execution?.status === "scheduling" ||
execution?.status === "scheduled" ||
execution?.status === "requested";
execution?.status === "requested" ||
execution?.status === "canceling";
const isCancellable = isRunning || execution?.status === "canceling";
const isCancellable = isRunning;
const startedAt = execution?.started_at
? new Date(execution.started_at)
@@ -241,13 +242,23 @@ const ExecutionPreviewPanel = memo(function ExecutionPreviewPanel({
{execution.executor && (
<div>
<dt className="text-xs font-medium text-gray-500 uppercase tracking-wide">
Executor
Initiated By
</dt>
<dd className="mt-0.5 text-sm text-gray-900 font-mono">
#{execution.executor}
</dd>
</div>
)}
{execution.worker && (
<div>
<dt className="text-xs font-medium text-gray-500 uppercase tracking-wide">
Worker
</dt>
<dd className="mt-0.5 text-sm text-gray-900 font-mono">
#{execution.worker}
</dd>
</div>
)}
{execution.workflow_task && (
<div>
<dt className="text-xs font-medium text-gray-500 uppercase tracking-wide">

View File

@@ -42,6 +42,10 @@ const PRESET_BANNER_COLORS: Record<TransitionPreset, string> = {
const MIN_ZOOM = 0.15;
const MAX_ZOOM = 3;
const ZOOM_SENSITIVITY = 0.0015;
const CANVAS_SIDE_PADDING = 120;
const CANVAS_TOP_PADDING = 140;
const CANVAS_BOTTOM_PADDING = 120;
const CANVAS_RIGHT_PADDING = 380;
/**
* Build CSS background style for the infinite grid.
@@ -465,6 +469,11 @@ export default function WorkflowCanvas({
maxY = Math.max(maxY, t.position.y + 140);
}
minX -= CANVAS_SIDE_PADDING;
minY -= CANVAS_TOP_PADDING;
maxX += CANVAS_RIGHT_PADDING;
maxY += CANVAS_BOTTOM_PADDING;
const contentW = maxX - minX;
const contentH = maxY - minY;
const pad = 80;
@@ -492,10 +501,10 @@ export default function WorkflowCanvas({
let maxX = 4000;
let maxY = 4000;
for (const task of tasks) {
minX = Math.min(minX, task.position.x - 100);
minY = Math.min(minY, task.position.y - 100);
maxX = Math.max(maxX, task.position.x + 500);
maxY = Math.max(maxY, task.position.y + 500);
minX = Math.min(minX, task.position.x - CANVAS_SIDE_PADDING);
minY = Math.min(minY, task.position.y - CANVAS_TOP_PADDING);
maxX = Math.max(maxX, task.position.x + CANVAS_RIGHT_PADDING);
maxY = Math.max(maxY, task.position.y + CANVAS_BOTTOM_PADDING + 380);
}
return { width: maxX - minX, height: maxY - minY };
}, [tasks]);

View File

@@ -56,6 +56,14 @@ interface WorkflowEdgesProps {
const NODE_WIDTH = 240;
const NODE_HEIGHT = 96;
const SELF_LOOP_RIGHT_OFFSET = 24;
const SELF_LOOP_TOP_OFFSET = 36;
const SELF_LOOP_BOTTOM_OFFSET = 30;
const ARROW_LENGTH = 12;
const ARROW_HALF_WIDTH = 5;
const ARROW_DIRECTION_LOOKBACK_PX = 10;
const ARROW_DIRECTION_SAMPLES = 48;
const ARROW_SHAFT_OVERLAP_PX = 2;
/** Color for each edge type (alias for shared constant) */
const EDGE_COLORS = EDGE_TYPE_COLORS;
@@ -159,13 +167,14 @@ function getBestConnectionPoints(
end: { x: number; y: number };
selfLoop?: boolean;
} {
// Self-loop: right side → top
// Self-loop uses a dedicated route that stays outside the task card so the
// arrowhead and label remain readable instead of being covered by the node.
if (fromTask.id === toTask.id) {
return {
start: getNodeBottomCenter(fromTask, nodeWidth, nodeHeight),
end: {
x: fromTask.position.x + nodeWidth * 0.75,
y: fromTask.position.y,
x: fromTask.position.x + nodeWidth,
y: fromTask.position.y + nodeHeight * 0.28,
},
selfLoop: true,
};
@@ -184,17 +193,25 @@ function getBestConnectionPoints(
return { start, end };
}
/**
* Build an SVG path for a self-loop.
*/
function buildSelfLoopPath(
start: { x: number; y: number },
end: { x: number; y: number },
): string {
const loopOffset = 50;
const cp1 = { x: start.x + loopOffset, y: start.y - 20 };
const cp2 = { x: end.x + loopOffset, y: end.y - 40 };
return `M ${start.x} ${start.y} C ${cp1.x} ${cp1.y}, ${cp2.x} ${cp2.y}, ${end.x} ${end.y}`;
function buildSelfLoopRoute(
task: WorkflowTask,
nodeWidth: number,
nodeHeight: number,
): { x: number; y: number }[] {
const start = getNodeBottomCenter(task, nodeWidth, nodeHeight);
const cardRight = task.position.x + nodeWidth;
const cardTop = task.position.y;
const loopRight = cardRight + SELF_LOOP_RIGHT_OFFSET;
const loopTop = cardTop + SELF_LOOP_TOP_OFFSET;
const loopBottom = start.y + SELF_LOOP_BOTTOM_OFFSET;
return [
start,
{ x: start.x, y: loopBottom },
{ x: loopRight, y: loopBottom },
{ x: loopRight, y: loopTop },
{ x: cardRight, y: task.position.y + nodeHeight * 0.28 },
];
}
/**
@@ -296,28 +313,13 @@ function getSegmentControlPoints(
function evaluatePathAtT(
allPoints: { x: number; y: number }[],
t: number,
selfLoop?: boolean,
_selfLoop?: boolean,
): { x: number; y: number } {
if (allPoints.length < 2) {
return allPoints[0] ?? { x: 0, y: 0 };
}
// Self-loop with no waypoints (allPoints = [start, end])
if (selfLoop && allPoints.length === 2) {
const start = allPoints[0];
const end = allPoints[1];
const loopOffset = 50;
const cp1 = { x: start.x + loopOffset, y: start.y - 20 };
const cp2 = { x: end.x + loopOffset, y: end.y - 40 };
return evaluateCubicBezier(
start,
cp1,
cp2,
end,
Math.max(0, Math.min(1, t)),
);
}
const numSegments = allPoints.length - 1;
const clampedT = Math.max(0, Math.min(1, t));
const scaledT = clampedT * numSegments;
@@ -341,7 +343,7 @@ function evaluatePathAtT(
function projectOntoPath(
allPoints: { x: number; y: number }[],
mousePos: { x: number; y: number },
selfLoop?: boolean,
_selfLoop?: boolean,
): number {
if (allPoints.length < 2) return 0;
@@ -349,25 +351,6 @@ function projectOntoPath(
let bestT = 0.5;
let bestDist = Infinity;
// Self-loop with no waypoints
if (selfLoop && allPoints.length === 2) {
const start = allPoints[0];
const end = allPoints[1];
const loopOffset = 50;
const cp1 = { x: start.x + loopOffset, y: start.y - 20 };
const cp2 = { x: end.x + loopOffset, y: end.y - 40 };
for (let s = 0; s <= samplesPerSegment; s++) {
const localT = s / samplesPerSegment;
const pt = evaluateCubicBezier(start, cp1, cp2, end, localT);
const dist = Math.hypot(pt.x - mousePos.x, pt.y - mousePos.y);
if (dist < bestDist) {
bestDist = dist;
bestT = localT;
}
}
return bestT;
}
const numSegments = allPoints.length - 1;
for (let seg = 0; seg < numSegments; seg++) {
@@ -466,6 +449,151 @@ function curveSegmentMidpoint(
return evaluateCubicBezier(p1, cp1, cp2, p2, 0.5);
}
function buildArrowHeadPath(
from: { x: number; y: number },
tip: { x: number; y: number },
): {
path: string;
} {
const dx = tip.x - from.x;
const dy = tip.y - from.y;
const length = Math.hypot(dx, dy) || 1;
const ux = dx / length;
const uy = dy / length;
const baseX = tip.x - ux * ARROW_LENGTH;
const baseY = tip.y - uy * ARROW_LENGTH;
const perpX = -uy;
const perpY = ux;
return {
path: `M ${tip.x} ${tip.y} L ${baseX + perpX * ARROW_HALF_WIDTH} ${baseY + perpY * ARROW_HALF_WIDTH} L ${baseX - perpX * ARROW_HALF_WIDTH} ${baseY - perpY * ARROW_HALF_WIDTH} Z`,
};
}
function getArrowDirectionPoint(
allPoints: { x: number; y: number }[],
lookbackPx: number = ARROW_DIRECTION_LOOKBACK_PX,
): { x: number; y: number } {
if (allPoints.length < 2) {
return allPoints[0] ?? { x: 0, y: 0 };
}
const segIdx = allPoints.length - 2;
const start = allPoints[segIdx];
const end = allPoints[segIdx + 1];
const { cp1, cp2 } = getSegmentControlPoints(allPoints, segIdx);
let prev = end;
let traversed = 0;
for (let i = ARROW_DIRECTION_SAMPLES - 1; i >= 0; i--) {
const t = i / ARROW_DIRECTION_SAMPLES;
const pt = evaluateCubicBezier(start, cp1, cp2, end, t);
traversed += Math.hypot(prev.x - pt.x, prev.y - pt.y);
if (traversed >= lookbackPx) {
return pt;
}
prev = pt;
}
return start;
}
function lerpPoint(
a: { x: number; y: number },
b: { x: number; y: number },
t: number,
): { x: number; y: number } {
return {
x: a.x + (b.x - a.x) * t,
y: a.y + (b.y - a.y) * t,
};
}
function splitCubicAtT(
p0: { x: number; y: number },
p1: { x: number; y: number },
p2: { x: number; y: number },
p3: { x: number; y: number },
t: number,
): {
leftCp1: { x: number; y: number };
leftCp2: { x: number; y: number };
point: { x: number; y: number };
} {
const p01 = lerpPoint(p0, p1, t);
const p12 = lerpPoint(p1, p2, t);
const p23 = lerpPoint(p2, p3, t);
const p012 = lerpPoint(p01, p12, t);
const p123 = lerpPoint(p12, p23, t);
const point = lerpPoint(p012, p123, t);
return {
leftCp1: p01,
leftCp2: p012,
point,
};
}
function findTrimmedSegmentEnd(
allPoints: { x: number; y: number }[],
trimPx: number,
): {
segIdx: number;
t: number;
point: { x: number; y: number };
} {
const segIdx = allPoints.length - 2;
const start = allPoints[segIdx];
const end = allPoints[segIdx + 1];
const { cp1, cp2 } = getSegmentControlPoints(allPoints, segIdx);
let prev = end;
let traversed = 0;
for (let i = ARROW_DIRECTION_SAMPLES - 1; i >= 0; i--) {
const t = i / ARROW_DIRECTION_SAMPLES;
const pt = evaluateCubicBezier(start, cp1, cp2, end, t);
traversed += Math.hypot(prev.x - pt.x, prev.y - pt.y);
if (traversed >= trimPx) {
return { segIdx, t, point: pt };
}
prev = pt;
}
return { segIdx, t: 0, point: start };
}
function buildTrimmedPath(
allPoints: { x: number; y: number }[],
trimPx: number,
): string {
if (allPoints.length < 2) return "";
if (trimPx <= 0) {
return allPoints.length === 2
? buildCurvePath(allPoints[0], allPoints[1])
: buildSmoothPath(allPoints);
}
const { segIdx, t } = findTrimmedSegmentEnd(allPoints, trimPx);
const start = allPoints[segIdx];
const end = allPoints[segIdx + 1];
const { cp1, cp2 } = getSegmentControlPoints(allPoints, segIdx);
const trimmed = splitCubicAtT(start, cp1, cp2, end, t);
let d = `M ${allPoints[0].x} ${allPoints[0].y}`;
for (let i = 0; i < segIdx; i++) {
const p2 = allPoints[i + 1];
const { cp1: segCp1, cp2: segCp2 } = getSegmentControlPoints(allPoints, i);
d += ` C ${segCp1.x} ${segCp1.y}, ${segCp2.x} ${segCp2.y}, ${p2.x} ${p2.y}`;
}
d += ` C ${trimmed.leftCp1.x} ${trimmed.leftCp1.y}, ${trimmed.leftCp2.x} ${trimmed.leftCp2.y}, ${trimmed.point.x} ${trimmed.point.y}`;
return d;
}
/** Check whether two SelectedEdgeInfo match the same edge */
function edgeMatches(
sel: SelectedEdgeInfo | null | undefined,
@@ -530,9 +658,12 @@ function WorkflowEdgesInner({
let maxX = 0;
let maxY = 0;
for (const task of tasks) {
minX = Math.min(minX, task.position.x - 100);
minY = Math.min(minY, task.position.y - 100);
maxX = Math.max(maxX, task.position.x + nodeWidth + 100);
minX = Math.min(minX, task.position.x - 120);
minY = Math.min(minY, task.position.y - 140);
maxX = Math.max(
maxX,
task.position.x + nodeWidth + SELF_LOOP_RIGHT_OFFSET + 40,
);
maxY = Math.max(maxY, task.position.y + nodeHeight + 100);
}
return {
@@ -909,56 +1040,20 @@ function WorkflowEdgesInner({
height={svgBounds.height}
style={{ zIndex: 1 }}
>
<defs>
{/* Arrow markers for each edge type */}
{Object.entries(EDGE_COLORS).map(([type, color]) => (
<marker
key={`arrow-${type}`}
id={`arrow-${type}`}
viewBox="0 0 10 10"
refX={9}
refY={5}
markerWidth={8}
markerHeight={8}
orient="auto-start-reverse"
>
<path d="M 0 0 L 10 5 L 0 10 z" fill={color} opacity={0.8} />
</marker>
))}
</defs>
<g className="pointer-events-auto">
{/* Dynamic arrow markers for custom-colored edges */}
{edges.map((edge, index) => {
if (!edge.color) return null;
return (
<marker
key={`arrow-custom-${index}`}
id={`arrow-custom-${index}`}
viewBox="0 0 10 10"
refX={9}
refY={5}
markerWidth={8}
markerHeight={8}
orient="auto-start-reverse"
>
<path d="M 0 0 L 10 5 L 0 10 z" fill={edge.color} opacity={0.8} />
</marker>
);
})}
{/* Render edges */}
{edges.map((edge, index) => {
const fromTask = taskMap.get(edge.from);
const toTask = taskMap.get(edge.to);
if (!fromTask || !toTask) return null;
const isSelfLoopEdge = edge.from === edge.to;
// Build the current waypoints first so we can pass them into
// connection-point selection as an approach hint.
let currentWaypoints: NodePosition[] = edge.waypoints
? [...edge.waypoints]
: [];
let currentWaypoints: NodePosition[] =
!isSelfLoopEdge && edge.waypoints ? [...edge.waypoints] : [];
if (
!isSelfLoopEdge &&
activeDrag &&
activeDrag.edgeFrom === edge.from &&
activeDrag.edgeTo === edge.to &&
@@ -985,26 +1080,21 @@ function WorkflowEdgesInner({
const isSelected = edgeMatches(selectedEdge, edge);
// All points: start → waypoints → end
const allPoints = [start, ...currentWaypoints, end];
const pathD =
const selfLoopRoute =
selfLoop && currentWaypoints.length === 0
? buildSelfLoopPath(start, end)
: allPoints.length === 2
? buildCurvePath(start, end)
: buildSmoothPath(allPoints);
? buildSelfLoopRoute(fromTask, nodeWidth, nodeHeight)
: null;
const color =
edge.color || EDGE_COLORS[edge.type] || EDGE_COLORS.complete;
const dash = edge.lineStyle ? LINE_STYLE_DASH[edge.lineStyle] : "";
const arrowId = edge.color
? `arrow-custom-${index}`
: `arrow-${edge.type}`;
const groupOpacity = isSelected ? 1 : 0.75;
// Label position — evaluate t-parameter on the actual path
let labelPos: { x: number; y: number };
const isSelfLoopEdge = selfLoop && currentWaypoints.length === 0;
const usesDefaultSelfLoopRoute =
selfLoop && currentWaypoints.length === 0;
const allPoints = selfLoopRoute ?? [start, ...currentWaypoints, end];
if (
activeDrag &&
activeDrag.type === "label" &&
@@ -1016,15 +1106,25 @@ function WorkflowEdgesInner({
// During drag, dragPos is already snapped to the curve
labelPos = dragPos;
} else {
const t = edge.labelPosition ?? 0.5;
labelPos = evaluatePathAtT(allPoints, t, isSelfLoopEdge);
const t =
edge.labelPosition ?? (usesDefaultSelfLoopRoute ? 0.62 : 0.5);
labelPos = evaluatePathAtT(allPoints, t, usesDefaultSelfLoopRoute);
}
const arrowDirectionPoint = getArrowDirectionPoint(allPoints);
const arrowHead = buildArrowHeadPath(arrowDirectionPoint, end);
const pathD = buildTrimmedPath(
allPoints,
ARROW_LENGTH - ARROW_SHAFT_OVERLAP_PX,
);
const labelText = edge.label || "";
const labelWidth = Math.max(labelText.length * 5.5 + 12, 48);
return (
<g key={`edge-${index}-${edge.from}-${edge.to}`}>
<g
key={`edge-${index}-${edge.from}-${edge.to}`}
opacity={groupOpacity}
>
{/* Edge path */}
<path
d={pathD}
@@ -1032,9 +1132,12 @@ function WorkflowEdgesInner({
stroke={color}
strokeWidth={isSelected ? 2.5 : 2}
strokeDasharray={dash}
markerEnd={`url(#${arrowId})`}
className="transition-opacity"
opacity={isSelected ? 1 : 0.75}
/>
<path
d={arrowHead.path}
fill={color}
className="pointer-events-none transition-opacity"
/>
{/* Selection glow for selected edge */}
@@ -1078,7 +1181,7 @@ function WorkflowEdgesInner({
edge,
labelPos,
allPoints,
isSelfLoopEdge,
usesDefaultSelfLoopRoute,
)
: undefined
}
@@ -1138,7 +1241,7 @@ function WorkflowEdgesInner({
)}
{/* === Selected edge interactive elements === */}
{isSelected && (
{isSelected && !isSelfLoopEdge && (
<>
{/* Waypoint handles */}
{currentWaypoints.map((wp, wpIdx) => {

View File

@@ -147,15 +147,18 @@ export function useExecutionArtifacts(
return response;
},
enabled: !!executionId,
staleTime: isRunning ? 3000 : 10000,
refetchInterval: isRunning ? 3000 : 10000,
staleTime: isRunning ? 3000 : 30000,
refetchInterval: isRunning ? 3000 : false,
});
}
/**
* Fetch a single artifact by ID (includes data field for progress artifacts).
*
* @param isRunning - When true, polls every 3s for live updates. When false,
* uses a longer stale time and disables automatic polling.
*/
export function useArtifact(id: number | undefined) {
export function useArtifact(id: number | undefined, isRunning = false) {
return useQuery({
queryKey: ["artifacts", id],
queryFn: async () => {
@@ -169,8 +172,8 @@ export function useArtifact(id: number | undefined) {
return response;
},
enabled: !!id,
staleTime: 3000,
refetchInterval: 3000,
staleTime: isRunning ? 3000 : 30000,
refetchInterval: isRunning ? 3000 : false,
});
}

View File

@@ -205,6 +205,11 @@ export function useExecutionStream(options: UseExecutionStreamOptions = {}) {
},
);
queryClient.invalidateQueries({
queryKey: ["history", "execution", executionNotification.entity_id],
exact: false,
});
// Update execution list queries by modifying existing data.
// We need to iterate manually to access query keys for filtering.
const queries = queryClient

View File

@@ -22,6 +22,16 @@ interface ExecutionsQueryParams {
topLevelOnly?: boolean;
}
function isExecutionActive(status: string | undefined): boolean {
return (
status === "requested" ||
status === "scheduling" ||
status === "scheduled" ||
status === "running" ||
status === "canceling"
);
}
export function useExecutions(params?: ExecutionsQueryParams) {
// Check if any filters are applied
const hasFilters =
@@ -67,7 +77,9 @@ export function useExecution(id: number) {
return response;
},
enabled: !!id,
staleTime: 30000, // 30 seconds - SSE handles real-time updates
staleTime: 30000,
refetchInterval: (query) =>
isExecutionActive(query.state.data?.data?.status) ? 3000 : false,
});
}
@@ -180,11 +192,7 @@ export function useChildExecutions(parentId: number | undefined) {
const data = query.state.data;
if (!data) return false;
const hasActive = data.data.some(
(e) =>
e.status === "requested" ||
e.status === "scheduling" ||
e.status === "scheduled" ||
e.status === "running",
(e) => isExecutionActive(e.status),
);
return hasActive ? 5000 : false;
},

View File

@@ -199,10 +199,10 @@ export default function ExecutionDetailPage() {
execution.status === ExecutionStatus.RUNNING ||
execution.status === ExecutionStatus.SCHEDULING ||
execution.status === ExecutionStatus.SCHEDULED ||
execution.status === ExecutionStatus.REQUESTED;
execution.status === ExecutionStatus.REQUESTED ||
execution.status === ExecutionStatus.CANCELING;
const isCancellable =
isRunning || execution.status === ExecutionStatus.CANCELING;
const isCancellable = isRunning;
return (
<div className="p-6 max-w-7xl mx-auto">
@@ -392,13 +392,23 @@ export default function ExecutionDetailPage() {
{execution.executor && (
<div>
<dt className="text-sm font-medium text-gray-500">
Executor ID
Initiated By
</dt>
<dd className="mt-1 text-sm text-gray-900">
{execution.executor}
</dd>
</div>
)}
{execution.worker && (
<div>
<dt className="text-sm font-medium text-gray-500">
Worker ID
</dt>
<dd className="mt-1 text-sm text-gray-900">
{execution.worker}
</dd>
</div>
)}
</dl>
{/* Inline progress bar (visible when execution has progress artifacts) */}