[WIP] client action streaming
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::HeaderMap,
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
@@ -13,6 +14,7 @@ use axum::{
|
||||
use chrono::Utc;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use attune_common::models::enums::ExecutionStatus;
|
||||
@@ -32,7 +34,10 @@ use attune_common::workflow::{CancellationPolicy, WorkflowDefinition};
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
auth::{
|
||||
jwt::{validate_token, Claims, JwtConfig, TokenType},
|
||||
middleware::{AuthenticatedUser, RequireAuth},
|
||||
},
|
||||
authz::{AuthorizationCheck, AuthorizationService},
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
@@ -46,6 +51,9 @@ use crate::{
|
||||
};
|
||||
use attune_common::rbac::{Action, AuthorizationContext, Resource};
|
||||
|
||||
const LOG_STREAM_POLL_INTERVAL: Duration = Duration::from_millis(250);
|
||||
const LOG_STREAM_READ_CHUNK_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// Create a new execution (manual execution)
|
||||
///
|
||||
/// This endpoint allows directly executing an action without a trigger or rule.
|
||||
@@ -925,6 +933,398 @@ pub async fn stream_execution_updates(
|
||||
Ok(Sse::new(filtered_stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct StreamExecutionLogParams {
|
||||
pub token: Option<String>,
|
||||
pub offset: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ExecutionLogStream {
|
||||
Stdout,
|
||||
Stderr,
|
||||
}
|
||||
|
||||
impl ExecutionLogStream {
|
||||
fn parse(name: &str) -> Result<Self, ApiError> {
|
||||
match name {
|
||||
"stdout" => Ok(Self::Stdout),
|
||||
"stderr" => Ok(Self::Stderr),
|
||||
_ => Err(ApiError::BadRequest(format!(
|
||||
"Unsupported log stream '{}'. Expected 'stdout' or 'stderr'.",
|
||||
name
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn file_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Stdout => "stdout.log",
|
||||
Self::Stderr => "stderr.log",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum ExecutionLogTailState {
|
||||
WaitingForFile {
|
||||
full_path: std::path::PathBuf,
|
||||
execution_id: i64,
|
||||
},
|
||||
SendInitial {
|
||||
full_path: std::path::PathBuf,
|
||||
execution_id: i64,
|
||||
offset: u64,
|
||||
pending_utf8: Vec<u8>,
|
||||
},
|
||||
Tail {
|
||||
full_path: std::path::PathBuf,
|
||||
execution_id: i64,
|
||||
offset: u64,
|
||||
idle_polls: u32,
|
||||
pending_utf8: Vec<u8>,
|
||||
},
|
||||
Finished,
|
||||
}
|
||||
|
||||
/// Stream stdout/stderr for an execution as SSE.
|
||||
///
|
||||
/// This tails the worker's live log files directly from the shared artifacts
|
||||
/// volume. The file may not exist yet when the worker has not emitted any
|
||||
/// output, so the stream waits briefly for it to appear.
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/executions/{id}/logs/{stream}/stream",
|
||||
tag = "executions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Execution ID"),
|
||||
("stream" = String, Path, description = "Log stream name: stdout or stderr"),
|
||||
("token" = String, Query, description = "JWT access token for authentication"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "SSE stream of execution log content", content_type = "text/event-stream"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "Execution not found"),
|
||||
),
|
||||
)]
|
||||
pub async fn stream_execution_log(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path((id, stream_name)): Path<(i64, String)>,
|
||||
Query(params): Query<StreamExecutionLogParams>,
|
||||
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
|
||||
let authenticated_user =
|
||||
authenticate_execution_log_stream_user(&state, &headers, user, params.token.as_deref())?;
|
||||
validate_execution_log_stream_user(&authenticated_user, id)?;
|
||||
|
||||
let execution = ExecutionRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Execution with ID {} not found", id)))?;
|
||||
authorize_execution_log_stream(&state, &authenticated_user, &execution).await?;
|
||||
|
||||
let stream_name = ExecutionLogStream::parse(&stream_name)?;
|
||||
let full_path = std::path::PathBuf::from(&state.config.artifacts_dir)
|
||||
.join(format!("execution_{}", id))
|
||||
.join(stream_name.file_name());
|
||||
let db = state.db.clone();
|
||||
|
||||
let initial_state = ExecutionLogTailState::WaitingForFile {
|
||||
full_path,
|
||||
execution_id: id,
|
||||
};
|
||||
let start_offset = params.offset.unwrap_or(0);
|
||||
|
||||
let stream = futures::stream::unfold(initial_state, move |state| {
|
||||
let db = db.clone();
|
||||
async move {
|
||||
match state {
|
||||
ExecutionLogTailState::Finished => None,
|
||||
ExecutionLogTailState::WaitingForFile {
|
||||
full_path,
|
||||
execution_id,
|
||||
} => {
|
||||
if full_path.exists() {
|
||||
Some((
|
||||
Ok(Event::default().event("waiting").data("Log file found")),
|
||||
ExecutionLogTailState::SendInitial {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset: start_offset,
|
||||
pending_utf8: Vec::new(),
|
||||
},
|
||||
))
|
||||
} else if execution_log_execution_terminal(&db, execution_id).await {
|
||||
Some((
|
||||
Ok(Event::default().event("done").data("")),
|
||||
ExecutionLogTailState::Finished,
|
||||
))
|
||||
} else {
|
||||
tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
|
||||
Some((
|
||||
Ok(Event::default()
|
||||
.event("waiting")
|
||||
.data("Waiting for log output")),
|
||||
ExecutionLogTailState::WaitingForFile {
|
||||
full_path,
|
||||
execution_id,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
ExecutionLogTailState::SendInitial {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset,
|
||||
pending_utf8,
|
||||
} => {
|
||||
let pending_utf8_on_empty = pending_utf8.clone();
|
||||
match read_log_chunk(
|
||||
&full_path,
|
||||
offset,
|
||||
LOG_STREAM_READ_CHUNK_SIZE,
|
||||
pending_utf8,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Some((content, new_offset, pending_utf8)) => Some((
|
||||
Ok(Event::default()
|
||||
.id(new_offset.to_string())
|
||||
.event("content")
|
||||
.data(content)),
|
||||
ExecutionLogTailState::SendInitial {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset: new_offset,
|
||||
pending_utf8,
|
||||
},
|
||||
)),
|
||||
None => Some((
|
||||
Ok(Event::default().comment("initial-catchup-complete")),
|
||||
ExecutionLogTailState::Tail {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset,
|
||||
idle_polls: 0,
|
||||
pending_utf8: pending_utf8_on_empty,
|
||||
},
|
||||
)),
|
||||
}
|
||||
}
|
||||
ExecutionLogTailState::Tail {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset,
|
||||
idle_polls,
|
||||
pending_utf8,
|
||||
} => {
|
||||
let pending_utf8_on_empty = pending_utf8.clone();
|
||||
match read_log_chunk(
|
||||
&full_path,
|
||||
offset,
|
||||
LOG_STREAM_READ_CHUNK_SIZE,
|
||||
pending_utf8,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Some((append, new_offset, pending_utf8)) => Some((
|
||||
Ok(Event::default()
|
||||
.id(new_offset.to_string())
|
||||
.event("append")
|
||||
.data(append)),
|
||||
ExecutionLogTailState::Tail {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset: new_offset,
|
||||
idle_polls: 0,
|
||||
pending_utf8,
|
||||
},
|
||||
)),
|
||||
None => {
|
||||
let terminal =
|
||||
execution_log_execution_terminal(&db, execution_id).await;
|
||||
if terminal && idle_polls >= 2 {
|
||||
Some((
|
||||
Ok(Event::default().event("done").data("Execution complete")),
|
||||
ExecutionLogTailState::Finished,
|
||||
))
|
||||
} else {
|
||||
tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
|
||||
Some((
|
||||
Ok(Event::default()
|
||||
.event("waiting")
|
||||
.data("Waiting for log output")),
|
||||
ExecutionLogTailState::Tail {
|
||||
full_path,
|
||||
execution_id,
|
||||
offset,
|
||||
idle_polls: idle_polls + 1,
|
||||
pending_utf8: pending_utf8_on_empty,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
async fn read_log_chunk(
|
||||
path: &std::path::Path,
|
||||
offset: u64,
|
||||
max_bytes: usize,
|
||||
mut pending_utf8: Vec<u8>,
|
||||
) -> Option<(String, u64, Vec<u8>)> {
|
||||
use tokio::io::{AsyncReadExt, AsyncSeekExt};
|
||||
|
||||
let mut file = tokio::fs::File::open(path).await.ok()?;
|
||||
let metadata = file.metadata().await.ok()?;
|
||||
if metadata.len() <= offset {
|
||||
return None;
|
||||
}
|
||||
|
||||
file.seek(std::io::SeekFrom::Start(offset)).await.ok()?;
|
||||
let bytes_to_read = ((metadata.len() - offset) as usize).min(max_bytes);
|
||||
let mut buf = vec![0u8; bytes_to_read];
|
||||
let read = file.read(&mut buf).await.ok()?;
|
||||
buf.truncate(read);
|
||||
if buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
pending_utf8.extend_from_slice(&buf);
|
||||
let (content, pending_utf8) = decode_utf8_chunk(pending_utf8);
|
||||
|
||||
Some((content, offset + read as u64, pending_utf8))
|
||||
}
|
||||
|
||||
async fn execution_log_execution_terminal(db: &sqlx::PgPool, execution_id: i64) -> bool {
|
||||
match ExecutionRepository::find_by_id(db, execution_id).await {
|
||||
Ok(Some(execution)) => matches!(
|
||||
execution.status,
|
||||
ExecutionStatus::Completed
|
||||
| ExecutionStatus::Failed
|
||||
| ExecutionStatus::Cancelled
|
||||
| ExecutionStatus::Timeout
|
||||
| ExecutionStatus::Abandoned
|
||||
),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_utf8_chunk(mut bytes: Vec<u8>) -> (String, Vec<u8>) {
|
||||
match std::str::from_utf8(&bytes) {
|
||||
Ok(valid) => (valid.to_string(), Vec::new()),
|
||||
Err(err) if err.error_len().is_none() => {
|
||||
let pending = bytes.split_off(err.valid_up_to());
|
||||
(String::from_utf8_lossy(&bytes).into_owned(), pending)
|
||||
}
|
||||
Err(_) => (String::from_utf8_lossy(&bytes).into_owned(), Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize_execution_log_stream(
|
||||
state: &Arc<AppState>,
|
||||
user: &AuthenticatedUser,
|
||||
execution: &attune_common::models::Execution,
|
||||
) -> Result<(), ApiError> {
|
||||
if user.claims.token_type != TokenType::Access {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_id = Some(execution.id);
|
||||
ctx.target_ref = Some(execution.action_ref.clone());
|
||||
|
||||
authz
|
||||
.authorize(
|
||||
user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Executions,
|
||||
action: Action::Read,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn authenticate_execution_log_stream_user(
|
||||
state: &Arc<AppState>,
|
||||
headers: &HeaderMap,
|
||||
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
|
||||
query_token: Option<&str>,
|
||||
) -> Result<AuthenticatedUser, ApiError> {
|
||||
match user {
|
||||
Ok(RequireAuth(user)) => Ok(user),
|
||||
Err(_) => {
|
||||
if let Some(user) = crate::auth::oidc::cookie_authenticated_user(headers, state)? {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let token = query_token.ok_or(ApiError::Unauthorized(
|
||||
"Missing authentication token".to_string(),
|
||||
))?;
|
||||
authenticate_execution_log_stream_query_token(token, &state.jwt_config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate_execution_log_stream_query_token(
|
||||
token: &str,
|
||||
jwt_config: &JwtConfig,
|
||||
) -> Result<AuthenticatedUser, ApiError> {
|
||||
let claims = validate_token(token, jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
|
||||
Ok(AuthenticatedUser { claims })
|
||||
}
|
||||
|
||||
fn validate_execution_log_stream_user(
|
||||
user: &AuthenticatedUser,
|
||||
execution_id: i64,
|
||||
) -> Result<(), ApiError> {
|
||||
let claims = &user.claims;
|
||||
|
||||
match claims.token_type {
|
||||
TokenType::Access => Ok(()),
|
||||
TokenType::Execution => validate_execution_token_scope(claims, execution_id),
|
||||
TokenType::Sensor | TokenType::Refresh => Err(ApiError::Unauthorized(
|
||||
"Invalid authentication token".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_execution_token_scope(claims: &Claims, execution_id: i64) -> Result<(), ApiError> {
|
||||
if claims.scope.as_deref() != Some("execution") {
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid authentication token".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let token_execution_id = claims
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("execution_id"))
|
||||
.and_then(|value| value.as_i64())
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
|
||||
if token_execution_id != execution_id {
|
||||
return Err(ApiError::Forbidden(format!(
|
||||
"Execution token is not valid for execution {}",
|
||||
execution_id
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct StreamExecutionParams {
|
||||
pub execution_id: Option<i64>,
|
||||
@@ -937,6 +1337,10 @@ pub fn routes() -> Router<Arc<AppState>> {
|
||||
.route("/executions/execute", axum::routing::post(create_execution))
|
||||
.route("/executions/stats", get(get_execution_stats))
|
||||
.route("/executions/stream", get(stream_execution_updates))
|
||||
.route(
|
||||
"/executions/{id}/logs/{stream}/stream",
|
||||
get(stream_execution_log),
|
||||
)
|
||||
.route("/executions/{id}", get(get_execution))
|
||||
.route(
|
||||
"/executions/{id}/cancel",
|
||||
@@ -955,10 +1359,26 @@ pub fn routes() -> Router<Arc<AppState>> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use attune_common::auth::jwt::generate_execution_token;
|
||||
|
||||
#[test]
|
||||
fn test_execution_routes_structure() {
|
||||
// Just verify the router can be constructed
|
||||
let _router = routes();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execution_token_scope_must_match_requested_execution() {
|
||||
let jwt_config = JwtConfig {
|
||||
secret: "test_secret_key_for_testing".to_string(),
|
||||
access_token_expiration: 3600,
|
||||
refresh_token_expiration: 604800,
|
||||
};
|
||||
|
||||
let token = generate_execution_token(42, 123, "core.echo", &jwt_config, None).unwrap();
|
||||
|
||||
let user = authenticate_execution_log_stream_query_token(&token, &jwt_config).unwrap();
|
||||
let err = validate_execution_log_stream_user(&user, 456).unwrap_err();
|
||||
assert!(matches!(err, ApiError::Forbidden(_)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user