artifacts!

This commit is contained in:
2026-03-03 13:42:41 -06:00
parent 5da940639a
commit 8299e5efcb
50 changed files with 4779 additions and 341 deletions

View File

@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
use reqwest::{Client as HttpClient, Method, RequestBuilder, Response, StatusCode};
use reqwest::{multipart, Client as HttpClient, Method, RequestBuilder, Response, StatusCode};
use serde::{de::DeserializeOwned, Serialize};
use std::path::PathBuf;
use std::time::Duration;
@@ -39,7 +39,7 @@ impl ApiClient {
Self {
client: HttpClient::builder()
.timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(300)) // longer timeout for uploads
.build()
.expect("Failed to build HTTP client"),
base_url,
@@ -50,10 +50,15 @@ impl ApiClient {
}
/// Create a new API client
/// Return the base URL this client is configured to talk to.
pub fn base_url(&self) -> &str {
&self.base_url
}
#[cfg(test)]
pub fn new(base_url: String, auth_token: Option<String>) -> Self {
let client = HttpClient::builder()
.timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(300))
.build()
.expect("Failed to build HTTP client");
@@ -296,6 +301,55 @@ impl ApiClient {
anyhow::bail!("API error ({}): {}", status, error_text);
}
}
    /// POST a multipart/form-data request with a file field and optional text fields.
    ///
    /// - `file_field_name`: the multipart field name for the file
    /// - `file_bytes`: raw bytes of the file content
    /// - `file_name`: filename hint sent in the Content-Disposition header
    /// - `mime_type`: MIME type of the file (e.g. `"application/gzip"`)
    /// - `extra_fields`: additional text key/value fields to include in the form
    ///
    /// On a 401 response, if a refresh token is available and the refresh
    /// succeeds, this returns an error asking the caller to retry rather than
    /// replaying the request automatically.
    pub async fn multipart_post<T: DeserializeOwned>(
        &mut self,
        path: &str,
        file_field_name: &str,
        file_bytes: Vec<u8>,
        file_name: &str,
        mime_type: &str,
        extra_fields: Vec<(&str, String)>,
    ) -> Result<T> {
        // All API endpoints are served under the /api/v1 prefix.
        let url = format!("{}/api/v1{}", self.base_url, path);
        let file_part = multipart::Part::bytes(file_bytes)
            .file_name(file_name.to_string())
            .mime_str(mime_type)
            .context("Invalid MIME type")?;
        let mut form = multipart::Form::new().part(file_field_name.to_string(), file_part);
        // Extra fields are plain text parts alongside the file part.
        for (key, value) in extra_fields {
            form = form.text(key.to_string(), value);
        }
        let mut req = self.client.post(&url).multipart(form);
        if let Some(token) = &self.auth_token {
            req = req.bearer_auth(token);
        }
        let response = req.send().await.context("Failed to send multipart request to API")?;
        // Handle 401 + refresh (same pattern as execute())
        if response.status() == StatusCode::UNAUTHORIZED && self.refresh_token.is_some() {
            // refresh_auth_token() returning true means new tokens were saved;
            // the caller must re-issue the command with the fresh token.
            if self.refresh_auth_token().await? {
                return Err(anyhow::anyhow!(
                    "Token expired and was refreshed. Please retry your command."
                ));
            }
            // A failed refresh falls through; handle_response reports the 401.
        }
        self.handle_response(response).await
    }
}
#[cfg(test)]

View File

@@ -6,6 +6,7 @@ use std::collections::HashMap;
use crate::client::ApiClient;
use crate::config::CliConfig;
use crate::output::{self, OutputFormat};
use crate::wait::{wait_for_execution, WaitOptions};
#[derive(Subcommand)]
pub enum ActionCommands {
@@ -74,6 +75,11 @@ pub enum ActionCommands {
/// Timeout in seconds when waiting (default: 300)
#[arg(long, default_value = "300", requires = "wait")]
timeout: u64,
/// Notifier WebSocket base URL (e.g. ws://localhost:8081).
/// Derived from --api-url automatically when not set.
#[arg(long, requires = "wait")]
notifier_url: Option<String>,
},
}
@@ -182,6 +188,7 @@ pub async fn handle_action_command(
params_json,
wait,
timeout,
notifier_url,
} => {
handle_execute(
action_ref,
@@ -191,6 +198,7 @@ pub async fn handle_action_command(
api_url,
wait,
timeout,
notifier_url,
output_format,
)
.await
@@ -415,6 +423,7 @@ async fn handle_execute(
api_url: &Option<String>,
wait: bool,
timeout: u64,
notifier_url: Option<String>,
output_format: OutputFormat,
) -> Result<()> {
let config = CliConfig::load_with_profile(profile.as_deref())?;
@@ -453,62 +462,61 @@ async fn handle_execute(
}
let path = "/executions/execute".to_string();
let mut execution: Execution = client.post(&path, &request).await?;
let execution: Execution = client.post(&path, &request).await?;
if wait {
if !wait {
match output_format {
OutputFormat::Json | OutputFormat::Yaml => {
output::print_output(&execution, output_format)?;
}
OutputFormat::Table => {
output::print_info(&format!(
"Waiting for execution {} to complete...",
execution.id
));
output::print_success(&format!("Execution {} started", execution.id));
output::print_key_value_table(vec![
("Execution ID", execution.id.to_string()),
("Action", execution.action_ref.clone()),
("Status", output::format_status(&execution.status)),
]);
}
_ => {}
}
// Poll for completion
let start = std::time::Instant::now();
let timeout_duration = std::time::Duration::from_secs(timeout);
loop {
if start.elapsed() > timeout_duration {
anyhow::bail!("Execution timed out after {} seconds", timeout);
}
let exec_path = format!("/executions/{}", execution.id);
execution = client.get(&exec_path).await?;
if execution.status == "succeeded"
|| execution.status == "failed"
|| execution.status == "canceled"
{
break;
}
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
}
return Ok(());
}
match output_format {
OutputFormat::Table => {
output::print_info(&format!(
"Waiting for execution {} to complete...",
execution.id
));
}
_ => {}
}
let verbose = matches!(output_format, OutputFormat::Table);
let summary = wait_for_execution(WaitOptions {
execution_id: execution.id,
timeout_secs: timeout,
api_client: &mut client,
notifier_ws_url: notifier_url,
verbose,
})
.await?;
match output_format {
OutputFormat::Json | OutputFormat::Yaml => {
output::print_output(&execution, output_format)?;
output::print_output(&summary, output_format)?;
}
OutputFormat::Table => {
output::print_success(&format!(
"Execution {} {}",
execution.id,
if wait { "completed" } else { "started" }
));
output::print_success(&format!("Execution {} completed", summary.id));
output::print_section("Execution Details");
output::print_key_value_table(vec![
("Execution ID", execution.id.to_string()),
("Action", execution.action_ref.clone()),
("Status", output::format_status(&execution.status)),
("Created", output::format_timestamp(&execution.created)),
("Updated", output::format_timestamp(&execution.updated)),
("Execution ID", summary.id.to_string()),
("Action", summary.action_ref.clone()),
("Status", output::format_status(&summary.status)),
("Created", output::format_timestamp(&summary.created)),
("Updated", output::format_timestamp(&summary.updated)),
]);
if let Some(result) = execution.result {
if let Some(result) = summary.result {
if !result.is_null() {
output::print_section("Result");
println!("{}", serde_json::to_string_pretty(&result)?);

View File

@@ -17,6 +17,14 @@ pub enum AuthCommands {
/// Password (will prompt if not provided)
#[arg(long)]
password: Option<String>,
/// API URL to log in to (saved into the profile for future use)
#[arg(long)]
url: Option<String>,
/// Save credentials into a named profile (creates it if it doesn't exist)
#[arg(long)]
save_profile: Option<String>,
},
/// Log out and clear authentication tokens
Logout,
@@ -53,8 +61,22 @@ pub async fn handle_auth_command(
output_format: OutputFormat,
) -> Result<()> {
match command {
AuthCommands::Login { username, password } => {
handle_login(username, password, profile, api_url, output_format).await
AuthCommands::Login {
username,
password,
url,
save_profile,
} => {
// --url is a convenient alias for --api-url at login time
let effective_api_url = url.or_else(|| api_url.clone());
handle_login(
username,
password,
save_profile.as_ref().or(profile.as_ref()),
&effective_api_url,
output_format,
)
.await
}
AuthCommands::Logout => handle_logout(profile, output_format).await,
AuthCommands::Whoami => handle_whoami(profile, api_url, output_format).await,
@@ -65,11 +87,44 @@ pub async fn handle_auth_command(
async fn handle_login(
username: String,
password: Option<String>,
profile: &Option<String>,
profile: Option<&String>,
api_url: &Option<String>,
output_format: OutputFormat,
) -> Result<()> {
let config = CliConfig::load_with_profile(profile.as_deref())?;
// Determine which profile name will own these credentials.
// If --save-profile / --profile was given, use that; otherwise use the
// currently-active profile.
let mut config = CliConfig::load()?;
let target_profile_name = profile
.cloned()
.unwrap_or_else(|| config.current_profile.clone());
// If a URL was provided and the target profile doesn't exist yet, create it.
if !config.profiles.contains_key(&target_profile_name) {
let url = api_url.clone().unwrap_or_else(|| "http://localhost:8080".to_string());
use crate::config::Profile;
config.set_profile(
target_profile_name.clone(),
Profile {
api_url: url,
auth_token: None,
refresh_token: None,
output_format: None,
description: None,
},
)?;
} else if let Some(url) = api_url {
// Profile exists — update its api_url if an explicit URL was provided.
if let Some(p) = config.profiles.get_mut(&target_profile_name) {
p.api_url = url.clone();
}
config.save()?;
}
// Build a temporary config view that points at the target profile so
// ApiClient uses the right base URL.
let mut login_config = CliConfig::load()?;
login_config.current_profile = target_profile_name.clone();
// Prompt for password if not provided
let password = match password {
@@ -82,7 +137,7 @@ async fn handle_login(
}
};
let mut client = ApiClient::from_config(&config, api_url);
let mut client = ApiClient::from_config(&login_config, api_url);
let login_req = LoginRequest {
login: username,
@@ -91,12 +146,17 @@ async fn handle_login(
let response: LoginResponse = client.post("/auth/login", &login_req).await?;
// Save tokens to config
// Persist tokens into the target profile.
let mut config = CliConfig::load()?;
config.set_auth(
response.access_token.clone(),
response.refresh_token.clone(),
)?;
// Ensure the profile exists (it may have just been created above and saved).
if let Some(p) = config.profiles.get_mut(&target_profile_name) {
p.auth_token = Some(response.access_token.clone());
p.refresh_token = Some(response.refresh_token.clone());
config.save()?;
} else {
// Fallback: set_auth writes to the current profile.
config.set_auth(response.access_token.clone(), response.refresh_token.clone())?;
}
match output_format {
OutputFormat::Json | OutputFormat::Yaml => {
@@ -105,6 +165,12 @@ async fn handle_login(
OutputFormat::Table => {
output::print_success("Successfully logged in");
output::print_info(&format!("Token expires in {} seconds", response.expires_in));
if target_profile_name != config.current_profile {
output::print_info(&format!(
"Credentials saved to profile '{}'",
target_profile_name
));
}
}
}

View File

@@ -1,5 +1,6 @@
use anyhow::Result;
use anyhow::{Context, Result};
use clap::Subcommand;
use flate2::{write::GzEncoder, Compression};
use serde::{Deserialize, Serialize};
use std::path::Path;
@@ -77,9 +78,9 @@ pub enum PackCommands {
#[arg(short = 'y', long)]
yes: bool,
},
/// Register a pack from a local directory
/// Register a pack from a local directory (path must be accessible by the API server)
Register {
/// Path to pack directory
/// Path to pack directory (must be a path the API server can access)
path: String,
/// Force re-registration if pack already exists
@@ -90,6 +91,22 @@ pub enum PackCommands {
#[arg(long)]
skip_tests: bool,
},
/// Upload a local pack directory to the API server and register it
///
/// This command tarballs the local directory and streams it to the API,
/// so it works regardless of whether the API is local or running in Docker.
Upload {
/// Path to the local pack directory (must contain pack.yaml)
path: String,
/// Force re-registration if a pack with the same ref already exists
#[arg(short, long)]
force: bool,
/// Skip running pack tests after upload
#[arg(long)]
skip_tests: bool,
},
/// Test a pack's test suite
Test {
/// Pack reference (name) or path to pack directory
@@ -256,6 +273,15 @@ struct RegisterPackRequest {
skip_tests: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct UploadPackResponse {
pack: Pack,
#[serde(default)]
test_result: Option<serde_json::Value>,
#[serde(default)]
tests_skipped: bool,
}
pub async fn handle_pack_command(
profile: &Option<String>,
command: PackCommands,
@@ -296,6 +322,11 @@ pub async fn handle_pack_command(
force,
skip_tests,
} => handle_register(profile, path, force, skip_tests, api_url, output_format).await,
PackCommands::Upload {
path,
force,
skip_tests,
} => handle_upload(profile, path, force, skip_tests, api_url, output_format).await,
PackCommands::Test {
pack,
verbose,
@@ -593,6 +624,160 @@ async fn handle_uninstall(
Ok(())
}
/// Handle `pack upload`: build an in-memory tar.gz of a local pack directory
/// and POST it to the API's `/packs/upload` multipart endpoint.
///
/// Unlike `pack register` (which sends only a path), this works even when the
/// API server cannot see the CLI's filesystem.
async fn handle_upload(
    profile: &Option<String>,
    path: String,
    force: bool,
    skip_tests: bool,
    api_url: &Option<String>,
    output_format: OutputFormat,
) -> Result<()> {
    let pack_dir = Path::new(&path);
    // Validate the directory exists and contains pack.yaml
    if !pack_dir.exists() {
        anyhow::bail!("Path does not exist: {}", path);
    }
    if !pack_dir.is_dir() {
        anyhow::bail!("Path is not a directory: {}", path);
    }
    let pack_yaml_path = pack_dir.join("pack.yaml");
    if !pack_yaml_path.exists() {
        anyhow::bail!("No pack.yaml found in: {}", path);
    }
    // Read pack ref from pack.yaml so we can display it
    let pack_yaml_content = std::fs::read_to_string(&pack_yaml_path)
        .context("Failed to read pack.yaml")?;
    let pack_yaml: serde_yaml_ng::Value =
        serde_yaml_ng::from_str(&pack_yaml_content).context("Failed to parse pack.yaml")?;
    // A missing or non-string `ref` is not fatal — it is only used for display
    // and for naming the uploaded archive.
    let pack_ref = pack_yaml
        .get("ref")
        .and_then(|v| v.as_str())
        .unwrap_or("unknown");
    match output_format {
        OutputFormat::Table => {
            output::print_info(&format!(
                "Uploading pack '{}' from: {}",
                pack_ref, path
            ));
            output::print_info("Creating archive...");
        }
        _ => {}
    }
    // Build an in-memory tar.gz of the pack directory
    let tar_gz_bytes = {
        let buf = Vec::new();
        let enc = GzEncoder::new(buf, Compression::default());
        let mut tar = tar::Builder::new(enc);
        // Walk the directory and add files to the archive
        // We strip the leading path so the archive root is the pack directory contents
        let abs_pack_dir = pack_dir
            .canonicalize()
            .context("Failed to resolve pack directory path")?;
        append_dir_to_tar(&mut tar, &abs_pack_dir, &abs_pack_dir)?;
        // into_inner() finalises the tar stream; finish() writes the gzip
        // footer and yields the underlying Vec<u8>.
        let encoder = tar.into_inner().context("Failed to finalise tar archive")?;
        encoder.finish().context("Failed to flush gzip stream")?
    };
    let archive_size_kb = tar_gz_bytes.len() / 1024;
    match output_format {
        OutputFormat::Table => {
            output::print_info(&format!(
                "Archive ready ({} KB), uploading...",
                archive_size_kb
            ));
        }
        _ => {}
    }
    let config = CliConfig::load_with_profile(profile.as_deref())?;
    let mut client = ApiClient::from_config(&config, api_url);
    // Optional flags are carried as plain text form fields; absent means false.
    let mut extra_fields = Vec::new();
    if force {
        extra_fields.push(("force", "true".to_string()));
    }
    if skip_tests {
        extra_fields.push(("skip_tests", "true".to_string()));
    }
    let archive_name = format!("{}.tar.gz", pack_ref);
    let response: UploadPackResponse = client
        .multipart_post(
            "/packs/upload",
            "pack",
            tar_gz_bytes,
            &archive_name,
            "application/gzip",
            extra_fields,
        )
        .await?;
    match output_format {
        OutputFormat::Json | OutputFormat::Yaml => {
            output::print_output(&response, output_format)?;
        }
        OutputFormat::Table => {
            println!();
            output::print_success(&format!(
                "✓ Pack '{}' uploaded and registered successfully",
                response.pack.pack_ref
            ));
            output::print_info(&format!(" Version: {}", response.pack.version));
            output::print_info(&format!(" ID: {}", response.pack.id));
            if response.tests_skipped {
                output::print_info(" ⚠ Tests were skipped");
            } else if let Some(test_result) = &response.test_result {
                // NOTE(review): assumes the server reports a "status" key of
                // "passed"/"failed" in test_result — other values print nothing.
                if let Some(status) = test_result.get("status").and_then(|s| s.as_str()) {
                    if status == "passed" {
                        output::print_success(" ✓ All tests passed");
                    } else if status == "failed" {
                        output::print_error(" ✗ Some tests failed");
                    }
                }
            }
        }
    }
    Ok(())
}
/// Recursively append a directory's contents to a tar archive.
///
/// `base` is the root directory being archived; `dir` is the directory
/// currently being walked. Entries are stored under paths relative to `base`,
/// so the archive root holds the pack directory's contents directly.
fn append_dir_to_tar<W: std::io::Write>(
    tar: &mut tar::Builder<W>,
    base: &Path,
    dir: &Path,
) -> Result<()> {
    for dir_entry in std::fs::read_dir(dir).context("Failed to read directory")? {
        let entry_path = dir_entry.context("Failed to read directory entry")?.path();
        let archived_as = entry_path
            .strip_prefix(base)
            .context("Failed to compute relative path")?;
        if entry_path.is_dir() {
            // Recurse; directories themselves get no explicit archive entry.
            append_dir_to_tar(tar, base, &entry_path)?;
        } else if entry_path.is_file() {
            tar.append_path_with_name(&entry_path, archived_as)
                .with_context(|| {
                    format!("Failed to add {} to archive", entry_path.display())
                })?;
        }
        // Anything else (symlinks etc.) is intentionally skipped.
    }
    Ok(())
}
async fn handle_register(
profile: &Option<String>,
path: String,
@@ -604,19 +789,39 @@ async fn handle_register(
let config = CliConfig::load_with_profile(profile.as_deref())?;
let mut client = ApiClient::from_config(&config, api_url);
// Warn if the path looks like a local filesystem path that the API server
// probably can't see (i.e. not a known container mount point).
let looks_local = !path.starts_with("/opt/attune/")
&& !path.starts_with("/app/")
&& !path.starts_with("/packs");
if looks_local {
match output_format {
OutputFormat::Table => {
output::print_info(&format!("Registering pack from: {}", path));
eprintln!(
"⚠ Warning: '{}' looks like a local path. If the API is running in \
Docker it may not be able to access this path.\n \
Use `attune pack upload {}` instead to upload the pack directly.",
path, path
);
}
_ => {}
}
} else {
match output_format {
OutputFormat::Table => {
output::print_info(&format!("Registering pack from: {}", path));
}
_ => {}
}
}
let request = RegisterPackRequest {
path: path.clone(),
force,
skip_tests,
};
match output_format {
OutputFormat::Table => {
output::print_info(&format!("Registering pack from: {}", path));
}
_ => {}
}
let response: PackInstallResponse = client.post("/packs/register", &request).await?;
match output_format {

View File

@@ -5,6 +5,7 @@ mod client;
mod commands;
mod config;
mod output;
mod wait;
use commands::{
action::{handle_action_command, ActionCommands},
@@ -112,6 +113,11 @@ enum Commands {
/// Timeout in seconds when waiting (default: 300)
#[arg(long, default_value = "300", requires = "wait")]
timeout: u64,
/// Notifier WebSocket base URL (e.g. ws://localhost:8081).
/// Derived from --api-url automatically when not set.
#[arg(long, requires = "wait")]
notifier_url: Option<String>,
},
}
@@ -193,6 +199,7 @@ async fn main() {
params_json,
wait,
timeout,
notifier_url,
} => {
// Delegate to action execute command
handle_action_command(
@@ -203,6 +210,7 @@ async fn main() {
params_json,
wait,
timeout,
notifier_url,
},
&cli.api_url,
output_format,

556
crates/cli/src/wait.rs Normal file
View File

@@ -0,0 +1,556 @@
//! Waiting for execution completion.
//!
//! Tries to connect to the notifier WebSocket first so the CLI reacts
//! *immediately* when the execution reaches a terminal state. If the
//! notifier is unreachable (not configured, different port, Docker network
//! boundary, etc.) it transparently falls back to REST polling.
//!
//! Public surface:
//! - [`WaitOptions`] caller-supplied parameters
//! - [`wait_for_execution`] the single entry point
use anyhow::Result;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::client::ApiClient;
// ── terminal status helpers ───────────────────────────────────────────────────

/// True once an execution has reached a state it will never leave.
/// Accepts both US/UK spellings of "canceled" and both timeout variants;
/// comparison is case-sensitive.
fn is_terminal(status: &str) -> bool {
    const TERMINAL_STATES: [&str; 7] = [
        "completed",
        "succeeded",
        "failed",
        "canceled",
        "cancelled",
        "timeout",
        "timed_out",
    ];
    TERMINAL_STATES.contains(&status)
}
// ── public types ─────────────────────────────────────────────────────────────

/// Result returned when the wait completes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionSummary {
    /// Execution ID.
    pub id: i64,
    /// Final status string as reported by the server.
    pub status: String,
    /// Action reference the execution ran.
    pub action_ref: String,
    /// Result payload, if the server recorded one.
    pub result: Option<serde_json::Value>,
    /// Creation timestamp, passed through as the API's string form.
    pub created: String,
    /// Last-update timestamp, passed through as the API's string form.
    pub updated: String,
}
/// Parameters that control how we wait.
pub struct WaitOptions<'a> {
    /// Execution ID to watch.
    pub execution_id: i64,
    /// Overall wall-clock limit in seconds (the CLI passes 300 by default).
    pub timeout_secs: u64,
    /// REST API client (already authenticated).
    pub api_client: &'a mut ApiClient,
    /// Base URL of the *notifier* WebSocket service, e.g. `ws://localhost:8081`.
    /// Derived from the API client's base URL when not explicitly set.
    pub notifier_ws_url: Option<String>,
    /// If `true`, print progress lines to stderr.
    pub verbose: bool,
}
// ── notifier WebSocket messages (mirrors websocket_server.rs) ────────────────

/// Messages the CLI sends to the notifier, tagged by a `type` field.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum ClientMsg {
    /// Subscribe to a notification filter, e.g. `entity:execution:<id>`.
    #[serde(rename = "subscribe")]
    Subscribe { filter: String },
    /// Keep-alive probe sent while waiting for notifications.
    #[serde(rename = "ping")]
    Ping,
}
/// Messages received from the notifier, tagged by a `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum ServerMsg {
    /// First message after connecting; carries the session's client id.
    #[serde(rename = "welcome")]
    Welcome {
        client_id: String,
        #[allow(dead_code)]
        message: String,
    },
    /// An entity event matching one of our subscriptions.
    #[serde(rename = "notification")]
    Notification(NotifierNotification),
    /// Server-reported error; the read loop treats this as fatal.
    #[serde(rename = "error")]
    Error { message: String },
    /// Catch-all for any unrecognised `type` tag; ignored by the read loop.
    #[serde(other)]
    Unknown,
}
/// Body of a `notification` message from the notifier.
#[derive(Debug, Deserialize)]
struct NotifierNotification {
    /// Event kind string; only surfaced in verbose logging here.
    pub notification_type: String,
    /// Entity kind the event refers to, e.g. "execution".
    pub entity_type: String,
    /// Numeric ID of the entity.
    pub entity_id: i64,
    /// Entity payload; for executions this carries fields like "status".
    pub payload: serde_json::Value,
}
// ── REST execution shape ──────────────────────────────────────────────────────

/// Subset of the execution row returned by `GET /executions/<id>`.
#[derive(Debug, Deserialize)]
struct RestExecution {
    id: i64,
    action_ref: String,
    status: String,
    result: Option<serde_json::Value>,
    created: String,
    updated: String,
}
impl From<RestExecution> for ExecutionSummary {
fn from(e: RestExecution) -> Self {
Self {
id: e.id,
status: e.status,
action_ref: e.action_ref,
result: e.result,
created: e.created,
updated: e.updated,
}
}
}
// ── entry point ───────────────────────────────────────────────────────────────

/// Wait for `execution_id` to reach a terminal status.
///
/// Strategy:
/// 1. Connect to the notifier WebSocket and subscribe to the execution with
///    the filter `entity:execution:<id>` for immediate terminal-state delivery.
/// 2. On any WebSocket failure — or when no notifier URL can be derived —
///    fall back to polling `GET /executions/<id>`.
/// 3. A single wall-clock limit of `timeout_secs` covers both paths.
///
/// Returns the final [`ExecutionSummary`] on success or an error if the
/// timeout is exceeded or a fatal error occurs.
pub async fn wait_for_execution(opts: WaitOptions<'_>) -> Result<ExecutionSummary> {
    // Minimum slice of the budget reserved for the polling fallback, so the
    // fallback always gets a fair chance even when the WS path consumes most
    // of the timeout.
    const MIN_POLL_BUDGET: Duration = Duration::from_secs(10);

    let overall_deadline = Instant::now() + Duration::from_secs(opts.timeout_secs);

    match resolve_ws_url(&opts) {
        Some(ws_url) => {
            // Cap the WS attempt at (timeout - MIN_POLL_BUDGET); with a very
            // short timeout the WS path just gets whatever remains.
            let ws_deadline = if overall_deadline > Instant::now() + MIN_POLL_BUDGET {
                overall_deadline - MIN_POLL_BUDGET
            } else {
                overall_deadline
            };
            let ws_attempt = wait_via_websocket(
                &ws_url,
                opts.execution_id,
                ws_deadline,
                opts.verbose,
                opts.api_client,
            )
            .await;
            match ws_attempt {
                Ok(summary) => return Ok(summary),
                Err(ws_err) => {
                    if opts.verbose {
                        eprintln!(" [notifier: {}] falling back to polling", ws_err);
                    }
                    // Fall through to polling below.
                }
            }
        }
        None => {
            if opts.verbose {
                eprintln!(" [notifier URL not configured] using polling");
            }
        }
    }

    // Polling always runs against the full overall deadline, so at minimum
    // MIN_POLL_BUDGET remains (and often the full timeout if WS failed fast).
    wait_via_polling(
        opts.api_client,
        opts.execution_id,
        overall_deadline,
        opts.verbose,
    )
    .await
}
// ── WebSocket path ────────────────────────────────────────────────────────────

/// Wait for the execution over the notifier WebSocket.
///
/// Protocol, in order:
/// 1. Connect to `<ws_base_url>/ws` (capped at 5s and at the remaining budget).
/// 2. Wait up to 5s for the server's `welcome` message.
/// 3. Send a `subscribe` for `entity:execution:<id>`.
/// 4. Do one REST check to close the subscribe race window.
/// 5. Read notifications until a terminal status arrives, pinging every 15s.
///
/// Any failure is returned as an error so the caller can fall back to polling.
async fn wait_via_websocket(
    ws_base_url: &str,
    execution_id: i64,
    deadline: Instant,
    verbose: bool,
    api_client: &mut ApiClient,
) -> Result<ExecutionSummary> {
    // Build the full WS endpoint URL.
    let ws_url = format!("{}/ws", ws_base_url.trim_end_matches('/'));
    let connect_timeout = Duration::from_secs(5);
    let remaining = deadline.saturating_duration_since(Instant::now());
    if remaining.is_zero() {
        anyhow::bail!("WS budget exhausted before connect");
    }
    // Never let the connect attempt outlive the WS budget.
    let effective_connect_timeout = connect_timeout.min(remaining);
    let connect_result =
        tokio::time::timeout(effective_connect_timeout, connect_async(&ws_url)).await;
    let (ws_stream, _response) = match connect_result {
        Ok(Ok(pair)) => pair,
        Ok(Err(e)) => anyhow::bail!("WebSocket connect failed: {}", e),
        Err(_) => anyhow::bail!("WebSocket connect timed out"),
    };
    if verbose {
        eprintln!(" [notifier] connected to {}", ws_url);
    }
    let (mut write, mut read) = ws_stream.split();
    // Wait for the welcome message before subscribing.
    tokio::time::timeout(Duration::from_secs(5), async {
        while let Some(msg) = read.next().await {
            if let Ok(Message::Text(txt)) = msg {
                if let Ok(ServerMsg::Welcome { client_id, .. }) =
                    serde_json::from_str::<ServerMsg>(&txt)
                {
                    if verbose {
                        eprintln!(" [notifier] session id {}", client_id);
                    }
                    return Ok(());
                }
            }
        }
        anyhow::bail!("connection closed before welcome")
    })
    .await
    .map_err(|_| anyhow::anyhow!("timed out waiting for welcome message"))??;
    // Subscribe to this specific execution.
    let subscribe_msg = ClientMsg::Subscribe {
        filter: format!("entity:execution:{}", execution_id),
    };
    let subscribe_json = serde_json::to_string(&subscribe_msg)?;
    SinkExt::send(&mut write, Message::Text(subscribe_json.into())).await?;
    if verbose {
        eprintln!(
            " [notifier] subscribed to entity:execution:{}",
            execution_id
        );
    }
    // ── Race-condition guard ──────────────────────────────────────────────
    // The execution may have already completed in the window between the
    // initial POST and when the WS subscription became active. Check once
    // with the REST API *after* subscribing so there is no gap: either the
    // notification arrives after this check (and we'll catch it in the loop
    // below) or we catch the terminal state here. A REST failure here is
    // deliberately ignored — the WS loop below still covers the execution.
    {
        let path = format!("/executions/{}", execution_id);
        if let Ok(exec) = api_client.get::<RestExecution>(&path).await {
            if is_terminal(&exec.status) {
                if verbose {
                    eprintln!(
                        " [notifier] execution {} already terminal ('{}') — caught by post-subscribe check",
                        execution_id, exec.status
                    );
                }
                return Ok(exec.into());
            }
        }
    }
    // Periodically ping to keep the connection alive and check the deadline.
    let ping_interval = Duration::from_secs(15);
    let mut next_ping = Instant::now() + ping_interval;
    loop {
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            anyhow::bail!("timed out waiting for execution {}", execution_id);
        }
        // Wait up to the earlier of: next ping time or deadline.
        let wait_for = remaining.min(next_ping.saturating_duration_since(Instant::now()));
        let msg_result = tokio::time::timeout(wait_for, read.next()).await;
        match msg_result {
            // Received a message within the window.
            Ok(Some(Ok(Message::Text(txt)))) => {
                match serde_json::from_str::<ServerMsg>(&txt) {
                    Ok(ServerMsg::Notification(n)) => {
                        if n.entity_type == "execution" && n.entity_id == execution_id {
                            if verbose {
                                eprintln!(
                                    " [notifier] {} for execution {} — status={:?}",
                                    n.notification_type,
                                    execution_id,
                                    n.payload.get("status").and_then(|s| s.as_str()),
                                );
                            }
                            // Extract status from the notification payload.
                            // The notifier broadcasts the full execution row in
                            // `payload`, so we can read the status directly.
                            if let Some(status) = n.payload.get("status").and_then(|s| s.as_str()) {
                                if is_terminal(status) {
                                    // Build a summary from the payload; any
                                    // fields missing from it are defaulted
                                    // (no extra REST round-trip).
                                    return build_summary_from_payload(execution_id, &n.payload);
                                }
                            }
                        }
                        // Not our execution or not yet terminal — keep waiting.
                    }
                    Ok(ServerMsg::Error { message }) => {
                        anyhow::bail!("notifier error: {}", message);
                    }
                    Ok(ServerMsg::Welcome { .. } | ServerMsg::Unknown) => {
                        // Ignore unexpected / unrecognised messages.
                    }
                    Err(e) => {
                        // Surface parse failures only in verbose mode — they can
                        // happen if the server sends a message format we don't
                        // recognise yet.
                        if verbose {
                            eprintln!(" [notifier] ignoring unrecognised message: {}", e);
                        }
                    }
                }
            }
            // Connection closed cleanly.
            Ok(Some(Ok(Message::Close(_)))) | Ok(None) => {
                anyhow::bail!("notifier WebSocket closed unexpectedly");
            }
            // Ping/pong/binary/raw frames — ignore.
            Ok(Some(Ok(
                Message::Ping(_) | Message::Pong(_) | Message::Binary(_) | Message::Frame(_),
            ))) => {}
            // WebSocket transport error.
            Ok(Some(Err(e))) => {
                anyhow::bail!("WebSocket error: {}", e);
            }
            // Timeout waiting for a message — time to ping. A ping failure is
            // ignored; a broken connection surfaces as a read error next loop.
            Err(_timeout) => {
                let now = Instant::now();
                if now >= next_ping {
                    let _ = SinkExt::send(
                        &mut write,
                        Message::Text(serde_json::to_string(&ClientMsg::Ping)?.into()),
                    )
                    .await;
                    next_ping = now + ping_interval;
                }
            }
        }
    }
}
/// Build an [`ExecutionSummary`] from the notification payload.
///
/// The notifier payload matches the REST execution shape closely enough that
/// a direct deserialize usually works; when it doesn't, individual fields are
/// pulled out with sensible defaults.
fn build_summary_from_payload(
    execution_id: i64,
    payload: &serde_json::Value,
) -> Result<ExecutionSummary> {
    // Fast path: the payload is a full execution row.
    if let Ok(full) = serde_json::from_value::<RestExecution>(payload.clone()) {
        return Ok(full.into());
    }

    // Partial payload — extract each string field, defaulting what's missing.
    let text_field = |key: &str, default: &str| -> String {
        payload
            .get(key)
            .and_then(|v| v.as_str())
            .unwrap_or(default)
            .to_string()
    };

    Ok(ExecutionSummary {
        id: execution_id,
        status: text_field("status", "unknown"),
        action_ref: text_field("action_ref", ""),
        result: payload.get("result").cloned(),
        created: text_field("created", ""),
        updated: text_field("updated", ""),
    })
}
// ── polling fallback ──────────────────────────────────────────────────────────

/// Delay before the first re-check.
const POLL_INTERVAL: Duration = Duration::from_millis(500);
/// Ceiling for the back-off; the delay between checks never exceeds this.
const POLL_INTERVAL_MAX: Duration = Duration::from_secs(2);
/// How quickly the poll interval grows on each successive check.
const POLL_BACKOFF_FACTOR: f64 = 1.5;
/// Poll `GET /executions/<id>` until a terminal status or the deadline.
///
/// The first check happens immediately (before any sleep) so an execution
/// that finished while we were attempting the WebSocket path is caught at
/// once. The delay between checks backs off geometrically up to a cap, and
/// transient request failures are retried rather than treated as fatal.
async fn wait_via_polling(
    client: &mut ApiClient,
    execution_id: i64,
    deadline: Instant,
    verbose: bool,
) -> Result<ExecutionSummary> {
    if verbose {
        eprintln!(" [poll] watching execution {}", execution_id);
    }
    let endpoint = format!("/executions/{}", execution_id);
    let mut delay = POLL_INTERVAL;
    loop {
        match client.get::<RestExecution>(&endpoint).await {
            // Terminal — done.
            Ok(exec) if is_terminal(&exec.status) => {
                if verbose {
                    eprintln!(" [poll] execution {} is {}", execution_id, exec.status);
                }
                return Ok(exec.into());
            }
            // Still in flight — report and keep going.
            Ok(exec) => {
                if verbose {
                    eprintln!(
                        " [poll] status = {} — checking again in {:.1}s",
                        exec.status,
                        delay.as_secs_f64()
                    );
                }
            }
            // Transient failure — retry on the next cycle.
            Err(e) => {
                if verbose {
                    eprintln!(" [poll] request failed ({}), retrying…", e);
                }
            }
        }
        // Deadline is checked *after* the poll so at least one check runs.
        if Instant::now() >= deadline {
            anyhow::bail!("timed out waiting for execution {}", execution_id);
        }
        // Sleep, but never past the deadline.
        let nap = delay.min(deadline.saturating_duration_since(Instant::now()));
        tokio::time::sleep(nap).await;
        // Exponential back-off up to the cap.
        delay = Duration::from_secs_f64(
            (delay.as_secs_f64() * POLL_BACKOFF_FACTOR).min(POLL_INTERVAL_MAX.as_secs_f64()),
        );
    }
}
// ── URL resolution ────────────────────────────────────────────────────────────

/// Derive the notifier WebSocket base URL.
///
/// Priority:
/// 1. Explicit `notifier_ws_url` in [`WaitOptions`].
/// 2. Replace the API base URL scheme (`http` → `ws`) and port (`8080` → `8081`).
///    This covers the standard single-host layout where both services share the
///    same hostname.
fn resolve_ws_url(opts: &WaitOptions<'_>) -> Option<String> {
    if let Some(url) = &opts.notifier_ws_url {
        return Some(url.clone());
    }
    // No explicit override: transform the client's REST base URL
    // (http(s)://host[:port]) into ws(s)://host:8081.
    derive_notifier_url(opts.api_client.base_url())
}
/// Convert an HTTP API base URL into the expected notifier WebSocket URL.
///
/// - `http://localhost:8080` → `ws://localhost:8081`
/// - `https://api.example.com` → `wss://api.example.com:8081`
/// - `http://api.example.com:9000` → `ws://api.example.com:8081`
fn derive_notifier_url(api_url: &str) -> Option<String> {
    let parsed = url::Url::parse(api_url).ok()?;
    // `https` maps to `wss`; every other scheme (normally `http`) maps to `ws`.
    let scheme = if parsed.scheme() == "https" { "wss" } else { "ws" };
    // The notifier is addressed on the same host at port 8081, regardless of
    // which port the API itself uses.
    parsed
        .host_str()
        .map(|host| format!("{}://{}:8081", scheme, host))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every terminal spelling is accepted; in-flight states are not.
    #[test]
    fn test_is_terminal() {
        assert!(is_terminal("completed"));
        assert!(is_terminal("succeeded"));
        assert!(is_terminal("failed"));
        assert!(is_terminal("canceled"));
        assert!(is_terminal("cancelled"));
        assert!(is_terminal("timeout"));
        assert!(is_terminal("timed_out"));
        assert!(!is_terminal("requested"));
        assert!(!is_terminal("scheduled"));
        assert!(!is_terminal("running"));
    }

    /// Scheme maps http→ws / https→wss and the port is always forced to 8081.
    #[test]
    fn test_derive_notifier_url() {
        assert_eq!(
            derive_notifier_url("http://localhost:8080"),
            Some("ws://localhost:8081".to_string())
        );
        assert_eq!(
            derive_notifier_url("https://api.example.com"),
            Some("wss://api.example.com:8081".to_string())
        );
        assert_eq!(
            derive_notifier_url("http://api.example.com:9000"),
            Some("ws://api.example.com:8081".to_string())
        );
        assert_eq!(
            derive_notifier_url("http://10.0.0.5:8080"),
            Some("ws://10.0.0.5:8081".to_string())
        );
    }

    /// A payload matching the full REST shape deserializes directly.
    #[test]
    fn test_build_summary_from_full_payload() {
        let payload = serde_json::json!({
            "id": 42,
            "action_ref": "core.echo",
            "status": "completed",
            "result": { "stdout": "hi" },
            "created": "2026-01-01T00:00:00Z",
            "updated": "2026-01-01T00:00:01Z"
        });
        let summary = build_summary_from_payload(42, &payload).unwrap();
        assert_eq!(summary.id, 42);
        assert_eq!(summary.status, "completed");
        assert_eq!(summary.action_ref, "core.echo");
    }

    /// A sparse payload falls back to field-by-field extraction with defaults.
    #[test]
    fn test_build_summary_from_partial_payload() {
        let payload = serde_json::json!({ "status": "failed" });
        let summary = build_summary_from_payload(7, &payload).unwrap();
        assert_eq!(summary.id, 7);
        assert_eq!(summary.status, "failed");
        assert_eq!(summary.action_ref, "");
    }
}