[WIP] client action streaming

This commit is contained in:
2026-04-01 20:23:56 -05:00
parent 4b525f4641
commit 104dcbb1b1
14 changed files with 1152 additions and 828 deletions

View File

@@ -2,6 +2,7 @@
use axum::{
extract::{Path, Query, State},
http::HeaderMap,
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
@@ -13,6 +14,7 @@ use axum::{
use chrono::Utc;
use futures::stream::{Stream, StreamExt};
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::wrappers::BroadcastStream;
use attune_common::models::enums::ExecutionStatus;
@@ -32,7 +34,10 @@ use attune_common::workflow::{CancellationPolicy, WorkflowDefinition};
use sqlx::Row;
use crate::{
auth::middleware::RequireAuth,
auth::{
jwt::{validate_token, Claims, JwtConfig, TokenType},
middleware::{AuthenticatedUser, RequireAuth},
},
authz::{AuthorizationCheck, AuthorizationService},
dto::{
common::{PaginatedResponse, PaginationParams},
@@ -46,6 +51,9 @@ use crate::{
};
use attune_common::rbac::{Action, AuthorizationContext, Resource};
const LOG_STREAM_POLL_INTERVAL: Duration = Duration::from_millis(250);
const LOG_STREAM_READ_CHUNK_SIZE: usize = 64 * 1024;
/// Create a new execution (manual execution)
///
/// This endpoint allows directly executing an action without a trigger or rule.
@@ -925,6 +933,398 @@ pub async fn stream_execution_updates(
Ok(Sse::new(filtered_stream).keep_alive(KeepAlive::default()))
}
#[derive(serde::Deserialize)]
pub struct StreamExecutionLogParams {
pub token: Option<String>,
pub offset: Option<u64>,
}
#[derive(Clone, Copy)]
enum ExecutionLogStream {
Stdout,
Stderr,
}
impl ExecutionLogStream {
fn parse(name: &str) -> Result<Self, ApiError> {
match name {
"stdout" => Ok(Self::Stdout),
"stderr" => Ok(Self::Stderr),
_ => Err(ApiError::BadRequest(format!(
"Unsupported log stream '{}'. Expected 'stdout' or 'stderr'.",
name
))),
}
}
fn file_name(self) -> &'static str {
match self {
Self::Stdout => "stdout.log",
Self::Stderr => "stderr.log",
}
}
}
enum ExecutionLogTailState {
WaitingForFile {
full_path: std::path::PathBuf,
execution_id: i64,
},
SendInitial {
full_path: std::path::PathBuf,
execution_id: i64,
offset: u64,
pending_utf8: Vec<u8>,
},
Tail {
full_path: std::path::PathBuf,
execution_id: i64,
offset: u64,
idle_polls: u32,
pending_utf8: Vec<u8>,
},
Finished,
}
/// Stream stdout/stderr for an execution as SSE.
///
/// This tails the worker's live log files directly from the shared artifacts
/// volume. The file may not exist yet when the worker has not emitted any
/// output, so the stream waits briefly for it to appear.
#[utoipa::path(
get,
path = "/api/v1/executions/{id}/logs/{stream}/stream",
tag = "executions",
params(
("id" = i64, Path, description = "Execution ID"),
("stream" = String, Path, description = "Log stream name: stdout or stderr"),
("token" = String, Query, description = "JWT access token for authentication"),
),
responses(
(status = 200, description = "SSE stream of execution log content", content_type = "text/event-stream"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Execution not found"),
),
)]
pub async fn stream_execution_log(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path((id, stream_name)): Path<(i64, String)>,
Query(params): Query<StreamExecutionLogParams>,
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
let authenticated_user =
authenticate_execution_log_stream_user(&state, &headers, user, params.token.as_deref())?;
validate_execution_log_stream_user(&authenticated_user, id)?;
let execution = ExecutionRepository::find_by_id(&state.db, id)
.await?
.ok_or_else(|| ApiError::NotFound(format!("Execution with ID {} not found", id)))?;
authorize_execution_log_stream(&state, &authenticated_user, &execution).await?;
let stream_name = ExecutionLogStream::parse(&stream_name)?;
let full_path = std::path::PathBuf::from(&state.config.artifacts_dir)
.join(format!("execution_{}", id))
.join(stream_name.file_name());
let db = state.db.clone();
let initial_state = ExecutionLogTailState::WaitingForFile {
full_path,
execution_id: id,
};
let start_offset = params.offset.unwrap_or(0);
let stream = futures::stream::unfold(initial_state, move |state| {
let db = db.clone();
async move {
match state {
ExecutionLogTailState::Finished => None,
ExecutionLogTailState::WaitingForFile {
full_path,
execution_id,
} => {
if full_path.exists() {
Some((
Ok(Event::default().event("waiting").data("Log file found")),
ExecutionLogTailState::SendInitial {
full_path,
execution_id,
offset: start_offset,
pending_utf8: Vec::new(),
},
))
} else if execution_log_execution_terminal(&db, execution_id).await {
Some((
Ok(Event::default().event("done").data("")),
ExecutionLogTailState::Finished,
))
} else {
tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
Some((
Ok(Event::default()
.event("waiting")
.data("Waiting for log output")),
ExecutionLogTailState::WaitingForFile {
full_path,
execution_id,
},
))
}
}
ExecutionLogTailState::SendInitial {
full_path,
execution_id,
offset,
pending_utf8,
} => {
let pending_utf8_on_empty = pending_utf8.clone();
match read_log_chunk(
&full_path,
offset,
LOG_STREAM_READ_CHUNK_SIZE,
pending_utf8,
)
.await
{
Some((content, new_offset, pending_utf8)) => Some((
Ok(Event::default()
.id(new_offset.to_string())
.event("content")
.data(content)),
ExecutionLogTailState::SendInitial {
full_path,
execution_id,
offset: new_offset,
pending_utf8,
},
)),
None => Some((
Ok(Event::default().comment("initial-catchup-complete")),
ExecutionLogTailState::Tail {
full_path,
execution_id,
offset,
idle_polls: 0,
pending_utf8: pending_utf8_on_empty,
},
)),
}
}
ExecutionLogTailState::Tail {
full_path,
execution_id,
offset,
idle_polls,
pending_utf8,
} => {
let pending_utf8_on_empty = pending_utf8.clone();
match read_log_chunk(
&full_path,
offset,
LOG_STREAM_READ_CHUNK_SIZE,
pending_utf8,
)
.await
{
Some((append, new_offset, pending_utf8)) => Some((
Ok(Event::default()
.id(new_offset.to_string())
.event("append")
.data(append)),
ExecutionLogTailState::Tail {
full_path,
execution_id,
offset: new_offset,
idle_polls: 0,
pending_utf8,
},
)),
None => {
let terminal =
execution_log_execution_terminal(&db, execution_id).await;
if terminal && idle_polls >= 2 {
Some((
Ok(Event::default().event("done").data("Execution complete")),
ExecutionLogTailState::Finished,
))
} else {
tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
Some((
Ok(Event::default()
.event("waiting")
.data("Waiting for log output")),
ExecutionLogTailState::Tail {
full_path,
execution_id,
offset,
idle_polls: idle_polls + 1,
pending_utf8: pending_utf8_on_empty,
},
))
}
}
}
}
}
}
});
Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
async fn read_log_chunk(
path: &std::path::Path,
offset: u64,
max_bytes: usize,
mut pending_utf8: Vec<u8>,
) -> Option<(String, u64, Vec<u8>)> {
use tokio::io::{AsyncReadExt, AsyncSeekExt};
let mut file = tokio::fs::File::open(path).await.ok()?;
let metadata = file.metadata().await.ok()?;
if metadata.len() <= offset {
return None;
}
file.seek(std::io::SeekFrom::Start(offset)).await.ok()?;
let bytes_to_read = ((metadata.len() - offset) as usize).min(max_bytes);
let mut buf = vec![0u8; bytes_to_read];
let read = file.read(&mut buf).await.ok()?;
buf.truncate(read);
if buf.is_empty() {
return None;
}
pending_utf8.extend_from_slice(&buf);
let (content, pending_utf8) = decode_utf8_chunk(pending_utf8);
Some((content, offset + read as u64, pending_utf8))
}
async fn execution_log_execution_terminal(db: &sqlx::PgPool, execution_id: i64) -> bool {
match ExecutionRepository::find_by_id(db, execution_id).await {
Ok(Some(execution)) => matches!(
execution.status,
ExecutionStatus::Completed
| ExecutionStatus::Failed
| ExecutionStatus::Cancelled
| ExecutionStatus::Timeout
| ExecutionStatus::Abandoned
),
_ => true,
}
}
fn decode_utf8_chunk(mut bytes: Vec<u8>) -> (String, Vec<u8>) {
match std::str::from_utf8(&bytes) {
Ok(valid) => (valid.to_string(), Vec::new()),
Err(err) if err.error_len().is_none() => {
let pending = bytes.split_off(err.valid_up_to());
(String::from_utf8_lossy(&bytes).into_owned(), pending)
}
Err(_) => (String::from_utf8_lossy(&bytes).into_owned(), Vec::new()),
}
}
async fn authorize_execution_log_stream(
state: &Arc<AppState>,
user: &AuthenticatedUser,
execution: &attune_common::models::Execution,
) -> Result<(), ApiError> {
if user.claims.token_type != TokenType::Access {
return Ok(());
}
let identity_id = user
.identity_id()
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
let authz = AuthorizationService::new(state.db.clone());
let mut ctx = AuthorizationContext::new(identity_id);
ctx.target_id = Some(execution.id);
ctx.target_ref = Some(execution.action_ref.clone());
authz
.authorize(
user,
AuthorizationCheck {
resource: Resource::Executions,
action: Action::Read,
context: ctx,
},
)
.await
}
fn authenticate_execution_log_stream_user(
state: &Arc<AppState>,
headers: &HeaderMap,
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
query_token: Option<&str>,
) -> Result<AuthenticatedUser, ApiError> {
match user {
Ok(RequireAuth(user)) => Ok(user),
Err(_) => {
if let Some(user) = crate::auth::oidc::cookie_authenticated_user(headers, state)? {
return Ok(user);
}
let token = query_token.ok_or(ApiError::Unauthorized(
"Missing authentication token".to_string(),
))?;
authenticate_execution_log_stream_query_token(token, &state.jwt_config)
}
}
}
fn authenticate_execution_log_stream_query_token(
token: &str,
jwt_config: &JwtConfig,
) -> Result<AuthenticatedUser, ApiError> {
let claims = validate_token(token, jwt_config)
.map_err(|_| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
Ok(AuthenticatedUser { claims })
}
fn validate_execution_log_stream_user(
user: &AuthenticatedUser,
execution_id: i64,
) -> Result<(), ApiError> {
let claims = &user.claims;
match claims.token_type {
TokenType::Access => Ok(()),
TokenType::Execution => validate_execution_token_scope(claims, execution_id),
TokenType::Sensor | TokenType::Refresh => Err(ApiError::Unauthorized(
"Invalid authentication token".to_string(),
)),
}
}
fn validate_execution_token_scope(claims: &Claims, execution_id: i64) -> Result<(), ApiError> {
if claims.scope.as_deref() != Some("execution") {
return Err(ApiError::Unauthorized(
"Invalid authentication token".to_string(),
));
}
let token_execution_id = claims
.metadata
.as_ref()
.and_then(|metadata| metadata.get("execution_id"))
.and_then(|value| value.as_i64())
.ok_or_else(|| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
if token_execution_id != execution_id {
return Err(ApiError::Forbidden(format!(
"Execution token is not valid for execution {}",
execution_id
)));
}
Ok(())
}
#[derive(serde::Deserialize)]
pub struct StreamExecutionParams {
pub execution_id: Option<i64>,
@@ -937,6 +1337,10 @@ pub fn routes() -> Router<Arc<AppState>> {
.route("/executions/execute", axum::routing::post(create_execution))
.route("/executions/stats", get(get_execution_stats))
.route("/executions/stream", get(stream_execution_updates))
.route(
"/executions/{id}/logs/{stream}/stream",
get(stream_execution_log),
)
.route("/executions/{id}", get(get_execution))
.route(
"/executions/{id}/cancel",
@@ -955,10 +1359,26 @@ pub fn routes() -> Router<Arc<AppState>> {
#[cfg(test)]
mod tests {
use super::*;
use attune_common::auth::jwt::generate_execution_token;
#[test]
fn test_execution_routes_structure() {
// Just verify the router can be constructed
let _router = routes();
}
#[test]
fn execution_token_scope_must_match_requested_execution() {
let jwt_config = JwtConfig {
secret: "test_secret_key_for_testing".to_string(),
access_token_expiration: 3600,
refresh_token_expiration: 604800,
};
let token = generate_execution_token(42, 123, "core.echo", &jwt_config, None).unwrap();
let user = authenticate_execution_log_stream_query_token(&token, &jwt_config).unwrap();
let err = validate_execution_log_stream_user(&user, 456).unwrap_err();
assert!(matches!(err, ApiError::Forbidden(_)));
}
}