Files
attune/crates/cli/src/wait.rs
2026-03-03 13:42:41 -06:00

557 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Waiting for execution completion.
//!
//! Tries to connect to the notifier WebSocket first so the CLI reacts
//! *immediately* when the execution reaches a terminal state. If the
//! notifier is unreachable (not configured, different port, Docker network
//! boundary, etc.) it transparently falls back to REST polling.
//!
//! Public surface:
//! - [`WaitOptions`] — caller-supplied parameters
//! - [`wait_for_execution`] — the single entry point
//! - [`ExecutionSummary`] — the result returned once the wait completes
use anyhow::Result;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::client::ApiClient;
// ── terminal status helpers ───────────────────────────────────────────────────
/// Returns `true` if `status` is one of the terminal statuses that end a wait.
///
/// Matching is exact and case-sensitive. Both `canceled`/`cancelled` and
/// `timeout`/`timed_out` are accepted so either spelling is recognised.
fn is_terminal(status: &str) -> bool {
    const TERMINAL: [&str; 7] = [
        "completed",
        "succeeded",
        "failed",
        "canceled",
        "cancelled",
        "timeout",
        "timed_out",
    ];
    TERMINAL.iter().any(|&candidate| candidate == status)
}
// ── public types ─────────────────────────────────────────────────────────────
/// Result returned when the wait completes.
///
/// Built either from a REST `GET /executions/<id>` response or from a notifier
/// notification payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionSummary {
    /// Execution ID.
    pub id: i64,
    /// Status string exactly as reported (not normalised), e.g. `"completed"`
    /// or `"failed"`; either of `"canceled"`/`"cancelled"` may appear.
    pub status: String,
    /// Reference of the action that ran (e.g. `"core.echo"`). May be empty
    /// when the summary was assembled from a partial notifier payload.
    pub action_ref: String,
    /// Execution result payload, if any.
    pub result: Option<serde_json::Value>,
    /// Creation timestamp, kept as the string the server sent.
    pub created: String,
    /// Last-update timestamp, kept as the string the server sent.
    pub updated: String,
}
/// Parameters that control how we wait.
pub struct WaitOptions<'a> {
    /// Execution ID to watch.
    pub execution_id: i64,
    /// Overall wall-clock limit in seconds, shared by the WebSocket and
    /// polling paths. Used as-is: this field is not optional and there is no
    /// default at this level, so callers must choose the value.
    pub timeout_secs: u64,
    /// REST API client (already authenticated). Used for the post-subscribe
    /// status check, the polling fallback, and deriving the notifier URL.
    pub api_client: &'a mut ApiClient,
    /// Base URL of the *notifier* WebSocket service, e.g. `ws://localhost:8081`
    /// (the `/ws` path is appended automatically). When `None`, it is derived
    /// from the API client's base URL.
    pub notifier_ws_url: Option<String>,
    /// If `true`, print progress lines to stderr.
    pub verbose: bool,
}
// ── notifier WebSocket messages (mirrors websocket_server.rs) ────────────────
/// Frames the CLI sends to the notifier, discriminated by a lowercase `type`
/// field: `{"type":"subscribe","filter":"..."}` and `{"type":"ping"}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum ClientMsg {
    /// Ask for notifications matching `filter` (e.g. `entity:execution:42`).
    Subscribe { filter: String },
    /// Application-level keep-alive.
    Ping,
}

/// Frames received from the notifier, discriminated by a lowercase `type`
/// field. Any `type` not listed here decodes as [`ServerMsg::Unknown`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum ServerMsg {
    /// Greeting sent right after connecting.
    Welcome {
        client_id: String,
        // Part of the protocol but never displayed by the CLI.
        #[allow(dead_code)]
        message: String,
    },
    /// An entity change the client is subscribed to.
    Notification(NotifierNotification),
    /// Server-side error report.
    Error { message: String },
    /// Catch-all for message types this client does not understand.
    #[serde(other)]
    Unknown,
}

/// Body of a `notification` frame.
#[derive(Debug, Deserialize)]
struct NotifierNotification {
    notification_type: String,
    entity_type: String,
    entity_id: i64,
    payload: serde_json::Value,
}
// ── REST execution shape ──────────────────────────────────────────────────────
/// Fields of a `GET /executions/<id>` response that the wait logic reads.
/// Any other fields in the body are ignored (serde's default behaviour).
/// Also used to decode notifier payloads that carry the full execution row.
#[derive(Debug, Deserialize)]
struct RestExecution {
    id: i64,
    action_ref: String,
    status: String,
    result: Option<serde_json::Value>,
    created: String,
    updated: String,
}

impl From<RestExecution> for ExecutionSummary {
    /// Field-for-field move; no values are transformed.
    fn from(exec: RestExecution) -> Self {
        let RestExecution {
            id,
            action_ref,
            status,
            result,
            created,
            updated,
        } = exec;
        Self {
            id,
            status,
            action_ref,
            result,
            created,
            updated,
        }
    }
}
// ── entry point ───────────────────────────────────────────────────────────────
/// Wait for `execution_id` to reach a terminal status.
///
/// 1. If a notifier URL is available (explicit, or derived from the API URL)
///    *and* the timeout is long enough to leave a polling reserve, connects to
///    the notifier WebSocket and subscribes with the filter
///    `entity:execution:<id>`. The WebSocket path only gets the budget that
///    remains after that reserve is set aside.
/// 2. If the WebSocket path is skipped or fails for any reason, falls back to
///    polling `GET /executions/<id>` with a backing-off interval.
/// 3. Both paths share one overall `timeout_secs` wall-clock deadline.
///
/// # Errors
///
/// WebSocket failures are never returned; they only trigger the fallback.
/// An error is returned when polling reaches the deadline without seeing a
/// terminal status.
pub async fn wait_for_execution(opts: WaitOptions<'_>) -> Result<ExecutionSummary> {
    // Always reserve at least this long for polling so the fallback gets a
    // fair chance even when the WebSocket path uses most of the budget.
    const MIN_POLL_BUDGET: Duration = Duration::from_secs(10);
    // `Instant + Duration` panics on overflow and `timeout_secs` comes straight
    // from the caller, so clamp absurd values to one year.
    const MAX_TIMEOUT: Duration = Duration::from_secs(60 * 60 * 24 * 365);

    let overall_deadline =
        Instant::now() + Duration::from_secs(opts.timeout_secs).min(MAX_TIMEOUT);

    match resolve_ws_url(&opts) {
        Some(ws_url) if overall_deadline > Instant::now() + MIN_POLL_BUDGET => {
            // Give WS everything except the polling reserve.
            let ws_deadline = overall_deadline - MIN_POLL_BUDGET;
            match wait_via_websocket(
                &ws_url,
                opts.execution_id,
                ws_deadline,
                opts.verbose,
                opts.api_client,
            )
            .await
            {
                Ok(summary) => return Ok(summary),
                Err(ws_err) => {
                    if opts.verbose {
                        eprintln!(" [notifier: {}] falling back to polling", ws_err);
                    }
                    // Fall through to polling below.
                }
            }
        }
        Some(_) => {
            // The timeout is too short to reserve a polling budget. Skip WS
            // entirely instead of letting it consume the whole deadline and
            // leaving polling with no time to run.
            if opts.verbose {
                eprintln!(" [notifier] timeout too short for WebSocket; using polling");
            }
        }
        None => {
            if opts.verbose {
                eprintln!(" [notifier URL not configured] using polling");
            }
        }
    }

    // Polling always runs against the overall deadline. If WS was attempted,
    // at least MIN_POLL_BUDGET of it remains.
    wait_via_polling(
        opts.api_client,
        opts.execution_id,
        overall_deadline,
        opts.verbose,
    )
    .await
}
// ── WebSocket path ────────────────────────────────────────────────────────────
/// Wait for a terminal status through the notifier WebSocket, giving up at
/// `deadline`.
///
/// The sequence is:
/// 1. Connect to `<ws_base_url>/ws`.
/// 2. Wait for the server's `welcome` message.
/// 3. Subscribe to `entity:execution:<execution_id>`.
/// 4. Make one REST check to close the race with an execution that finished
///    before the subscription became active.
/// 5. Read notifications until one for this execution carries a terminal
///    `status`.
///
/// While waiting, a text `ping` message is sent every 15 s. For a terminal
/// notification whose payload is not a full execution row, the full record is
/// fetched over REST. If that fetch fails or does not yet report a terminal
/// status, a summary is assembled from the partial payload instead.
///
/// # Errors
///
/// Any of the following is returned as an error so the caller can fall back
/// to polling:
/// - a connect error or connect timeout;
/// - no welcome message;
/// - an `error` message from the notifier;
/// - the socket closing, or a transport error;
/// - reaching `deadline`.
async fn wait_via_websocket(
    ws_base_url: &str,
    execution_id: i64,
    deadline: Instant,
    verbose: bool,
    api_client: &mut ApiClient,
) -> Result<ExecutionSummary> {
    // Upper bounds for the handshake phases. Each one is further clamped to
    // the remaining WS budget so the caller's polling reserve is not eaten into.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
    const WELCOME_TIMEOUT: Duration = Duration::from_secs(5);
    const PING_INTERVAL: Duration = Duration::from_secs(15);

    let ws_url = format!("{}/ws", ws_base_url.trim_end_matches('/'));
    let exec_path = format!("/executions/{}", execution_id);

    let remaining = deadline.saturating_duration_since(Instant::now());
    if remaining.is_zero() {
        anyhow::bail!("WS budget exhausted before connect");
    }
    let connect_result =
        tokio::time::timeout(CONNECT_TIMEOUT.min(remaining), connect_async(&ws_url)).await;
    let (ws_stream, _response) = match connect_result {
        Ok(Ok(pair)) => pair,
        Ok(Err(e)) => anyhow::bail!("WebSocket connect failed: {}", e),
        Err(_) => anyhow::bail!("WebSocket connect timed out"),
    };
    if verbose {
        eprintln!(" [notifier] connected to {}", ws_url);
    }

    let (mut write, mut read) = ws_stream.split();

    // Wait for the welcome message before subscribing. This is bounded by the
    // WS deadline as well as WELCOME_TIMEOUT.
    let welcome_budget =
        WELCOME_TIMEOUT.min(deadline.saturating_duration_since(Instant::now()));
    tokio::time::timeout(welcome_budget, async {
        while let Some(msg) = read.next().await {
            if let Ok(Message::Text(txt)) = msg {
                if let Ok(ServerMsg::Welcome { client_id, .. }) =
                    serde_json::from_str::<ServerMsg>(&txt)
                {
                    if verbose {
                        eprintln!(" [notifier] session id {}", client_id);
                    }
                    return Ok(());
                }
            }
        }
        anyhow::bail!("connection closed before welcome")
    })
    .await
    .map_err(|_| anyhow::anyhow!("timed out waiting for welcome message"))??;

    // Subscribe to this specific execution.
    let subscribe_msg = ClientMsg::Subscribe {
        filter: format!("entity:execution:{}", execution_id),
    };
    let subscribe_json = serde_json::to_string(&subscribe_msg)?;
    SinkExt::send(&mut write, Message::Text(subscribe_json.into())).await?;
    if verbose {
        eprintln!(
            " [notifier] subscribed to entity:execution:{}",
            execution_id
        );
    }

    // ── Race-condition guard ──────────────────────────────────────────────
    // The execution may have completed between the initial POST and the
    // moment the subscription became active. Checking REST *after* subscribing
    // leaves no gap: either we see the terminal state here, or the
    // notification arrives later and the loop below catches it. This check is
    // best-effort: a REST failure just means we rely on notifications.
    if let Ok(exec) = api_client.get::<RestExecution>(&exec_path).await {
        if is_terminal(&exec.status) {
            if verbose {
                eprintln!(
                    " [notifier] execution {} already terminal ('{}') — caught by post-subscribe check",
                    execution_id, exec.status
                );
            }
            return Ok(exec.into());
        }
    }

    // Wake up for whichever comes first: a message, the next ping, or the deadline.
    let mut next_ping = Instant::now() + PING_INTERVAL;
    loop {
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            anyhow::bail!("timed out waiting for execution {}", execution_id);
        }
        let wait_for = remaining.min(next_ping.saturating_duration_since(Instant::now()));

        match tokio::time::timeout(wait_for, read.next()).await {
            // A text frame arrived within the window.
            Ok(Some(Ok(Message::Text(txt)))) => match serde_json::from_str::<ServerMsg>(&txt) {
                Ok(ServerMsg::Notification(n)) => {
                    if n.entity_type != "execution" || n.entity_id != execution_id {
                        // Not our execution — keep waiting.
                        continue;
                    }
                    let status = n.payload.get("status").and_then(|s| s.as_str());
                    if verbose {
                        eprintln!(
                            " [notifier] {} for execution {} — status={:?}",
                            n.notification_type, execution_id, status,
                        );
                    }
                    if matches!(status, Some(s) if is_terminal(s)) {
                        // Prefer the payload when it carries the whole execution row.
                        if let Ok(exec) =
                            serde_json::from_value::<RestExecution>(n.payload.clone())
                        {
                            return Ok(exec.into());
                        }
                        // Partial payload: fetch the full record so the summary
                        // is not missing action_ref/timestamps. Only trust REST
                        // once it also reports a terminal status.
                        if let Ok(exec) = api_client.get::<RestExecution>(&exec_path).await {
                            if is_terminal(&exec.status) {
                                return Ok(exec.into());
                            }
                        }
                        return build_summary_from_payload(execution_id, &n.payload);
                    }
                    // Our execution, but not terminal yet — keep waiting.
                }
                Ok(ServerMsg::Error { message }) => {
                    anyhow::bail!("notifier error: {}", message);
                }
                Ok(ServerMsg::Welcome { .. } | ServerMsg::Unknown) => {
                    // Ignore unexpected / unrecognised messages.
                }
                Err(e) => {
                    // Parse failures can happen if the server sends a message
                    // format we don't recognise yet; they are not fatal.
                    if verbose {
                        eprintln!(" [notifier] ignoring unrecognised message: {}", e);
                    }
                }
            },
            // Connection closed cleanly.
            Ok(Some(Ok(Message::Close(_)))) | Ok(None) => {
                anyhow::bail!("notifier WebSocket closed unexpectedly");
            }
            // Ping/pong, binary and raw frames — ignore.
            Ok(Some(Ok(
                Message::Ping(_) | Message::Pong(_) | Message::Binary(_) | Message::Frame(_),
            ))) => {}
            // WebSocket transport error.
            Ok(Some(Err(e))) => {
                anyhow::bail!("WebSocket error: {}", e);
            }
            // No message before the timer fired. Either ping time or the
            // deadline (which is checked at the top of the loop).
            Err(_elapsed) => {
                let now = Instant::now();
                if now >= next_ping {
                    // Best-effort keep-alive: the send result is ignored.
                    // Presumably a dead connection then shows up as a read
                    // error or close on a later iteration.
                    let ping_json = serde_json::to_string(&ClientMsg::Ping)?;
                    let _ = SinkExt::send(&mut write, Message::Text(ping_json.into())).await;
                    next_ping = now + PING_INTERVAL;
                }
            }
        }
    }
}
/// Build an [`ExecutionSummary`] from a notifier notification payload.
///
/// If `payload` has the full REST execution shape, it is deserialized
/// directly. Otherwise a summary is assembled field by field:
/// - `id` comes from `execution_id`;
/// - missing string fields become `""` (`"unknown"` for `status`);
/// - a missing or JSON-`null` `result` becomes `None`, matching how the full
///   deserialize path treats `null`.
///
/// This currently never fails; the `Result` return type is kept for callers.
fn build_summary_from_payload(
    execution_id: i64,
    payload: &serde_json::Value,
) -> Result<ExecutionSummary> {
    /// Read `key` as a string from `payload`, or fall back to `default`.
    fn str_field(payload: &serde_json::Value, key: &str, default: &str) -> String {
        payload
            .get(key)
            .and_then(serde_json::Value::as_str)
            .unwrap_or(default)
            .to_string()
    }

    // Try a full deserialize first.
    if let Ok(exec) = serde_json::from_value::<RestExecution>(payload.clone()) {
        return Ok(exec.into());
    }

    // Partial payload — assemble what we can.
    Ok(ExecutionSummary {
        id: execution_id,
        status: str_field(payload, "status", "unknown"),
        action_ref: str_field(payload, "action_ref", ""),
        // Treat an explicit `null` like an absent result so both paths agree.
        result: payload.get("result").filter(|v| !v.is_null()).cloned(),
        created: str_field(payload, "created", ""),
        updated: str_field(payload, "updated", ""),
    })
}
// ── polling fallback ──────────────────────────────────────────────────────────
/// Delay before the second status check; later delays grow from here.
const POLL_INTERVAL: Duration = Duration::from_millis(500);
/// Ceiling for the back-off delay between status checks.
const POLL_INTERVAL_MAX: Duration = Duration::from_secs(2);
/// How quickly the poll interval grows on each successive check.
const POLL_BACKOFF_FACTOR: f64 = 1.5;

/// Poll `GET /executions/<id>` until the execution reaches a terminal status.
///
/// The first request is sent immediately. After that, the delay starts at
/// [`POLL_INTERVAL`] and is multiplied by [`POLL_BACKOFF_FACTOR`] after each
/// sleep, capped at [`POLL_INTERVAL_MAX`].
///
/// Failed requests are logged (when `verbose`) and retried rather than
/// aborting the wait. Sleeps are clipped to `deadline`, and one last check
/// runs after the final clipped sleep.
///
/// # Errors
///
/// Returns an error once `deadline` has passed without a terminal status.
async fn wait_via_polling(
    client: &mut ApiClient,
    execution_id: i64,
    deadline: Instant,
    verbose: bool,
) -> Result<ExecutionSummary> {
    if verbose {
        eprintln!(" [poll] watching execution {}", execution_id);
    }

    let path = format!("/executions/{}", execution_id);
    let mut delay = POLL_INTERVAL;

    loop {
        // Check before sleeping, so an execution that finished while the
        // notifier was being tried is picked up immediately.
        match client.get::<RestExecution>(&path).await {
            Ok(exec) if is_terminal(&exec.status) => {
                if verbose {
                    eprintln!(" [poll] execution {} is {}", execution_id, exec.status);
                }
                return Ok(exec.into());
            }
            Ok(exec) => {
                if verbose {
                    eprintln!(
                        " [poll] status = {} — checking again in {:.1}s",
                        exec.status,
                        delay.as_secs_f64()
                    );
                }
            }
            Err(err) => {
                if verbose {
                    eprintln!(" [poll] request failed ({}), retrying…", err);
                }
            }
        }

        // The deadline is tested only after a request, so at least one check happens.
        if Instant::now() >= deadline {
            anyhow::bail!("timed out waiting for execution {}", execution_id);
        }

        let until_deadline = deadline.saturating_duration_since(Instant::now());
        tokio::time::sleep(delay.min(until_deadline)).await;

        delay = delay.mul_f64(POLL_BACKOFF_FACTOR).min(POLL_INTERVAL_MAX);
    }
}
// ── URL resolution ────────────────────────────────────────────────────────────
/// Work out the notifier WebSocket base URL for `opts`.
///
/// An explicit [`WaitOptions::notifier_ws_url`] always wins. Otherwise the URL
/// is derived from the API client's base URL by [`derive_notifier_url`]: the
/// host is kept, the scheme is mapped to `ws`/`wss`, and the port is replaced.
/// This suits the standard single-host layout where both services share a
/// hostname.
///
/// Returns `None` when no explicit URL is given and the API base URL cannot be
/// converted.
fn resolve_ws_url(opts: &WaitOptions<'_>) -> Option<String> {
    match &opts.notifier_ws_url {
        Some(explicit) => Some(explicit.clone()),
        None => {
            let api_url = opts.api_client.base_url();
            derive_notifier_url(&api_url)
        }
    }
}
/// Convert an HTTP(S) API base URL into the expected notifier WebSocket URL.
///
/// The scheme maps `https` → `wss` and anything else → `ws`. The host is kept
/// and the port is always `8081`. Any path, query, or credentials in `api_url`
/// are dropped.
///
/// Returns `None` if `api_url` does not parse or has no host.
///
/// - `http://localhost:8080` → `ws://localhost:8081`
/// - `https://api.example.com` → `wss://api.example.com:8081`
/// - `http://api.example.com:9000` → `ws://api.example.com:8081`
fn derive_notifier_url(api_url: &str) -> Option<String> {
    let parsed = url::Url::parse(api_url).ok()?;
    let host = parsed.host_str()?;
    let scheme = if parsed.scheme() == "https" { "wss" } else { "ws" };
    Some(format!("{}://{}:8081", scheme, host))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_terminal() {
        let terminal = [
            "completed",
            "succeeded",
            "failed",
            "canceled",
            "cancelled",
            "timeout",
            "timed_out",
        ];
        for status in terminal {
            assert!(is_terminal(status), "expected '{}' to be terminal", status);
        }
        for status in ["requested", "scheduled", "running"] {
            assert!(!is_terminal(status), "expected '{}' to be non-terminal", status);
        }
    }

    #[test]
    fn test_derive_notifier_url() {
        let cases = [
            ("http://localhost:8080", "ws://localhost:8081"),
            ("https://api.example.com", "wss://api.example.com:8081"),
            ("http://api.example.com:9000", "ws://api.example.com:8081"),
            ("http://10.0.0.5:8080", "ws://10.0.0.5:8081"),
        ];
        for (api_url, expected) in cases {
            assert_eq!(
                derive_notifier_url(api_url),
                Some(expected.to_string()),
                "api url {}",
                api_url
            );
        }
    }

    #[test]
    fn test_build_summary_from_full_payload() {
        let full = serde_json::json!({
            "id": 42,
            "action_ref": "core.echo",
            "status": "completed",
            "result": { "stdout": "hi" },
            "created": "2026-01-01T00:00:00Z",
            "updated": "2026-01-01T00:00:01Z"
        });
        let summary = build_summary_from_payload(42, &full).unwrap();
        assert_eq!(summary.id, 42);
        assert_eq!(summary.status, "completed");
        assert_eq!(summary.action_ref, "core.echo");
    }

    #[test]
    fn test_build_summary_from_partial_payload() {
        let partial = serde_json::json!({ "status": "failed" });
        let summary = build_summary_from_payload(7, &partial).unwrap();
        assert_eq!(summary.id, 7);
        assert_eq!(summary.status, "failed");
        assert_eq!(summary.action_ref, "");
    }
}