Compare commits
60 Commits
a7ed135af2
...
ha-executo
| Author | SHA1 | Date | |
|---|---|---|---|
| f93e9229d2 | |||
| 8e91440f23 | |||
| 8278030699 | |||
| b34617ded1 | |||
| b6446cc574 | |||
| cf82de87ea | |||
| a4c303ec84 | |||
| a0f59114a3 | |||
| 104dcbb1b1 | |||
| b342005e17 | |||
| 4b525f4641 | |||
|
|
7ef2b59b23 | ||
| 3a13bf754a | |||
| f4ef823f43 | |||
| ab7d31de2f | |||
| 938c271ff5 | |||
| da8055cb79 | |||
| 03a239d22b | |||
| ba83958337 | |||
| c11bc1a2bf | |||
| eb82755137 | |||
| 058f392616 | |||
| 0264a66b5a | |||
| 542e72a454 | |||
| a118563366 | |||
| a057ad5db5 | |||
| 8e273ec683 | |||
| 16f1c2f079 | |||
| 62307e8c65 | |||
| 2ebb03b868 | |||
| af5175b96a | |||
| 8af8c1af9c | |||
| d4c6240485 | |||
| 4d5a3b1bf5 | |||
| 8ba7e3bb84 | |||
| 0782675a2b | |||
| 5a18c73572 | |||
| 1c16f65476 | |||
| ae8029f9c4 | |||
| 882ba0da84 | |||
| ee4fc31b9d | |||
| c791495572 | |||
| 35182ccb28 | |||
| 16e6b69fc7 | |||
| a7962eec09 | |||
| 2182be1008 | |||
| 43b27044bb | |||
| 4df621c5c8 | |||
| 57fa3bf7cf | |||
| 1d59ff5de4 | |||
| f96861d417 | |||
| 643023b6d5 | |||
| feb070c165 | |||
| 6a86dd7ca6 | |||
| 6307888722 | |||
| 9b0ff4a6d2 | |||
| 5c0ff6f271 | |||
| 1645ad84ee | |||
| 765afc7d76 | |||
| b5d6bb2243 |
0
.codex_write_test
Normal file
0
.codex_write_test
Normal file
@@ -50,8 +50,8 @@ web/node_modules/
|
||||
web/dist/
|
||||
web/.vite/
|
||||
|
||||
# SQLx offline data (generated at build time)
|
||||
#.sqlx/
|
||||
# SQLx offline data (generated when using `cargo sqlx prepare`)
|
||||
# .sqlx/
|
||||
|
||||
# Configuration files (copied selectively)
|
||||
config.development.yaml
|
||||
@@ -61,6 +61,7 @@ config.example.yaml
|
||||
|
||||
# Scripts (not needed in runtime)
|
||||
scripts/
|
||||
!scripts/load_core_pack.py
|
||||
|
||||
# Cargo lock (workspace handles this)
|
||||
# Uncomment if you want deterministic builds:
|
||||
|
||||
@@ -9,19 +9,32 @@ on:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
RUST_MIN_STACK: 16777216
|
||||
RUST_MIN_STACK: 67108864
|
||||
CARGO_INCREMENTAL: 0
|
||||
CARGO_NET_RETRY: 10
|
||||
RUSTUP_MAX_RETRIES: 10
|
||||
# Gitea Actions runner tool cache. Actions like setup-node/setup-python can reuse this.
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
jobs:
|
||||
rust-fmt:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Rust toolchain
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.rustup/toolchains
|
||||
~/.rustup/update-hashes
|
||||
key: rustup-rustfmt-${{ runner.os }}-stable-v1
|
||||
restore-keys: |
|
||||
rustup-${{ runner.os }}-stable-v1
|
||||
rustup-
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
@@ -32,11 +45,22 @@ jobs:
|
||||
|
||||
rust-clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Rust toolchain
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.rustup/toolchains
|
||||
~/.rustup/update-hashes
|
||||
key: rustup-clippy-${{ runner.os }}-stable-v1
|
||||
restore-keys: |
|
||||
rustup-${{ runner.os }}-stable-v1
|
||||
rustup-
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
@@ -67,11 +91,22 @@ jobs:
|
||||
|
||||
rust-test:
|
||||
name: Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Rust toolchain
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.rustup/toolchains
|
||||
~/.rustup/update-hashes
|
||||
key: rustup-test-${{ runner.os }}-stable-v1
|
||||
restore-keys: |
|
||||
rustup-${{ runner.os }}-stable-v1
|
||||
rustup-
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
@@ -100,11 +135,22 @@ jobs:
|
||||
|
||||
rust-audit:
|
||||
name: Cargo Audit & Deny
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Rust toolchain
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.rustup/toolchains
|
||||
~/.rustup/update-hashes
|
||||
key: rustup-audit-${{ runner.os }}-stable-v1
|
||||
restore-keys: |
|
||||
rustup-${{ runner.os }}-stable-v1
|
||||
rustup-
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
@@ -142,7 +188,7 @@ jobs:
|
||||
|
||||
web-blocking:
|
||||
name: Web Blocking Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
defaults:
|
||||
run:
|
||||
working-directory: web
|
||||
@@ -171,7 +217,7 @@ jobs:
|
||||
|
||||
security-blocking:
|
||||
name: Security Blocking Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -204,7 +250,7 @@ jobs:
|
||||
|
||||
web-advisory:
|
||||
name: Web Advisory Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
continue-on-error: true
|
||||
defaults:
|
||||
run:
|
||||
@@ -233,7 +279,7 @@ jobs:
|
||||
|
||||
security-advisory:
|
||||
name: Security Advisory Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: build-amd64
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
1130
.gitea/workflows/publish.yml
Normal file
1130
.gitea/workflows/publish.yml
Normal file
File diff suppressed because it is too large
Load Diff
15
.githooks/pre-commit
Executable file
15
.githooks/pre-commit
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
repo_root="$(git rev-parse --show-toplevel)"
|
||||
cd "$repo_root"
|
||||
|
||||
echo "Formatting Rust code..."
|
||||
cargo fmt --all
|
||||
|
||||
echo "Refreshing staged Rust files..."
|
||||
git add --all '*.rs'
|
||||
|
||||
echo "Running pre-commit checks..."
|
||||
make pre-commit
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -11,6 +11,7 @@ target/
|
||||
# Configuration files (keep *.example.yaml)
|
||||
config.yaml
|
||||
config.*.yaml
|
||||
!docker/distributable/config.docker.yaml
|
||||
!config.example.yaml
|
||||
!config.development.yaml
|
||||
!config.test.yaml
|
||||
@@ -35,6 +36,7 @@ logs/
|
||||
# Build artifacts
|
||||
dist/
|
||||
build/
|
||||
artifacts/
|
||||
|
||||
# Testing
|
||||
coverage/
|
||||
@@ -78,3 +80,8 @@ docker-compose.override.yml
|
||||
*.pid
|
||||
|
||||
packs.examples/
|
||||
packs.external/
|
||||
codex/
|
||||
|
||||
# Compiled pack binaries (built via Docker or build-pack-binaries.sh)
|
||||
packs/core/sensors/attune-core-timer-sensor
|
||||
|
||||
@@ -2,6 +2,8 @@ target/
|
||||
web/dist/
|
||||
web/node_modules/
|
||||
web/src/api/
|
||||
packs/
|
||||
packs.dev/
|
||||
packs.external/
|
||||
tests/
|
||||
docs/
|
||||
*.md
|
||||
|
||||
788
Cargo.lock
generated
788
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
12
Cargo.toml
12
Cargo.toml
@@ -21,7 +21,7 @@ repository = "https://git.rdrx.app/attune-system/attune"
|
||||
[workspace.dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1.50", features = ["full"] }
|
||||
tokio-util = "0.7"
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
|
||||
# Web framework
|
||||
@@ -52,17 +52,17 @@ config = "0.15"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# UUID
|
||||
uuid = { version = "1.21", features = ["v4", "serde"] }
|
||||
uuid = { version = "1.22", features = ["v4", "serde"] }
|
||||
|
||||
# Validation
|
||||
validator = { version = "0.20", features = ["derive"] }
|
||||
|
||||
# CLI
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
clap = { version = "4.6", features = ["derive"] }
|
||||
|
||||
# Message queue / PubSub
|
||||
# RabbitMQ
|
||||
lapin = "4.1"
|
||||
lapin = "4.3"
|
||||
# Redis
|
||||
redis = { version = "1.0", features = ["tokio-comp", "connection-manager"] }
|
||||
|
||||
@@ -101,7 +101,7 @@ tar = "0.4"
|
||||
flate2 = "1.1"
|
||||
|
||||
# WebSocket client
|
||||
tokio-tungstenite = { version = "0.28", features = ["native-tls"] }
|
||||
tokio-tungstenite = { version = "0.28", features = ["rustls-tls-native-roots"] }
|
||||
|
||||
# URL parsing
|
||||
url = "2.5"
|
||||
@@ -114,7 +114,7 @@ futures = "0.3"
|
||||
semver = { version = "1.0", features = ["serde"] }
|
||||
|
||||
# Temp files
|
||||
tempfile = "3.26"
|
||||
tempfile = "3.27"
|
||||
|
||||
# Testing
|
||||
mockall = "0.14"
|
||||
|
||||
134
Makefile
134
Makefile
@@ -3,7 +3,12 @@
|
||||
docker-up docker-down docker-cache-warm docker-stop-system-services dev watch generate-agents-index \
|
||||
docker-build-workers docker-build-worker-base docker-build-worker-python \
|
||||
docker-build-worker-node docker-build-worker-full deny ci-rust ci-web-blocking ci-web-advisory \
|
||||
ci-security-blocking ci-security-advisory ci-blocking ci-advisory
|
||||
ci-security-blocking ci-security-advisory ci-blocking ci-advisory \
|
||||
fmt-check pre-commit install-git-hooks \
|
||||
build-agent docker-build-agent docker-build-agent-arm64 docker-build-agent-all \
|
||||
run-agent run-agent-release \
|
||||
docker-up-agent docker-down-agent \
|
||||
docker-build-pack-binaries docker-build-pack-binaries-arm64 docker-build-pack-binaries-all
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -25,8 +30,12 @@ help:
|
||||
@echo ""
|
||||
@echo "Code Quality:"
|
||||
@echo " make fmt - Format all code"
|
||||
@echo " make fmt-check - Verify formatting without changing files"
|
||||
@echo " make clippy - Run linter"
|
||||
@echo " make lint - Run both fmt and clippy"
|
||||
@echo " make deny - Run cargo-deny checks"
|
||||
@echo " make pre-commit - Run the git pre-commit checks locally"
|
||||
@echo " make install-git-hooks - Configure git to use the repo hook scripts"
|
||||
@echo ""
|
||||
@echo "Running Services:"
|
||||
@echo " make run-api - Run API service"
|
||||
@@ -55,6 +64,21 @@ help:
|
||||
@echo " make docker-up - Start services with docker compose"
|
||||
@echo " make docker-down - Stop services"
|
||||
@echo ""
|
||||
@echo "Agent (Universal Worker):"
|
||||
@echo " make build-agent - Build statically-linked agent binary (musl)"
|
||||
@echo " make docker-build-agent - Build agent Docker image (amd64, default)"
|
||||
@echo " make docker-build-agent-arm64 - Build agent Docker image (arm64)"
|
||||
@echo " make docker-build-agent-all - Build agent Docker images (amd64 + arm64)"
|
||||
@echo " make run-agent - Run agent in development mode"
|
||||
@echo " make run-agent-release - Run agent in release mode"
|
||||
@echo " make docker-up-agent - Start all services + agent workers (ruby, etc.)"
|
||||
@echo " make docker-down-agent - Stop agent stack"
|
||||
@echo ""
|
||||
@echo "Pack Binaries:"
|
||||
@echo " make docker-build-pack-binaries - Build pack binaries Docker image (amd64, default)"
|
||||
@echo " make docker-build-pack-binaries-arm64 - Build pack binaries Docker image (arm64)"
|
||||
@echo " make docker-build-pack-binaries-all - Build pack binaries Docker images (amd64 + arm64)"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make watch - Watch and rebuild on changes"
|
||||
@echo " make install-tools - Install development tools"
|
||||
@@ -64,7 +88,7 @@ help:
|
||||
@echo ""
|
||||
|
||||
# Increase rustc stack size to prevent SIGSEGV during compilation
|
||||
export RUST_MIN_STACK := 16777216
|
||||
export RUST_MIN_STACK:=67108864
|
||||
|
||||
# Building
|
||||
build:
|
||||
@@ -111,6 +135,9 @@ check:
|
||||
fmt:
|
||||
cargo fmt --all
|
||||
|
||||
fmt-check:
|
||||
cargo fmt --all -- --check
|
||||
|
||||
clippy:
|
||||
cargo clippy --all-features -- -D warnings
|
||||
|
||||
@@ -219,38 +246,86 @@ docker-build-api:
|
||||
docker-build-web:
|
||||
docker compose build web
|
||||
|
||||
# Build worker images
|
||||
docker-build-workers: docker-build-worker-base docker-build-worker-python docker-build-worker-node docker-build-worker-full
|
||||
@echo "✅ All worker images built successfully"
|
||||
# Agent binary (statically-linked for injection into any container)
|
||||
AGENT_RUST_TARGET ?= x86_64-unknown-linux-musl
|
||||
|
||||
docker-build-worker-base:
|
||||
@echo "Building base worker (shell only)..."
|
||||
DOCKER_BUILDKIT=1 docker build --target worker-base -t attune-worker:base -f docker/Dockerfile.worker .
|
||||
@echo "✅ Base worker image built: attune-worker:base"
|
||||
# Pack binaries (statically-linked for packs volume)
|
||||
PACK_BINARIES_RUST_TARGET ?= x86_64-unknown-linux-musl
|
||||
|
||||
docker-build-worker-python:
|
||||
@echo "Building Python worker (shell + python)..."
|
||||
DOCKER_BUILDKIT=1 docker build --target worker-python -t attune-worker:python -f docker/Dockerfile.worker .
|
||||
@echo "✅ Python worker image built: attune-worker:python"
|
||||
build-agent:
|
||||
@echo "Installing musl target (if not already installed)..."
|
||||
rustup target add $(AGENT_RUST_TARGET) 2>/dev/null || true
|
||||
@echo "Building statically-linked worker and sensor agent binaries..."
|
||||
SQLX_OFFLINE=true cargo build --release --target $(AGENT_RUST_TARGET) --bin attune-agent --bin attune-sensor-agent
|
||||
strip target/$(AGENT_RUST_TARGET)/release/attune-agent
|
||||
strip target/$(AGENT_RUST_TARGET)/release/attune-sensor-agent
|
||||
@echo "✅ Agent binaries built:"
|
||||
@echo " - target/$(AGENT_RUST_TARGET)/release/attune-agent"
|
||||
@echo " - target/$(AGENT_RUST_TARGET)/release/attune-sensor-agent"
|
||||
@ls -lh target/$(AGENT_RUST_TARGET)/release/attune-agent
|
||||
@ls -lh target/$(AGENT_RUST_TARGET)/release/attune-sensor-agent
|
||||
|
||||
docker-build-worker-node:
|
||||
@echo "Building Node.js worker (shell + node)..."
|
||||
DOCKER_BUILDKIT=1 docker build --target worker-node -t attune-worker:node -f docker/Dockerfile.worker .
|
||||
@echo "✅ Node.js worker image built: attune-worker:node"
|
||||
docker-build-agent:
|
||||
@echo "Building agent Docker image ($(AGENT_RUST_TARGET))..."
|
||||
DOCKER_BUILDKIT=1 docker buildx build --build-arg RUST_TARGET=$(AGENT_RUST_TARGET) --target agent-init -f docker/Dockerfile.agent -t attune-agent:latest .
|
||||
@echo "✅ Agent image built: attune-agent:latest ($(AGENT_RUST_TARGET))"
|
||||
|
||||
docker-build-worker-full:
|
||||
@echo "Building full worker (all runtimes)..."
|
||||
DOCKER_BUILDKIT=1 docker build --target worker-full -t attune-worker:full -f docker/Dockerfile.worker .
|
||||
@echo "✅ Full worker image built: attune-worker:full"
|
||||
docker-build-agent-arm64:
|
||||
@echo "Building arm64 agent Docker image..."
|
||||
DOCKER_BUILDKIT=1 docker buildx build --build-arg RUST_TARGET=aarch64-unknown-linux-musl --target agent-init -f docker/Dockerfile.agent -t attune-agent:arm64 .
|
||||
@echo "✅ Agent image built: attune-agent:arm64"
|
||||
|
||||
docker-build-agent-all:
|
||||
@echo "Building agent Docker images for all architectures..."
|
||||
$(MAKE) docker-build-agent
|
||||
$(MAKE) docker-build-agent-arm64
|
||||
@echo "✅ All agent images built: attune-agent:latest (amd64), attune-agent:arm64"
|
||||
|
||||
run-agent:
|
||||
cargo run --bin attune-agent
|
||||
|
||||
run-agent-release:
|
||||
cargo run --bin attune-agent --release
|
||||
|
||||
# Pack binaries (statically-linked for packs volume)
|
||||
docker-build-pack-binaries:
|
||||
@echo "Building pack binaries Docker image ($(PACK_BINARIES_RUST_TARGET))..."
|
||||
DOCKER_BUILDKIT=1 docker buildx build --build-arg RUST_TARGET=$(PACK_BINARIES_RUST_TARGET) --target pack-binaries-init -f docker/Dockerfile.pack-binaries -t attune-pack-builder:latest .
|
||||
@echo "✅ Pack binaries image built: attune-pack-builder:latest ($(PACK_BINARIES_RUST_TARGET))"
|
||||
|
||||
docker-build-pack-binaries-arm64:
|
||||
@echo "Building arm64 pack binaries Docker image..."
|
||||
DOCKER_BUILDKIT=1 docker buildx build --build-arg RUST_TARGET=aarch64-unknown-linux-musl --target pack-binaries-init -f docker/Dockerfile.pack-binaries -t attune-pack-builder:arm64 .
|
||||
@echo "✅ Pack binaries image built: attune-pack-builder:arm64"
|
||||
|
||||
docker-build-pack-binaries-all:
|
||||
@echo "Building pack binaries Docker images for all architectures..."
|
||||
$(MAKE) docker-build-pack-binaries
|
||||
$(MAKE) docker-build-pack-binaries-arm64
|
||||
@echo "✅ All pack binary images built: attune-pack-builder:latest (amd64), attune-pack-builder:arm64"
|
||||
|
||||
run-sensor-agent:
|
||||
cargo run --bin attune-sensor-agent
|
||||
|
||||
run-sensor-agent-release:
|
||||
cargo run --bin attune-sensor-agent --release
|
||||
|
||||
docker-up:
|
||||
@echo "Starting all services with Docker Compose..."
|
||||
docker compose up -d
|
||||
|
||||
docker-up-agent:
|
||||
@echo "Starting all services + agent-based workers..."
|
||||
docker compose -f docker-compose.yaml -f docker-compose.agent.yaml up -d
|
||||
|
||||
docker-down:
|
||||
@echo "Stopping all services..."
|
||||
docker compose down
|
||||
|
||||
docker-down-agent:
|
||||
@echo "Stopping all services (including agent workers)..."
|
||||
docker compose -f docker-compose.yaml -f docker-compose.agent.yaml down
|
||||
|
||||
docker-down-volumes:
|
||||
@echo "Stopping all services and removing volumes (WARNING: deletes data)..."
|
||||
docker compose down -v
|
||||
@@ -341,6 +416,11 @@ ci-web-blocking:
|
||||
cd web && npm run typecheck
|
||||
cd web && npm run build
|
||||
|
||||
ci-web-pre-commit:
|
||||
cd web && npm ci
|
||||
cd web && npm run lint
|
||||
cd web && npm run typecheck
|
||||
|
||||
ci-web-advisory:
|
||||
cd web && npm ci
|
||||
cd web && npm run knip
|
||||
@@ -381,9 +461,15 @@ licenses:
|
||||
cargo license --json > licenses.json
|
||||
@echo "License information saved to licenses.json"
|
||||
|
||||
# All-in-one check before committing
|
||||
pre-commit: fmt clippy test
|
||||
@echo "✅ All checks passed! Ready to commit."
|
||||
# Blocking checks run by the git pre-commit hook after formatting.
|
||||
# Keep the local web step fast; full production builds stay in CI.
|
||||
pre-commit: deny ci-web-pre-commit ci-security-blocking
|
||||
@echo "✅ Pre-commit checks passed."
|
||||
|
||||
install-git-hooks:
|
||||
git config core.hooksPath .githooks
|
||||
chmod +x .githooks/pre-commit
|
||||
@echo "✅ Git hooks configured to use .githooks/"
|
||||
|
||||
# CI simulation
|
||||
ci: ci-blocking ci-advisory
|
||||
|
||||
6
charts/attune/Chart.yaml
Normal file
6
charts/attune/Chart.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
apiVersion: v2
|
||||
name: attune
|
||||
description: Helm chart for deploying the Attune automation platform
|
||||
type: application
|
||||
version: 0.1.0
|
||||
appVersion: "0.1.0"
|
||||
26
charts/attune/templates/NOTES.txt
Normal file
26
charts/attune/templates/NOTES.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
1. Set `global.imageRegistry`, `global.imageNamespace`, and `global.imageTag` so the chart pulls the images published by the Gitea workflow.
|
||||
2. Set `web.config.apiUrl` and `web.config.wsUrl` to browser-reachable endpoints before exposing the web UI.
|
||||
3. The shared `packs`, `runtime_envs`, and `artifacts` PVCs default to `ReadWriteMany`; your cluster storage class must support RWX or you need to override those claims.
|
||||
{{- if .Values.agentWorkers }}
|
||||
|
||||
Agent-based workers enabled:
|
||||
{{- range .Values.agentWorkers }}
|
||||
- {{ .name }}: image={{ .image }}, replicas={{ .replicas | default 1 }}
|
||||
{{- if .runtimes }} runtimes={{ join "," .runtimes }}{{ else }} runtimes=auto-detect{{ end }}
|
||||
{{- end }}
|
||||
|
||||
Each agent worker uses an init container to copy the statically-linked
|
||||
attune-agent binary into the worker pod via an emptyDir volume. The agent
|
||||
auto-detects available runtimes in the container and registers with Attune.
|
||||
|
||||
The default sensor deployment also uses the same injection pattern, copying
|
||||
`attune-sensor-agent` into the pod before starting a stock runtime image.
|
||||
|
||||
To add more agent workers, append entries to `agentWorkers` in your values:
|
||||
|
||||
agentWorkers:
|
||||
- name: my-runtime
|
||||
image: my-org/my-image:latest
|
||||
replicas: 1
|
||||
runtimes: [] # auto-detect
|
||||
{{- end }}
|
||||
113
charts/attune/templates/_helpers.tpl
Normal file
113
charts/attune/templates/_helpers.tpl
Normal file
@@ -0,0 +1,113 @@
|
||||
{{- define "attune.name" -}}
|
||||
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.fullname" -}}
|
||||
{{- if .Values.fullnameOverride -}}
|
||||
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s-%s" .Release.Name (include "attune.name" .) | trunc 63 | trimSuffix "-" -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.chart" -}}
|
||||
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.labels" -}}
|
||||
helm.sh/chart: {{ include "attune.chart" . }}
|
||||
app.kubernetes.io/name: {{ include "attune.name" . }}
|
||||
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.selectorLabels" -}}
|
||||
app.kubernetes.io/name: {{ include "attune.name" . }}
|
||||
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.componentLabels" -}}
|
||||
{{ include "attune.selectorLabels" .root }}
|
||||
app.kubernetes.io/component: {{ .component }}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.image" -}}
|
||||
{{- $root := .root -}}
|
||||
{{- $image := .image -}}
|
||||
{{- $registry := $root.Values.global.imageRegistry -}}
|
||||
{{- $namespace := $root.Values.global.imageNamespace -}}
|
||||
{{- $repository := $image.repository -}}
|
||||
{{- $tag := default $root.Values.global.imageTag $image.tag -}}
|
||||
{{- if and $registry $namespace -}}
|
||||
{{- printf "%s/%s/%s:%s" $registry $namespace $repository $tag -}}
|
||||
{{- else if $registry -}}
|
||||
{{- printf "%s/%s:%s" $registry $repository $tag -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s:%s" $repository $tag -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.secretName" -}}
|
||||
{{- if .Values.security.existingSecret -}}
|
||||
{{- .Values.security.existingSecret -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s-secrets" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.postgresqlServiceName" -}}
|
||||
{{- if .Values.database.host -}}
|
||||
{{- .Values.database.host -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s-postgresql" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.rabbitmqServiceName" -}}
|
||||
{{- if .Values.rabbitmq.host -}}
|
||||
{{- .Values.rabbitmq.host -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s-rabbitmq" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.redisServiceName" -}}
|
||||
{{- if .Values.redis.host -}}
|
||||
{{- .Values.redis.host -}}
|
||||
{{- else -}}
|
||||
{{- printf "%s-redis" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.databaseUrl" -}}
|
||||
{{- if .Values.database.url -}}
|
||||
{{- .Values.database.url -}}
|
||||
{{- else -}}
|
||||
{{- printf "postgresql://%s:%s@%s:%v/%s" .Values.database.username .Values.database.password (include "attune.postgresqlServiceName" .) .Values.database.port .Values.database.database -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.rabbitmqUrl" -}}
|
||||
{{- if .Values.rabbitmq.url -}}
|
||||
{{- .Values.rabbitmq.url -}}
|
||||
{{- else -}}
|
||||
{{- printf "amqp://%s:%s@%s:%v" .Values.rabbitmq.username .Values.rabbitmq.password (include "attune.rabbitmqServiceName" .) .Values.rabbitmq.port -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.redisUrl" -}}
|
||||
{{- if .Values.redis.url -}}
|
||||
{{- .Values.redis.url -}}
|
||||
{{- else -}}
|
||||
{{- printf "redis://%s:%v" (include "attune.redisServiceName" .) .Values.redis.port -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.apiServiceName" -}}
|
||||
{{- printf "%s-api" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "attune.notifierServiceName" -}}
|
||||
{{- printf "%s-notifier" (include "attune.fullname" .) -}}
|
||||
{{- end -}}
|
||||
137
charts/attune/templates/agent-workers.yaml
Normal file
137
charts/attune/templates/agent-workers.yaml
Normal file
@@ -0,0 +1,137 @@
|
||||
{{- range .Values.agentWorkers }}
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" $ }}-agent-worker-{{ .name }}
|
||||
labels:
|
||||
{{- include "attune.labels" $ | nindent 4 }}
|
||||
app.kubernetes.io/component: agent-worker-{{ .name }}
|
||||
spec:
|
||||
replicas: {{ .replicas | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.selectorLabels" $ | nindent 6 }}
|
||||
app.kubernetes.io/component: agent-worker-{{ .name }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.selectorLabels" $ | nindent 8 }}
|
||||
app.kubernetes.io/component: agent-worker-{{ .name }}
|
||||
spec:
|
||||
{{- if $.Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml $.Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- if .runtimeClassName }}
|
||||
runtimeClassName: {{ .runtimeClassName }}
|
||||
{{- end }}
|
||||
{{- if .nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml .nodeSelector | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- if .tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml .tolerations | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- if .stopGracePeriod }}
|
||||
terminationGracePeriodSeconds: {{ .stopGracePeriod }}
|
||||
{{- else }}
|
||||
terminationGracePeriodSeconds: 45
|
||||
{{- end }}
|
||||
initContainers:
|
||||
- name: agent-loader
|
||||
image: {{ include "attune.image" (dict "root" $ "image" $.Values.images.agent) }}
|
||||
imagePullPolicy: {{ $.Values.images.agent.pullPolicy }}
|
||||
command: ["cp", "/usr/local/bin/attune-agent", "/opt/attune/agent/attune-agent"]
|
||||
volumeMounts:
|
||||
- name: agent-bin
|
||||
mountPath: /opt/attune/agent
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" $ }}
|
||||
- name: wait-for-packs
|
||||
image: busybox:1.36
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until [ -f /opt/attune/packs/core/pack.yaml ]; do
|
||||
echo "waiting for packs";
|
||||
sleep 2;
|
||||
done
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
containers:
|
||||
- name: worker
|
||||
image: {{ .image }}
|
||||
{{- if .imagePullPolicy }}
|
||||
imagePullPolicy: {{ .imagePullPolicy }}
|
||||
{{- end }}
|
||||
command: ["/opt/attune/agent/attune-agent"]
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" $ }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ $.Values.database.schema | quote }}
|
||||
- name: ATTUNE_WORKER_TYPE
|
||||
value: container
|
||||
- name: ATTUNE_WORKER_NAME
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
- name: ATTUNE_API_URL
|
||||
value: http://{{ include "attune.apiServiceName" $ }}:{{ $.Values.api.service.port }}
|
||||
- name: RUST_LOG
|
||||
value: {{ .logLevel | default "info" }}
|
||||
{{- if .runtimes }}
|
||||
- name: ATTUNE_WORKER_RUNTIMES
|
||||
value: {{ join "," .runtimes | quote }}
|
||||
{{- end }}
|
||||
{{- if .env }}
|
||||
{{- toYaml .env | nindent 12 }}
|
||||
{{- end }}
|
||||
resources:
|
||||
{{- toYaml (.resources | default dict) | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: agent-bin
|
||||
mountPath: /opt/attune/agent
|
||||
readOnly: true
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
readOnly: true
|
||||
- name: runtime-envs
|
||||
mountPath: /opt/attune/runtime_envs
|
||||
- name: artifacts
|
||||
mountPath: /opt/attune/artifacts
|
||||
volumes:
|
||||
- name: agent-bin
|
||||
emptyDir: {}
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" $ }}-config
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" $ }}-packs
|
||||
- name: runtime-envs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" $ }}-runtime-envs
|
||||
- name: artifacts
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" $ }}-artifacts
|
||||
{{- end }}
|
||||
542
charts/attune/templates/applications.yaml
Normal file
542
charts/attune/templates/applications.yaml
Normal file
@@ -0,0 +1,542 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.apiServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
type: {{ .Values.api.service.type }}
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "api") | nindent 4 }}
|
||||
ports:
|
||||
- name: http
|
||||
port: {{ .Values.api.service.port }}
|
||||
targetPort: http
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.apiServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.api.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "api") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "api") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
initContainers:
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
- name: wait-for-packs
|
||||
image: busybox:1.36
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until [ -f /opt/attune/packs/core/pack.yaml ]; do
|
||||
echo "waiting for packs";
|
||||
sleep 2;
|
||||
done
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
containers:
|
||||
- name: api
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.api) }}
|
||||
imagePullPolicy: {{ .Values.images.api.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ .Values.database.schema | quote }}
|
||||
- name: ATTUNE__WORKER__WORKER_TYPE
|
||||
value: container
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: 8080
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 20
|
||||
periodSeconds: 15
|
||||
resources:
|
||||
{{- toYaml .Values.api.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
- name: runtime-envs
|
||||
mountPath: /opt/attune/runtime_envs
|
||||
- name: artifacts
|
||||
mountPath: /opt/attune/artifacts
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-packs
|
||||
- name: runtime-envs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-runtime-envs
|
||||
- name: artifacts
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-artifacts
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-executor
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.executor.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "executor") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "executor") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
initContainers:
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
- name: wait-for-packs
|
||||
image: busybox:1.36
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until [ -f /opt/attune/packs/core/pack.yaml ]; do
|
||||
echo "waiting for packs";
|
||||
sleep 2;
|
||||
done
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
containers:
|
||||
- name: executor
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.executor) }}
|
||||
imagePullPolicy: {{ .Values.images.executor.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ .Values.database.schema | quote }}
|
||||
- name: ATTUNE__WORKER__WORKER_TYPE
|
||||
value: container
|
||||
resources:
|
||||
{{- toYaml .Values.executor.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
- name: artifacts
|
||||
mountPath: /opt/attune/artifacts
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-packs
|
||||
- name: artifacts
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-artifacts
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-worker
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.worker.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "worker") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "worker") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
initContainers:
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
- name: wait-for-packs
|
||||
image: busybox:1.36
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until [ -f /opt/attune/packs/core/pack.yaml ]; do
|
||||
echo "waiting for packs";
|
||||
sleep 2;
|
||||
done
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
containers:
|
||||
- name: worker
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.worker) }}
|
||||
imagePullPolicy: {{ .Values.images.worker.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ .Values.database.schema | quote }}
|
||||
- name: ATTUNE_WORKER_RUNTIMES
|
||||
value: {{ .Values.worker.runtimes | quote }}
|
||||
- name: ATTUNE_WORKER_TYPE
|
||||
value: container
|
||||
- name: ATTUNE_WORKER_NAME
|
||||
value: {{ .Values.worker.name | quote }}
|
||||
- name: ATTUNE_API_URL
|
||||
value: http://{{ include "attune.apiServiceName" . }}:{{ .Values.api.service.port }}
|
||||
resources:
|
||||
{{- toYaml .Values.worker.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
- name: runtime-envs
|
||||
mountPath: /opt/attune/runtime_envs
|
||||
- name: artifacts
|
||||
mountPath: /opt/attune/artifacts
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-packs
|
||||
- name: runtime-envs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-runtime-envs
|
||||
- name: artifacts
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-artifacts
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-sensor
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.sensor.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "sensor") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "sensor") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
terminationGracePeriodSeconds: 45
|
||||
initContainers:
|
||||
- name: sensor-agent-loader
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.agent) }}
|
||||
imagePullPolicy: {{ .Values.images.agent.pullPolicy }}
|
||||
command: ["cp", "/usr/local/bin/attune-sensor-agent", "/opt/attune/agent/attune-sensor-agent"]
|
||||
volumeMounts:
|
||||
- name: agent-bin
|
||||
mountPath: /opt/attune/agent
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
- name: wait-for-packs
|
||||
image: busybox:1.36
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until [ -f /opt/attune/packs/core/pack.yaml ]; do
|
||||
echo "waiting for packs";
|
||||
sleep 2;
|
||||
done
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
containers:
|
||||
- name: sensor
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.sensor) }}
|
||||
imagePullPolicy: {{ .Values.images.sensor.pullPolicy }}
|
||||
command: ["/opt/attune/agent/attune-sensor-agent"]
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ .Values.database.schema | quote }}
|
||||
- name: ATTUNE__WORKER__WORKER_TYPE
|
||||
value: container
|
||||
- name: ATTUNE_SENSOR_RUNTIMES
|
||||
value: {{ .Values.sensor.runtimes | quote }}
|
||||
- name: ATTUNE_API_URL
|
||||
value: http://{{ include "attune.apiServiceName" . }}:{{ .Values.api.service.port }}
|
||||
- name: ATTUNE_MQ_URL
|
||||
value: {{ include "attune.rabbitmqUrl" . | quote }}
|
||||
- name: ATTUNE_PACKS_BASE_DIR
|
||||
value: /opt/attune/packs
|
||||
- name: RUST_LOG
|
||||
value: {{ .Values.sensor.logLevel | quote }}
|
||||
resources:
|
||||
{{- toYaml .Values.sensor.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: agent-bin
|
||||
mountPath: /opt/attune/agent
|
||||
readOnly: true
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
readOnly: true
|
||||
- name: runtime-envs
|
||||
mountPath: /opt/attune/runtime_envs
|
||||
volumes:
|
||||
- name: agent-bin
|
||||
emptyDir: {}
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-packs
|
||||
- name: runtime-envs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-runtime-envs
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.notifierServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
type: {{ .Values.notifier.service.type }}
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "notifier") | nindent 4 }}
|
||||
ports:
|
||||
- name: ws
|
||||
port: {{ .Values.notifier.service.port }}
|
||||
targetPort: ws
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.notifierServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.notifier.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "notifier") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "notifier") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
initContainers:
|
||||
- name: wait-for-schema
|
||||
image: postgres:16-alpine
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for schema";
|
||||
sleep 2;
|
||||
done
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
containers:
|
||||
- name: notifier
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.notifier) }}
|
||||
imagePullPolicy: {{ .Values.images.notifier.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: ATTUNE_CONFIG
|
||||
value: /opt/attune/config.yaml
|
||||
- name: ATTUNE__DATABASE__SCHEMA
|
||||
value: {{ .Values.database.schema | quote }}
|
||||
- name: ATTUNE__WORKER__WORKER_TYPE
|
||||
value: container
|
||||
ports:
|
||||
- name: ws
|
||||
containerPort: 8081
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: ws
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: ws
|
||||
initialDelaySeconds: 20
|
||||
periodSeconds: 15
|
||||
resources:
|
||||
{{- toYaml .Values.notifier.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /opt/attune/config.yaml
|
||||
subPath: config.yaml
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-web
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
type: {{ .Values.web.service.type }}
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "web") | nindent 4 }}
|
||||
ports:
|
||||
- name: http
|
||||
port: {{ .Values.web.service.port }}
|
||||
targetPort: http
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-web
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: {{ .Values.web.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "web") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "web") | nindent 8 }}
|
||||
spec:
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: web
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.web) }}
|
||||
imagePullPolicy: {{ .Values.images.web.pullPolicy }}
|
||||
env:
|
||||
- name: API_URL
|
||||
value: {{ .Values.web.config.apiUrl | quote }}
|
||||
- name: WS_URL
|
||||
value: {{ .Values.web.config.wsUrl | quote }}
|
||||
- name: ENVIRONMENT
|
||||
value: {{ .Values.web.config.environment | quote }}
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: 80
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 20
|
||||
periodSeconds: 15
|
||||
resources:
|
||||
{{- toYaml .Values.web.resources | nindent 12 }}
|
||||
9
charts/attune/templates/configmap.yaml
Normal file
9
charts/attune/templates/configmap.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-config
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
data:
|
||||
config.yaml: |
|
||||
{{ .Files.Get "files/config.docker.yaml" | indent 4 }}
|
||||
225
charts/attune/templates/infrastructure.yaml
Normal file
225
charts/attune/templates/infrastructure.yaml
Normal file
@@ -0,0 +1,225 @@
|
||||
{{- if .Values.database.postgresql.enabled }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.postgresqlServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "postgresql") | nindent 4 }}
|
||||
ports:
|
||||
- name: postgres
|
||||
port: {{ .Values.database.port }}
|
||||
targetPort: postgres
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
metadata:
|
||||
name: {{ include "attune.postgresqlServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
serviceName: {{ include "attune.postgresqlServiceName" . }}
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "postgresql") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "postgresql") | nindent 8 }}
|
||||
spec:
|
||||
containers:
|
||||
- name: postgresql
|
||||
image: "{{ .Values.database.postgresql.image.repository }}:{{ .Values.database.postgresql.image.tag }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: POSTGRES_USER
|
||||
value: {{ .Values.database.username | quote }}
|
||||
- name: POSTGRES_PASSWORD
|
||||
value: {{ .Values.database.password | quote }}
|
||||
- name: POSTGRES_DB
|
||||
value: {{ .Values.database.database | quote }}
|
||||
- name: PGDATA
|
||||
value: /var/lib/postgresql/data/pgdata
|
||||
ports:
|
||||
- name: postgres
|
||||
containerPort: 5432
|
||||
livenessProbe:
|
||||
exec:
|
||||
command: ["pg_isready", "-U", "{{ .Values.database.username }}"]
|
||||
initialDelaySeconds: 20
|
||||
periodSeconds: 10
|
||||
readinessProbe:
|
||||
exec:
|
||||
command: ["pg_isready", "-U", "{{ .Values.database.username }}"]
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
{{- toYaml .Values.database.postgresql.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: data
|
||||
mountPath: /var/lib/postgresql/data
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: data
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.database.postgresql.persistence.accessModes | nindent 10 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.database.postgresql.persistence.size }}
|
||||
{{- if .Values.database.postgresql.persistence.storageClassName }}
|
||||
storageClassName: {{ .Values.database.postgresql.persistence.storageClassName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- if .Values.rabbitmq.enabled }}
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.rabbitmqServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "rabbitmq") | nindent 4 }}
|
||||
ports:
|
||||
- name: amqp
|
||||
port: {{ .Values.rabbitmq.port }}
|
||||
targetPort: amqp
|
||||
- name: management
|
||||
port: {{ .Values.rabbitmq.managementPort }}
|
||||
targetPort: management
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
metadata:
|
||||
name: {{ include "attune.rabbitmqServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
serviceName: {{ include "attune.rabbitmqServiceName" . }}
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "rabbitmq") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "rabbitmq") | nindent 8 }}
|
||||
spec:
|
||||
containers:
|
||||
- name: rabbitmq
|
||||
image: "{{ .Values.rabbitmq.image.repository }}:{{ .Values.rabbitmq.image.tag }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: RABBITMQ_DEFAULT_USER
|
||||
value: {{ .Values.rabbitmq.username | quote }}
|
||||
- name: RABBITMQ_DEFAULT_PASS
|
||||
value: {{ .Values.rabbitmq.password | quote }}
|
||||
- name: RABBITMQ_DEFAULT_VHOST
|
||||
value: /
|
||||
ports:
|
||||
- name: amqp
|
||||
containerPort: 5672
|
||||
- name: management
|
||||
containerPort: 15672
|
||||
livenessProbe:
|
||||
exec:
|
||||
command: ["rabbitmq-diagnostics", "-q", "ping"]
|
||||
initialDelaySeconds: 20
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
exec:
|
||||
command: ["rabbitmq-diagnostics", "-q", "ping"]
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
{{- toYaml .Values.rabbitmq.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: data
|
||||
mountPath: /var/lib/rabbitmq
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: data
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.rabbitmq.persistence.accessModes | nindent 10 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.rabbitmq.persistence.size }}
|
||||
{{- if .Values.rabbitmq.persistence.storageClassName }}
|
||||
storageClassName: {{ .Values.rabbitmq.persistence.storageClassName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- if .Values.redis.enabled }}
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "attune.redisServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
selector:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "redis") | nindent 4 }}
|
||||
ports:
|
||||
- name: redis
|
||||
port: {{ .Values.redis.port }}
|
||||
targetPort: redis
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
metadata:
|
||||
name: {{ include "attune.redisServiceName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
serviceName: {{ include "attune.redisServiceName" . }}
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "redis") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "redis") | nindent 8 }}
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: "{{ .Values.redis.image.repository }}:{{ .Values.redis.image.tag }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["redis-server", "--appendonly", "yes"]
|
||||
ports:
|
||||
- name: redis
|
||||
containerPort: 6379
|
||||
livenessProbe:
|
||||
exec:
|
||||
command: ["redis-cli", "ping"]
|
||||
initialDelaySeconds: 15
|
||||
periodSeconds: 10
|
||||
readinessProbe:
|
||||
exec:
|
||||
command: ["redis-cli", "ping"]
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
{{- toYaml .Values.redis.resources | nindent 12 }}
|
||||
volumeMounts:
|
||||
- name: data
|
||||
mountPath: /data
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: data
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.redis.persistence.accessModes | nindent 10 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.redis.persistence.size }}
|
||||
{{- if .Values.redis.persistence.storageClassName }}
|
||||
storageClassName: {{ .Values.redis.persistence.storageClassName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
35
charts/attune/templates/ingress.yaml
Normal file
35
charts/attune/templates/ingress.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
{{- if .Values.web.ingress.enabled }}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-web
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
{{- with .Values.web.ingress.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- if .Values.web.ingress.className }}
|
||||
ingressClassName: {{ .Values.web.ingress.className }}
|
||||
{{- end }}
|
||||
rules:
|
||||
{{- range .Values.web.ingress.hosts }}
|
||||
- host: {{ .host | quote }}
|
||||
http:
|
||||
paths:
|
||||
{{- range .paths }}
|
||||
- path: {{ .path }}
|
||||
pathType: {{ .pathType }}
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "attune.fullname" $ }}-web
|
||||
port:
|
||||
number: {{ $.Values.web.service.port }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- with .Values.web.ingress.tls }}
|
||||
tls:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
154
charts/attune/templates/jobs.yaml
Normal file
154
charts/attune/templates/jobs.yaml
Normal file
@@ -0,0 +1,154 @@
|
||||
apiVersion: batch/v1
|
||||
kind: Job
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-migrations
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: migrations
|
||||
annotations:
|
||||
helm.sh/hook: post-install,post-upgrade
|
||||
helm.sh/hook-weight: "-20"
|
||||
helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded
|
||||
spec:
|
||||
ttlSecondsAfterFinished: {{ .Values.jobs.migrations.ttlSecondsAfterFinished }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "migrations") | nindent 8 }}
|
||||
spec:
|
||||
restartPolicy: OnFailure
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: migrations
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.migrations) }}
|
||||
imagePullPolicy: {{ .Values.images.migrations.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
env:
|
||||
- name: MIGRATIONS_DIR
|
||||
value: /migrations
|
||||
resources:
|
||||
{{- toYaml .Values.jobs.migrations.resources | nindent 12 }}
|
||||
---
|
||||
apiVersion: batch/v1
|
||||
kind: Job
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-init-user
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: init-user
|
||||
annotations:
|
||||
helm.sh/hook: post-install,post-upgrade
|
||||
helm.sh/hook-weight: "-10"
|
||||
helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded
|
||||
spec:
|
||||
ttlSecondsAfterFinished: {{ .Values.jobs.initUser.ttlSecondsAfterFinished }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "init-user") | nindent 8 }}
|
||||
spec:
|
||||
restartPolicy: OnFailure
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: init-user
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.initUser) }}
|
||||
imagePullPolicy: {{ .Values.images.initUser.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until PGPASSWORD="$DB_PASSWORD" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -tAc "SELECT to_regclass('${DB_SCHEMA}.identity')" | grep -q identity; do
|
||||
echo "waiting for database schema";
|
||||
sleep 2;
|
||||
done
|
||||
exec /init-user.sh
|
||||
resources:
|
||||
{{- toYaml .Values.jobs.initUser.resources | nindent 12 }}
|
||||
---
|
||||
apiVersion: batch/v1
|
||||
kind: Job
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-init-packs
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: init-packs
|
||||
annotations:
|
||||
helm.sh/hook: post-install,post-upgrade
|
||||
helm.sh/hook-weight: "0"
|
||||
helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded
|
||||
spec:
|
||||
ttlSecondsAfterFinished: {{ .Values.jobs.initPacks.ttlSecondsAfterFinished }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "attune.componentLabels" (dict "root" . "component" "init-packs") | nindent 8 }}
|
||||
spec:
|
||||
restartPolicy: OnFailure
|
||||
{{- if .Values.global.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml .Values.global.imagePullSecrets | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: init-packs
|
||||
image: {{ include "attune.image" (dict "root" . "image" .Values.images.initPacks) }}
|
||||
imagePullPolicy: {{ .Values.images.initPacks.pullPolicy }}
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
command: ["/bin/sh", "-ec"]
|
||||
args:
|
||||
- |
|
||||
until python3 - <<'PY'
|
||||
import os
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=os.environ["DB_HOST"],
|
||||
port=os.environ["DB_PORT"],
|
||||
user=os.environ["DB_USER"],
|
||||
password=os.environ["DB_PASSWORD"],
|
||||
dbname=os.environ["DB_NAME"],
|
||||
)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SET search_path TO %s, public" % os.environ["DB_SCHEMA"])
|
||||
cur.execute("SELECT to_regclass(%s)", (f"{os.environ['DB_SCHEMA']}.identity",))
|
||||
value = cur.fetchone()[0]
|
||||
raise SystemExit(0 if value else 1)
|
||||
finally:
|
||||
conn.close()
|
||||
PY
|
||||
do
|
||||
echo "waiting for database schema";
|
||||
sleep 2;
|
||||
done
|
||||
exec /init-packs.sh
|
||||
volumeMounts:
|
||||
- name: packs
|
||||
mountPath: /opt/attune/packs
|
||||
- name: runtime-envs
|
||||
mountPath: /opt/attune/runtime_envs
|
||||
- name: artifacts
|
||||
mountPath: /opt/attune/artifacts
|
||||
resources:
|
||||
{{- toYaml .Values.jobs.initPacks.resources | nindent 12 }}
|
||||
volumes:
|
||||
- name: packs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-packs
|
||||
- name: runtime-envs
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-runtime-envs
|
||||
- name: artifacts
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ include "attune.fullname" . }}-artifacts
|
||||
53
charts/attune/templates/pvc.yaml
Normal file
53
charts/attune/templates/pvc.yaml
Normal file
@@ -0,0 +1,53 @@
|
||||
{{- if .Values.sharedStorage.packs.enabled }}
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-packs
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.sharedStorage.packs.accessModes | nindent 4 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.sharedStorage.packs.size }}
|
||||
{{- if .Values.sharedStorage.packs.storageClassName }}
|
||||
storageClassName: {{ .Values.sharedStorage.packs.storageClassName }}
|
||||
{{- end }}
|
||||
---
|
||||
{{- end }}
|
||||
{{- if .Values.sharedStorage.runtimeEnvs.enabled }}
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-runtime-envs
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.sharedStorage.runtimeEnvs.accessModes | nindent 4 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.sharedStorage.runtimeEnvs.size }}
|
||||
{{- if .Values.sharedStorage.runtimeEnvs.storageClassName }}
|
||||
storageClassName: {{ .Values.sharedStorage.runtimeEnvs.storageClassName }}
|
||||
{{- end }}
|
||||
---
|
||||
{{- end }}
|
||||
{{- if .Values.sharedStorage.artifacts.enabled }}
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: {{ include "attune.fullname" . }}-artifacts
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
spec:
|
||||
accessModes:
|
||||
{{- toYaml .Values.sharedStorage.artifacts.accessModes | nindent 4 }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.sharedStorage.artifacts.size }}
|
||||
{{- if .Values.sharedStorage.artifacts.storageClassName }}
|
||||
storageClassName: {{ .Values.sharedStorage.artifacts.storageClassName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
31
charts/attune/templates/secret.yaml
Normal file
31
charts/attune/templates/secret.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
{{- if not .Values.security.existingSecret }}
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: {{ include "attune.secretName" . }}
|
||||
labels:
|
||||
{{- include "attune.labels" . | nindent 4 }}
|
||||
type: Opaque
|
||||
stringData:
|
||||
ATTUNE__SECURITY__JWT_SECRET: {{ .Values.security.jwtSecret | quote }}
|
||||
ATTUNE__SECURITY__ENCRYPTION_KEY: {{ .Values.security.encryptionKey | quote }}
|
||||
ATTUNE__DATABASE__URL: {{ include "attune.databaseUrl" . | quote }}
|
||||
ATTUNE__MESSAGE_QUEUE__URL: {{ include "attune.rabbitmqUrl" . | quote }}
|
||||
ATTUNE__REDIS__URL: {{ include "attune.redisUrl" . | quote }}
|
||||
DB_HOST: {{ include "attune.postgresqlServiceName" . | quote }}
|
||||
DB_PORT: {{ .Values.database.port | quote }}
|
||||
DB_USER: {{ .Values.database.username | quote }}
|
||||
DB_PASSWORD: {{ .Values.database.password | quote }}
|
||||
DB_NAME: {{ .Values.database.database | quote }}
|
||||
DB_SCHEMA: {{ .Values.database.schema | quote }}
|
||||
TEST_LOGIN: {{ .Values.bootstrap.testUser.login | quote }}
|
||||
TEST_DISPLAY_NAME: {{ .Values.bootstrap.testUser.displayName | quote }}
|
||||
TEST_PASSWORD: {{ .Values.bootstrap.testUser.password | quote }}
|
||||
DEFAULT_ADMIN_LOGIN: {{ .Values.bootstrap.testUser.login | quote }}
|
||||
DEFAULT_ADMIN_PERMISSION_SET_REF: "core.admin"
|
||||
SOURCE_PACKS_DIR: "/source/packs"
|
||||
TARGET_PACKS_DIR: "/opt/attune/packs"
|
||||
RUNTIME_ENVS_DIR: "/opt/attune/runtime_envs"
|
||||
ARTIFACTS_DIR: "/opt/attune/artifacts"
|
||||
LOADER_SCRIPT: "/scripts/load_core_pack.py"
|
||||
{{- end }}
|
||||
253
charts/attune/values.yaml
Normal file
253
charts/attune/values.yaml
Normal file
@@ -0,0 +1,253 @@
|
||||
nameOverride: ""
|
||||
fullnameOverride: ""
|
||||
|
||||
global:
|
||||
imageRegistry: ""
|
||||
imageNamespace: ""
|
||||
imageTag: edge
|
||||
imagePullSecrets: []
|
||||
|
||||
security:
|
||||
existingSecret: ""
|
||||
jwtSecret: change-me-in-production
|
||||
encryptionKey: change-me-in-production-32-bytes-minimum
|
||||
|
||||
database:
|
||||
schema: public
|
||||
username: attune
|
||||
password: attune
|
||||
database: attune
|
||||
host: ""
|
||||
port: 5432
|
||||
url: ""
|
||||
postgresql:
|
||||
enabled: true
|
||||
image:
|
||||
repository: timescale/timescaledb
|
||||
tag: 2.17.2-pg16
|
||||
persistence:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
size: 20Gi
|
||||
storageClassName: ""
|
||||
resources: {}
|
||||
|
||||
rabbitmq:
|
||||
username: attune
|
||||
password: attune
|
||||
host: ""
|
||||
port: 5672
|
||||
url: ""
|
||||
managementPort: 15672
|
||||
enabled: true
|
||||
image:
|
||||
repository: rabbitmq
|
||||
tag: 3.13-management-alpine
|
||||
persistence:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
size: 8Gi
|
||||
storageClassName: ""
|
||||
resources: {}
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
host: ""
|
||||
port: 6379
|
||||
url: ""
|
||||
image:
|
||||
repository: redis
|
||||
tag: 7-alpine
|
||||
persistence:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
size: 8Gi
|
||||
storageClassName: ""
|
||||
resources: {}
|
||||
|
||||
bootstrap:
|
||||
testUser:
|
||||
login: test@attune.local
|
||||
displayName: Test User
|
||||
password: TestPass123!
|
||||
|
||||
sharedStorage:
|
||||
packs:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
size: 2Gi
|
||||
storageClassName: ""
|
||||
runtimeEnvs:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
size: 10Gi
|
||||
storageClassName: ""
|
||||
artifacts:
|
||||
enabled: true
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
size: 20Gi
|
||||
storageClassName: ""
|
||||
|
||||
images:
|
||||
api:
|
||||
repository: attune-api
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
executor:
|
||||
repository: attune-executor
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
worker:
|
||||
repository: attune-worker
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
sensor:
|
||||
repository: nikolaik/python-nodejs
|
||||
tag: python3.12-nodejs22-slim
|
||||
pullPolicy: IfNotPresent
|
||||
notifier:
|
||||
repository: attune-notifier
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
web:
|
||||
repository: attune-web
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
migrations:
|
||||
repository: attune-migrations
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
initUser:
|
||||
repository: attune-init-user
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
initPacks:
|
||||
repository: attune-init-packs
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
agent:
|
||||
repository: attune-agent
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
|
||||
jobs:
|
||||
migrations:
|
||||
ttlSecondsAfterFinished: 300
|
||||
resources: {}
|
||||
initUser:
|
||||
ttlSecondsAfterFinished: 300
|
||||
resources: {}
|
||||
initPacks:
|
||||
ttlSecondsAfterFinished: 300
|
||||
resources: {}
|
||||
|
||||
api:
|
||||
replicaCount: 1
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 8080
|
||||
resources: {}
|
||||
|
||||
executor:
|
||||
replicaCount: 1
|
||||
resources: {}
|
||||
|
||||
worker:
|
||||
replicaCount: 1
|
||||
runtimes: shell,python,node,native
|
||||
name: worker-full-01
|
||||
resources: {}
|
||||
|
||||
sensor:
|
||||
replicaCount: 1
|
||||
runtimes: shell,python,node,native
|
||||
logLevel: debug
|
||||
resources: {}
|
||||
|
||||
notifier:
|
||||
replicaCount: 1
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 8081
|
||||
resources: {}
|
||||
|
||||
web:
|
||||
replicaCount: 1
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 80
|
||||
config:
|
||||
environment: kubernetes
|
||||
apiUrl: http://localhost:8080
|
||||
wsUrl: ws://localhost:8081
|
||||
resources: {}
|
||||
ingress:
|
||||
enabled: false
|
||||
className: ""
|
||||
annotations: {}
|
||||
hosts:
|
||||
- host: attune.local
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
tls: []
|
||||
|
||||
# Agent-based workers
|
||||
# These deploy the universal worker agent into any container image.
|
||||
# The agent auto-detects available runtimes (python, ruby, node, etc.)
|
||||
# and registers with the Attune platform.
|
||||
#
|
||||
# Each entry creates a separate Deployment with an init container that
|
||||
# copies the statically-linked agent binary into the worker container.
|
||||
#
|
||||
# Supported fields per worker:
|
||||
# name (required) - Unique name for this worker (used in resource names)
|
||||
# image (required) - Container image with your desired runtime(s)
|
||||
# replicas (optional) - Number of pod replicas (default: 1)
|
||||
# runtimes (optional) - List of runtimes to expose; [] = auto-detect
|
||||
# resources (optional) - Kubernetes resource requests/limits
|
||||
# env (optional) - Extra environment variables (list of {name, value})
|
||||
# imagePullPolicy (optional) - Pull policy for the worker image
|
||||
# logLevel (optional) - RUST_LOG level (default: "info")
|
||||
# runtimeClassName (optional) - Kubernetes RuntimeClass (e.g., "nvidia" for GPU)
|
||||
# nodeSelector (optional) - Node selector map for pod scheduling
|
||||
# tolerations (optional) - Tolerations list for pod scheduling
|
||||
# stopGracePeriod (optional) - Termination grace period in seconds (default: 45)
|
||||
#
|
||||
# Examples:
|
||||
# agentWorkers:
|
||||
# - name: ruby
|
||||
# image: ruby:3.3
|
||||
# replicas: 2
|
||||
# runtimes: [] # auto-detect
|
||||
# resources: {}
|
||||
#
|
||||
# - name: python-gpu
|
||||
# image: nvidia/cuda:12.3.1-runtime-ubuntu22.04
|
||||
# replicas: 1
|
||||
# runtimes: [python, shell]
|
||||
# runtimeClassName: nvidia
|
||||
# nodeSelector:
|
||||
# gpu: "true"
|
||||
# tolerations:
|
||||
# - key: nvidia.com/gpu
|
||||
# operator: Exists
|
||||
# effect: NoSchedule
|
||||
# resources:
|
||||
# limits:
|
||||
# nvidia.com/gpu: 1
|
||||
#
|
||||
# - name: custom
|
||||
# image: my-org/my-custom-image:latest
|
||||
# replicas: 1
|
||||
# runtimes: []
|
||||
# env:
|
||||
# - name: MY_CUSTOM_VAR
|
||||
# value: my-value
|
||||
agentWorkers: []
|
||||
@@ -47,6 +47,21 @@ security:
|
||||
encryption_key: test-encryption-key-32-chars-okay
|
||||
enable_auth: true
|
||||
allow_self_registration: true
|
||||
oidc:
|
||||
enabled: false
|
||||
discovery_url: https://auth.rdrx.app/.well-known/openid-configuration
|
||||
client_id: 31d194737840d32bd3afe6474826976bae346d77247a158c4dc43887278eb605
|
||||
client_secret: null
|
||||
redirect_uri: http://localhost:3000/auth/callback
|
||||
post_logout_redirect_uri: http://localhost:3000/login
|
||||
scopes:
|
||||
- groups
|
||||
ldap:
|
||||
enabled: false
|
||||
url: ldap://localhost:389
|
||||
bind_dn_template: "uid={login},ou=users,dc=example,dc=com"
|
||||
provider_name: ldap
|
||||
provider_label: Development LDAP
|
||||
|
||||
# Packs directory (where pack action files are located)
|
||||
packs_base_dir: ./packs
|
||||
@@ -110,3 +125,8 @@ executor:
|
||||
scheduled_timeout: 120 # 2 minutes (faster feedback in dev)
|
||||
timeout_check_interval: 30 # Check every 30 seconds
|
||||
enable_timeout_monitor: true
|
||||
|
||||
# Agent binary distribution (optional - for local development)
|
||||
# Binary is built via: make build-agent
|
||||
# agent:
|
||||
# binary_dir: ./target/x86_64-unknown-linux-musl/release
|
||||
|
||||
@@ -86,6 +86,48 @@ security:
|
||||
# Enable authentication
|
||||
enable_auth: true
|
||||
|
||||
# Login page defaults for the web UI. Users can still override with:
|
||||
# /login?auth=direct
|
||||
# /login?auth=<provider_name>
|
||||
login_page:
|
||||
show_local_login: true
|
||||
show_oidc_login: true
|
||||
show_ldap_login: true
|
||||
|
||||
# Optional OIDC browser login configuration
|
||||
oidc:
|
||||
enabled: false
|
||||
discovery_url: https://auth.example.com/.well-known/openid-configuration
|
||||
client_id: your-confidential-client-id
|
||||
provider_name: sso
|
||||
provider_label: Example SSO
|
||||
provider_icon_url: https://auth.example.com/assets/logo.svg
|
||||
client_secret: your-confidential-client-secret
|
||||
redirect_uri: http://localhost:3000/auth/callback
|
||||
post_logout_redirect_uri: http://localhost:3000/login
|
||||
scopes:
|
||||
- groups
|
||||
|
||||
# Optional LDAP authentication configuration
|
||||
ldap:
|
||||
enabled: false
|
||||
url: ldap://ldap.example.com:389
|
||||
# Direct-bind mode: construct DN from template
|
||||
# bind_dn_template: "uid={login},ou=users,dc=example,dc=com"
|
||||
# Search-and-bind mode: search for user with a service account
|
||||
user_search_base: "ou=users,dc=example,dc=com"
|
||||
user_filter: "(uid={login})"
|
||||
search_bind_dn: "cn=readonly,dc=example,dc=com"
|
||||
search_bind_password: "readonly-password"
|
||||
login_attr: uid
|
||||
email_attr: mail
|
||||
display_name_attr: cn
|
||||
group_attr: memberOf
|
||||
starttls: false
|
||||
danger_skip_tls_verify: false
|
||||
provider_name: ldap
|
||||
provider_label: Company LDAP
|
||||
|
||||
# Worker configuration (optional, for worker services)
|
||||
# Uncomment and configure if running worker processes
|
||||
# worker:
|
||||
|
||||
@@ -62,6 +62,8 @@ pack_registry:
|
||||
enabled: true
|
||||
default_registry: https://registry.attune.example.com
|
||||
cache_ttl: 300
|
||||
allowed_source_hosts:
|
||||
- registry.attune.example.com
|
||||
|
||||
# Test worker configuration
|
||||
# worker:
|
||||
|
||||
@@ -27,6 +27,8 @@ futures = { workspace = true }
|
||||
|
||||
# Web framework
|
||||
axum = { workspace = true, features = ["multipart"] }
|
||||
axum-extra = { version = "0.10", features = ["cookie"] }
|
||||
cookie = "0.18"
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
|
||||
@@ -67,6 +69,9 @@ jsonschema = { workspace = true }
|
||||
|
||||
# HTTP client
|
||||
reqwest = { workspace = true }
|
||||
openidconnect = "4.0"
|
||||
ldap3 = { version = "0.12", default-features = false, features = ["sync", "tls-rustls-ring"] }
|
||||
url = { workspace = true }
|
||||
|
||||
# Archive/compression
|
||||
tar = { workspace = true }
|
||||
@@ -84,10 +89,12 @@ hmac = "0.12"
|
||||
sha1 = "0.10"
|
||||
sha2 = { workspace = true }
|
||||
hex = "0.4"
|
||||
subtle = "2.6"
|
||||
|
||||
# OpenAPI/Swagger
|
||||
utoipa = { workspace = true, features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "9.0", features = ["axum"] }
|
||||
jsonwebtoken = { workspace = true, features = ["rust_crypto"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockall = { workspace = true }
|
||||
|
||||
504
crates/api/src/auth/ldap.rs
Normal file
504
crates/api/src/auth/ldap.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
//! LDAP authentication helpers for username/password login.
|
||||
|
||||
use attune_common::{
|
||||
config::LdapConfig,
|
||||
repositories::{
|
||||
identity::{
|
||||
CreateIdentityInput, IdentityRepository, IdentityRoleAssignmentRepository,
|
||||
UpdateIdentityInput,
|
||||
},
|
||||
Create, Update,
|
||||
},
|
||||
};
|
||||
use ldap3::{dn_escape, ldap_escape, Ldap, LdapConnAsync, LdapConnSettings, Scope, SearchEntry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::{
|
||||
auth::jwt::{generate_access_token, generate_refresh_token},
|
||||
dto::TokenResponse,
|
||||
middleware::error::ApiError,
|
||||
state::SharedState,
|
||||
};
|
||||
|
||||
/// Claims extracted from the LDAP directory for an authenticated user.
///
/// The whole struct is serialized into the identity's `attributes` JSON
/// under the `"ldap"` key (see `upsert_identity`), so field names are part
/// of the stored schema — rename with care.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LdapUserClaims {
    /// The LDAP server URL the user was authenticated against.
    /// Combined with `dn`, this keys the identity lookup and the
    /// deterministic fallback login hash.
    pub server_url: String,
    /// The user's full distinguished name.
    pub dn: String,
    /// Login attribute value (uid, sAMAccountName, etc.); `None` when the
    /// configured login attribute was absent from the entry.
    pub login: Option<String>,
    /// Email address, taken from the configured email attribute.
    pub email: Option<String>,
    /// Display name (cn), taken from the configured display-name attribute.
    pub display_name: Option<String>,
    /// Group memberships (memberOf values); empty when the group attribute
    /// is absent.
    pub groups: Vec<String>,
}
|
||||
|
||||
/// The result of a successful LDAP authentication.
#[derive(Debug, Clone)]
pub struct LdapAuthenticatedIdentity {
    /// Issued JWT access/refresh token pair, augmented with the identity's
    /// id, login, and display name for the client.
    pub token_response: TokenResponse,
}
|
||||
|
||||
/// Authenticate a user against the configured LDAP directory.
|
||||
///
|
||||
/// This performs a bind (either direct or search+bind) to verify
|
||||
/// the user's credentials, then fetches their attributes and upserts
|
||||
/// the identity in the database.
|
||||
pub async fn authenticate(
|
||||
state: &SharedState,
|
||||
login: &str,
|
||||
password: &str,
|
||||
) -> Result<LdapAuthenticatedIdentity, ApiError> {
|
||||
let ldap_config = ldap_config(state)?;
|
||||
|
||||
// Connect and authenticate
|
||||
let claims = if ldap_config.bind_dn_template.is_some() {
|
||||
direct_bind(&ldap_config, login, password).await?
|
||||
} else {
|
||||
search_and_bind(&ldap_config, login, password).await?
|
||||
};
|
||||
|
||||
// Upsert identity in DB and issue JWT tokens
|
||||
let identity = upsert_identity(state, &claims).await?;
|
||||
if identity.frozen {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Identity is frozen and cannot authenticate".to_string(),
|
||||
));
|
||||
}
|
||||
let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
|
||||
let token_response = TokenResponse::new(
|
||||
access_token,
|
||||
refresh_token,
|
||||
state.jwt_config.access_token_expiration,
|
||||
)
|
||||
.with_user(
|
||||
identity.id,
|
||||
identity.login.clone(),
|
||||
identity.display_name.clone(),
|
||||
);
|
||||
|
||||
Ok(LdapAuthenticatedIdentity { token_response })
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn ldap_config(state: &SharedState) -> Result<LdapConfig, ApiError> {
|
||||
let config = state
|
||||
.config
|
||||
.security
|
||||
.ldap
|
||||
.clone()
|
||||
.filter(|ldap| ldap.enabled)
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotImplemented("LDAP authentication is not configured".to_string())
|
||||
})?;
|
||||
|
||||
// Reject partial service-account configuration: having exactly one of
|
||||
// search_bind_dn / search_bind_password is almost certainly a config
|
||||
// error and would silently fall back to anonymous search, which is a
|
||||
// very different security posture than the admin intended.
|
||||
let has_dn = config.search_bind_dn.is_some();
|
||||
let has_pw = config.search_bind_password.is_some();
|
||||
if has_dn != has_pw {
|
||||
let missing = if has_dn {
|
||||
"search_bind_password"
|
||||
} else {
|
||||
"search_bind_dn"
|
||||
};
|
||||
return Err(ApiError::InternalServerError(format!(
|
||||
"LDAP misconfiguration: search_bind_dn and search_bind_password must both be set \
|
||||
or both be omitted (missing {missing})"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Build an `LdapConnSettings` from the config.
|
||||
fn conn_settings(config: &LdapConfig) -> LdapConnSettings {
|
||||
let mut settings = LdapConnSettings::new();
|
||||
if config.starttls {
|
||||
settings = settings.set_starttls(true);
|
||||
}
|
||||
if config.danger_skip_tls_verify {
|
||||
settings = settings.set_no_tls_verify(true);
|
||||
}
|
||||
settings
|
||||
}
|
||||
|
||||
/// Open a new LDAP connection.
|
||||
async fn connect(config: &LdapConfig) -> Result<Ldap, ApiError> {
|
||||
let settings = conn_settings(config);
|
||||
let url = config.url.as_deref().unwrap_or_default();
|
||||
let (conn, ldap) = LdapConnAsync::with_settings(settings, url)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ApiError::InternalServerError(format!("Failed to connect to LDAP server: {err}"))
|
||||
})?;
|
||||
// Drive the connection in the background
|
||||
ldap3::drive!(conn);
|
||||
Ok(ldap)
|
||||
}
|
||||
|
||||
/// Direct-bind authentication: construct the DN from the template and bind.
|
||||
async fn direct_bind(
|
||||
config: &LdapConfig,
|
||||
login: &str,
|
||||
password: &str,
|
||||
) -> Result<LdapUserClaims, ApiError> {
|
||||
let template = config.bind_dn_template.as_deref().unwrap_or_default();
|
||||
// Escape the login value for safe interpolation into a Distinguished Name
|
||||
// (RFC 4514). Without this, characters like `,`, `+`, `"`, `\`, `<`, `>`,
|
||||
// `;`, `=`, NUL, `#` (leading), or space (leading/trailing) in the username
|
||||
// would alter the DN structure.
|
||||
let escaped_login = dn_escape(login);
|
||||
let bind_dn = template.replace("{login}", &escaped_login);
|
||||
|
||||
let mut ldap = connect(config).await?;
|
||||
|
||||
// Bind as the user
|
||||
let result = ldap
|
||||
.simple_bind(&bind_dn, password)
|
||||
.await
|
||||
.map_err(|err| ApiError::InternalServerError(format!("LDAP bind failed: {err}")))?;
|
||||
|
||||
if result.rc != 0 {
|
||||
let _ = ldap.unbind().await;
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid LDAP credentials".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch user attributes
|
||||
let claims = fetch_user_attributes(config, &mut ldap, &bind_dn).await?;
|
||||
|
||||
let _ = ldap.unbind().await;
|
||||
Ok(claims)
|
||||
}
|
||||
|
||||
/// Search-and-bind authentication:
/// 1. Bind as the service account (or anonymous)
/// 2. Search for the user entry (must match exactly one)
/// 3. Re-bind as the user with their DN + password
///
/// Requires `user_search_base` to be configured; the operation order is
/// deliberate (service bind before search, user bind last) and must not be
/// rearranged.
async fn search_and_bind(
    config: &LdapConfig,
    login: &str,
    password: &str,
) -> Result<LdapUserClaims, ApiError> {
    let search_base = config.user_search_base.as_deref().ok_or_else(|| {
        ApiError::InternalServerError(
            "LDAP user_search_base is required when bind_dn_template is not set".to_string(),
        )
    })?;

    let mut ldap = connect(config).await?;

    // Step 1: Bind as service account or anonymous.
    // Partial config (only one of dn/password) is already rejected by
    // ldap_config(), so this match is exhaustive over valid states.
    if let (Some(bind_dn), Some(bind_pw)) = (
        config.search_bind_dn.as_deref(),
        config.search_bind_password.as_deref(),
    ) {
        let result = ldap.simple_bind(bind_dn, bind_pw).await.map_err(|err| {
            ApiError::InternalServerError(format!("LDAP service bind failed: {err}"))
        })?;
        if result.rc != 0 {
            let _ = ldap.unbind().await;
            return Err(ApiError::InternalServerError(
                "LDAP service account bind failed — check search_bind_dn and search_bind_password"
                    .to_string(),
            ));
        }
    }
    // If no service account, we proceed with an anonymous connection (already connected)

    // Step 2: Search for the user.
    // Escape the login value for safe interpolation into an LDAP search filter
    // (RFC 4515). Without this, characters like `(`, `)`, `*`, `\`, and NUL in
    // the username could broaden the filter, match unintended entries, or break
    // the search entirely.
    let escaped_login = ldap_escape(login);
    let filter = config.user_filter.replace("{login}", &escaped_login);
    let attrs = vec![
        config.login_attr.as_str(),
        config.email_attr.as_str(),
        config.display_name_attr.as_str(),
        config.group_attr.as_str(),
        "dn",
    ];

    let (results, _result) = ldap
        .search(search_base, Scope::Subtree, &filter, attrs)
        .await
        .map_err(|err| ApiError::InternalServerError(format!("LDAP user search failed: {err}")))?
        .success()
        .map_err(|err| ApiError::InternalServerError(format!("LDAP search error: {err}")))?;

    // The search must return exactly one entry. Zero means the user was not
    // found; more than one means the filter or directory layout is ambiguous
    // and we must not guess which identity to authenticate.
    // Zero matches returns the same generic message as a wrong password,
    // so callers cannot distinguish "no such user" from "bad password".
    let result_count = results.len();
    if result_count == 0 {
        let _ = ldap.unbind().await;
        return Err(ApiError::Unauthorized(
            "Invalid LDAP credentials".to_string(),
        ));
    }
    if result_count > 1 {
        let _ = ldap.unbind().await;
        return Err(ApiError::InternalServerError(format!(
            "LDAP user search returned {result_count} entries (expected exactly 1) — \
             tighten the user_filter or user_search_base to ensure uniqueness"
        )));
    }

    // SAFETY: result_count == 1 guaranteed by the checks above.
    let entry = results
        .into_iter()
        .next()
        .expect("checked result_count == 1");
    let search_entry = SearchEntry::construct(entry);
    let user_dn = search_entry.dn.clone();

    // Step 3: Re-bind as the user. This is the actual password check; a
    // non-zero result code means the credentials are wrong.
    let result = ldap
        .simple_bind(&user_dn, password)
        .await
        .map_err(|err| ApiError::InternalServerError(format!("LDAP user bind failed: {err}")))?;
    if result.rc != 0 {
        let _ = ldap.unbind().await;
        return Err(ApiError::Unauthorized(
            "Invalid LDAP credentials".to_string(),
        ));
    }

    // Claims come from the entry fetched in step 2 (service/anonymous view),
    // so no second attribute fetch is needed after the user bind.
    let claims = extract_claims(config, &search_entry);
    let _ = ldap.unbind().await;
    Ok(claims)
}
|
||||
|
||||
/// Fetch the user's LDAP attributes after a successful bind.
|
||||
async fn fetch_user_attributes(
|
||||
config: &LdapConfig,
|
||||
ldap: &mut Ldap,
|
||||
user_dn: &str,
|
||||
) -> Result<LdapUserClaims, ApiError> {
|
||||
let attrs = vec![
|
||||
config.login_attr.as_str(),
|
||||
config.email_attr.as_str(),
|
||||
config.display_name_attr.as_str(),
|
||||
config.group_attr.as_str(),
|
||||
];
|
||||
|
||||
let (results, _result) = ldap
|
||||
.search(user_dn, Scope::Base, "(objectClass=*)", attrs)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ApiError::InternalServerError(format!(
|
||||
"LDAP attribute fetch failed for DN {user_dn}: {err}"
|
||||
))
|
||||
})?
|
||||
.success()
|
||||
.map_err(|err| {
|
||||
ApiError::InternalServerError(format!("LDAP attribute search error: {err}"))
|
||||
})?;
|
||||
|
||||
let entry = results.into_iter().next().ok_or_else(|| {
|
||||
ApiError::InternalServerError(format!("LDAP entry not found for DN: {user_dn}"))
|
||||
})?;
|
||||
let search_entry = SearchEntry::construct(entry);
|
||||
|
||||
Ok(extract_claims(config, &search_entry))
|
||||
}
|
||||
|
||||
/// Extract user claims from an LDAP search entry.
|
||||
fn extract_claims(config: &LdapConfig, entry: &SearchEntry) -> LdapUserClaims {
|
||||
let first_attr =
|
||||
|name: &str| -> Option<String> { entry.attrs.get(name).and_then(|v| v.first()).cloned() };
|
||||
|
||||
let groups = entry
|
||||
.attrs
|
||||
.get(&config.group_attr)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
LdapUserClaims {
|
||||
server_url: config.url.clone().unwrap_or_default(),
|
||||
dn: entry.dn.clone(),
|
||||
login: first_attr(&config.login_attr),
|
||||
email: first_attr(&config.email_attr),
|
||||
display_name: first_attr(&config.display_name_attr),
|
||||
groups,
|
||||
}
|
||||
}
|
||||
|
||||
/// Upsert an identity row for the LDAP-authenticated user.
///
/// Identities are keyed by (server_url, dn); on update the display name and
/// the `"ldap"` attribute blob are refreshed, on create a collision-free
/// login is chosen. In both branches the LDAP-managed role assignments are
/// re-synced from the user's group memberships.
async fn upsert_identity(
    state: &SharedState,
    claims: &LdapUserClaims,
) -> Result<attune_common::models::identity::Identity, ApiError> {
    let existing =
        IdentityRepository::find_by_ldap_dn(&state.db, &claims.server_url, &claims.dn).await?;
    let desired_login = derive_login(claims);
    let display_name = claims.display_name.clone();
    // Raw directory claims are stored under the "ldap" key for later audit.
    let attributes = json!({ "ldap": claims });

    match existing {
        Some(identity) => {
            // Existing identity: refresh mutable fields only. password_hash
            // stays None (LDAP users have no local password) and frozen is
            // left untouched.
            let updated = UpdateIdentityInput {
                display_name,
                password_hash: None,
                attributes: Some(attributes),
                frozen: None,
            };
            let identity = IdentityRepository::update(&state.db, identity.id, updated)
                .await
                .map_err(ApiError::from)?;
            sync_roles(&state.db, identity.id, "ldap", &claims.groups).await?;
            Ok(identity)
        }
        None => {
            // Avoid login collisions
            // NOTE(review): this check-then-create is a TOCTOU window — two
            // concurrent first logins could race past find_by_login; verify
            // a unique constraint on login backs this up at the DB level.
            let login = match IdentityRepository::find_by_login(&state.db, &desired_login).await? {
                Some(_) => fallback_dn_login(claims),
                None => desired_login,
            };

            let identity = IdentityRepository::create(
                &state.db,
                CreateIdentityInput {
                    login,
                    display_name,
                    password_hash: None,
                    attributes,
                },
            )
            .await
            .map_err(ApiError::from)?;
            sync_roles(&state.db, identity.id, "ldap", &claims.groups).await?;
            Ok(identity)
        }
    }
}
|
||||
|
||||
async fn sync_roles(
|
||||
db: &sqlx::PgPool,
|
||||
identity_id: i64,
|
||||
source: &str,
|
||||
roles: &[String],
|
||||
) -> Result<(), ApiError> {
|
||||
IdentityRoleAssignmentRepository::replace_managed_roles(db, identity_id, source, roles)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Derive the login name from LDAP claims.
|
||||
fn derive_login(claims: &LdapUserClaims) -> String {
|
||||
claims
|
||||
.login
|
||||
.clone()
|
||||
.or_else(|| claims.email.clone())
|
||||
.unwrap_or_else(|| fallback_dn_login(claims))
|
||||
}
|
||||
|
||||
/// Generate a deterministic fallback login from the LDAP server URL + DN.
|
||||
fn fallback_dn_login(claims: &LdapUserClaims) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(claims.server_url.as_bytes());
|
||||
hasher.update(b":");
|
||||
hasher.update(claims.dn.as_bytes());
|
||||
let digest = hex::encode(hasher.finalize());
|
||||
format!("ldap:{}", &digest[..24])
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Injection-resistance test: a login crafted to smuggle extra RDNs
    // into the bind DN must come out fully escaped.
    #[test]
    fn direct_bind_dn_escapes_special_characters() {
        // Simulate what direct_bind does with the template
        let template = "uid={login},ou=users,dc=example,dc=com";
        let malicious_login = "admin,ou=admins,dc=evil,dc=com";
        let escaped = dn_escape(malicious_login);
        let bind_dn = template.replace("{login}", &escaped);
        // The commas in the login value must be escaped so they don't
        // introduce additional RDN components.
        assert!(
            bind_dn.contains("\\2c"),
            "commas in login must be escaped in DN: {bind_dn}"
        );
        assert!(
            bind_dn.starts_with("uid=admin\\2cou\\3dadmins\\2cdc\\3devil\\2cdc\\3dcom,ou=users"),
            "DN structure must be preserved: {bind_dn}"
        );
    }

    // Injection-resistance test: filter metacharacters in the login must be
    // hex-escaped per RFC 4515 so the search filter shape is unchanged.
    #[test]
    fn search_filter_escapes_special_characters() {
        let filter_template = "(uid={login})";
        let malicious_login = "admin)(|(uid=*))";
        let escaped = ldap_escape(malicious_login);
        let filter = filter_template.replace("{login}", &escaped);
        // The parentheses and asterisk must be escaped so they don't
        // alter the filter structure.
        assert!(
            !filter.contains(")("),
            "parentheses in login must be escaped in filter: {filter}"
        );
        assert!(
            filter.contains("\\28"),
            "open-paren must be hex-escaped: {filter}"
        );
        assert!(
            filter.contains("\\29"),
            "close-paren must be hex-escaped: {filter}"
        );
        assert!(
            filter.contains("\\2a"),
            "asterisk must be hex-escaped: {filter}"
        );
    }

    // Escaping must be a no-op for ordinary usernames.
    #[test]
    fn dn_escape_preserves_safe_usernames() {
        let safe = "jdoe";
        let escaped = dn_escape(safe);
        assert_eq!(escaped.as_ref(), "jdoe");
    }

    #[test]
    fn filter_escape_preserves_safe_usernames() {
        let safe = "jdoe";
        let escaped = ldap_escape(safe);
        assert_eq!(escaped.as_ref(), "jdoe");
    }

    // The fallback login must be stable across calls (same input → same
    // output) and follow the "ldap:" + 24-hex-chars shape.
    #[test]
    fn fallback_dn_login_is_deterministic() {
        let claims = LdapUserClaims {
            server_url: "ldap://ldap.example.com".to_string(),
            dn: "uid=test,ou=users,dc=example,dc=com".to_string(),
            login: None,
            email: None,
            display_name: None,
            groups: vec![],
        };
        let a = fallback_dn_login(&claims);
        let b = fallback_dn_login(&claims);
        assert_eq!(a, b);
        assert!(a.starts_with("ldap:"));
        assert_eq!(a.len(), "ldap:".len() + 24);
    }
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header::AUTHORIZATION, StatusCode},
|
||||
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
@@ -14,6 +14,8 @@ use attune_common::auth::jwt::{
|
||||
extract_token_from_header, validate_token, Claims, JwtConfig, TokenType,
|
||||
};
|
||||
|
||||
use super::oidc::{cookie_authenticated_user, ACCESS_COOKIE_NAME};
|
||||
|
||||
/// Authentication middleware state
|
||||
#[derive(Clone)]
|
||||
pub struct AuthMiddleware {
|
||||
@@ -50,21 +52,7 @@ pub async fn require_auth(
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AuthError> {
|
||||
// Extract Authorization header
|
||||
let auth_header = request
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(AuthError::MissingToken)?;
|
||||
|
||||
// Extract token from Bearer scheme
|
||||
let token = extract_token_from_header(auth_header).ok_or(AuthError::InvalidToken)?;
|
||||
|
||||
// Validate token
|
||||
let claims = validate_token(token, &auth.jwt_config).map_err(|e| match e {
|
||||
super::jwt::JwtError::Expired => AuthError::ExpiredToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
})?;
|
||||
let claims = extract_claims(request.headers(), &auth.jwt_config)?;
|
||||
|
||||
// Add claims to request extensions
|
||||
request
|
||||
@@ -90,22 +78,13 @@ impl axum::extract::FromRequestParts<crate::state::SharedState> for RequireAuth
|
||||
return Ok(RequireAuth(user.clone()));
|
||||
}
|
||||
|
||||
// Otherwise, extract and validate token directly from header
|
||||
// Extract Authorization header
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(AuthError::MissingToken)?;
|
||||
|
||||
// Extract token from Bearer scheme
|
||||
let token = extract_token_from_header(auth_header).ok_or(AuthError::InvalidToken)?;
|
||||
|
||||
// Validate token using jwt_config from app state
|
||||
let claims = validate_token(token, &state.jwt_config).map_err(|e| match e {
|
||||
super::jwt::JwtError::Expired => AuthError::ExpiredToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
})?;
|
||||
let claims = if let Some(user) =
|
||||
cookie_authenticated_user(&parts.headers, state).map_err(map_cookie_auth_error)?
|
||||
{
|
||||
user.claims
|
||||
} else {
|
||||
extract_claims(&parts.headers, &state.jwt_config)?
|
||||
};
|
||||
|
||||
// Allow access, sensor, and execution-scoped tokens
|
||||
if claims.token_type != TokenType::Access
|
||||
@@ -119,6 +98,33 @@ impl axum::extract::FromRequestParts<crate::state::SharedState> for RequireAuth
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_claims(headers: &HeaderMap, jwt_config: &JwtConfig) -> Result<Claims, AuthError> {
|
||||
if let Some(auth_header) = headers.get(AUTHORIZATION).and_then(|h| h.to_str().ok()) {
|
||||
let token = extract_token_from_header(auth_header).ok_or(AuthError::InvalidToken)?;
|
||||
return validate_token(token, jwt_config).map_err(|e| match e {
|
||||
super::jwt::JwtError::Expired => AuthError::ExpiredToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
});
|
||||
}
|
||||
|
||||
if headers
|
||||
.get(axum::http::header::COOKIE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.is_some_and(|cookies| cookies.contains(ACCESS_COOKIE_NAME))
|
||||
{
|
||||
return Err(AuthError::InvalidToken);
|
||||
}
|
||||
|
||||
Err(AuthError::MissingToken)
|
||||
}
|
||||
|
||||
fn map_cookie_auth_error(error: crate::middleware::error::ApiError) -> AuthError {
|
||||
match error {
|
||||
crate::middleware::error::ApiError::Unauthorized(_) => AuthError::InvalidToken,
|
||||
_ => AuthError::InvalidToken,
|
||||
}
|
||||
}
|
||||
|
||||
/// Authentication errors
|
||||
#[derive(Debug)]
|
||||
pub enum AuthError {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
//! Authentication and authorization module
|
||||
|
||||
pub mod jwt;
|
||||
pub mod ldap;
|
||||
pub mod middleware;
|
||||
pub mod oidc;
|
||||
pub mod password;
|
||||
|
||||
pub use jwt::{generate_token, validate_token, Claims};
|
||||
|
||||
803
crates/api/src/auth/oidc.rs
Normal file
803
crates/api/src/auth/oidc.rs
Normal file
@@ -0,0 +1,803 @@
|
||||
//! OpenID Connect helpers for browser login.
|
||||
|
||||
use attune_common::{
|
||||
config::OidcConfig,
|
||||
repositories::{
|
||||
identity::{
|
||||
CreateIdentityInput, IdentityRepository, IdentityRoleAssignmentRepository,
|
||||
UpdateIdentityInput,
|
||||
},
|
||||
Create, Update,
|
||||
},
|
||||
};
|
||||
use axum::{
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
use cookie::time::Duration as CookieDuration;
|
||||
use jsonwebtoken::{
|
||||
decode, decode_header,
|
||||
jwk::{AlgorithmParameters, JwkSet},
|
||||
Algorithm, DecodingKey, Validation,
|
||||
};
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata, CoreUserInfoClaims},
|
||||
reqwest::Client as OidcHttpClient,
|
||||
AuthorizationCode, ClientId, ClientSecret, CsrfToken, LocalizedClaim, Nonce,
|
||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope,
|
||||
TokenResponse as OidcTokenResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use sha2::{Digest, Sha256};
|
||||
use url::{form_urlencoded::byte_serialize, Url};
|
||||
|
||||
use crate::{
|
||||
auth::jwt::{generate_access_token, generate_refresh_token, validate_token},
|
||||
dto::{CurrentUserResponse, TokenResponse},
|
||||
middleware::error::ApiError,
|
||||
state::SharedState,
|
||||
};
|
||||
|
||||
/// Cookie holding the short-lived JWT access token.
pub const ACCESS_COOKIE_NAME: &str = "attune_access_token";
/// Cookie holding the long-lived JWT refresh token.
pub const REFRESH_COOKIE_NAME: &str = "attune_refresh_token";
/// Cookie caching the raw OIDC ID token (used as `id_token_hint` at logout).
pub const OIDC_ID_TOKEN_COOKIE_NAME: &str = "attune_oidc_id_token";
/// Short-lived cookie carrying the CSRF state for the in-flight login.
pub const OIDC_STATE_COOKIE_NAME: &str = "attune_oidc_state";
/// Short-lived cookie carrying the nonce expected back in the ID token.
pub const OIDC_NONCE_COOKIE_NAME: &str = "attune_oidc_nonce";
/// Short-lived cookie carrying the PKCE code verifier.
pub const OIDC_PKCE_COOKIE_NAME: &str = "attune_oidc_pkce_verifier";
/// Short-lived cookie remembering where to send the user after login.
pub const OIDC_REDIRECT_COOKIE_NAME: &str = "attune_oidc_redirect_to";

/// Frontend route that finishes the login and stores the tokens.
const LOGIN_CALLBACK_PATH: &str = "/login/callback";

/// Provider discovery document: the standard OIDC metadata plus the optional
/// RP-initiated-logout endpoint, which `CoreProviderMetadata` does not model.
#[derive(Debug, Clone, Deserialize)]
pub struct OidcDiscoveryDocument {
    #[serde(flatten)]
    pub metadata: CoreProviderMetadata,
    #[serde(default)]
    pub end_session_endpoint: Option<String>,
}

/// Normalized identity claims extracted from the ID token (and optionally
/// topped up from the UserInfo endpoint).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcIdentityClaims {
    pub issuer: String,
    pub sub: String,
    pub email: Option<String>,
    pub email_verified: Option<bool>,
    pub name: Option<String>,
    pub preferred_username: Option<String>,
    pub groups: Vec<String>,
}

/// Claim set deserialized after local signature verification of the ID token.
#[derive(Debug, Clone, Deserialize)]
struct VerifiedIdTokenClaims {
    iss: String,
    sub: String,
    #[serde(default)]
    nonce: Option<String>,
    #[serde(default)]
    email: Option<String>,
    #[serde(default)]
    email_verified: Option<bool>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    preferred_username: Option<String>,
    #[serde(default)]
    groups: Vec<String>,
}

/// Result of a successful OIDC callback: the authenticated user, the API
/// token pair, and the raw ID token (kept so logout can hint the provider).
#[derive(Debug, Clone)]
pub struct OidcAuthenticatedIdentity {
    pub current_user: CurrentUserResponse,
    pub token_response: TokenResponse,
    pub id_token: String,
}

/// Authorization URL plus the transient cookies that protect the flow.
#[derive(Debug, Clone)]
pub struct OidcLoginRedirect {
    pub authorization_url: String,
    pub cookies: Vec<Cookie<'static>>,
}

/// Provider (or local) logout URL plus cookie-removal instructions.
#[derive(Debug, Clone)]
pub struct OidcLogoutRedirect {
    pub redirect_url: String,
    pub cookies: Vec<Cookie<'static>>,
}

/// Query parameters the provider may send back to the redirect URI.
#[derive(Debug, Deserialize)]
pub struct OidcCallbackQuery {
    pub code: Option<String>,
    pub state: Option<String>,
    pub error: Option<String>,
    pub error_description: Option<String>,
}
|
||||
|
||||
pub async fn build_login_redirect(
|
||||
state: &SharedState,
|
||||
redirect_to: Option<&str>,
|
||||
) -> Result<OidcLoginRedirect, ApiError> {
|
||||
let oidc = oidc_config(state)?;
|
||||
let discovery = fetch_discovery_document(&oidc).await?;
|
||||
let _http_client = OidcHttpClient::builder()
|
||||
.redirect(openidconnect::reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.map_err(|err| {
|
||||
ApiError::InternalServerError(format!("Failed to build OIDC HTTP client: {err}"))
|
||||
})?;
|
||||
let redirect_uri_str = oidc.redirect_uri.clone().unwrap_or_default();
|
||||
let redirect_uri = RedirectUrl::new(redirect_uri_str).map_err(|err| {
|
||||
ApiError::InternalServerError(format!("Invalid OIDC redirect URI: {err}"))
|
||||
})?;
|
||||
let client_secret = oidc.client_secret.clone().ok_or_else(|| {
|
||||
ApiError::InternalServerError("OIDC client secret is missing".to_string())
|
||||
})?;
|
||||
let client_id = oidc.client_id.clone().unwrap_or_default();
|
||||
let client = CoreClient::from_provider_metadata(
|
||||
discovery.metadata.clone(),
|
||||
ClientId::new(client_id),
|
||||
Some(ClientSecret::new(client_secret)),
|
||||
)
|
||||
.set_redirect_uri(redirect_uri);
|
||||
|
||||
let redirect_target = sanitize_redirect_target(redirect_to);
|
||||
let pkce = PkceCodeChallenge::new_random_sha256();
|
||||
let (auth_url, csrf_state, nonce) = client
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.add_scope(Scope::new("openid".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.add_scopes(
|
||||
oidc.scopes
|
||||
.iter()
|
||||
.filter(|scope| !matches!(scope.as_str(), "openid" | "email" | "profile"))
|
||||
.cloned()
|
||||
.map(Scope::new),
|
||||
)
|
||||
.set_pkce_challenge(pkce.0)
|
||||
.url();
|
||||
|
||||
Ok(OidcLoginRedirect {
|
||||
authorization_url: auth_url.to_string(),
|
||||
cookies: vec![
|
||||
build_cookie(
|
||||
state,
|
||||
OIDC_STATE_COOKIE_NAME,
|
||||
csrf_state.secret().to_string(),
|
||||
600,
|
||||
true,
|
||||
),
|
||||
build_cookie(
|
||||
state,
|
||||
OIDC_NONCE_COOKIE_NAME,
|
||||
nonce.secret().to_string(),
|
||||
600,
|
||||
true,
|
||||
),
|
||||
build_cookie(
|
||||
state,
|
||||
OIDC_PKCE_COOKIE_NAME,
|
||||
pkce.1.secret().to_string(),
|
||||
600,
|
||||
true,
|
||||
),
|
||||
build_cookie(
|
||||
state,
|
||||
OIDC_REDIRECT_COOKIE_NAME,
|
||||
redirect_target,
|
||||
600,
|
||||
false,
|
||||
),
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Complete the OIDC authorization-code flow.
///
/// Validates any provider error and the CSRF state, exchanges the code
/// (proving the PKCE verifier from the login cookie), verifies the ID
/// token's signature and nonce, optionally enriches the claims from the
/// UserInfo endpoint, upserts the local identity, and mints the API token
/// pair.
///
/// # Errors
/// - `BadRequest` when `code` or `state` is missing from the query.
/// - `Unauthorized` for provider errors, missing/mismatched state or nonce,
///   and token-exchange or ID-token validation failures.
/// - `Forbidden` when the matched identity is frozen.
/// - `InternalServerError` for configuration or discovery problems.
pub async fn handle_callback(
    state: &SharedState,
    headers: &HeaderMap,
    query: &OidcCallbackQuery,
) -> Result<OidcAuthenticatedIdentity, ApiError> {
    // Surface an explicit provider-side error (e.g. access_denied) first.
    if let Some(error) = &query.error {
        let description = query
            .error_description
            .as_deref()
            .unwrap_or("OpenID Connect login failed");
        return Err(ApiError::Unauthorized(format!("{error}: {description}")));
    }

    let code = query
        .code
        .as_ref()
        .ok_or_else(|| ApiError::BadRequest("Missing authorization code".to_string()))?;
    let returned_state = query
        .state
        .as_ref()
        .ok_or_else(|| ApiError::BadRequest("Missing OIDC state".to_string()))?;

    // The state/nonce/PKCE values were stashed in short-lived cookies by
    // build_login_redirect; all three must still be present.
    let expected_state = get_cookie_value(headers, OIDC_STATE_COOKIE_NAME)
        .ok_or_else(|| ApiError::Unauthorized("Missing OIDC state cookie".to_string()))?;
    let expected_nonce = get_cookie_value(headers, OIDC_NONCE_COOKIE_NAME)
        .ok_or_else(|| ApiError::Unauthorized("Missing OIDC nonce cookie".to_string()))?;
    let pkce_verifier = get_cookie_value(headers, OIDC_PKCE_COOKIE_NAME)
        .ok_or_else(|| ApiError::Unauthorized("Missing OIDC PKCE verifier cookie".to_string()))?;

    // CSRF check: the state echoed by the provider must match our cookie.
    if returned_state != &expected_state {
        return Err(ApiError::Unauthorized(
            "OIDC state validation failed".to_string(),
        ));
    }

    let oidc = oidc_config(state)?;
    let discovery = fetch_discovery_document(&oidc).await?;
    // Redirect-following is disabled so token/userinfo responses cannot be
    // bounced to an unexpected location.
    let http_client = OidcHttpClient::builder()
        .redirect(openidconnect::reqwest::redirect::Policy::none())
        .build()
        .map_err(|err| {
            ApiError::InternalServerError(format!("Failed to build OIDC HTTP client: {err}"))
        })?;
    let redirect_uri_str = oidc.redirect_uri.clone().unwrap_or_default();
    let redirect_uri = RedirectUrl::new(redirect_uri_str).map_err(|err| {
        ApiError::InternalServerError(format!("Invalid OIDC redirect URI: {err}"))
    })?;
    let client_secret = oidc.client_secret.clone().ok_or_else(|| {
        ApiError::InternalServerError("OIDC client secret is missing".to_string())
    })?;
    let client_id = oidc.client_id.clone().unwrap_or_default();
    let client = CoreClient::from_provider_metadata(
        discovery.metadata.clone(),
        ClientId::new(client_id),
        Some(ClientSecret::new(client_secret)),
    )
    .set_redirect_uri(redirect_uri);

    // Exchange the one-time code for tokens, proving possession of the PKCE
    // verifier matching the challenge sent at login.
    let token_response = client
        .exchange_code(AuthorizationCode::new(code.clone()))
        .map_err(|err| {
            ApiError::InternalServerError(format!("OIDC token request is misconfigured: {err}"))
        })?
        .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier))
        .request_async(&http_client)
        .await
        .map_err(|err| ApiError::Unauthorized(format!("OIDC token exchange failed: {err}")))?;

    let id_token = token_response.id_token().ok_or_else(|| {
        ApiError::Unauthorized("OIDC provider did not return an ID token".to_string())
    })?;

    // Verify the signature against the provider's JWKS and check the nonce.
    let raw_id_token = id_token.to_string();
    let claims = verify_id_token(&raw_id_token, &discovery, &oidc, &expected_nonce).await?;

    let mut oidc_claims = OidcIdentityClaims {
        issuer: claims.iss,
        sub: claims.sub,
        email: claims.email,
        email_verified: claims.email_verified,
        name: claims.name,
        preferred_username: claims.preferred_username,
        groups: claims.groups,
    };

    // Best-effort enrichment: fill gaps (email/name/username/groups) from
    // UserInfo; failures here are deliberately ignored.
    if let Ok(userinfo_request) = client.user_info(token_response.access_token().to_owned(), None) {
        if let Ok(userinfo) = userinfo_request.request_async(&http_client).await {
            merge_userinfo_claims(&mut oidc_claims, &userinfo);
        }
    }

    // Create or refresh the local identity (and its provider-managed roles).
    let identity = upsert_identity(state, &oidc_claims).await?;
    if identity.frozen {
        return Err(ApiError::Forbidden(
            "Identity is frozen and cannot authenticate".to_string(),
        ));
    }
    let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
    let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;

    let token_response = TokenResponse::new(
        access_token,
        refresh_token,
        state.jwt_config.access_token_expiration,
    )
    .with_user(
        identity.id,
        identity.login.clone(),
        identity.display_name.clone(),
    );

    Ok(OidcAuthenticatedIdentity {
        current_user: CurrentUserResponse {
            id: identity.id,
            login: identity.login.clone(),
            display_name: identity.display_name.clone(),
        },
        id_token: raw_id_token,
        token_response,
    })
}
|
||||
|
||||
/// Build the redirect that logs the user out.
///
/// When the provider advertises an `end_session_endpoint`, the user is sent
/// there (with `id_token_hint`, `post_logout_redirect_uri`, and
/// `client_id`); otherwise they go straight to the configured post-logout
/// URI (default "/login"). All auth cookies are cleared in both cases.
pub async fn build_logout_redirect(
    state: &SharedState,
    headers: &HeaderMap,
) -> Result<OidcLogoutRedirect, ApiError> {
    let oidc = oidc_config(state)?;
    let discovery = fetch_discovery_document(&oidc).await?;
    let post_logout_redirect_uri = oidc
        .post_logout_redirect_uri
        .clone()
        .unwrap_or_else(|| "/login".to_string());

    let redirect_url = if let Some(end_session_endpoint) = discovery.end_session_endpoint {
        let mut url = Url::parse(&end_session_endpoint).map_err(|err| {
            ApiError::InternalServerError(format!("Invalid end_session_endpoint: {err}"))
        })?;
        // Inner scope ends the mutable query-pairs borrow before `url` is
        // consumed below.
        {
            let mut pairs = url.query_pairs_mut();
            // id_token_hint is optional: attach it only when the login flow
            // cached an ID token in the cookie.
            if let Some(id_token_hint) = get_cookie_value(headers, OIDC_ID_TOKEN_COOKIE_NAME) {
                pairs.append_pair("id_token_hint", &id_token_hint);
            }
            pairs.append_pair("post_logout_redirect_uri", &post_logout_redirect_uri);
            pairs.append_pair("client_id", oidc.client_id.as_deref().unwrap_or_default());
        }
        String::from(url)
    } else {
        post_logout_redirect_uri
    };

    Ok(OidcLogoutRedirect {
        redirect_url,
        cookies: clear_auth_cookies(state),
    })
}
|
||||
|
||||
pub fn clear_auth_cookies(state: &SharedState) -> Vec<Cookie<'static>> {
|
||||
[
|
||||
ACCESS_COOKIE_NAME,
|
||||
REFRESH_COOKIE_NAME,
|
||||
OIDC_ID_TOKEN_COOKIE_NAME,
|
||||
OIDC_STATE_COOKIE_NAME,
|
||||
OIDC_NONCE_COOKIE_NAME,
|
||||
OIDC_PKCE_COOKIE_NAME,
|
||||
OIDC_REDIRECT_COOKIE_NAME,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|name| remove_cookie(state, name))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build the session cookies for an authenticated user: the access and
/// refresh tokens, plus (when non-empty) the raw OIDC ID token so logout
/// can pass it as `id_token_hint`.
pub fn build_auth_cookies(
    state: &SharedState,
    token_response: &TokenResponse,
    id_token: &str,
) -> Vec<Cookie<'static>> {
    let mut cookies = vec![
        // Each cookie's lifetime mirrors the corresponding JWT's expiry.
        build_cookie(
            state,
            ACCESS_COOKIE_NAME,
            token_response.access_token.clone(),
            state.jwt_config.access_token_expiration,
            true,
        ),
        build_cookie(
            state,
            REFRESH_COOKIE_NAME,
            token_response.refresh_token.clone(),
            state.jwt_config.refresh_token_expiration,
            true,
        ),
    ];

    // Presumably non-OIDC logins pass an empty id_token — skip the cookie
    // then. TODO(review): confirm against the callers.
    if !id_token.is_empty() {
        cookies.push(build_cookie(
            state,
            OIDC_ID_TOKEN_COOKIE_NAME,
            id_token.to_string(),
            state.jwt_config.refresh_token_expiration,
            true,
        ));
    }

    cookies
}
|
||||
|
||||
pub fn apply_cookies_to_headers(
|
||||
headers: &mut HeaderMap,
|
||||
cookies: &[Cookie<'static>],
|
||||
) -> Result<(), ApiError> {
|
||||
for cookie in cookies {
|
||||
let value = HeaderValue::from_str(&cookie.to_string()).map_err(|err| {
|
||||
ApiError::InternalServerError(format!("Failed to serialize cookie header: {err}"))
|
||||
})?;
|
||||
headers.append(header::SET_COOKIE, value);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Final response of the OIDC callback: a temporary redirect to the
/// frontend's login-callback route with the tokens in the URL fragment
/// (fragments are not sent to servers), plus `Set-Cookie` headers that
/// establish the session and clear the transient login cookies.
pub fn oidc_callback_redirect_response(
    state: &SharedState,
    token_response: &TokenResponse,
    redirect_to: Option<String>,
    id_token: &str,
) -> Result<Response, ApiError> {
    // Re-sanitize: the target comes back via a client-readable cookie.
    let redirect_target = sanitize_redirect_target(redirect_to.as_deref());
    let redirect_url = format!(
        "{LOGIN_CALLBACK_PATH}#access_token={}&refresh_token={}&expires_in={}&redirect_to={}",
        encode_fragment_value(&token_response.access_token),
        encode_fragment_value(&token_response.refresh_token),
        token_response.expires_in,
        encode_fragment_value(&redirect_target),
    );

    let mut response = Redirect::temporary(&redirect_url).into_response();
    let mut cookies = build_auth_cookies(state, token_response, id_token);
    // The one-shot login cookies (state/nonce/PKCE/redirect) are spent now.
    cookies.push(remove_cookie(state, OIDC_STATE_COOKIE_NAME));
    cookies.push(remove_cookie(state, OIDC_NONCE_COOKIE_NAME));
    cookies.push(remove_cookie(state, OIDC_PKCE_COOKIE_NAME));
    cookies.push(remove_cookie(state, OIDC_REDIRECT_COOKIE_NAME));
    apply_cookies_to_headers(response.headers_mut(), &cookies)?;
    Ok(response)
}
|
||||
|
||||
pub fn cookie_authenticated_user(
|
||||
headers: &HeaderMap,
|
||||
state: &SharedState,
|
||||
) -> Result<Option<crate::auth::middleware::AuthenticatedUser>, ApiError> {
|
||||
let Some(token) = get_cookie_value(headers, ACCESS_COOKIE_NAME) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let claims = validate_token(&token, &state.jwt_config).map_err(ApiError::from)?;
|
||||
Ok(Some(crate::auth::middleware::AuthenticatedUser { claims }))
|
||||
}
|
||||
|
||||
pub fn get_cookie_value(headers: &HeaderMap, name: &str) -> Option<String> {
|
||||
headers
|
||||
.get_all(header::COOKIE)
|
||||
.iter()
|
||||
.filter_map(|value| value.to_str().ok())
|
||||
.flat_map(|value| value.split(';'))
|
||||
.filter_map(|part| {
|
||||
let mut pieces = part.trim().splitn(2, '=');
|
||||
let key = pieces.next()?.trim();
|
||||
let value = pieces.next()?.trim();
|
||||
if key == name {
|
||||
Some(value.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.next()
|
||||
}
|
||||
|
||||
/// Return the OIDC configuration, or `NotImplemented` when OIDC is absent
/// or explicitly disabled.
fn oidc_config(state: &SharedState) -> Result<OidcConfig, ApiError> {
    state
        .config
        .security
        .oidc
        .clone()
        // A present-but-disabled config is treated the same as no config.
        .filter(|oidc| oidc.enabled)
        .ok_or_else(|| {
            ApiError::NotImplemented("OIDC authentication is not configured".to_string())
        })
}
|
||||
|
||||
/// Fetch and parse the provider's OIDC discovery document.
///
/// # Errors
/// `InternalServerError` when the request fails, returns a non-2xx status,
/// or the body does not deserialize into `OidcDiscoveryDocument`.
///
/// NOTE(review): this refetches on every login/logout — consider caching
/// the document if provider round-trips become a latency concern.
async fn fetch_discovery_document(oidc: &OidcConfig) -> Result<OidcDiscoveryDocument, ApiError> {
    // An unset discovery URL yields "" here and fails below with a clear error.
    let discovery_url = oidc.discovery_url.as_deref().unwrap_or_default();
    let discovery = reqwest::get(discovery_url).await.map_err(|err| {
        ApiError::InternalServerError(format!("Failed to fetch OIDC discovery document: {err}"))
    })?;

    if !discovery.status().is_success() {
        return Err(ApiError::InternalServerError(format!(
            "OIDC discovery request failed with status {}",
            discovery.status()
        )));
    }

    discovery
        .json::<OidcDiscoveryDocument>()
        .await
        .map_err(|err| {
            ApiError::InternalServerError(format!("Failed to parse OIDC discovery document: {err}"))
        })
}
|
||||
|
||||
/// Create or update the local identity for a verified set of OIDC claims.
///
/// Lookup is keyed on (issuer, subject) — the stable OIDC identifier — not
/// on login/email, which may change at the provider. On update, the display
/// name and stored claims are refreshed; on create, if the preferred login
/// is already taken by another identity, a deterministic hash-based login is
/// used instead. Provider-managed roles are synced from `groups` in both
/// paths.
async fn upsert_identity(
    state: &SharedState,
    oidc_claims: &OidcIdentityClaims,
) -> Result<attune_common::models::identity::Identity, ApiError> {
    let existing_by_subject =
        IdentityRepository::find_by_oidc_subject(&state.db, &oidc_claims.issuer, &oidc_claims.sub)
            .await?;
    let desired_login = derive_login(oidc_claims);
    let display_name = derive_display_name(oidc_claims);
    // The full claim set is persisted under the "oidc" attribute key.
    let attributes = json!({
        "oidc": oidc_claims,
    });

    match existing_by_subject {
        Some(identity) => {
            let updated = UpdateIdentityInput {
                display_name,
                password_hash: None,
                attributes: Some(attributes.clone()),
                frozen: None,
            };
            let identity = IdentityRepository::update(&state.db, identity.id, updated)
                .await
                .map_err(ApiError::from)?;
            sync_roles(&state.db, identity.id, "oidc", &oidc_claims.groups).await?;
            Ok(identity)
        }
        None => {
            // Avoid login collisions with pre-existing (e.g. local) accounts.
            let login = match IdentityRepository::find_by_login(&state.db, &desired_login).await? {
                Some(_) => fallback_subject_login(oidc_claims),
                None => desired_login,
            };

            let identity = IdentityRepository::create(
                &state.db,
                CreateIdentityInput {
                    login,
                    display_name,
                    password_hash: None,
                    attributes,
                },
            )
            .await
            .map_err(ApiError::from)?;
            sync_roles(&state.db, identity.id, "oidc", &oidc_claims.groups).await?;
            Ok(identity)
        }
    }
}
|
||||
|
||||
async fn sync_roles(
|
||||
db: &sqlx::PgPool,
|
||||
identity_id: i64,
|
||||
source: &str,
|
||||
roles: &[String],
|
||||
) -> Result<(), ApiError> {
|
||||
IdentityRoleAssignmentRepository::replace_managed_roles(db, identity_id, source, roles)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn derive_login(oidc_claims: &OidcIdentityClaims) -> String {
|
||||
oidc_claims
|
||||
.email
|
||||
.clone()
|
||||
.or_else(|| oidc_claims.preferred_username.clone())
|
||||
.unwrap_or_else(|| fallback_subject_login(oidc_claims))
|
||||
}
|
||||
|
||||
/// Verify an OIDC ID token locally and return its claims.
///
/// Checks, in order: a supported RSA signing algorithm in the JOSE header,
/// the signature against a key from the provider's JWKS (matched by `kid`),
/// issuer, audience (= our client_id), required claims and expiry, and
/// finally the nonce against the value stored at login time.
///
/// # Errors
/// `Unauthorized` for any validation failure; `InternalServerError` when
/// the JWKS cannot be fetched or parsed.
async fn verify_id_token(
    raw_id_token: &str,
    discovery: &OidcDiscoveryDocument,
    oidc: &OidcConfig,
    expected_nonce: &str,
) -> Result<VerifiedIdTokenClaims, ApiError> {
    let header = decode_header(raw_id_token).map_err(|err| {
        ApiError::Unauthorized(format!("OIDC ID token header decode failed: {err}"))
    })?;

    // Only the RSA family is accepted; anything else (HS*, ES*, none) is
    // rejected outright, which also blocks alg-confusion attacks.
    let algorithm = match header.alg {
        Algorithm::RS256 => Algorithm::RS256,
        Algorithm::RS384 => Algorithm::RS384,
        Algorithm::RS512 => Algorithm::RS512,
        other => {
            return Err(ApiError::Unauthorized(format!(
                "OIDC ID token uses unsupported signing algorithm: {other:?}"
            )))
        }
    };

    // NOTE(review): the JWKS is fetched on every verification — a cache
    // keyed by jwks_uri would remove one round-trip per login.
    let jwks = reqwest::get(discovery.metadata.jwks_uri().url().as_str())
        .await
        .map_err(|err| ApiError::InternalServerError(format!("Failed to fetch OIDC JWKS: {err}")))?
        .json::<JwkSet>()
        .await
        .map_err(|err| {
            ApiError::InternalServerError(format!("Failed to parse OIDC JWKS: {err}"))
        })?;

    // Match on kid AND `use == sig`. NOTE(review): providers that omit the
    // optional `use` member will fail this lookup — confirm this strictness
    // is intended.
    let jwk = jwks
        .keys
        .iter()
        .find(|jwk| {
            jwk.common.key_id == header.kid
                && matches!(
                    jwk.common.public_key_use,
                    Some(jsonwebtoken::jwk::PublicKeyUse::Signature)
                )
                && matches!(
                    jwk.algorithm,
                    AlgorithmParameters::RSA(_) | AlgorithmParameters::EllipticCurve(_)
                )
        })
        .ok_or_else(|| ApiError::Unauthorized("OIDC signing key not found in JWKS".to_string()))?;

    let decoding_key = DecodingKey::from_jwk(jwk)
        .map_err(|err| ApiError::Unauthorized(format!("OIDC JWK decode failed: {err}")))?;

    let issuer = discovery.metadata.issuer().to_string();
    let mut validation = Validation::new(algorithm);
    validation.set_issuer(&[issuer.as_str()]);
    validation.set_audience(&[oidc.client_id.as_deref().unwrap_or_default()]);
    validation.set_required_spec_claims(&["exp", "iat", "iss", "sub", "aud"]);
    validation.validate_nbf = false;

    let token = decode::<VerifiedIdTokenClaims>(raw_id_token, &decoding_key, &validation)
        .map_err(|err| ApiError::Unauthorized(format!("OIDC ID token validation failed: {err}")))?;

    // Replay protection: the nonce baked into the token must equal the one
    // generated for this specific login attempt.
    if token.claims.nonce.as_deref() != Some(expected_nonce) {
        return Err(ApiError::Unauthorized(
            "OIDC nonce validation failed".to_string(),
        ));
    }

    Ok(token.claims)
}
|
||||
|
||||
fn derive_display_name(oidc_claims: &OidcIdentityClaims) -> Option<String> {
|
||||
oidc_claims
|
||||
.name
|
||||
.clone()
|
||||
.or_else(|| oidc_claims.preferred_username.clone())
|
||||
.or_else(|| oidc_claims.email.clone())
|
||||
}
|
||||
|
||||
fn fallback_subject_login(oidc_claims: &OidcIdentityClaims) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(oidc_claims.issuer.as_bytes());
|
||||
hasher.update(b":");
|
||||
hasher.update(oidc_claims.sub.as_bytes());
|
||||
let digest = hex::encode(hasher.finalize());
|
||||
format!("oidc:{}", &digest[..24])
|
||||
}
|
||||
|
||||
fn extract_groups_from_claims<T>(claims: &T) -> Vec<String>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let Ok(json) = serde_json::to_value(claims) else {
|
||||
return Vec::new();
|
||||
};
|
||||
match json.get("groups") {
|
||||
Some(JsonValue::Array(values)) => values
|
||||
.iter()
|
||||
.filter_map(|value| value.as_str().map(ToString::to_string))
|
||||
.collect(),
|
||||
Some(JsonValue::String(value)) => vec![value.to_string()],
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill gaps in the ID-token claims from a UserInfo response.
///
/// UserInfo values never override what the ID token already provided; they
/// are used only when the corresponding field is missing or empty.
fn merge_userinfo_claims(oidc_claims: &mut OidcIdentityClaims, userinfo: &CoreUserInfoClaims) {
    if oidc_claims.email.is_none() {
        oidc_claims.email = userinfo.email().map(|email| email.as_str().to_string());
    }
    if oidc_claims.name.is_none() {
        // `name` is a localized claim; take the first available localization.
        oidc_claims.name = userinfo.name().and_then(first_localized_claim);
    }
    if oidc_claims.preferred_username.is_none() {
        oidc_claims.preferred_username = userinfo
            .preferred_username()
            .map(|username| username.as_str().to_string());
    }
    if oidc_claims.groups.is_empty() {
        // `groups` is a non-standard claim, so it lives in additional_claims.
        oidc_claims.groups = extract_groups_from_claims(userinfo.additional_claims());
    }
}
|
||||
|
||||
fn first_localized_claim<T>(claim: &LocalizedClaim<T>) -> Option<String>
|
||||
where
|
||||
T: std::ops::Deref<Target = String>,
|
||||
{
|
||||
claim
|
||||
.iter()
|
||||
.next()
|
||||
.map(|(_, value)| value.as_str().to_string())
|
||||
}
|
||||
|
||||
/// Build a cookie with this module's standard attributes: path "/",
/// SameSite=Lax, a max-age in seconds, and Secure when the deployment
/// warrants it. `http_only` is false only for cookies the frontend must
/// read (e.g. the post-login redirect target).
fn build_cookie(
    state: &SharedState,
    name: &'static str,
    value: String,
    max_age_seconds: i64,
    http_only: bool,
) -> Cookie<'static> {
    let mut cookie = Cookie::build((name, value))
        .path("/")
        .same_site(SameSite::Lax)
        .http_only(http_only)
        .max_age(CookieDuration::seconds(max_age_seconds))
        .build();

    if should_use_secure_cookies(state) {
        cookie.set_secure(true);
    }

    cookie
}
|
||||
|
||||
fn remove_cookie(state: &SharedState, name: &'static str) -> Cookie<'static> {
|
||||
let mut cookie = Cookie::build((name, String::new()))
|
||||
.path("/")
|
||||
.same_site(SameSite::Lax)
|
||||
.http_only(true)
|
||||
.max_age(CookieDuration::seconds(0))
|
||||
.build();
|
||||
cookie.make_removal();
|
||||
if should_use_secure_cookies(state) {
|
||||
cookie.set_secure(true);
|
||||
}
|
||||
cookie
|
||||
}
|
||||
|
||||
fn should_use_secure_cookies(state: &SharedState) -> bool {
|
||||
state.config.is_production()
|
||||
|| state
|
||||
.config
|
||||
.security
|
||||
.oidc
|
||||
.as_ref()
|
||||
.and_then(|oidc| oidc.redirect_uri.as_deref())
|
||||
.map(|uri| uri.starts_with("https://"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Constrain a post-login redirect target to a same-origin absolute path.
///
/// Anything that is not a plain local path — an absolute URL, a
/// scheme-relative `//host` target, or the `/\host` variant (browsers
/// normalize backslashes to slashes, so `/\evil.com` behaves like
/// `//evil.com` and would be an open redirect) — is replaced with "/".
fn sanitize_redirect_target(redirect_to: Option<&str>) -> String {
    const FALLBACK: &str = "/";
    let Some(target) = redirect_to else {
        return FALLBACK.to_string();
    };
    let is_local_path = target.starts_with('/')
        && !target.starts_with("//")
        // Defense against browser backslash normalization (open redirect).
        && !target.starts_with("/\\");
    if is_local_path {
        target.to_string()
    } else {
        FALLBACK.to_string()
    }
}
|
||||
|
||||
/// Redirect an unauthenticated browser request to `location`, forcing a
/// 302 Found status (overriding whatever status `Redirect::to` defaults to).
pub fn unauthorized_redirect(location: &str) -> Response {
    let mut response = Redirect::to(location).into_response();
    *response.status_mut() = StatusCode::FOUND;
    response
}
|
||||
|
||||
/// Form-urlencode a value (via the `url` crate's `byte_serialize`) so it
/// can be safely embedded in the login-callback URL fragment.
fn encode_fragment_value(value: &str) -> String {
    byte_serialize(value.as_bytes()).collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Open-redirect defense: only same-origin absolute paths pass through.
    #[test]
    fn sanitize_redirect_target_rejects_external_urls() {
        assert_eq!(sanitize_redirect_target(Some("https://example.com")), "/");
        assert_eq!(sanitize_redirect_target(Some("//example.com")), "/");
        assert_eq!(
            sanitize_redirect_target(Some("/executions/42")),
            "/executions/42"
        );
    }

    // The non-standard `groups` claim may be a string array or a single
    // string depending on the provider; both forms must be accepted.
    #[test]
    fn extract_groups_from_claims_accepts_array_and_string() {
        let array_claims = serde_json::json!({ "groups": ["admins", "operators"] });
        let string_claims = serde_json::json!({ "groups": "admins" });

        assert_eq!(
            extract_groups_from_claims(&array_claims),
            vec!["admins".to_string(), "operators".to_string()]
        );
        assert_eq!(
            extract_groups_from_claims(&string_claims),
            vec!["admins".to_string()]
        );
    }
}
|
||||
@@ -10,7 +10,7 @@ use crate::{
|
||||
use attune_common::{
|
||||
rbac::{Action, AuthorizationContext, Grant, Resource},
|
||||
repositories::{
|
||||
identity::{IdentityRepository, PermissionSetRepository},
|
||||
identity::{IdentityRepository, IdentityRoleAssignmentRepository, PermissionSetRepository},
|
||||
FindById,
|
||||
},
|
||||
};
|
||||
@@ -95,8 +95,16 @@ impl AuthorizationService {
|
||||
}
|
||||
|
||||
async fn load_effective_grants(&self, identity_id: i64) -> Result<Vec<Grant>, ApiError> {
|
||||
let permission_sets =
|
||||
let mut permission_sets =
|
||||
PermissionSetRepository::find_by_identity(&self.db, identity_id).await?;
|
||||
let roles =
|
||||
IdentityRoleAssignmentRepository::find_role_names_by_identity(&self.db, identity_id)
|
||||
.await?;
|
||||
let role_permission_sets = PermissionSetRepository::find_by_roles(&self.db, &roles).await?;
|
||||
permission_sets.extend(role_permission_sets);
|
||||
|
||||
let mut seen_permission_sets = std::collections::HashSet::new();
|
||||
permission_sets.retain(|permission_set| seen_permission_sets.insert(permission_set.id));
|
||||
|
||||
let mut grants = Vec::new();
|
||||
for permission_set in permission_sets {
|
||||
@@ -126,10 +134,6 @@ fn resource_name(resource: Resource) -> &'static str {
|
||||
Resource::Inquiries => "inquiries",
|
||||
Resource::Keys => "keys",
|
||||
Resource::Artifacts => "artifacts",
|
||||
Resource::Workflows => "workflows",
|
||||
Resource::Webhooks => "webhooks",
|
||||
Resource::Analytics => "analytics",
|
||||
Resource::History => "history",
|
||||
Resource::Identities => "identities",
|
||||
Resource::Permissions => "permissions",
|
||||
}
|
||||
@@ -145,5 +149,6 @@ fn action_name(action: Action) -> &'static str {
|
||||
Action::Cancel => "cancel",
|
||||
Action::Respond => "respond",
|
||||
Action::Manage => "manage",
|
||||
Action::Decrypt => "decrypt",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,9 +25,8 @@ pub struct CreateActionRequest {
|
||||
pub label: String,
|
||||
|
||||
/// Action description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point for action execution (e.g., path to script, function name)
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
@@ -63,7 +62,6 @@ pub struct UpdateActionRequest {
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Action description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Posts a message to a Slack channel with enhanced features")]
|
||||
pub description: Option<String>,
|
||||
|
||||
@@ -76,9 +74,8 @@ pub struct UpdateActionRequest {
|
||||
#[schema(example = 1)]
|
||||
pub runtime: Option<i64>,
|
||||
|
||||
/// Optional semver version constraint for the runtime (e.g., ">=3.12", ">=3.12,<4.0", "~18.0")
|
||||
#[schema(example = ">=3.12", nullable = true)]
|
||||
pub runtime_version_constraint: Option<Option<String>>,
|
||||
/// Optional semver version constraint patch for the runtime.
|
||||
pub runtime_version_constraint: Option<RuntimeVersionConstraintPatch>,
|
||||
|
||||
/// Parameter schema (StackStorm-style with inline required/secret)
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
@@ -89,6 +86,14 @@ pub struct UpdateActionRequest {
|
||||
pub out_schema: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Explicit patch operation for a nullable runtime version constraint.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum RuntimeVersionConstraintPatch {
|
||||
Set(String),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Response DTO for action information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct ActionResponse {
|
||||
@@ -114,7 +119,7 @@ pub struct ActionResponse {
|
||||
|
||||
/// Action description
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/actions/slack/post_message.py")]
|
||||
@@ -176,7 +181,7 @@ pub struct ActionSummary {
|
||||
|
||||
/// Action description
|
||||
#[schema(example = "Posts a message to a Slack channel")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/actions/slack/post_message.py")]
|
||||
@@ -314,7 +319,7 @@ mod tests {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
entrypoint: "/actions/test.py".to_string(),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
@@ -331,7 +336,7 @@ mod tests {
|
||||
r#ref: "test.action".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
entrypoint: "/actions/test.py".to_string(),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
|
||||
@@ -97,19 +97,41 @@ pub struct UpdateArtifactRequest {
|
||||
pub retention_limit: Option<i32>,
|
||||
|
||||
/// Updated name
|
||||
pub name: Option<String>,
|
||||
pub name: Option<ArtifactStringPatch>,
|
||||
|
||||
/// Updated description
|
||||
pub description: Option<String>,
|
||||
pub description: Option<ArtifactStringPatch>,
|
||||
|
||||
/// Updated content type
|
||||
pub content_type: Option<String>,
|
||||
pub content_type: Option<ArtifactStringPatch>,
|
||||
|
||||
/// Updated execution ID (re-links artifact to a different execution)
|
||||
pub execution: Option<i64>,
|
||||
/// Updated execution patch (set a new execution ID or clear the link)
|
||||
pub execution: Option<ArtifactExecutionPatch>,
|
||||
|
||||
/// Updated structured data (replaces existing data entirely)
|
||||
pub data: Option<JsonValue>,
|
||||
pub data: Option<ArtifactJsonPatch>,
|
||||
}
|
||||
|
||||
/// Explicit patch operation for a nullable execution link.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum ArtifactExecutionPatch {
|
||||
Set(i64),
|
||||
Clear,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum ArtifactStringPatch {
|
||||
Set(String),
|
||||
Clear,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum ArtifactJsonPatch {
|
||||
Set(JsonValue),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Request DTO for appending to a progress-type artifact
|
||||
|
||||
@@ -136,3 +136,63 @@ pub struct CurrentUserResponse {
|
||||
#[schema(example = "Administrator")]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Public authentication settings for the login page.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct AuthSettingsResponse {
|
||||
/// Whether authentication is enabled for the server.
|
||||
#[schema(example = true)]
|
||||
pub authentication_enabled: bool,
|
||||
|
||||
/// Whether local username/password login is configured.
|
||||
#[schema(example = true)]
|
||||
pub local_password_enabled: bool,
|
||||
|
||||
/// Whether local username/password login should be shown by default.
|
||||
#[schema(example = true)]
|
||||
pub local_password_visible_by_default: bool,
|
||||
|
||||
/// Whether OIDC login is configured and enabled.
|
||||
#[schema(example = false)]
|
||||
pub oidc_enabled: bool,
|
||||
|
||||
/// Whether OIDC login should be shown by default.
|
||||
#[schema(example = false)]
|
||||
pub oidc_visible_by_default: bool,
|
||||
|
||||
/// Provider name for `?auth=<provider>`.
|
||||
#[schema(example = "sso")]
|
||||
pub oidc_provider_name: Option<String>,
|
||||
|
||||
/// User-facing provider label for the login button.
|
||||
#[schema(example = "Example SSO")]
|
||||
pub oidc_provider_label: Option<String>,
|
||||
|
||||
/// Optional icon URL shown beside the provider label.
|
||||
#[schema(example = "https://auth.example.com/assets/logo.svg")]
|
||||
pub oidc_provider_icon_url: Option<String>,
|
||||
|
||||
/// Whether LDAP login is configured and enabled.
|
||||
#[schema(example = false)]
|
||||
pub ldap_enabled: bool,
|
||||
|
||||
/// Whether LDAP login should be shown by default.
|
||||
#[schema(example = false)]
|
||||
pub ldap_visible_by_default: bool,
|
||||
|
||||
/// Provider name for `?auth=<provider>`.
|
||||
#[schema(example = "ldap")]
|
||||
pub ldap_provider_name: Option<String>,
|
||||
|
||||
/// User-facing provider label for the login button.
|
||||
#[schema(example = "Company LDAP")]
|
||||
pub ldap_provider_label: Option<String>,
|
||||
|
||||
/// Optional icon URL shown beside the provider label.
|
||||
#[schema(example = "https://ldap.example.com/assets/logo.svg")]
|
||||
pub ldap_provider_icon_url: Option<String>,
|
||||
|
||||
/// Whether unauthenticated self-service registration is allowed.
|
||||
#[schema(example = false)]
|
||||
pub self_registration_enabled: bool,
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ pub mod key;
|
||||
pub mod pack;
|
||||
pub mod permission;
|
||||
pub mod rule;
|
||||
pub mod runtime;
|
||||
pub mod trigger;
|
||||
pub mod webhook;
|
||||
pub mod workflow;
|
||||
@@ -29,8 +30,8 @@ pub use artifact::{
|
||||
CreateVersionJsonRequest, SetDataRequest, UpdateArtifactRequest,
|
||||
};
|
||||
pub use auth::{
|
||||
ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest, RegisterRequest,
|
||||
TokenResponse,
|
||||
AuthSettingsResponse, ChangePasswordRequest, CurrentUserResponse, LoginRequest,
|
||||
RefreshTokenRequest, RegisterRequest, TokenResponse,
|
||||
};
|
||||
pub use common::{
|
||||
ApiResponse, PaginatedResponse, PaginationMeta, PaginationParams, SuccessResponse,
|
||||
@@ -50,11 +51,13 @@ pub use inquiry::{
|
||||
pub use key::{CreateKeyRequest, KeyQueryParams, KeyResponse, KeySummary, UpdateKeyRequest};
|
||||
pub use pack::{CreatePackRequest, PackResponse, PackSummary, UpdatePackRequest};
|
||||
pub use permission::{
|
||||
CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse, IdentitySummary,
|
||||
PermissionAssignmentResponse, PermissionSetQueryParams, PermissionSetSummary,
|
||||
UpdateIdentityRequest,
|
||||
CreateIdentityRequest, CreateIdentityRoleAssignmentRequest, CreatePermissionAssignmentRequest,
|
||||
CreatePermissionSetRoleAssignmentRequest, IdentityResponse, IdentityRoleAssignmentResponse,
|
||||
IdentitySummary, PermissionAssignmentResponse, PermissionSetQueryParams,
|
||||
PermissionSetRoleAssignmentResponse, PermissionSetSummary, UpdateIdentityRequest,
|
||||
};
|
||||
pub use rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest};
|
||||
pub use runtime::{CreateRuntimeRequest, RuntimeResponse, RuntimeSummary, UpdateRuntimeRequest};
|
||||
pub use trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,
|
||||
TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
|
||||
@@ -129,7 +129,7 @@ pub struct UpdatePackRequest {
|
||||
|
||||
/// Pack description
|
||||
#[schema(example = "Enhanced Slack integration with new features")]
|
||||
pub description: Option<String>,
|
||||
pub description: Option<PackDescriptionPatch>,
|
||||
|
||||
/// Pack version
|
||||
#[validate(length(min = 1, max = 50))]
|
||||
@@ -165,6 +165,13 @@ pub struct UpdatePackRequest {
|
||||
pub is_standard: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum PackDescriptionPatch {
|
||||
Set(String),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Response DTO for pack information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PackResponse {
|
||||
|
||||
@@ -14,10 +14,32 @@ pub struct IdentitySummary {
|
||||
pub id: i64,
|
||||
pub login: String,
|
||||
pub display_name: Option<String>,
|
||||
pub frozen: bool,
|
||||
pub attributes: JsonValue,
|
||||
pub roles: Vec<String>,
|
||||
}
|
||||
|
||||
pub type IdentityResponse = IdentitySummary;
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct IdentityRoleAssignmentResponse {
|
||||
pub id: i64,
|
||||
pub identity_id: i64,
|
||||
pub role: String,
|
||||
pub source: String,
|
||||
pub managed: bool,
|
||||
pub created: chrono::DateTime<chrono::Utc>,
|
||||
pub updated: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct IdentityResponse {
|
||||
pub id: i64,
|
||||
pub login: String,
|
||||
pub display_name: Option<String>,
|
||||
pub frozen: bool,
|
||||
pub attributes: JsonValue,
|
||||
pub roles: Vec<IdentityRoleAssignmentResponse>,
|
||||
pub direct_permissions: Vec<PermissionAssignmentResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PermissionSetSummary {
|
||||
@@ -27,6 +49,7 @@ pub struct PermissionSetSummary {
|
||||
pub label: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub grants: JsonValue,
|
||||
pub roles: Vec<PermissionSetRoleAssignmentResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
@@ -38,6 +61,15 @@ pub struct PermissionAssignmentResponse {
|
||||
pub created: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct PermissionSetRoleAssignmentResponse {
|
||||
pub id: i64,
|
||||
pub permission_set_id: i64,
|
||||
pub permission_set_ref: Option<String>,
|
||||
pub role: String,
|
||||
pub created: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
||||
pub struct CreatePermissionAssignmentRequest {
|
||||
pub identity_id: Option<i64>,
|
||||
@@ -45,6 +77,18 @@ pub struct CreatePermissionAssignmentRequest {
|
||||
pub permission_set_ref: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateIdentityRoleAssignmentRequest {
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreatePermissionSetRoleAssignmentRequest {
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateIdentityRequest {
|
||||
#[validate(length(min = 3, max = 255))]
|
||||
@@ -62,4 +106,5 @@ pub struct UpdateIdentityRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub attributes: Option<JsonValue>,
|
||||
pub frozen: Option<bool>,
|
||||
}
|
||||
|
||||
@@ -25,9 +25,8 @@ pub struct CreateRuleRequest {
|
||||
pub label: String,
|
||||
|
||||
/// Rule description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Action reference to execute when rule matches
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
@@ -69,7 +68,6 @@ pub struct UpdateRuleRequest {
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Rule description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Enhanced error notification with filtering")]
|
||||
pub description: Option<String>,
|
||||
|
||||
@@ -115,7 +113,7 @@ pub struct RuleResponse {
|
||||
|
||||
/// Rule description
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Action ID (null if the referenced action has been deleted)
|
||||
#[schema(example = 1)]
|
||||
@@ -183,7 +181,7 @@ pub struct RuleSummary {
|
||||
|
||||
/// Rule description
|
||||
#[schema(example = "Send Slack notification when an error occurs")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Action reference
|
||||
#[schema(example = "slack.post_message")]
|
||||
@@ -297,7 +295,7 @@ mod tests {
|
||||
r#ref: "".to_string(), // Invalid: empty
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Rule".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
action_ref: "test.action".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
conditions: default_empty_object(),
|
||||
@@ -315,7 +313,7 @@ mod tests {
|
||||
r#ref: "test.rule".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Rule".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
action_ref: "test.action".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
conditions: serde_json::json!({
|
||||
|
||||
181
crates/api/src/dto/runtime.rs
Normal file
181
crates/api/src/dto/runtime.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
//! Runtime DTOs for API requests and responses
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
/// Request DTO for creating a runtime.
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct CreateRuntimeRequest {
|
||||
/// Unique reference identifier (e.g. "core.python", "core.nodejs")
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "core.python")]
|
||||
pub r#ref: String,
|
||||
|
||||
/// Optional pack reference this runtime belongs to
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "core", nullable = true)]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
/// Optional human-readable description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Python runtime with virtualenv support", nullable = true)]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Display name
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Python")]
|
||||
pub name: String,
|
||||
|
||||
/// Distribution metadata used for verification and platform support
|
||||
#[serde(default)]
|
||||
#[schema(value_type = Object, example = json!({"linux": {"supported": true}}))]
|
||||
pub distributions: JsonValue,
|
||||
|
||||
/// Optional installation metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(value_type = Object, nullable = true, example = json!({"method": "system"}))]
|
||||
pub installation: Option<JsonValue>,
|
||||
|
||||
/// Runtime execution configuration
|
||||
#[serde(default)]
|
||||
#[schema(value_type = Object, example = json!({"interpreter": {"command": "python3"}}))]
|
||||
pub execution_config: JsonValue,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a runtime.
|
||||
#[derive(Debug, Clone, Deserialize, Validate, ToSchema)]
|
||||
pub struct UpdateRuntimeRequest {
|
||||
/// Optional human-readable description patch.
|
||||
pub description: Option<NullableStringPatch>,
|
||||
|
||||
/// Display name
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
#[schema(example = "Python 3")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Distribution metadata used for verification and platform support
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub distributions: Option<JsonValue>,
|
||||
|
||||
/// Optional installation metadata patch.
|
||||
pub installation: Option<NullableJsonPatch>,
|
||||
|
||||
/// Runtime execution configuration
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub execution_config: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Explicit patch operation for nullable string fields.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum NullableStringPatch {
|
||||
#[schema(title = "SetString")]
|
||||
Set(String),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Explicit patch operation for nullable JSON fields.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum NullableJsonPatch {
|
||||
#[schema(title = "SetJson")]
|
||||
Set(JsonValue),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Full runtime response.
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct RuntimeResponse {
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
#[schema(example = "core.python")]
|
||||
pub r#ref: String,
|
||||
|
||||
#[schema(example = 1, nullable = true)]
|
||||
pub pack: Option<i64>,
|
||||
|
||||
#[schema(example = "core", nullable = true)]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
#[schema(example = "Python runtime with virtualenv support", nullable = true)]
|
||||
pub description: Option<String>,
|
||||
|
||||
#[schema(example = "Python")]
|
||||
pub name: String,
|
||||
|
||||
#[schema(value_type = Object)]
|
||||
pub distributions: JsonValue,
|
||||
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub installation: Option<JsonValue>,
|
||||
|
||||
#[schema(value_type = Object)]
|
||||
pub execution_config: JsonValue,
|
||||
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Runtime summary for list views.
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct RuntimeSummary {
|
||||
#[schema(example = 1)]
|
||||
pub id: i64,
|
||||
|
||||
#[schema(example = "core.python")]
|
||||
pub r#ref: String,
|
||||
|
||||
#[schema(example = "core", nullable = true)]
|
||||
pub pack_ref: Option<String>,
|
||||
|
||||
#[schema(example = "Python runtime with virtualenv support", nullable = true)]
|
||||
pub description: Option<String>,
|
||||
|
||||
#[schema(example = "Python")]
|
||||
pub name: String,
|
||||
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<attune_common::models::runtime::Runtime> for RuntimeResponse {
|
||||
fn from(runtime: attune_common::models::runtime::Runtime) -> Self {
|
||||
Self {
|
||||
id: runtime.id,
|
||||
r#ref: runtime.r#ref,
|
||||
pack: runtime.pack,
|
||||
pack_ref: runtime.pack_ref,
|
||||
description: runtime.description,
|
||||
name: runtime.name,
|
||||
distributions: runtime.distributions,
|
||||
installation: runtime.installation,
|
||||
execution_config: runtime.execution_config,
|
||||
created: runtime.created,
|
||||
updated: runtime.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::models::runtime::Runtime> for RuntimeSummary {
|
||||
fn from(runtime: attune_common::models::runtime::Runtime) -> Self {
|
||||
Self {
|
||||
id: runtime.id,
|
||||
r#ref: runtime.r#ref,
|
||||
pack_ref: runtime.pack_ref,
|
||||
description: runtime.description,
|
||||
name: runtime.name,
|
||||
created: runtime.created,
|
||||
updated: runtime.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -54,21 +54,35 @@ pub struct UpdateTriggerRequest {
|
||||
|
||||
/// Trigger description
|
||||
#[schema(example = "Updated webhook trigger description")]
|
||||
pub description: Option<String>,
|
||||
pub description: Option<TriggerStringPatch>,
|
||||
|
||||
/// Parameter schema (StackStorm-style with inline required/secret)
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
pub param_schema: Option<TriggerJsonPatch>,
|
||||
|
||||
/// Output schema
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub out_schema: Option<JsonValue>,
|
||||
pub out_schema: Option<TriggerJsonPatch>,
|
||||
|
||||
/// Whether the trigger is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum TriggerStringPatch {
|
||||
Set(String),
|
||||
Clear,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum TriggerJsonPatch {
|
||||
Set(JsonValue),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Response DTO for trigger information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct TriggerResponse {
|
||||
@@ -189,9 +203,8 @@ pub struct CreateSensorRequest {
|
||||
pub label: String,
|
||||
|
||||
/// Sensor description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point for sensor execution (e.g., path to script, function name)
|
||||
#[validate(length(min = 1, max = 1024))]
|
||||
@@ -233,7 +246,6 @@ pub struct UpdateSensorRequest {
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Sensor description
|
||||
#[validate(length(min = 1))]
|
||||
#[schema(example = "Enhanced CPU monitoring with alerts")]
|
||||
pub description: Option<String>,
|
||||
|
||||
@@ -244,13 +256,20 @@ pub struct UpdateSensorRequest {
|
||||
|
||||
/// Parameter schema (StackStorm-style with inline required/secret)
|
||||
#[schema(value_type = Object, nullable = true)]
|
||||
pub param_schema: Option<JsonValue>,
|
||||
pub param_schema: Option<SensorJsonPatch>,
|
||||
|
||||
/// Whether the sensor is enabled
|
||||
#[schema(example = false)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
pub enum SensorJsonPatch {
|
||||
Set(JsonValue),
|
||||
Clear,
|
||||
}
|
||||
|
||||
/// Response DTO for sensor information
|
||||
#[derive(Debug, Clone, Serialize, ToSchema)]
|
||||
pub struct SensorResponse {
|
||||
@@ -276,7 +295,7 @@ pub struct SensorResponse {
|
||||
|
||||
/// Sensor description
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Entry point
|
||||
#[schema(example = "/sensors/monitoring/cpu_monitor.py")]
|
||||
@@ -336,7 +355,7 @@ pub struct SensorSummary {
|
||||
|
||||
/// Sensor description
|
||||
#[schema(example = "Monitors CPU usage and generates events")]
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Trigger reference
|
||||
#[schema(example = "monitoring.cpu_threshold")]
|
||||
@@ -478,7 +497,7 @@ mod tests {
|
||||
r#ref: "test.sensor".to_string(),
|
||||
pack_ref: "test-pack".to_string(),
|
||||
label: "Test Sensor".to_string(),
|
||||
description: "Test description".to_string(),
|
||||
description: Some("Test description".to_string()),
|
||||
entrypoint: "/sensors/test.py".to_string(),
|
||||
runtime_ref: "python3".to_string(),
|
||||
trigger_ref: "test.trigger".to_string(),
|
||||
|
||||
@@ -48,10 +48,6 @@ pub struct SaveWorkflowFileRequest {
|
||||
/// Tags for categorization
|
||||
#[schema(example = json!(["deployment", "automation"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Request DTO for creating a new workflow
|
||||
@@ -96,10 +92,6 @@ pub struct CreateWorkflowRequest {
|
||||
/// Tags for categorization and search
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Request DTO for updating a workflow
|
||||
@@ -134,10 +126,6 @@ pub struct UpdateWorkflowRequest {
|
||||
/// Tags
|
||||
#[schema(example = json!(["incident", "slack", "approval", "automation"]))]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response DTO for workflow information
|
||||
@@ -187,10 +175,6 @@ pub struct WorkflowResponse {
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
@@ -231,10 +215,6 @@ pub struct WorkflowSummary {
|
||||
#[schema(example = json!(["incident", "slack", "approval"]))]
|
||||
pub tags: Vec<String>,
|
||||
|
||||
/// Whether the workflow is enabled
|
||||
#[schema(example = true)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Creation timestamp
|
||||
#[schema(example = "2024-01-13T10:30:00Z")]
|
||||
pub created: DateTime<Utc>,
|
||||
@@ -259,7 +239,6 @@ impl From<attune_common::models::workflow::WorkflowDefinition> for WorkflowRespo
|
||||
out_schema: workflow.out_schema,
|
||||
definition: workflow.definition,
|
||||
tags: workflow.tags,
|
||||
enabled: workflow.enabled,
|
||||
created: workflow.created,
|
||||
updated: workflow.updated,
|
||||
}
|
||||
@@ -277,7 +256,6 @@ impl From<attune_common::models::workflow::WorkflowDefinition> for WorkflowSumma
|
||||
description: workflow.description,
|
||||
version: workflow.version,
|
||||
tags: workflow.tags,
|
||||
enabled: workflow.enabled,
|
||||
created: workflow.created,
|
||||
updated: workflow.updated,
|
||||
}
|
||||
@@ -291,10 +269,6 @@ pub struct WorkflowSearchParams {
|
||||
#[param(example = "incident,approval")]
|
||||
pub tags: Option<String>,
|
||||
|
||||
/// Filter by enabled status
|
||||
#[param(example = true)]
|
||||
pub enabled: Option<bool>,
|
||||
|
||||
/// Search term for label/description (case-insensitive)
|
||||
#[param(example = "incident")]
|
||||
pub search: Option<String>,
|
||||
@@ -320,7 +294,6 @@ mod tests {
|
||||
out_schema: None,
|
||||
definition: serde_json::json!({"tasks": []}),
|
||||
tags: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
assert!(req.validate().is_err());
|
||||
@@ -338,7 +311,6 @@ mod tests {
|
||||
out_schema: None,
|
||||
definition: serde_json::json!({"tasks": []}),
|
||||
tags: Some(vec!["test".to_string()]),
|
||||
enabled: Some(true),
|
||||
};
|
||||
|
||||
assert!(req.validate().is_ok());
|
||||
@@ -354,7 +326,6 @@ mod tests {
|
||||
out_schema: None,
|
||||
definition: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
};
|
||||
|
||||
// Should be valid even with all None values
|
||||
@@ -365,7 +336,6 @@ mod tests {
|
||||
fn test_workflow_search_params() {
|
||||
let params = WorkflowSearchParams {
|
||||
tags: Some("incident,approval".to_string()),
|
||||
enabled: Some(true),
|
||||
search: Some("response".to_string()),
|
||||
pack_ref: Some("core".to_string()),
|
||||
};
|
||||
|
||||
@@ -115,8 +115,9 @@ async fn mq_reconnect_loop(state: Arc<AppState>, mq_url: String) {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Install HMAC-only JWT crypto provider (must be before any token operations)
|
||||
attune_common::auth::install_crypto_provider();
|
||||
// Install a JWT crypto provider that supports both Attune's HS tokens
|
||||
// and external RS256 OIDC identity tokens.
|
||||
let _ = jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER.install_default();
|
||||
|
||||
// Initialize tracing subscriber
|
||||
tracing_subscriber::fmt()
|
||||
|
||||
@@ -10,8 +10,8 @@ use crate::dto::{
|
||||
ActionResponse, ActionSummary, CreateActionRequest, QueueStatsResponse, UpdateActionRequest,
|
||||
},
|
||||
auth::{
|
||||
ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest,
|
||||
RegisterRequest, TokenResponse,
|
||||
AuthSettingsResponse, ChangePasswordRequest, CurrentUserResponse, LoginRequest,
|
||||
RefreshTokenRequest, RegisterRequest, TokenResponse,
|
||||
},
|
||||
common::{ApiResponse, PaginatedResponse, PaginationMeta, SuccessResponse},
|
||||
event::{EnforcementResponse, EnforcementSummary, EventResponse, EventSummary},
|
||||
@@ -27,10 +27,14 @@ use crate::dto::{
|
||||
UpdatePackRequest, WorkflowSyncResult,
|
||||
},
|
||||
permission::{
|
||||
CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse,
|
||||
IdentitySummary, PermissionAssignmentResponse, PermissionSetSummary, UpdateIdentityRequest,
|
||||
CreateIdentityRequest, CreateIdentityRoleAssignmentRequest,
|
||||
CreatePermissionAssignmentRequest, CreatePermissionSetRoleAssignmentRequest,
|
||||
IdentityResponse, IdentityRoleAssignmentResponse, IdentitySummary,
|
||||
PermissionAssignmentResponse, PermissionSetRoleAssignmentResponse, PermissionSetSummary,
|
||||
UpdateIdentityRequest,
|
||||
},
|
||||
rule::{CreateRuleRequest, RuleResponse, RuleSummary, UpdateRuleRequest},
|
||||
runtime::{CreateRuntimeRequest, RuntimeResponse, RuntimeSummary, UpdateRuntimeRequest},
|
||||
trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary, TriggerResponse,
|
||||
TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
@@ -67,7 +71,9 @@ use crate::dto::{
|
||||
crate::routes::health::liveness,
|
||||
|
||||
// Authentication
|
||||
crate::routes::auth::auth_settings,
|
||||
crate::routes::auth::login,
|
||||
crate::routes::auth::ldap_login,
|
||||
crate::routes::auth::register,
|
||||
crate::routes::auth::refresh_token,
|
||||
crate::routes::auth::get_current_user,
|
||||
@@ -96,6 +102,14 @@ use crate::dto::{
|
||||
crate::routes::actions::delete_action,
|
||||
crate::routes::actions::get_queue_stats,
|
||||
|
||||
// Runtimes
|
||||
crate::routes::runtimes::list_runtimes,
|
||||
crate::routes::runtimes::list_runtimes_by_pack,
|
||||
crate::routes::runtimes::get_runtime,
|
||||
crate::routes::runtimes::create_runtime,
|
||||
crate::routes::runtimes::update_runtime,
|
||||
crate::routes::runtimes::delete_runtime,
|
||||
|
||||
// Triggers
|
||||
crate::routes::triggers::list_triggers,
|
||||
crate::routes::triggers::list_enabled_triggers,
|
||||
@@ -174,6 +188,12 @@ use crate::dto::{
|
||||
crate::routes::permissions::list_identity_permissions,
|
||||
crate::routes::permissions::create_permission_assignment,
|
||||
crate::routes::permissions::delete_permission_assignment,
|
||||
crate::routes::permissions::create_identity_role_assignment,
|
||||
crate::routes::permissions::delete_identity_role_assignment,
|
||||
crate::routes::permissions::create_permission_set_role_assignment,
|
||||
crate::routes::permissions::delete_permission_set_role_assignment,
|
||||
crate::routes::permissions::freeze_identity,
|
||||
crate::routes::permissions::unfreeze_identity,
|
||||
|
||||
// Workflows
|
||||
crate::routes::workflows::list_workflows,
|
||||
@@ -188,15 +208,21 @@ use crate::dto::{
|
||||
crate::routes::webhooks::disable_webhook,
|
||||
crate::routes::webhooks::regenerate_webhook_key,
|
||||
crate::routes::webhooks::receive_webhook,
|
||||
|
||||
// Agent
|
||||
crate::routes::agent::download_agent_binary,
|
||||
crate::routes::agent::agent_info,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
// Common types
|
||||
ApiResponse<TokenResponse>,
|
||||
ApiResponse<AuthSettingsResponse>,
|
||||
ApiResponse<CurrentUserResponse>,
|
||||
ApiResponse<PackResponse>,
|
||||
ApiResponse<PackInstallResponse>,
|
||||
ApiResponse<ActionResponse>,
|
||||
ApiResponse<RuntimeResponse>,
|
||||
ApiResponse<TriggerResponse>,
|
||||
ApiResponse<SensorResponse>,
|
||||
ApiResponse<RuleResponse>,
|
||||
@@ -211,6 +237,7 @@ use crate::dto::{
|
||||
ApiResponse<QueueStatsResponse>,
|
||||
PaginatedResponse<PackSummary>,
|
||||
PaginatedResponse<ActionSummary>,
|
||||
PaginatedResponse<RuntimeSummary>,
|
||||
PaginatedResponse<TriggerSummary>,
|
||||
PaginatedResponse<SensorSummary>,
|
||||
PaginatedResponse<RuleSummary>,
|
||||
@@ -226,6 +253,7 @@ use crate::dto::{
|
||||
|
||||
// Auth DTOs
|
||||
LoginRequest,
|
||||
crate::routes::auth::LdapLoginRequest,
|
||||
RegisterRequest,
|
||||
RefreshTokenRequest,
|
||||
ChangePasswordRequest,
|
||||
@@ -258,6 +286,16 @@ use crate::dto::{
|
||||
PermissionSetSummary,
|
||||
PermissionAssignmentResponse,
|
||||
CreatePermissionAssignmentRequest,
|
||||
CreateIdentityRoleAssignmentRequest,
|
||||
IdentityRoleAssignmentResponse,
|
||||
CreatePermissionSetRoleAssignmentRequest,
|
||||
PermissionSetRoleAssignmentResponse,
|
||||
|
||||
// Runtime DTOs
|
||||
CreateRuntimeRequest,
|
||||
UpdateRuntimeRequest,
|
||||
RuntimeResponse,
|
||||
RuntimeSummary,
|
||||
IdentitySummary,
|
||||
|
||||
// Action DTOs
|
||||
@@ -320,6 +358,10 @@ use crate::dto::{
|
||||
WebhookReceiverRequest,
|
||||
WebhookReceiverResponse,
|
||||
ApiResponse<WebhookReceiverResponse>,
|
||||
|
||||
// Agent DTOs
|
||||
crate::routes::agent::AgentBinaryInfo,
|
||||
crate::routes::agent::AgentArchInfo,
|
||||
)
|
||||
),
|
||||
modifiers(&SecurityAddon),
|
||||
@@ -338,6 +380,7 @@ use crate::dto::{
|
||||
(name = "secrets", description = "Secret management endpoints"),
|
||||
(name = "workflows", description = "Workflow management endpoints"),
|
||||
(name = "webhooks", description = "Webhook management and receiver endpoints"),
|
||||
(name = "agent", description = "Agent binary distribution endpoints"),
|
||||
)
|
||||
)]
|
||||
pub struct ApiDoc;
|
||||
@@ -420,18 +463,57 @@ mod tests {
|
||||
// We have 57 unique paths with 81 total operations (HTTP methods)
|
||||
// This test ensures we don't accidentally remove endpoints
|
||||
assert!(
|
||||
path_count >= 57,
|
||||
"Expected at least 57 unique API paths, found {}",
|
||||
path_count >= 59,
|
||||
"Expected at least 59 unique API paths, found {}",
|
||||
path_count
|
||||
);
|
||||
|
||||
assert!(
|
||||
operation_count >= 81,
|
||||
"Expected at least 81 API operations, found {}",
|
||||
operation_count >= 83,
|
||||
"Expected at least 83 API operations, found {}",
|
||||
operation_count
|
||||
);
|
||||
|
||||
println!("Total API paths: {}", path_count);
|
||||
println!("Total API operations: {}", operation_count);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_endpoints_registered() {
|
||||
let doc = ApiDoc::openapi();
|
||||
|
||||
let expected_auth_paths = vec![
|
||||
"/auth/settings",
|
||||
"/auth/login",
|
||||
"/auth/ldap/login",
|
||||
"/auth/register",
|
||||
"/auth/refresh",
|
||||
"/auth/me",
|
||||
"/auth/change-password",
|
||||
];
|
||||
|
||||
for path in &expected_auth_paths {
|
||||
assert!(
|
||||
doc.paths.paths.contains_key(*path),
|
||||
"Expected auth endpoint {} to be registered in OpenAPI spec, but it was missing. \
|
||||
Registered paths: {:?}",
|
||||
path,
|
||||
doc.paths.paths.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ldap_login_request_schema_registered() {
|
||||
let doc = ApiDoc::openapi();
|
||||
|
||||
let components = doc.components.as_ref().expect("components should exist");
|
||||
|
||||
assert!(
|
||||
components.schemas.contains_key("LdapLoginRequest"),
|
||||
"Expected LdapLoginRequest schema to be registered in OpenAPI components. \
|
||||
Registered schemas: {:?}",
|
||||
components.schemas.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ use attune_common::repositories::{
|
||||
action::{ActionRepository, ActionSearchFilters, CreateActionInput, UpdateActionInput},
|
||||
pack::PackRepository,
|
||||
queue_stats::QueueStatsRepository,
|
||||
Create, Delete, FindByRef, Update,
|
||||
Create, Delete, FindByRef, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -24,7 +24,7 @@ use crate::{
|
||||
dto::{
|
||||
action::{
|
||||
ActionResponse, ActionSummary, CreateActionRequest, QueueStatsResponse,
|
||||
UpdateActionRequest,
|
||||
RuntimeVersionConstraintPatch, UpdateActionRequest,
|
||||
},
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
ApiResponse, SuccessResponse,
|
||||
@@ -277,10 +277,13 @@ pub async fn update_action(
|
||||
// Create update input
|
||||
let update_input = UpdateActionInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
description: request.description.map(Patch::Set),
|
||||
entrypoint: request.entrypoint,
|
||||
runtime: request.runtime,
|
||||
runtime_version_constraint: request.runtime_version_constraint,
|
||||
runtime_version_constraint: request.runtime_version_constraint.map(|patch| match patch {
|
||||
RuntimeVersionConstraintPatch::Set(value) => Patch::Set(value),
|
||||
RuntimeVersionConstraintPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
parameter_delivery: None,
|
||||
|
||||
482
crates/api/src/routes/agent.rs
Normal file
482
crates/api/src/routes/agent.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
//! Agent binary download endpoints
|
||||
//!
|
||||
//! Provides endpoints for downloading the attune-agent binary for injection
|
||||
//! into arbitrary containers. This supports deployments where shared Docker
|
||||
//! volumes are impractical (Kubernetes, ECS, remote Docker hosts).
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Query, State},
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use subtle::ConstantTimeEq;
|
||||
use tokio::fs;
|
||||
use tokio_util::io::ReaderStream;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Query parameters for the binary download endpoint
#[derive(Debug, Deserialize, IntoParams)]
pub struct BinaryDownloadParams {
    /// Target architecture (x86_64, aarch64). Defaults to x86_64.
    /// `arm64` is also accepted and normalized to `aarch64` (see `validate_arch`).
    #[param(example = "x86_64")]
    pub arch: Option<String>,
    /// Optional bootstrap token for authentication.
    /// Alternative to the `X-Agent-Token` header; the header wins when both
    /// are supplied (see `validate_token`).
    pub token: Option<String>,
}
|
||||
|
||||
/// Agent binary metadata
///
/// Top-level payload returned by `GET /api/v1/agent/info`.
#[derive(Debug, Serialize, ToSchema)]
pub struct AgentBinaryInfo {
    /// Available architectures
    pub architectures: Vec<AgentArchInfo>,
    /// Agent version (from build)
    // Populated from CARGO_PKG_VERSION at compile time in `agent_info`.
    pub version: String,
}
|
||||
|
||||
/// Per-architecture binary info
#[derive(Debug, Serialize, ToSchema)]
pub struct AgentArchInfo {
    /// Architecture name
    pub arch: String,
    /// Binary size in bytes
    // 0 when the binary is unavailable or its metadata could not be read.
    pub size_bytes: u64,
    /// Whether this binary is available
    pub available: bool,
}
|
||||
|
||||
/// Validate that the architecture name is safe (no path traversal) and normalize it.
|
||||
fn validate_arch(arch: &str) -> Result<&str, (StatusCode, Json<serde_json::Value>)> {
|
||||
match arch {
|
||||
"x86_64" | "aarch64" => Ok(arch),
|
||||
// Accept arm64 as an alias for aarch64
|
||||
"arm64" => Ok("aarch64"),
|
||||
_ => Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "Invalid architecture",
|
||||
"message": format!("Unsupported architecture '{}'. Supported: x86_64, aarch64", arch),
|
||||
})),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate bootstrap token if configured.
///
/// If the agent config has a `bootstrap_token` set, the request must provide it
/// via the `X-Agent-Token` header or the `token` query parameter. If no token
/// is configured, access is unrestricted.
fn validate_token(
    config: &attune_common::config::Config,
    headers: &HeaderMap,
    query_token: &Option<String>,
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
    // The token is doubly optional: the `agent` section may be absent, and
    // `bootstrap_token` may be None inside it.
    let expected_token = config
        .agent
        .as_ref()
        .and_then(|ac| ac.bootstrap_token.as_ref());

    let expected_token = match expected_token {
        Some(t) => t,
        None => {
            // No token configured → open access. Warn exactly once per process
            // so operators notice the endpoint is unrestricted without the
            // warning flooding the logs on every download.
            use std::sync::Once;
            static WARN_ONCE: Once = Once::new();
            WARN_ONCE.call_once(|| {
                tracing::warn!(
                    "Agent binary download endpoint has no bootstrap_token configured. \
                     Anyone with network access to the API can download the agent binary. \
                     Set agent.bootstrap_token in config to restrict access."
                );
            });
            return Ok(());
        }
    };

    // Check X-Agent-Token header first, then query param
    let provided_token = headers
        .get("x-agent-token")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string())
        .or_else(|| query_token.clone());

    match provided_token {
        // Constant-time byte comparison (subtle::ConstantTimeEq) so the check
        // does not leak how many leading bytes matched via response timing.
        Some(ref t) if bool::from(t.as_bytes().ct_eq(expected_token.as_bytes())) => Ok(()),
        // A token was supplied but does not match → 401 "Invalid token".
        Some(_) => Err((
            StatusCode::UNAUTHORIZED,
            Json(serde_json::json!({
                "error": "Invalid token",
                "message": "The provided bootstrap token is invalid",
            })),
        )),
        // A token is required but none was supplied → 401 "Token required".
        None => Err((
            StatusCode::UNAUTHORIZED,
            Json(serde_json::json!({
                "error": "Token required",
                "message": "A bootstrap token is required. Provide via X-Agent-Token header or token query parameter.",
            })),
        )),
    }
}
|
||||
|
||||
/// Download the agent binary
///
/// Returns the statically-linked attune-agent binary for the requested architecture.
/// The binary can be injected into any container to turn it into an Attune worker.
///
/// Flow: bootstrap-token check → config check (503 when `agent` is not
/// configured) → arch validation (400) → file resolution (404 when missing)
/// → streamed `application/octet-stream` response.
#[utoipa::path(
    get,
    path = "/api/v1/agent/binary",
    params(BinaryDownloadParams),
    responses(
        (status = 200, description = "Agent binary", content_type = "application/octet-stream"),
        (status = 400, description = "Invalid architecture"),
        (status = 401, description = "Invalid or missing bootstrap token"),
        (status = 404, description = "Agent binary not found"),
        (status = 503, description = "Agent binary distribution not configured"),
    ),
    tag = "agent"
)]
pub async fn download_agent_binary(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Query(params): Query<BinaryDownloadParams>,
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
    // Validate bootstrap token if configured
    validate_token(&state.config, &headers, &params.token)?;

    let agent_config = state.config.agent.as_ref().ok_or_else(|| {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({
                "error": "Not configured",
                "message": "Agent binary distribution is not configured. Set agent.binary_dir in config.",
            })),
        )
    })?;

    // Default to x86_64 when no arch was requested; validation also
    // normalizes the arm64 alias and rejects anything else (no traversal).
    let arch = params.arch.as_deref().unwrap_or("x86_64");
    let arch = validate_arch(arch)?;

    let binary_dir = std::path::Path::new(&agent_config.binary_dir);

    // Try arch-specific binary first, then fall back to generic name.
    // IMPORTANT: The generic `attune-agent` binary is only safe to serve for
    // x86_64 requests, because the current build pipeline produces an
    // x86_64-unknown-linux-musl binary. Serving it for aarch64/arm64 would
    // give the caller an incompatible executable (exec format error).
    let arch_specific = binary_dir.join(format!("attune-agent-{}", arch));
    let generic = binary_dir.join("attune-agent");

    let binary_path = if arch_specific.exists() {
        arch_specific
    } else if arch == "x86_64" && generic.exists() {
        tracing::debug!(
            "Arch-specific binary not found at {:?}, falling back to generic {:?} (safe for x86_64)",
            arch_specific,
            generic
        );
        generic
    } else {
        tracing::warn!(
            "Agent binary not found. Checked: {:?} and {:?}",
            arch_specific,
            generic
        );
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "Not found",
                "message": format!(
                    "Agent binary not found for architecture '{}'. Ensure the agent binary is built and placed in '{}'.",
                    arch,
                    agent_config.binary_dir
                ),
            })),
        ));
    };

    // Get file metadata for Content-Length
    let metadata = fs::metadata(&binary_path).await.map_err(|e| {
        tracing::error!(
            "Failed to read agent binary metadata at {:?}: {}",
            binary_path,
            e
        );
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "Internal error",
                "message": "Failed to read agent binary",
            })),
        )
    })?;

    // Open file for streaming
    let file = fs::File::open(&binary_path).await.map_err(|e| {
        tracing::error!("Failed to open agent binary at {:?}: {}", binary_path, e);
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "Internal error",
                "message": "Failed to open agent binary",
            })),
        )
    })?;

    // Stream the file body instead of buffering the whole binary in memory.
    let stream = ReaderStream::new(file);
    let body = Body::from_stream(stream);

    // NOTE(review): Cache-Control is "public" even when a bootstrap_token
    // gates access — a shared cache could serve the binary to clients without
    // the token. Consider "private" when a token is configured — confirm intent.
    let headers_response = [
        (header::CONTENT_TYPE, "application/octet-stream".to_string()),
        (
            header::CONTENT_DISPOSITION,
            "attachment; filename=\"attune-agent\"".to_string(),
        ),
        (header::CONTENT_LENGTH, metadata.len().to_string()),
        (header::CACHE_CONTROL, "public, max-age=3600".to_string()),
    ];

    tracing::info!(
        arch = arch,
        size_bytes = metadata.len(),
        path = ?binary_path,
        "Serving agent binary download"
    );

    Ok((headers_response, body))
}
|
||||
|
||||
/// Get agent binary metadata
|
||||
///
|
||||
/// Returns information about available agent binaries, including
|
||||
/// supported architectures and binary sizes.
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/agent/info",
|
||||
responses(
|
||||
(status = 200, description = "Agent binary info", body = AgentBinaryInfo),
|
||||
(status = 503, description = "Agent binary distribution not configured"),
|
||||
),
|
||||
tag = "agent"
|
||||
)]
|
||||
pub async fn agent_info(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
|
||||
let agent_config = state.config.agent.as_ref().ok_or_else(|| {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({
|
||||
"error": "Not configured",
|
||||
"message": "Agent binary distribution is not configured.",
|
||||
})),
|
||||
)
|
||||
})?;
|
||||
|
||||
let binary_dir = std::path::Path::new(&agent_config.binary_dir);
|
||||
let architectures = ["x86_64", "aarch64"];
|
||||
|
||||
let mut arch_infos = Vec::new();
|
||||
for arch in &architectures {
|
||||
let arch_specific = binary_dir.join(format!("attune-agent-{}", arch));
|
||||
let generic = binary_dir.join("attune-agent");
|
||||
|
||||
// Only fall back to the generic binary for x86_64, since the build
|
||||
// pipeline currently produces x86_64-only generic binaries.
|
||||
let (available, size_bytes) = if arch_specific.exists() {
|
||||
match fs::metadata(&arch_specific).await {
|
||||
Ok(m) => (true, m.len()),
|
||||
Err(_) => (false, 0),
|
||||
}
|
||||
} else if *arch == "x86_64" && generic.exists() {
|
||||
match fs::metadata(&generic).await {
|
||||
Ok(m) => (true, m.len()),
|
||||
Err(_) => (false, 0),
|
||||
}
|
||||
} else {
|
||||
(false, 0)
|
||||
};
|
||||
|
||||
arch_infos.push(AgentArchInfo {
|
||||
arch: arch.to_string(),
|
||||
size_bytes,
|
||||
available,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Json(AgentBinaryInfo {
|
||||
architectures: arch_infos,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Create agent routes
///
/// Registers the agent distribution endpoints:
/// - `GET /agent/binary` — download the agent executable
/// - `GET /agent/info`   — metadata about available binaries
///
/// Access control is handled inside the handlers via the optional
/// bootstrap token (`validate_token`), not by a router-level auth layer.
pub fn routes() -> Router<Arc<AppState>> {
    Router::new()
        .route("/agent/binary", get(download_agent_binary))
        .route("/agent/info", get(agent_info))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for `validate_arch` and `validate_token`.
    //!
    //! `test_config` loads `config.test.yaml` relative to the crate manifest
    //! (workspace root), so these tests require that file to exist.

    use super::*;
    use attune_common::config::AgentConfig;
    use axum::http::{HeaderMap, HeaderValue};

    // ── validate_arch tests ─────────────────────────────────────────

    #[test]
    fn test_validate_arch_valid_x86_64() {
        let result = validate_arch("x86_64");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "x86_64");
    }

    #[test]
    fn test_validate_arch_valid_aarch64() {
        let result = validate_arch("aarch64");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "aarch64");
    }

    #[test]
    fn test_validate_arch_arm64_alias() {
        // "arm64" is an alias for "aarch64"
        let result = validate_arch("arm64");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "aarch64");
    }

    #[test]
    fn test_validate_arch_invalid() {
        // Unsupported arch names are rejected with 400, which also blocks
        // any path-traversal attempt through the arch parameter.
        let result = validate_arch("mips");
        assert!(result.is_err());
        let (status, body) = result.unwrap_err();
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.0["error"], "Invalid architecture");
    }

    // ── validate_token tests ────────────────────────────────────────

    /// Helper: build a minimal Config with the given agent config.
    /// Only the `agent` field is relevant for `validate_token`.
    fn test_config(agent: Option<AgentConfig>) -> attune_common::config::Config {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
        let config_path = format!("{}/../../config.test.yaml", manifest_dir);
        let mut config = attune_common::config::Config::load_from_file(&config_path)
            .expect("Failed to load test config");
        config.agent = agent;
        config
    }

    #[test]
    fn test_validate_token_no_config() {
        // When no agent config is set at all, no token is required.
        let config = test_config(None);
        let headers = HeaderMap::new();
        let query_token = None;

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_token_no_bootstrap_token_configured() {
        // Agent config exists but bootstrap_token is None → no token required.
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: None,
        }));
        let headers = HeaderMap::new();
        let query_token = None;

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_token_valid_from_header() {
        // Matching token supplied via the X-Agent-Token header.
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: Some("s3cret-bootstrap".to_string()),
        }));
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-agent-token",
            HeaderValue::from_static("s3cret-bootstrap"),
        );
        let query_token = None;

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_token_valid_from_query() {
        // Matching token supplied via the ?token= query parameter.
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: Some("s3cret-bootstrap".to_string()),
        }));
        let headers = HeaderMap::new();
        let query_token = Some("s3cret-bootstrap".to_string());

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_token_invalid() {
        // Wrong token → 401 "Invalid token".
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: Some("correct-token".to_string()),
        }));
        let mut headers = HeaderMap::new();
        headers.insert("x-agent-token", HeaderValue::from_static("wrong-token"));
        let query_token = None;

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_err());
        let (status, body) = result.unwrap_err();
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body.0["error"], "Invalid token");
    }

    #[test]
    fn test_validate_token_missing_when_required() {
        // bootstrap_token is configured but caller provides nothing.
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: Some("required-token".to_string()),
        }));
        let headers = HeaderMap::new();
        let query_token = None;

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_err());
        let (status, body) = result.unwrap_err();
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body.0["error"], "Token required");
    }

    #[test]
    fn test_validate_token_header_takes_precedence_over_query() {
        // When both header and query provide a token, the header value is
        // checked first (it appears first in the or_else chain). Provide a
        // valid token in the header and an invalid one in the query — should
        // succeed because the header matches.
        let config = test_config(Some(AgentConfig {
            binary_dir: "/tmp/test".to_string(),
            bootstrap_token: Some("the-real-token".to_string()),
        }));
        let mut headers = HeaderMap::new();
        headers.insert("x-agent-token", HeaderValue::from_static("the-real-token"));
        let query_token = Some("wrong-token".to_string());

        let result = validate_token(&config, &headers, &query_token);
        assert!(result.is_ok());
    }
}
|
||||
@@ -36,15 +36,17 @@ use attune_common::repositories::{
|
||||
ArtifactRepository, ArtifactSearchFilters, ArtifactVersionRepository, CreateArtifactInput,
|
||||
CreateArtifactVersionInput, UpdateArtifactInput,
|
||||
},
|
||||
Create, Delete, FindById, FindByRef, Update,
|
||||
Create, Delete, FindById, FindByRef, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
auth::{jwt::TokenType, middleware::AuthenticatedUser, middleware::RequireAuth},
|
||||
authz::{AuthorizationCheck, AuthorizationService},
|
||||
dto::{
|
||||
artifact::{
|
||||
AllocateFileVersionByRefRequest, AppendProgressRequest, ArtifactQueryParams,
|
||||
ArtifactResponse, ArtifactSummary, ArtifactVersionResponse, ArtifactVersionSummary,
|
||||
AllocateFileVersionByRefRequest, AppendProgressRequest, ArtifactExecutionPatch,
|
||||
ArtifactJsonPatch, ArtifactQueryParams, ArtifactResponse, ArtifactStringPatch,
|
||||
ArtifactSummary, ArtifactVersionResponse, ArtifactVersionSummary,
|
||||
CreateArtifactRequest, CreateFileVersionRequest, CreateVersionJsonRequest,
|
||||
SetDataRequest, UpdateArtifactRequest,
|
||||
},
|
||||
@@ -54,6 +56,7 @@ use crate::{
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
use attune_common::rbac::{Action, AuthorizationContext, Resource};
|
||||
|
||||
// ============================================================================
|
||||
// Artifact CRUD
|
||||
@@ -71,7 +74,7 @@ use crate::{
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_artifacts(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(query): Query<ArtifactQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -87,8 +90,16 @@ pub async fn list_artifacts(
|
||||
};
|
||||
|
||||
let result = ArtifactRepository::search(&state.db, &filters).await?;
|
||||
let mut rows = result.rows;
|
||||
|
||||
let items: Vec<ArtifactSummary> = result.rows.into_iter().map(ArtifactSummary::from).collect();
|
||||
if let Some((identity_id, grants)) = ensure_can_read_any_artifact(&state, &user).await? {
|
||||
rows.retain(|artifact| {
|
||||
let ctx = artifact_authorization_context(identity_id, artifact);
|
||||
AuthorizationService::is_allowed(&grants, Resource::Artifacts, Action::Read, &ctx)
|
||||
});
|
||||
}
|
||||
|
||||
let items: Vec<ArtifactSummary> = rows.into_iter().map(ArtifactSummary::from).collect();
|
||||
|
||||
let pagination = PaginationParams {
|
||||
page: query.page,
|
||||
@@ -112,7 +123,7 @@ pub async fn list_artifacts(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -120,6 +131,10 @@ pub async fn get_artifact(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactResponse::from(artifact))),
|
||||
@@ -139,7 +154,7 @@ pub async fn get_artifact(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_artifact_by_ref(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(artifact_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -147,6 +162,10 @@ pub async fn get_artifact_by_ref(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact '{}' not found", artifact_ref)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact '{}' not found", artifact_ref)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(ArtifactResponse::from(artifact))),
|
||||
@@ -167,7 +186,7 @@ pub async fn get_artifact_by_ref(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<CreateArtifactRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -199,6 +218,16 @@ pub async fn create_artifact(
|
||||
}
|
||||
});
|
||||
|
||||
authorize_artifact_create(
|
||||
&state,
|
||||
&user,
|
||||
&request.r#ref,
|
||||
request.scope,
|
||||
&request.owner,
|
||||
visibility,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let input = CreateArtifactInput {
|
||||
r#ref: request.r#ref,
|
||||
scope: request.scope,
|
||||
@@ -239,16 +268,18 @@ pub async fn create_artifact(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<UpdateArtifactRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
let input = UpdateArtifactInput {
|
||||
r#ref: None, // Ref is immutable after creation
|
||||
scope: request.scope,
|
||||
@@ -257,12 +288,27 @@ pub async fn update_artifact(
|
||||
visibility: request.visibility,
|
||||
retention_policy: request.retention_policy,
|
||||
retention_limit: request.retention_limit,
|
||||
name: request.name,
|
||||
description: request.description,
|
||||
content_type: request.content_type,
|
||||
name: request.name.map(|patch| match patch {
|
||||
ArtifactStringPatch::Set(value) => Patch::Set(value),
|
||||
ArtifactStringPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
description: request.description.map(|patch| match patch {
|
||||
ArtifactStringPatch::Set(value) => Patch::Set(value),
|
||||
ArtifactStringPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
content_type: request.content_type.map(|patch| match patch {
|
||||
ArtifactStringPatch::Set(value) => Patch::Set(value),
|
||||
ArtifactStringPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
size_bytes: None, // Managed by version creation trigger
|
||||
execution: request.execution.map(Some),
|
||||
data: request.data,
|
||||
execution: request.execution.map(|patch| match patch {
|
||||
ArtifactExecutionPatch::Set(value) => Patch::Set(value),
|
||||
ArtifactExecutionPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
data: request.data.map(|patch| match patch {
|
||||
ArtifactJsonPatch::Set(value) => Patch::Set(value),
|
||||
ArtifactJsonPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
};
|
||||
|
||||
let updated = ArtifactRepository::update(&state.db, id, input).await?;
|
||||
@@ -289,7 +335,7 @@ pub async fn update_artifact(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_artifact(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -297,6 +343,8 @@ pub async fn delete_artifact(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Delete, &artifact).await?;
|
||||
|
||||
// Before deleting DB rows, clean up any file-backed versions on disk
|
||||
let file_versions =
|
||||
ArtifactVersionRepository::find_file_versions_by_artifact(&state.db, id).await?;
|
||||
@@ -339,11 +387,17 @@ pub async fn delete_artifact(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_artifacts_by_execution(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(execution_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let artifacts = ArtifactRepository::find_by_execution(&state.db, execution_id).await?;
|
||||
let mut artifacts = ArtifactRepository::find_by_execution(&state.db, execution_id).await?;
|
||||
if let Some((identity_id, grants)) = ensure_can_read_any_artifact(&state, &user).await? {
|
||||
artifacts.retain(|artifact| {
|
||||
let ctx = artifact_authorization_context(identity_id, artifact);
|
||||
AuthorizationService::is_allowed(&grants, Resource::Artifacts, Action::Read, &ctx)
|
||||
});
|
||||
}
|
||||
let items: Vec<ArtifactSummary> = artifacts.into_iter().map(ArtifactSummary::from).collect();
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::new(items))))
|
||||
@@ -371,7 +425,7 @@ pub async fn list_artifacts_by_execution(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn append_progress(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<AppendProgressRequest>,
|
||||
@@ -380,6 +434,8 @@ pub async fn append_progress(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
if artifact.r#type != ArtifactType::Progress {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Artifact '{}' is type {:?}, not progress. Use version endpoints for file artifacts.",
|
||||
@@ -414,16 +470,18 @@ pub async fn append_progress(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn set_artifact_data(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<SetDataRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
let updated = ArtifactRepository::set_data(&state.db, id, &request.data).await?;
|
||||
|
||||
Ok((
|
||||
@@ -452,15 +510,19 @@ pub async fn set_artifact_data(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_versions(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let versions = ArtifactVersionRepository::list_by_artifact(&state.db, id).await?;
|
||||
let items: Vec<ArtifactVersionSummary> = versions
|
||||
.into_iter()
|
||||
@@ -486,15 +548,19 @@ pub async fn list_versions(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Verify artifact exists
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_by_version(&state.db, id, version)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
@@ -520,14 +586,18 @@ pub async fn get_version(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_latest_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
let ver = ArtifactVersionRepository::find_latest(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("No versions found for artifact {}", id)))?;
|
||||
@@ -552,15 +622,17 @@ pub async fn get_latest_version(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_version_json(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<CreateVersionJsonRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
let input = CreateArtifactVersionInput {
|
||||
artifact: id,
|
||||
content_type: Some(
|
||||
@@ -608,7 +680,7 @@ pub async fn create_version_json(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_version_file(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
Json(request): Json<CreateFileVersionRequest>,
|
||||
@@ -617,6 +689,8 @@ pub async fn create_version_file(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
// Validate this is a file-type artifact
|
||||
if !is_file_backed_type(artifact.r#type) {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
@@ -710,15 +784,17 @@ pub async fn create_version_file(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn upload_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
mut multipart: Multipart,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
ArtifactRepository::find_by_id(&state.db, id)
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Update, &artifact).await?;
|
||||
|
||||
let mut file_data: Option<Vec<u8>> = None;
|
||||
let mut content_type: Option<String> = None;
|
||||
let mut meta: Option<serde_json::Value> = None;
|
||||
@@ -838,7 +914,7 @@ pub async fn upload_version(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn download_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -846,6 +922,10 @@ pub async fn download_version(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
// First try without content (cheaper query) to check for file_path
|
||||
let ver = ArtifactVersionRepository::find_by_version(&state.db, id, version)
|
||||
.await?
|
||||
@@ -888,7 +968,7 @@ pub async fn download_version(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn download_latest(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -896,6 +976,10 @@ pub async fn download_latest(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
// First try without content (cheaper query) to check for file_path
|
||||
let ver = ArtifactVersionRepository::find_latest(&state.db, id)
|
||||
.await?
|
||||
@@ -939,7 +1023,7 @@ pub async fn download_latest(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_version(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path((id, version)): Path<(i64, i32)>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
@@ -948,6 +1032,8 @@ pub async fn delete_version(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Delete, &artifact).await?;
|
||||
|
||||
// Find the version by artifact + version number
|
||||
let ver = ArtifactVersionRepository::find_by_version(&state.db, id, version)
|
||||
.await?
|
||||
@@ -1026,7 +1112,7 @@ pub async fn delete_version(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn upload_version_by_ref(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(artifact_ref): Path<String>,
|
||||
mut multipart: Multipart,
|
||||
@@ -1141,6 +1227,8 @@ pub async fn upload_version_by_ref(
|
||||
// Upsert: find existing artifact or create a new one
|
||||
let artifact = match ArtifactRepository::find_by_ref(&state.db, &artifact_ref).await? {
|
||||
Some(existing) => {
|
||||
authorize_artifact_action(&state, &user, Action::Update, &existing).await?;
|
||||
|
||||
// Update execution link if a new execution ID was provided
|
||||
if execution_id.is_some() && execution_id != existing.execution {
|
||||
let update_input = UpdateArtifactInput {
|
||||
@@ -1155,7 +1243,7 @@ pub async fn upload_version_by_ref(
|
||||
description: None,
|
||||
content_type: None,
|
||||
size_bytes: None,
|
||||
execution: execution_id.map(Some),
|
||||
execution: execution_id.map(Patch::Set),
|
||||
data: None,
|
||||
};
|
||||
ArtifactRepository::update(&state.db, existing.id, update_input).await?
|
||||
@@ -1195,6 +1283,16 @@ pub async fn upload_version_by_ref(
|
||||
}
|
||||
};
|
||||
|
||||
authorize_artifact_create(
|
||||
&state,
|
||||
&user,
|
||||
&artifact_ref,
|
||||
a_scope,
|
||||
owner.as_deref().unwrap_or_default(),
|
||||
a_visibility,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Parse retention
|
||||
let a_retention_policy: RetentionPolicyType = match &retention_policy {
|
||||
Some(rp) if !rp.is_empty() => {
|
||||
@@ -1281,7 +1379,7 @@ pub async fn upload_version_by_ref(
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn allocate_file_version_by_ref(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(artifact_ref): Path<String>,
|
||||
Json(request): Json<AllocateFileVersionByRefRequest>,
|
||||
@@ -1289,6 +1387,8 @@ pub async fn allocate_file_version_by_ref(
|
||||
// Upsert: find existing artifact or create a new one
|
||||
let artifact = match ArtifactRepository::find_by_ref(&state.db, &artifact_ref).await? {
|
||||
Some(existing) => {
|
||||
authorize_artifact_action(&state, &user, Action::Update, &existing).await?;
|
||||
|
||||
// Update execution link if a new execution ID was provided
|
||||
if request.execution.is_some() && request.execution != existing.execution {
|
||||
let update_input = UpdateArtifactInput {
|
||||
@@ -1303,7 +1403,7 @@ pub async fn allocate_file_version_by_ref(
|
||||
description: None,
|
||||
content_type: None,
|
||||
size_bytes: None,
|
||||
execution: request.execution.map(Some),
|
||||
execution: request.execution.map(Patch::Set),
|
||||
data: None,
|
||||
};
|
||||
ArtifactRepository::update(&state.db, existing.id, update_input).await?
|
||||
@@ -1331,6 +1431,16 @@ pub async fn allocate_file_version_by_ref(
|
||||
.unwrap_or(RetentionPolicyType::Versions);
|
||||
let a_retention_limit = request.retention_limit.unwrap_or(10);
|
||||
|
||||
authorize_artifact_create(
|
||||
&state,
|
||||
&user,
|
||||
&artifact_ref,
|
||||
a_scope,
|
||||
request.owner.as_deref().unwrap_or_default(),
|
||||
a_visibility,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let create_input = CreateArtifactInput {
|
||||
r#ref: artifact_ref.clone(),
|
||||
scope: a_scope,
|
||||
@@ -1421,6 +1531,105 @@ pub async fn allocate_file_version_by_ref(
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
async fn authorize_artifact_action(
|
||||
state: &Arc<AppState>,
|
||||
user: &AuthenticatedUser,
|
||||
action: Action,
|
||||
artifact: &attune_common::models::artifact::Artifact,
|
||||
) -> Result<(), ApiError> {
|
||||
if user.claims.token_type != TokenType::Access {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
authz
|
||||
.authorize(
|
||||
user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Artifacts,
|
||||
action,
|
||||
context: artifact_authorization_context(identity_id, artifact),
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn authorize_artifact_create(
|
||||
state: &Arc<AppState>,
|
||||
user: &AuthenticatedUser,
|
||||
artifact_ref: &str,
|
||||
scope: OwnerType,
|
||||
owner: &str,
|
||||
visibility: ArtifactVisibility,
|
||||
) -> Result<(), ApiError> {
|
||||
if user.claims.token_type != TokenType::Access {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_ref = Some(artifact_ref.to_string());
|
||||
ctx.owner_type = Some(scope);
|
||||
ctx.owner_ref = Some(owner.to_string());
|
||||
ctx.visibility = Some(visibility);
|
||||
|
||||
authz
|
||||
.authorize(
|
||||
user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Artifacts,
|
||||
action: Action::Create,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn ensure_can_read_any_artifact(
|
||||
state: &Arc<AppState>,
|
||||
user: &AuthenticatedUser,
|
||||
) -> Result<Option<(i64, Vec<attune_common::rbac::Grant>)>, ApiError> {
|
||||
if user.claims.token_type != TokenType::Access {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let grants = authz.effective_grants(user).await?;
|
||||
|
||||
let can_read_any_artifact = grants
|
||||
.iter()
|
||||
.any(|g| g.resource == Resource::Artifacts && g.actions.contains(&Action::Read));
|
||||
if !can_read_any_artifact {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Insufficient permissions: artifacts:read".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some((identity_id, grants)))
|
||||
}
|
||||
|
||||
fn artifact_authorization_context(
|
||||
identity_id: i64,
|
||||
artifact: &attune_common::models::artifact::Artifact,
|
||||
) -> AuthorizationContext {
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_id = Some(artifact.id);
|
||||
ctx.target_ref = Some(artifact.r#ref.clone());
|
||||
ctx.owner_type = Some(artifact.scope);
|
||||
ctx.owner_ref = Some(artifact.owner.clone());
|
||||
ctx.visibility = Some(artifact.visibility);
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Returns true for artifact types that should use file-backed storage on disk.
|
||||
fn is_file_backed_type(artifact_type: ArtifactType) -> bool {
|
||||
matches!(
|
||||
@@ -1759,14 +1968,19 @@ pub async fn stream_artifact(
|
||||
let token = params.token.as_ref().ok_or(ApiError::Unauthorized(
|
||||
"Missing authentication token".to_string(),
|
||||
))?;
|
||||
validate_token(token, &state.jwt_config)
|
||||
let claims = validate_token(token, &state.jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
let user = AuthenticatedUser { claims };
|
||||
|
||||
// --- resolve artifact + latest version ---------------------------------
|
||||
let artifact = ArtifactRepository::find_by_id(&state.db, id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
authorize_artifact_action(&state, &user, Action::Read, &artifact)
|
||||
.await
|
||||
.map_err(|_| ApiError::NotFound(format!("Artifact with ID {} not found", id)))?;
|
||||
|
||||
if !is_file_backed_type(artifact.r#type) {
|
||||
return Err(ApiError::BadRequest(format!(
|
||||
"Artifact '{}' is type {:?} which is not file-backed. \
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
//! Authentication routes
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
extract::{Query, State},
|
||||
http::HeaderMap,
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
@@ -21,11 +23,16 @@ use crate::{
|
||||
TokenType,
|
||||
},
|
||||
middleware::RequireAuth,
|
||||
oidc::{
|
||||
apply_cookies_to_headers, build_login_redirect, build_logout_redirect,
|
||||
cookie_authenticated_user, get_cookie_value, oidc_callback_redirect_response,
|
||||
OidcCallbackQuery, REFRESH_COOKIE_NAME,
|
||||
},
|
||||
verify_password,
|
||||
},
|
||||
dto::{
|
||||
ApiResponse, ChangePasswordRequest, CurrentUserResponse, LoginRequest, RefreshTokenRequest,
|
||||
RegisterRequest, SuccessResponse, TokenResponse,
|
||||
ApiResponse, AuthSettingsResponse, ChangePasswordRequest, CurrentUserResponse,
|
||||
LoginRequest, RefreshTokenRequest, RegisterRequest, SuccessResponse, TokenResponse,
|
||||
},
|
||||
middleware::error::ApiError,
|
||||
state::SharedState,
|
||||
@@ -63,7 +70,12 @@ pub struct SensorTokenResponse {
|
||||
/// Create authentication routes
|
||||
pub fn routes() -> Router<SharedState> {
|
||||
Router::new()
|
||||
.route("/settings", get(auth_settings))
|
||||
.route("/login", post(login))
|
||||
.route("/oidc/login", get(oidc_login))
|
||||
.route("/callback", get(oidc_callback))
|
||||
.route("/ldap/login", post(ldap_login))
|
||||
.route("/logout", get(logout))
|
||||
.route("/register", post(register))
|
||||
.route("/refresh", post(refresh_token))
|
||||
.route("/me", get(get_current_user))
|
||||
@@ -72,6 +84,63 @@ pub fn routes() -> Router<SharedState> {
|
||||
.route("/internal/sensor-token", post(create_sensor_token_internal))
|
||||
}
|
||||
|
||||
/// Authentication settings endpoint
|
||||
///
|
||||
/// GET /auth/settings
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/auth/settings",
|
||||
tag = "auth",
|
||||
responses(
|
||||
(status = 200, description = "Authentication settings", body = inline(ApiResponse<AuthSettingsResponse>))
|
||||
)
|
||||
)]
|
||||
pub async fn auth_settings(
|
||||
State(state): State<SharedState>,
|
||||
) -> Result<Json<ApiResponse<AuthSettingsResponse>>, ApiError> {
|
||||
let oidc = state
|
||||
.config
|
||||
.security
|
||||
.oidc
|
||||
.as_ref()
|
||||
.filter(|oidc| oidc.enabled);
|
||||
|
||||
let ldap = state
|
||||
.config
|
||||
.security
|
||||
.ldap
|
||||
.as_ref()
|
||||
.filter(|ldap| ldap.enabled);
|
||||
|
||||
let response = AuthSettingsResponse {
|
||||
authentication_enabled: state.config.security.enable_auth,
|
||||
local_password_enabled: state.config.security.enable_auth,
|
||||
local_password_visible_by_default: state.config.security.enable_auth
|
||||
&& state.config.security.login_page.show_local_login,
|
||||
oidc_enabled: oidc.is_some(),
|
||||
oidc_visible_by_default: oidc.is_some() && state.config.security.login_page.show_oidc_login,
|
||||
oidc_provider_name: oidc.map(|oidc| oidc.provider_name.clone()),
|
||||
oidc_provider_label: oidc.map(|oidc| {
|
||||
oidc.provider_label
|
||||
.clone()
|
||||
.unwrap_or_else(|| oidc.provider_name.clone())
|
||||
}),
|
||||
oidc_provider_icon_url: oidc.and_then(|oidc| oidc.provider_icon_url.clone()),
|
||||
ldap_enabled: ldap.is_some(),
|
||||
ldap_visible_by_default: ldap.is_some() && state.config.security.login_page.show_ldap_login,
|
||||
ldap_provider_name: ldap.map(|ldap| ldap.provider_name.clone()),
|
||||
ldap_provider_label: ldap.map(|ldap| {
|
||||
ldap.provider_label
|
||||
.clone()
|
||||
.unwrap_or_else(|| ldap.provider_name.clone())
|
||||
}),
|
||||
ldap_provider_icon_url: ldap.and_then(|ldap| ldap.provider_icon_url.clone()),
|
||||
self_registration_enabled: state.config.security.allow_self_registration,
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Login endpoint
|
||||
///
|
||||
/// POST /auth/login
|
||||
@@ -100,6 +169,12 @@ pub async fn login(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid login or password".to_string()))?;
|
||||
|
||||
if identity.frozen {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Identity is frozen and cannot authenticate".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check if identity has a password set
|
||||
let password_hash = identity
|
||||
.password_hash
|
||||
@@ -221,15 +296,22 @@ pub async fn register(
|
||||
)]
|
||||
pub async fn refresh_token(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<RefreshTokenRequest>,
|
||||
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid refresh token request: {}", e)))?;
|
||||
headers: HeaderMap,
|
||||
payload: Option<Json<RefreshTokenRequest>>,
|
||||
) -> Result<Response, ApiError> {
|
||||
let browser_cookie_refresh = payload.is_none();
|
||||
let refresh_token = if let Some(Json(payload)) = payload {
|
||||
payload.validate().map_err(|e| {
|
||||
ApiError::ValidationError(format!("Invalid refresh token request: {}", e))
|
||||
})?;
|
||||
payload.refresh_token
|
||||
} else {
|
||||
get_cookie_value(&headers, REFRESH_COOKIE_NAME)
|
||||
.ok_or_else(|| ApiError::Unauthorized("Missing refresh token".to_string()))?
|
||||
};
|
||||
|
||||
// Validate refresh token
|
||||
let claims = validate_token(&payload.refresh_token, &state.jwt_config)
|
||||
let claims = validate_token(&refresh_token, &state.jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid or expired refresh token".to_string()))?;
|
||||
|
||||
// Ensure it's a refresh token
|
||||
@@ -248,6 +330,12 @@ pub async fn refresh_token(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Identity not found".to_string()))?;
|
||||
|
||||
if identity.frozen {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Identity is frozen and cannot authenticate".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Generate new tokens
|
||||
let access_token = generate_access_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
let refresh_token = generate_refresh_token(identity.id, &identity.login, &state.jwt_config)?;
|
||||
@@ -257,8 +345,18 @@ pub async fn refresh_token(
|
||||
refresh_token,
|
||||
state.jwt_config.access_token_expiration,
|
||||
);
|
||||
let response_body = Json(ApiResponse::new(response.clone()));
|
||||
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
if browser_cookie_refresh {
|
||||
let mut http_response = response_body.into_response();
|
||||
apply_cookies_to_headers(
|
||||
http_response.headers_mut(),
|
||||
&crate::auth::oidc::build_auth_cookies(&state, &response, ""),
|
||||
)?;
|
||||
return Ok(http_response);
|
||||
}
|
||||
|
||||
Ok(response_body.into_response())
|
||||
}
|
||||
|
||||
/// Get current user endpoint
|
||||
@@ -279,15 +377,27 @@ pub async fn refresh_token(
|
||||
)]
|
||||
pub async fn get_current_user(
|
||||
State(state): State<SharedState>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
headers: HeaderMap,
|
||||
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
|
||||
) -> Result<Json<ApiResponse<CurrentUserResponse>>, ApiError> {
|
||||
let identity_id = user.identity_id()?;
|
||||
let authenticated_user = match user {
|
||||
Ok(RequireAuth(user)) => user,
|
||||
Err(_) => cookie_authenticated_user(&headers, &state)?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Unauthorized".to_string()))?,
|
||||
};
|
||||
let identity_id = authenticated_user.identity_id()?;
|
||||
|
||||
// Fetch identity from database
|
||||
let identity = IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound("Identity not found".to_string()))?;
|
||||
|
||||
if identity.frozen {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Identity is frozen and cannot authenticate".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let response = CurrentUserResponse {
|
||||
id: identity.id,
|
||||
login: identity.login,
|
||||
@@ -297,6 +407,106 @@ pub async fn get_current_user(
|
||||
Ok(Json(ApiResponse::new(response)))
|
||||
}
|
||||
|
||||
/// Request body for LDAP login.
|
||||
#[derive(Debug, Serialize, Deserialize, Validate, ToSchema)]
|
||||
pub struct LdapLoginRequest {
|
||||
/// User login name (uid, sAMAccountName, etc.)
|
||||
#[validate(length(min = 1, max = 255))]
|
||||
pub login: String,
|
||||
/// User password
|
||||
#[validate(length(min = 1, max = 512))]
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcLoginParams {
|
||||
pub redirect_to: Option<String>,
|
||||
}
|
||||
|
||||
/// Begin browser OIDC login by redirecting to the provider.
|
||||
pub async fn oidc_login(
|
||||
State(state): State<SharedState>,
|
||||
Query(params): Query<OidcLoginParams>,
|
||||
) -> Result<Response, ApiError> {
|
||||
let login_redirect = build_login_redirect(&state, params.redirect_to.as_deref()).await?;
|
||||
let mut response = Redirect::temporary(&login_redirect.authorization_url).into_response();
|
||||
apply_cookies_to_headers(response.headers_mut(), &login_redirect.cookies)?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Handle the OIDC authorization code callback.
|
||||
pub async fn oidc_callback(
|
||||
State(state): State<SharedState>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<OidcCallbackQuery>,
|
||||
) -> Result<Response, ApiError> {
|
||||
let redirect_to = get_cookie_value(&headers, crate::auth::oidc::OIDC_REDIRECT_COOKIE_NAME);
|
||||
let authenticated = crate::auth::oidc::handle_callback(&state, &headers, &query).await?;
|
||||
oidc_callback_redirect_response(
|
||||
&state,
|
||||
&authenticated.token_response,
|
||||
redirect_to,
|
||||
&authenticated.id_token,
|
||||
)
|
||||
}
|
||||
|
||||
/// Authenticate via LDAP directory.
|
||||
///
|
||||
/// POST /auth/ldap/login
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/ldap/login",
|
||||
tag = "auth",
|
||||
request_body = LdapLoginRequest,
|
||||
responses(
|
||||
(status = 200, description = "Successfully authenticated via LDAP", body = inline(ApiResponse<TokenResponse>)),
|
||||
(status = 401, description = "Invalid LDAP credentials"),
|
||||
(status = 501, description = "LDAP not configured")
|
||||
)
|
||||
)]
|
||||
pub async fn ldap_login(
|
||||
State(state): State<SharedState>,
|
||||
Json(payload): Json<LdapLoginRequest>,
|
||||
) -> Result<Json<ApiResponse<TokenResponse>>, ApiError> {
|
||||
payload
|
||||
.validate()
|
||||
.map_err(|e| ApiError::ValidationError(format!("Invalid LDAP login request: {e}")))?;
|
||||
|
||||
let authenticated =
|
||||
crate::auth::ldap::authenticate(&state, &payload.login, &payload.password).await?;
|
||||
|
||||
Ok(Json(ApiResponse::new(authenticated.token_response)))
|
||||
}
|
||||
|
||||
/// Logout the current browser session and optionally redirect through the provider logout flow.
|
||||
pub async fn logout(
|
||||
State(state): State<SharedState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Response, ApiError> {
|
||||
let oidc_enabled = state
|
||||
.config
|
||||
.security
|
||||
.oidc
|
||||
.as_ref()
|
||||
.is_some_and(|oidc| oidc.enabled);
|
||||
|
||||
let response = if oidc_enabled {
|
||||
let logout_redirect = build_logout_redirect(&state, &headers).await?;
|
||||
let mut response = Redirect::temporary(&logout_redirect.redirect_url).into_response();
|
||||
apply_cookies_to_headers(response.headers_mut(), &logout_redirect.cookies)?;
|
||||
response
|
||||
} else {
|
||||
let mut response = Redirect::temporary("/login").into_response();
|
||||
apply_cookies_to_headers(
|
||||
response.headers_mut(),
|
||||
&crate::auth::oidc::clear_auth_cookies(&state),
|
||||
)?;
|
||||
response
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Change password endpoint
|
||||
///
|
||||
/// POST /auth/change-password
|
||||
@@ -359,6 +569,7 @@ pub async fn change_password(
|
||||
display_name: None,
|
||||
password_hash: Some(new_password_hash),
|
||||
attributes: None,
|
||||
frozen: None,
|
||||
};
|
||||
|
||||
IdentityRepository::update(&state.db, identity_id, update_input).await?;
|
||||
|
||||
@@ -82,6 +82,17 @@ pub async fn create_event(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(payload): Json<CreateEventRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Only sensor and execution tokens may create events directly.
|
||||
// User sessions must go through the webhook receiver instead.
|
||||
use crate::auth::jwt::TokenType;
|
||||
if user.0.claims.token_type == TokenType::Access {
|
||||
return Err(ApiError::Forbidden(
|
||||
"Events may only be created by sensor services. To fire an event as a user, \
|
||||
enable webhooks on the trigger and POST to its webhook URL."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate request
|
||||
payload
|
||||
.validate()
|
||||
@@ -128,7 +139,6 @@ pub async fn create_event(
|
||||
};
|
||||
|
||||
// Determine source (sensor) from authenticated user if it's a sensor token
|
||||
use crate::auth::jwt::TokenType;
|
||||
let (source_id, source_ref) = match user.0.claims.token_type {
|
||||
TokenType::Sensor => {
|
||||
// Extract sensor reference from login
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::HeaderMap,
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
@@ -13,6 +14,7 @@ use axum::{
|
||||
use chrono::Utc;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use attune_common::models::enums::ExecutionStatus;
|
||||
@@ -32,7 +34,10 @@ use attune_common::workflow::{CancellationPolicy, WorkflowDefinition};
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
auth::{
|
||||
jwt::{validate_token, Claims, JwtConfig, TokenType},
|
||||
middleware::{AuthenticatedUser, RequireAuth},
|
||||
},
|
||||
authz::{AuthorizationCheck, AuthorizationService},
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
@@ -46,6 +51,9 @@ use crate::{
|
||||
};
|
||||
use attune_common::rbac::{Action, AuthorizationContext, Resource};
|
||||
|
||||
const LOG_STREAM_POLL_INTERVAL: Duration = Duration::from_millis(250);
|
||||
const LOG_STREAM_READ_CHUNK_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// Create a new execution (manual execution)
|
||||
///
|
||||
/// This endpoint allows directly executing an action without a trigger or rule.
|
||||
@@ -93,19 +101,6 @@ pub async fn create_execution(
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut execution_ctx = AuthorizationContext::new(identity_id);
|
||||
execution_ctx.pack_ref = Some(action.pack_ref.clone());
|
||||
authz
|
||||
.authorize(
|
||||
&user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Executions,
|
||||
action: Action::Create,
|
||||
context: execution_ctx,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Create execution input
|
||||
@@ -938,6 +933,398 @@ pub async fn stream_execution_updates(
|
||||
Ok(Sse::new(filtered_stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
/// Query parameters for the execution log SSE endpoint.
#[derive(serde::Deserialize)]
pub struct StreamExecutionLogParams {
    /// Optional JWT supplied via the query string, used as a fallback when
    /// header/cookie authentication is unavailable.
    pub token: Option<String>,
    /// Byte offset to resume streaming from; `None` means start of file
    /// (treated as 0 by the handler).
    pub offset: Option<u64>,
}
|
||||
|
||||
/// Which of the two per-execution log files to tail.
#[derive(Clone, Copy)]
enum ExecutionLogStream {
    Stdout,
    Stderr,
}
|
||||
|
||||
impl ExecutionLogStream {
|
||||
fn parse(name: &str) -> Result<Self, ApiError> {
|
||||
match name {
|
||||
"stdout" => Ok(Self::Stdout),
|
||||
"stderr" => Ok(Self::Stderr),
|
||||
_ => Err(ApiError::BadRequest(format!(
|
||||
"Unsupported log stream '{}'. Expected 'stdout' or 'stderr'.",
|
||||
name
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn file_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Stdout => "stdout.log",
|
||||
Self::Stderr => "stderr.log",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine driving the SSE log tail in `stream_execution_log`.
enum ExecutionLogTailState {
    /// Log file does not exist yet; poll until it appears or the execution
    /// reaches a terminal status.
    WaitingForFile {
        full_path: std::path::PathBuf,
        execution_id: i64,
    },
    /// Replay existing file content from `offset` as "content" events until
    /// caught up with the end of the file.
    SendInitial {
        full_path: std::path::PathBuf,
        execution_id: i64,
        offset: u64,
        // Trailing bytes of a multi-byte UTF-8 character split across
        // chunk boundaries, carried into the next read.
        pending_utf8: Vec<u8>,
    },
    /// Follow the file for new data as "append" events; `idle_polls`
    /// counts consecutive empty reads and gates stream completion.
    Tail {
        full_path: std::path::PathBuf,
        execution_id: i64,
        offset: u64,
        idle_polls: u32,
        pending_utf8: Vec<u8>,
    },
    /// Terminal state: the unfold returns `None` and the SSE stream ends.
    Finished,
}
|
||||
|
||||
/// Stream stdout/stderr for an execution as SSE.
///
/// This tails the worker's live log files directly from the shared artifacts
/// volume. The file may not exist yet when the worker has not emitted any
/// output, so the stream waits briefly for it to appear.
#[utoipa::path(
    get,
    path = "/api/v1/executions/{id}/logs/{stream}/stream",
    tag = "executions",
    params(
        ("id" = i64, Path, description = "Execution ID"),
        ("stream" = String, Path, description = "Log stream name: stdout or stderr"),
        ("token" = String, Query, description = "JWT access token for authentication"),
    ),
    responses(
        (status = 200, description = "SSE stream of execution log content", content_type = "text/event-stream"),
        (status = 401, description = "Unauthorized"),
        (status = 404, description = "Execution not found"),
    ),
)]
pub async fn stream_execution_log(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Path((id, stream_name)): Path<(i64, String)>,
    Query(params): Query<StreamExecutionLogParams>,
    user: Result<RequireAuth, crate::auth::middleware::AuthError>,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
    // Authentication may come from the auth extractor, the session cookie,
    // or a `?token=` query parameter.
    let authenticated_user =
        authenticate_execution_log_stream_user(&state, &headers, user, params.token.as_deref())?;
    // Execution tokens must be scoped to exactly this execution id.
    validate_execution_log_stream_user(&authenticated_user, id)?;

    let execution = ExecutionRepository::find_by_id(&state.db, id)
        .await?
        .ok_or_else(|| ApiError::NotFound(format!("Execution with ID {} not found", id)))?;
    // RBAC read check for Access-token users (no-op for other token types).
    authorize_execution_log_stream(&state, &authenticated_user, &execution).await?;

    let stream_name = ExecutionLogStream::parse(&stream_name)?;
    // Log files live under <artifacts_dir>/execution_<id>/<stream>.log.
    let full_path = std::path::PathBuf::from(&state.config.artifacts_dir)
        .join(format!("execution_{}", id))
        .join(stream_name.file_name());
    let db = state.db.clone();

    let initial_state = ExecutionLogTailState::WaitingForFile {
        full_path,
        execution_id: id,
    };
    let start_offset = params.offset.unwrap_or(0);

    // State machine: WaitingForFile -> SendInitial -> Tail -> Finished.
    // Each unfold step emits exactly one SSE event plus the next state.
    let stream = futures::stream::unfold(initial_state, move |state| {
        let db = db.clone();
        async move {
            match state {
                // Returning None ends the SSE stream.
                ExecutionLogTailState::Finished => None,
                ExecutionLogTailState::WaitingForFile {
                    full_path,
                    execution_id,
                } => {
                    // NOTE(review): `exists()` is synchronous filesystem I/O
                    // on the async runtime; presumably cheap for the local
                    // artifacts volume — confirm if this can be remote.
                    if full_path.exists() {
                        Some((
                            Ok(Event::default().event("waiting").data("Log file found")),
                            ExecutionLogTailState::SendInitial {
                                full_path,
                                execution_id,
                                offset: start_offset,
                                pending_utf8: Vec::new(),
                            },
                        ))
                    } else if execution_log_execution_terminal(&db, execution_id).await {
                        // Execution already finished without producing this
                        // file — nothing will ever appear; end the stream.
                        Some((
                            Ok(Event::default().event("done").data("")),
                            ExecutionLogTailState::Finished,
                        ))
                    } else {
                        tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
                        Some((
                            Ok(Event::default()
                                .event("waiting")
                                .data("Waiting for log output")),
                            ExecutionLogTailState::WaitingForFile {
                                full_path,
                                execution_id,
                            },
                        ))
                    }
                }
                // Replay existing content from the requested offset as
                // "content" events until caught up with end-of-file.
                ExecutionLogTailState::SendInitial {
                    full_path,
                    execution_id,
                    offset,
                    pending_utf8,
                } => {
                    // `read_log_chunk` consumes the carry buffer; keep a copy
                    // so an empty read does not lose pending partial bytes.
                    let pending_utf8_on_empty = pending_utf8.clone();
                    match read_log_chunk(
                        &full_path,
                        offset,
                        LOG_STREAM_READ_CHUNK_SIZE,
                        pending_utf8,
                    )
                    .await
                    {
                        Some((content, new_offset, pending_utf8)) => Some((
                            // The SSE event id carries the byte offset so a
                            // client can resume via `?offset=`.
                            Ok(Event::default()
                                .id(new_offset.to_string())
                                .event("content")
                                .data(content)),
                            ExecutionLogTailState::SendInitial {
                                full_path,
                                execution_id,
                                offset: new_offset,
                                pending_utf8,
                            },
                        )),
                        None => Some((
                            Ok(Event::default().comment("initial-catchup-complete")),
                            ExecutionLogTailState::Tail {
                                full_path,
                                execution_id,
                                offset,
                                idle_polls: 0,
                                pending_utf8: pending_utf8_on_empty,
                            },
                        )),
                    }
                }
                // Follow the live file; newly appended data becomes
                // "append" events.
                ExecutionLogTailState::Tail {
                    full_path,
                    execution_id,
                    offset,
                    idle_polls,
                    pending_utf8,
                } => {
                    let pending_utf8_on_empty = pending_utf8.clone();
                    match read_log_chunk(
                        &full_path,
                        offset,
                        LOG_STREAM_READ_CHUNK_SIZE,
                        pending_utf8,
                    )
                    .await
                    {
                        Some((append, new_offset, pending_utf8)) => Some((
                            Ok(Event::default()
                                .id(new_offset.to_string())
                                .event("append")
                                .data(append)),
                            ExecutionLogTailState::Tail {
                                full_path,
                                execution_id,
                                offset: new_offset,
                                idle_polls: 0,
                                pending_utf8,
                            },
                        )),
                        None => {
                            let terminal =
                                execution_log_execution_terminal(&db, execution_id).await;
                            // Require a few consecutive empty polls after the
                            // execution ends so late-flushed output is not
                            // dropped.
                            if terminal && idle_polls >= 2 {
                                Some((
                                    Ok(Event::default().event("done").data("Execution complete")),
                                    ExecutionLogTailState::Finished,
                                ))
                            } else {
                                tokio::time::sleep(LOG_STREAM_POLL_INTERVAL).await;
                                Some((
                                    Ok(Event::default()
                                        .event("waiting")
                                        .data("Waiting for log output")),
                                    ExecutionLogTailState::Tail {
                                        full_path,
                                        execution_id,
                                        offset,
                                        idle_polls: idle_polls + 1,
                                        pending_utf8: pending_utf8_on_empty,
                                    },
                                ))
                            }
                        }
                    }
                }
            }
        }
    });

    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
|
||||
|
||||
async fn read_log_chunk(
|
||||
path: &std::path::Path,
|
||||
offset: u64,
|
||||
max_bytes: usize,
|
||||
mut pending_utf8: Vec<u8>,
|
||||
) -> Option<(String, u64, Vec<u8>)> {
|
||||
use tokio::io::{AsyncReadExt, AsyncSeekExt};
|
||||
|
||||
let mut file = tokio::fs::File::open(path).await.ok()?;
|
||||
let metadata = file.metadata().await.ok()?;
|
||||
if metadata.len() <= offset {
|
||||
return None;
|
||||
}
|
||||
|
||||
file.seek(std::io::SeekFrom::Start(offset)).await.ok()?;
|
||||
let bytes_to_read = ((metadata.len() - offset) as usize).min(max_bytes);
|
||||
let mut buf = vec![0u8; bytes_to_read];
|
||||
let read = file.read(&mut buf).await.ok()?;
|
||||
buf.truncate(read);
|
||||
if buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
pending_utf8.extend_from_slice(&buf);
|
||||
let (content, pending_utf8) = decode_utf8_chunk(pending_utf8);
|
||||
|
||||
Some((content, offset + read as u64, pending_utf8))
|
||||
}
|
||||
|
||||
async fn execution_log_execution_terminal(db: &sqlx::PgPool, execution_id: i64) -> bool {
|
||||
match ExecutionRepository::find_by_id(db, execution_id).await {
|
||||
Ok(Some(execution)) => matches!(
|
||||
execution.status,
|
||||
ExecutionStatus::Completed
|
||||
| ExecutionStatus::Failed
|
||||
| ExecutionStatus::Cancelled
|
||||
| ExecutionStatus::Timeout
|
||||
| ExecutionStatus::Abandoned
|
||||
),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a byte buffer as UTF-8, tolerating chunk boundaries that split
/// multi-byte characters.
///
/// Returns the decoded text plus any trailing bytes that form the prefix of
/// an incomplete UTF-8 sequence; the caller feeds those back in with the
/// next chunk. Genuinely invalid bytes are replaced with U+FFFD.
fn decode_utf8_chunk(mut bytes: Vec<u8>) -> (String, Vec<u8>) {
    match std::str::from_utf8(&bytes) {
        // Fast path: the whole buffer is valid UTF-8.
        Ok(valid) => (valid.to_string(), Vec::new()),
        // The only problem is a truncated sequence at the very end: emit the
        // valid prefix and carry the partial sequence forward.
        Err(err) if err.error_len().is_none() => {
            let pending = bytes.split_off(err.valid_up_to());
            (String::from_utf8_lossy(&bytes).into_owned(), pending)
        }
        // Invalid bytes somewhere in the middle. Lossy-decode those, but
        // still hold back a trailing *incomplete* sequence (if any) so a
        // character split across chunks is not mangled into U+FFFD.
        // (Previously the whole buffer was lossy-decoded, corrupting a
        // split character whenever an earlier byte in the chunk was bad.)
        Err(_) => {
            let keep = bytes.len() - incomplete_utf8_suffix_len(&bytes);
            let pending = bytes.split_off(keep);
            (String::from_utf8_lossy(&bytes).into_owned(), pending)
        }
    }
}

/// Length of a trailing incomplete (but so-far well-formed) UTF-8 sequence.
///
/// Scans at most the last 3 bytes: a truncated sequence is a lead byte plus
/// fewer continuation bytes than it requires. Returns 0 when the buffer ends
/// on a complete character, plain ASCII, or invalid bytes.
fn incomplete_utf8_suffix_len(bytes: &[u8]) -> usize {
    for back in 1..=bytes.len().min(3) {
        let byte = bytes[bytes.len() - back];
        if byte & 0b1100_0000 == 0b1000_0000 {
            // Continuation byte — keep scanning backwards for the lead.
            continue;
        }
        let required = match byte {
            0xC2..=0xDF => 2,
            0xE0..=0xEF => 3,
            0xF0..=0xF4 => 4,
            // ASCII or an invalid lead byte: nothing to carry forward.
            _ => return 0,
        };
        return if required > back { back } else { 0 };
    }
    0
}
|
||||
|
||||
async fn authorize_execution_log_stream(
|
||||
state: &Arc<AppState>,
|
||||
user: &AuthenticatedUser,
|
||||
execution: &attune_common::models::Execution,
|
||||
) -> Result<(), ApiError> {
|
||||
if user.claims.token_type != TokenType::Access {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_id = Some(execution.id);
|
||||
ctx.target_ref = Some(execution.action_ref.clone());
|
||||
|
||||
authz
|
||||
.authorize(
|
||||
user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Executions,
|
||||
action: Action::Read,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn authenticate_execution_log_stream_user(
|
||||
state: &Arc<AppState>,
|
||||
headers: &HeaderMap,
|
||||
user: Result<RequireAuth, crate::auth::middleware::AuthError>,
|
||||
query_token: Option<&str>,
|
||||
) -> Result<AuthenticatedUser, ApiError> {
|
||||
match user {
|
||||
Ok(RequireAuth(user)) => Ok(user),
|
||||
Err(_) => {
|
||||
if let Some(user) = crate::auth::oidc::cookie_authenticated_user(headers, state)? {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let token = query_token.ok_or(ApiError::Unauthorized(
|
||||
"Missing authentication token".to_string(),
|
||||
))?;
|
||||
authenticate_execution_log_stream_query_token(token, &state.jwt_config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate_execution_log_stream_query_token(
|
||||
token: &str,
|
||||
jwt_config: &JwtConfig,
|
||||
) -> Result<AuthenticatedUser, ApiError> {
|
||||
let claims = validate_token(token, jwt_config)
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
|
||||
Ok(AuthenticatedUser { claims })
|
||||
}
|
||||
|
||||
fn validate_execution_log_stream_user(
|
||||
user: &AuthenticatedUser,
|
||||
execution_id: i64,
|
||||
) -> Result<(), ApiError> {
|
||||
let claims = &user.claims;
|
||||
|
||||
match claims.token_type {
|
||||
TokenType::Access => Ok(()),
|
||||
TokenType::Execution => validate_execution_token_scope(claims, execution_id),
|
||||
TokenType::Sensor | TokenType::Refresh => Err(ApiError::Unauthorized(
|
||||
"Invalid authentication token".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_execution_token_scope(claims: &Claims, execution_id: i64) -> Result<(), ApiError> {
|
||||
if claims.scope.as_deref() != Some("execution") {
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Invalid authentication token".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let token_execution_id = claims
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("execution_id"))
|
||||
.and_then(|value| value.as_i64())
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid authentication token".to_string()))?;
|
||||
|
||||
if token_execution_id != execution_id {
|
||||
return Err(ApiError::Forbidden(format!(
|
||||
"Execution token is not valid for execution {}",
|
||||
execution_id
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct StreamExecutionParams {
|
||||
pub execution_id: Option<i64>,
|
||||
@@ -950,6 +1337,10 @@ pub fn routes() -> Router<Arc<AppState>> {
|
||||
.route("/executions/execute", axum::routing::post(create_execution))
|
||||
.route("/executions/stats", get(get_execution_stats))
|
||||
.route("/executions/stream", get(stream_execution_updates))
|
||||
.route(
|
||||
"/executions/{id}/logs/{stream}/stream",
|
||||
get(stream_execution_log),
|
||||
)
|
||||
.route("/executions/{id}", get(get_execution))
|
||||
.route(
|
||||
"/executions/{id}/cancel",
|
||||
@@ -968,10 +1359,26 @@ pub fn routes() -> Router<Arc<AppState>> {
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use attune_common::auth::jwt::generate_execution_token;

    #[test]
    fn test_execution_routes_structure() {
        // Just verify the router can be constructed
        let _router = routes();
    }

    #[test]
    fn execution_token_scope_must_match_requested_execution() {
        let jwt_config = JwtConfig {
            secret: "test_secret_key_for_testing".to_string(),
            access_token_expiration: 3600,
            refresh_token_expiration: 604800,
        };

        // Token is minted for execution 42; asking for execution 456 below
        // must be rejected by the scope check.
        let token = generate_execution_token(42, 123, "core.echo", &jwt_config, None).unwrap();

        let user = authenticate_execution_log_stream_query_token(&token, &jwt_config).unwrap();
        let err = validate_execution_log_stream_user(&user, 456).unwrap_err();
        // Mismatched execution id maps to Forbidden (not Unauthorized).
        assert!(matches!(err, ApiError::Forbidden(_)));
    }
}
|
||||
|
||||
@@ -120,12 +120,16 @@ pub async fn get_key(
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
|
||||
|
||||
if user.0.claims.token_type == TokenType::Access {
|
||||
// For encrypted keys, track whether this caller is permitted to see the value.
|
||||
// Non-Access tokens (sensor, execution) always get full access.
|
||||
let can_decrypt = if user.0.claims.token_type == TokenType::Access {
|
||||
let identity_id = user
|
||||
.0
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
|
||||
// Basic read check — hide behind 404 to prevent enumeration.
|
||||
authz
|
||||
.authorize(
|
||||
&user.0,
|
||||
@@ -136,28 +140,55 @@ pub async fn get_key(
|
||||
},
|
||||
)
|
||||
.await
|
||||
// Hide unauthorized records behind 404 to reduce enumeration leakage.
|
||||
.map_err(|_| ApiError::NotFound(format!("Key '{}' not found", key_ref)))?;
|
||||
}
|
||||
|
||||
// Decrypt value if encrypted
|
||||
// For encrypted keys, separately check Keys::Decrypt.
|
||||
// Failing this is not an error — we just return the value as null.
|
||||
if key.encrypted {
|
||||
authz
|
||||
.authorize(
|
||||
&user.0,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Keys,
|
||||
action: Action::Decrypt,
|
||||
context: key_authorization_context(identity_id, &key),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.is_ok()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
// Decrypt value if encrypted and caller has permission.
|
||||
// If they lack Keys::Decrypt, return null rather than the ciphertext.
|
||||
if key.encrypted {
|
||||
let encryption_key = state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::InternalServerError("Encryption key not configured on server".to_string())
|
||||
})?;
|
||||
if can_decrypt {
|
||||
let encryption_key =
|
||||
state
|
||||
.config
|
||||
.security
|
||||
.encryption_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
ApiError::InternalServerError(
|
||||
"Encryption key not configured on server".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let decrypted_value = attune_common::crypto::decrypt_json(&key.value, encryption_key)
|
||||
.map_err(|e| {
|
||||
let decrypted_value = attune_common::crypto::decrypt_json(&key.value, encryption_key)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to decrypt key '{}': {}", key_ref, e);
|
||||
ApiError::InternalServerError(format!("Failed to decrypt key: {}", e))
|
||||
})?;
|
||||
|
||||
key.value = decrypted_value;
|
||||
key.value = decrypted_value;
|
||||
} else {
|
||||
key.value = serde_json::Value::Null;
|
||||
}
|
||||
}
|
||||
|
||||
let response = ApiResponse::new(KeyResponse::from(key));
|
||||
@@ -195,6 +226,7 @@ pub async fn create_key(
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.owner_identity_id = request.owner_identity;
|
||||
ctx.owner_type = Some(request.owner_type);
|
||||
ctx.owner_ref = requested_key_owner_ref(&request);
|
||||
ctx.encrypted = Some(request.encrypted);
|
||||
ctx.target_ref = Some(request.r#ref.clone());
|
||||
|
||||
@@ -541,6 +573,38 @@ fn key_authorization_context(identity_id: i64, key: &Key) -> AuthorizationContex
|
||||
ctx.target_ref = Some(key.r#ref.clone());
|
||||
ctx.owner_identity_id = key.owner_identity;
|
||||
ctx.owner_type = Some(key.owner_type);
|
||||
ctx.owner_ref = key_owner_ref(
|
||||
key.owner_type,
|
||||
key.owner.as_deref(),
|
||||
key.owner_pack_ref.as_deref(),
|
||||
key.owner_action_ref.as_deref(),
|
||||
key.owner_sensor_ref.as_deref(),
|
||||
);
|
||||
ctx.encrypted = Some(key.encrypted);
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Resolve the owner ref for a key-creation request by delegating to
/// `key_owner_ref` with the request's owner fields.
fn requested_key_owner_ref(request: &CreateKeyRequest) -> Option<String> {
    key_owner_ref(
        request.owner_type,
        request.owner.as_deref(),
        request.owner_pack_ref.as_deref(),
        request.owner_action_ref.as_deref(),
        request.owner_sensor_ref.as_deref(),
    )
}
|
||||
|
||||
fn key_owner_ref(
|
||||
owner_type: OwnerType,
|
||||
owner: Option<&str>,
|
||||
owner_pack_ref: Option<&str>,
|
||||
owner_action_ref: Option<&str>,
|
||||
owner_sensor_ref: Option<&str>,
|
||||
) -> Option<String> {
|
||||
match owner_type {
|
||||
OwnerType::Pack => owner_pack_ref.map(str::to_string),
|
||||
OwnerType::Action => owner_action_ref.map(str::to_string),
|
||||
OwnerType::Sensor => owner_sensor_ref.map(str::to_string),
|
||||
_ => owner.map(str::to_string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! API route modules
|
||||
|
||||
pub mod actions;
|
||||
pub mod agent;
|
||||
pub mod analytics;
|
||||
pub mod artifacts;
|
||||
pub mod auth;
|
||||
@@ -13,11 +14,13 @@ pub mod keys;
|
||||
pub mod packs;
|
||||
pub mod permissions;
|
||||
pub mod rules;
|
||||
pub mod runtimes;
|
||||
pub mod triggers;
|
||||
pub mod webhooks;
|
||||
pub mod workflows;
|
||||
|
||||
pub use actions::routes as action_routes;
|
||||
pub use agent::routes as agent_routes;
|
||||
pub use analytics::routes as analytics_routes;
|
||||
pub use artifacts::routes as artifact_routes;
|
||||
pub use auth::routes as auth_routes;
|
||||
@@ -30,6 +33,7 @@ pub use keys::routes as key_routes;
|
||||
pub use packs::routes as pack_routes;
|
||||
pub use permissions::routes as permission_routes;
|
||||
pub use rules::routes as rule_routes;
|
||||
pub use runtimes::routes as runtime_routes;
|
||||
pub use triggers::routes as trigger_routes;
|
||||
pub use webhooks::routes as webhook_routes;
|
||||
pub use workflows::routes as workflow_routes;
|
||||
|
||||
@@ -16,7 +16,8 @@ use attune_common::mq::{MessageEnvelope, MessageType, PackRegisteredPayload};
|
||||
use attune_common::rbac::{Action, AuthorizationContext, Resource};
|
||||
use attune_common::repositories::{
|
||||
pack::{CreatePackInput, UpdatePackInput},
|
||||
Create, Delete, FindById, FindByRef, PackRepository, PackTestRepository, Pagination, Update,
|
||||
Create, Delete, FindById, FindByRef, PackRepository, PackTestRepository, Pagination, Patch,
|
||||
Update,
|
||||
};
|
||||
use attune_common::workflow::{PackWorkflowService, PackWorkflowServiceConfig};
|
||||
|
||||
@@ -28,9 +29,10 @@ use crate::{
|
||||
pack::{
|
||||
BuildPackEnvsRequest, BuildPackEnvsResponse, CreatePackRequest, DownloadPacksRequest,
|
||||
DownloadPacksResponse, GetPackDependenciesRequest, GetPackDependenciesResponse,
|
||||
InstallPackRequest, PackInstallResponse, PackResponse, PackSummary,
|
||||
PackWorkflowSyncResponse, PackWorkflowValidationResponse, RegisterPackRequest,
|
||||
RegisterPacksRequest, RegisterPacksResponse, UpdatePackRequest, WorkflowSyncResult,
|
||||
InstallPackRequest, PackDescriptionPatch, PackInstallResponse, PackResponse,
|
||||
PackSummary, PackWorkflowSyncResponse, PackWorkflowValidationResponse,
|
||||
RegisterPackRequest, RegisterPacksRequest, RegisterPacksResponse, UpdatePackRequest,
|
||||
WorkflowSyncResult,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
@@ -258,7 +260,10 @@ pub async fn update_pack(
|
||||
// Create update input
|
||||
let update_input = UpdatePackInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
description: request.description.map(|patch| match patch {
|
||||
PackDescriptionPatch::Set(value) => Patch::Set(value),
|
||||
PackDescriptionPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
version: request.version,
|
||||
conf_schema: request.conf_schema,
|
||||
config: request.config,
|
||||
@@ -876,7 +881,10 @@ async fn register_pack_internal(
|
||||
// Update existing pack in place — preserves pack ID and all child entity IDs
|
||||
let update_input = UpdatePackInput {
|
||||
label: Some(label),
|
||||
description: Some(description.unwrap_or_default()),
|
||||
description: Some(match description {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
version: Some(version.clone()),
|
||||
conf_schema: Some(conf_schema),
|
||||
config: None, // preserve user-set config
|
||||
|
||||
@@ -9,12 +9,14 @@ use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::{
|
||||
models::identity::{Identity, PermissionSet},
|
||||
models::identity::{Identity, IdentityRoleAssignment},
|
||||
rbac::{Action, AuthorizationContext, Resource},
|
||||
repositories::{
|
||||
identity::{
|
||||
CreateIdentityInput, CreatePermissionAssignmentInput, IdentityRepository,
|
||||
PermissionAssignmentRepository, PermissionSetRepository, UpdateIdentityInput,
|
||||
CreateIdentityInput, CreateIdentityRoleAssignmentInput,
|
||||
CreatePermissionAssignmentInput, CreatePermissionSetRoleAssignmentInput,
|
||||
IdentityRepository, IdentityRoleAssignmentRepository, PermissionAssignmentRepository,
|
||||
PermissionSetRepository, PermissionSetRoleAssignmentRepository, UpdateIdentityInput,
|
||||
},
|
||||
Create, Delete, FindById, FindByRef, List, Update,
|
||||
},
|
||||
@@ -26,9 +28,12 @@ use crate::{
|
||||
authz::{AuthorizationCheck, AuthorizationService},
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
ApiResponse, CreateIdentityRequest, CreatePermissionAssignmentRequest, IdentityResponse,
|
||||
IdentitySummary, PermissionAssignmentResponse, PermissionSetQueryParams,
|
||||
PermissionSetSummary, SuccessResponse, UpdateIdentityRequest,
|
||||
ApiResponse, CreateIdentityRequest, CreateIdentityRoleAssignmentRequest,
|
||||
CreatePermissionAssignmentRequest, CreatePermissionSetRoleAssignmentRequest,
|
||||
IdentityResponse, IdentityRoleAssignmentResponse, IdentitySummary,
|
||||
PermissionAssignmentResponse, PermissionSetQueryParams,
|
||||
PermissionSetRoleAssignmentResponse, PermissionSetSummary, SuccessResponse,
|
||||
UpdateIdentityRequest,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
@@ -58,16 +63,22 @@ pub async fn list_identities(
|
||||
let page_items = if start >= identities.len() {
|
||||
Vec::new()
|
||||
} else {
|
||||
identities[start..end]
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(IdentitySummary::from)
|
||||
.collect()
|
||||
identities[start..end].to_vec()
|
||||
};
|
||||
|
||||
let mut summaries = Vec::with_capacity(page_items.len());
|
||||
for identity in page_items {
|
||||
let role_assignments =
|
||||
IdentityRoleAssignmentRepository::find_by_identity(&state.db, identity.id).await?;
|
||||
let roles = role_assignments.into_iter().map(|ra| ra.role).collect();
|
||||
let mut summary = IdentitySummary::from(identity);
|
||||
summary.roles = roles;
|
||||
summaries.push(summary);
|
||||
}
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(PaginatedResponse::new(page_items, &query, total)),
|
||||
Json(PaginatedResponse::new(summaries, &query, total)),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -94,10 +105,42 @@ pub async fn get_identity(
|
||||
let identity = IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Identity '{}' not found", identity_id)))?;
|
||||
let roles = IdentityRoleAssignmentRepository::find_by_identity(&state.db, identity_id).await?;
|
||||
let assignments =
|
||||
PermissionAssignmentRepository::find_by_identity(&state.db, identity_id).await?;
|
||||
let permission_sets = PermissionSetRepository::find_by_identity(&state.db, identity_id).await?;
|
||||
let permission_set_refs = permission_sets
|
||||
.into_iter()
|
||||
.map(|ps| (ps.id, ps.r#ref))
|
||||
.collect::<std::collections::HashMap<_, _>>();
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(IdentityResponse::from(identity))),
|
||||
Json(ApiResponse::new(IdentityResponse {
|
||||
id: identity.id,
|
||||
login: identity.login,
|
||||
display_name: identity.display_name,
|
||||
frozen: identity.frozen,
|
||||
attributes: identity.attributes,
|
||||
roles: roles
|
||||
.into_iter()
|
||||
.map(IdentityRoleAssignmentResponse::from)
|
||||
.collect(),
|
||||
direct_permissions: assignments
|
||||
.into_iter()
|
||||
.filter_map(|assignment| {
|
||||
permission_set_refs.get(&assignment.permset).cloned().map(
|
||||
|permission_set_ref| PermissionAssignmentResponse {
|
||||
id: assignment.id,
|
||||
identity_id: assignment.identity,
|
||||
permission_set_id: assignment.permset,
|
||||
permission_set_ref,
|
||||
created: assignment.created,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
})),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -180,6 +223,7 @@ pub async fn update_identity(
|
||||
display_name: request.display_name,
|
||||
password_hash,
|
||||
attributes: request.attributes,
|
||||
frozen: request.frozen,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
@@ -257,10 +301,33 @@ pub async fn list_permission_sets(
|
||||
permission_sets.retain(|ps| ps.pack_ref.as_deref() == Some(pack_ref.as_str()));
|
||||
}
|
||||
|
||||
let response: Vec<PermissionSetSummary> = permission_sets
|
||||
.into_iter()
|
||||
.map(PermissionSetSummary::from)
|
||||
.collect();
|
||||
let mut response = Vec::with_capacity(permission_sets.len());
|
||||
for permission_set in permission_sets {
|
||||
let permission_set_ref = permission_set.r#ref.clone();
|
||||
let roles = PermissionSetRoleAssignmentRepository::find_by_permission_set(
|
||||
&state.db,
|
||||
permission_set.id,
|
||||
)
|
||||
.await?;
|
||||
response.push(PermissionSetSummary {
|
||||
id: permission_set.id,
|
||||
r#ref: permission_set.r#ref,
|
||||
pack_ref: permission_set.pack_ref,
|
||||
label: permission_set.label,
|
||||
description: permission_set.description,
|
||||
grants: permission_set.grants,
|
||||
roles: roles
|
||||
.into_iter()
|
||||
.map(|assignment| PermissionSetRoleAssignmentResponse {
|
||||
id: assignment.id,
|
||||
permission_set_id: assignment.permset,
|
||||
permission_set_ref: Some(permission_set_ref.clone()),
|
||||
role: assignment.role,
|
||||
created: assignment.created,
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -412,6 +479,229 @@ pub async fn delete_permission_assignment(
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/identities/{id}/roles",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Identity ID")
|
||||
),
|
||||
request_body = CreateIdentityRoleAssignmentRequest,
|
||||
responses(
|
||||
(status = 201, description = "Identity role assignment created", body = inline(ApiResponse<IdentityRoleAssignmentResponse>)),
|
||||
(status = 404, description = "Identity not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_identity_role_assignment(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(identity_id): Path<i64>,
|
||||
Json(request): Json<CreateIdentityRoleAssignmentRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
|
||||
request.validate()?;
|
||||
|
||||
IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Identity '{}' not found", identity_id)))?;
|
||||
|
||||
let assignment = IdentityRoleAssignmentRepository::create(
|
||||
&state.db,
|
||||
CreateIdentityRoleAssignmentInput {
|
||||
identity: identity_id,
|
||||
role: request.role,
|
||||
source: "manual".to_string(),
|
||||
managed: false,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::new(IdentityRoleAssignmentResponse::from(
|
||||
assignment,
|
||||
))),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/identities/roles/{id}",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Identity role assignment ID")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Identity role assignment deleted", body = inline(ApiResponse<SuccessResponse>)),
|
||||
(status = 404, description = "Identity role assignment not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_identity_role_assignment(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(assignment_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
|
||||
|
||||
let assignment = IdentityRoleAssignmentRepository::find_by_id(&state.db, assignment_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!(
|
||||
"Identity role assignment '{}' not found",
|
||||
assignment_id
|
||||
))
|
||||
})?;
|
||||
|
||||
if assignment.managed {
|
||||
return Err(ApiError::BadRequest(
|
||||
"Managed role assignments must be updated through the identity provider sync"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
IdentityRoleAssignmentRepository::delete(&state.db, assignment_id).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(SuccessResponse::new(
|
||||
"Identity role assignment deleted successfully",
|
||||
))),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/permissions/sets/{id}/roles",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Permission set ID")
|
||||
),
|
||||
request_body = CreatePermissionSetRoleAssignmentRequest,
|
||||
responses(
|
||||
(status = 201, description = "Permission set role assignment created", body = inline(ApiResponse<PermissionSetRoleAssignmentResponse>)),
|
||||
(status = 404, description = "Permission set not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_permission_set_role_assignment(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(permission_set_id): Path<i64>,
|
||||
Json(request): Json<CreatePermissionSetRoleAssignmentRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
|
||||
request.validate()?;
|
||||
|
||||
let permission_set = PermissionSetRepository::find_by_id(&state.db, permission_set_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!("Permission set '{}' not found", permission_set_id))
|
||||
})?;
|
||||
|
||||
let assignment = PermissionSetRoleAssignmentRepository::create(
|
||||
&state.db,
|
||||
CreatePermissionSetRoleAssignmentInput {
|
||||
permset: permission_set_id,
|
||||
role: request.role,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::new(PermissionSetRoleAssignmentResponse {
|
||||
id: assignment.id,
|
||||
permission_set_id: assignment.permset,
|
||||
permission_set_ref: Some(permission_set.r#ref),
|
||||
role: assignment.role,
|
||||
created: assignment.created,
|
||||
})),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/permissions/sets/roles/{id}",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Permission set role assignment ID")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Permission set role assignment deleted", body = inline(ApiResponse<SuccessResponse>)),
|
||||
(status = 404, description = "Permission set role assignment not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_permission_set_role_assignment(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(assignment_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
authorize_permissions(&state, &user, Resource::Permissions, Action::Manage).await?;
|
||||
|
||||
PermissionSetRoleAssignmentRepository::find_by_id(&state.db, assignment_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ApiError::NotFound(format!(
|
||||
"Permission set role assignment '{}' not found",
|
||||
assignment_id
|
||||
))
|
||||
})?;
|
||||
|
||||
PermissionSetRoleAssignmentRepository::delete(&state.db, assignment_id).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(SuccessResponse::new(
|
||||
"Permission set role assignment deleted successfully",
|
||||
))),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/identities/{id}/freeze",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Identity ID")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Identity frozen", body = inline(ApiResponse<SuccessResponse>)),
|
||||
(status = 404, description = "Identity not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn freeze_identity(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(identity_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
set_identity_frozen(&state, &user, identity_id, true).await
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/identities/{id}/unfreeze",
|
||||
tag = "permissions",
|
||||
params(
|
||||
("id" = i64, Path, description = "Identity ID")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Identity unfrozen", body = inline(ApiResponse<SuccessResponse>)),
|
||||
(status = 404, description = "Identity not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn unfreeze_identity(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(identity_id): Path<i64>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
set_identity_frozen(&state, &user, identity_id, false).await
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/identities", get(list_identities).post(create_identity))
|
||||
@@ -421,11 +711,29 @@ pub fn routes() -> Router<Arc<AppState>> {
|
||||
.put(update_identity)
|
||||
.delete(delete_identity),
|
||||
)
|
||||
.route(
|
||||
"/identities/{id}/roles",
|
||||
post(create_identity_role_assignment),
|
||||
)
|
||||
.route(
|
||||
"/identities/{id}/permissions",
|
||||
get(list_identity_permissions),
|
||||
)
|
||||
.route("/identities/{id}/freeze", post(freeze_identity))
|
||||
.route("/identities/{id}/unfreeze", post(unfreeze_identity))
|
||||
.route(
|
||||
"/identities/roles/{id}",
|
||||
delete(delete_identity_role_assignment),
|
||||
)
|
||||
.route("/permissions/sets", get(list_permission_sets))
|
||||
.route(
|
||||
"/permissions/sets/{id}/roles",
|
||||
post(create_permission_set_role_assignment),
|
||||
)
|
||||
.route(
|
||||
"/permissions/sets/roles/{id}",
|
||||
delete(delete_permission_set_role_assignment),
|
||||
)
|
||||
.route(
|
||||
"/permissions/assignments",
|
||||
post(create_permission_assignment),
|
||||
@@ -488,20 +796,82 @@ impl From<Identity> for IdentitySummary {
|
||||
id: value.id,
|
||||
login: value.login,
|
||||
display_name: value.display_name,
|
||||
frozen: value.frozen,
|
||||
attributes: value.attributes,
|
||||
roles: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PermissionSet> for PermissionSetSummary {
|
||||
fn from(value: PermissionSet) -> Self {
|
||||
impl From<IdentityRoleAssignment> for IdentityRoleAssignmentResponse {
|
||||
fn from(value: IdentityRoleAssignment) -> Self {
|
||||
Self {
|
||||
id: value.id,
|
||||
r#ref: value.r#ref,
|
||||
pack_ref: value.pack_ref,
|
||||
label: value.label,
|
||||
description: value.description,
|
||||
grants: value.grants,
|
||||
identity_id: value.identity,
|
||||
role: value.role,
|
||||
source: value.source,
|
||||
managed: value.managed,
|
||||
created: value.created,
|
||||
updated: value.updated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Identity> for IdentityResponse {
|
||||
fn from(value: Identity) -> Self {
|
||||
Self {
|
||||
id: value.id,
|
||||
login: value.login,
|
||||
display_name: value.display_name,
|
||||
frozen: value.frozen,
|
||||
attributes: value.attributes,
|
||||
roles: Vec::new(),
|
||||
direct_permissions: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_identity_frozen(
|
||||
state: &Arc<AppState>,
|
||||
user: &crate::auth::middleware::AuthenticatedUser,
|
||||
identity_id: i64,
|
||||
frozen: bool,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
authorize_permissions(state, user, Resource::Identities, Action::Update).await?;
|
||||
|
||||
let caller_identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
if caller_identity_id == identity_id && frozen {
|
||||
return Err(ApiError::BadRequest(
|
||||
"Refusing to freeze the currently authenticated identity".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
IdentityRepository::find_by_id(&state.db, identity_id)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Identity '{}' not found", identity_id)))?;
|
||||
|
||||
IdentityRepository::update(
|
||||
&state.db,
|
||||
identity_id,
|
||||
UpdateIdentityInput {
|
||||
display_name: None,
|
||||
password_hash: None,
|
||||
attributes: None,
|
||||
frozen: Some(frozen),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let message = if frozen {
|
||||
"Identity frozen successfully"
|
||||
} else {
|
||||
"Identity unfrozen successfully"
|
||||
};
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(SuccessResponse::new(message))),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
rule::{CreateRuleInput, RuleRepository, RuleSearchFilters, UpdateRuleInput},
|
||||
trigger::TriggerRepository,
|
||||
Create, Delete, FindByRef, Update,
|
||||
Create, Delete, FindByRef, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -474,7 +474,7 @@ pub async fn update_rule(
|
||||
// Create update input
|
||||
let update_input = UpdateRuleInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
description: request.description.map(Patch::Set),
|
||||
conditions: request.conditions,
|
||||
action_params: request.action_params,
|
||||
trigger_params: request.trigger_params,
|
||||
|
||||
307
crates/api/src/routes/runtimes.rs
Normal file
307
crates/api/src/routes/runtimes.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
//! Runtime management API routes
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use validator::Validate;
|
||||
|
||||
use attune_common::repositories::{
|
||||
pack::PackRepository,
|
||||
runtime::{CreateRuntimeInput, RuntimeRepository, UpdateRuntimeInput},
|
||||
Create, Delete, FindByRef, List, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
runtime::{
|
||||
CreateRuntimeRequest, NullableJsonPatch, NullableStringPatch, RuntimeResponse,
|
||||
RuntimeSummary, UpdateRuntimeRequest,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
middleware::{ApiError, ApiResult},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/runtimes",
|
||||
tag = "runtimes",
|
||||
params(PaginationParams),
|
||||
responses(
|
||||
(status = 200, description = "List of runtimes", body = PaginatedResponse<RuntimeSummary>)
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_runtimes(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let all_runtimes = RuntimeRepository::list(&state.db).await?;
|
||||
let total = all_runtimes.len() as u64;
|
||||
let rows: Vec<_> = all_runtimes
|
||||
.into_iter()
|
||||
.skip(pagination.offset() as usize)
|
||||
.take(pagination.limit() as usize)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(
|
||||
rows.into_iter().map(RuntimeSummary::from).collect(),
|
||||
&pagination,
|
||||
total,
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/packs/{pack_ref}/runtimes",
|
||||
tag = "runtimes",
|
||||
params(
|
||||
("pack_ref" = String, Path, description = "Pack reference identifier"),
|
||||
PaginationParams
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of runtimes for a pack", body = PaginatedResponse<RuntimeSummary>),
|
||||
(status = 404, description = "Pack not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn list_runtimes_by_pack(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(pack_ref): Path<String>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let pack = PackRepository::find_by_ref(&state.db, &pack_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref)))?;
|
||||
|
||||
let all_runtimes = RuntimeRepository::find_by_pack(&state.db, pack.id).await?;
|
||||
let total = all_runtimes.len() as u64;
|
||||
let rows: Vec<_> = all_runtimes
|
||||
.into_iter()
|
||||
.skip(pagination.offset() as usize)
|
||||
.take(pagination.limit() as usize)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(
|
||||
rows.into_iter().map(RuntimeSummary::from).collect(),
|
||||
&pagination,
|
||||
total,
|
||||
);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/runtimes/{ref}",
|
||||
tag = "runtimes",
|
||||
params(("ref" = String, Path, description = "Runtime reference identifier")),
|
||||
responses(
|
||||
(status = 200, description = "Runtime details", body = ApiResponse<RuntimeResponse>),
|
||||
(status = 404, description = "Runtime not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn get_runtime(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(runtime_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let runtime = RuntimeRepository::find_by_ref(&state.db, &runtime_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Runtime '{}' not found", runtime_ref)))?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::new(RuntimeResponse::from(runtime))),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/runtimes",
|
||||
tag = "runtimes",
|
||||
request_body = CreateRuntimeRequest,
|
||||
responses(
|
||||
(status = 201, description = "Runtime created successfully", body = ApiResponse<RuntimeResponse>),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Pack not found"),
|
||||
(status = 409, description = "Runtime with same ref already exists")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn create_runtime(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Json(request): Json<CreateRuntimeRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
request.validate()?;
|
||||
|
||||
if RuntimeRepository::find_by_ref(&state.db, &request.r#ref)
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
return Err(ApiError::Conflict(format!(
|
||||
"Runtime with ref '{}' already exists",
|
||||
request.r#ref
|
||||
)));
|
||||
}
|
||||
|
||||
let (pack_id, pack_ref) = if let Some(ref pack_ref_str) = request.pack_ref {
|
||||
let pack = PackRepository::find_by_ref(&state.db, pack_ref_str)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Pack '{}' not found", pack_ref_str)))?;
|
||||
(Some(pack.id), Some(pack.r#ref))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let runtime = RuntimeRepository::create(
|
||||
&state.db,
|
||||
CreateRuntimeInput {
|
||||
r#ref: request.r#ref,
|
||||
pack: pack_id,
|
||||
pack_ref,
|
||||
description: request.description,
|
||||
name: request.name,
|
||||
aliases: vec![],
|
||||
distributions: request.distributions,
|
||||
installation: request.installation,
|
||||
execution_config: request.execution_config,
|
||||
auto_detected: false,
|
||||
detection_config: serde_json::json!({}),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(ApiResponse::with_message(
|
||||
RuntimeResponse::from(runtime),
|
||||
"Runtime created successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/runtimes/{ref}",
|
||||
tag = "runtimes",
|
||||
params(("ref" = String, Path, description = "Runtime reference identifier")),
|
||||
request_body = UpdateRuntimeRequest,
|
||||
responses(
|
||||
(status = 200, description = "Runtime updated successfully", body = ApiResponse<RuntimeResponse>),
|
||||
(status = 400, description = "Validation error"),
|
||||
(status = 404, description = "Runtime not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn update_runtime(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(runtime_ref): Path<String>,
|
||||
Json(request): Json<UpdateRuntimeRequest>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
request.validate()?;
|
||||
|
||||
let existing_runtime = RuntimeRepository::find_by_ref(&state.db, &runtime_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Runtime '{}' not found", runtime_ref)))?;
|
||||
|
||||
let runtime = RuntimeRepository::update(
|
||||
&state.db,
|
||||
existing_runtime.id,
|
||||
UpdateRuntimeInput {
|
||||
description: request.description.map(|patch| match patch {
|
||||
NullableStringPatch::Set(value) => Patch::Set(value),
|
||||
NullableStringPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
name: request.name,
|
||||
distributions: request.distributions,
|
||||
installation: request.installation.map(|patch| match patch {
|
||||
NullableJsonPatch::Set(value) => Patch::Set(value),
|
||||
NullableJsonPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
execution_config: request.execution_config,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(ApiResponse::with_message(
|
||||
RuntimeResponse::from(runtime),
|
||||
"Runtime updated successfully",
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/runtimes/{ref}",
|
||||
tag = "runtimes",
|
||||
params(("ref" = String, Path, description = "Runtime reference identifier")),
|
||||
responses(
|
||||
(status = 200, description = "Runtime deleted successfully", body = SuccessResponse),
|
||||
(status = 404, description = "Runtime not found")
|
||||
),
|
||||
security(("bearer_auth" = []))
|
||||
)]
|
||||
pub async fn delete_runtime(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Path(runtime_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let runtime = RuntimeRepository::find_by_ref(&state.db, &runtime_ref)
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Runtime '{}' not found", runtime_ref)))?;
|
||||
|
||||
let deleted = RuntimeRepository::delete(&state.db, runtime.id).await?;
|
||||
if !deleted {
|
||||
return Err(ApiError::NotFound(format!(
|
||||
"Runtime '{}' not found",
|
||||
runtime_ref
|
||||
)));
|
||||
}
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(SuccessResponse::new(format!(
|
||||
"Runtime '{}' deleted successfully",
|
||||
runtime_ref
|
||||
))),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/runtimes", get(list_runtimes).post(create_runtime))
|
||||
.route(
|
||||
"/runtimes/{ref}",
|
||||
get(get_runtime).put(update_runtime).delete(delete_runtime),
|
||||
)
|
||||
.route("/packs/{pack_ref}/runtimes", get(list_runtimes_by_pack))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_runtime_routes_structure() {
|
||||
let _router = routes();
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ use attune_common::repositories::{
|
||||
CreateSensorInput, CreateTriggerInput, SensorRepository, SensorSearchFilters,
|
||||
TriggerRepository, TriggerSearchFilters, UpdateSensorInput, UpdateTriggerInput,
|
||||
},
|
||||
Create, Delete, FindByRef, Update,
|
||||
Create, Delete, FindByRef, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -25,8 +25,9 @@ use crate::{
|
||||
dto::{
|
||||
common::{PaginatedResponse, PaginationParams},
|
||||
trigger::{
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorResponse, SensorSummary,
|
||||
TriggerResponse, TriggerSummary, UpdateSensorRequest, UpdateTriggerRequest,
|
||||
CreateSensorRequest, CreateTriggerRequest, SensorJsonPatch, SensorResponse,
|
||||
SensorSummary, TriggerJsonPatch, TriggerResponse, TriggerStringPatch, TriggerSummary,
|
||||
UpdateSensorRequest, UpdateTriggerRequest,
|
||||
},
|
||||
ApiResponse, SuccessResponse,
|
||||
},
|
||||
@@ -274,10 +275,19 @@ pub async fn update_trigger(
|
||||
// Create update input
|
||||
let update_input = UpdateTriggerInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
description: request.description.map(|patch| match patch {
|
||||
TriggerStringPatch::Set(value) => Patch::Set(value),
|
||||
TriggerStringPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
out_schema: request.out_schema,
|
||||
param_schema: request.param_schema.map(|patch| match patch {
|
||||
TriggerJsonPatch::Set(value) => Patch::Set(value),
|
||||
TriggerJsonPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
out_schema: request.out_schema.map(|patch| match patch {
|
||||
TriggerJsonPatch::Set(value) => Patch::Set(value),
|
||||
TriggerJsonPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
};
|
||||
|
||||
let trigger = TriggerRepository::update(&state.db, existing_trigger.id, update_input).await?;
|
||||
@@ -714,7 +724,7 @@ pub async fn update_sensor(
|
||||
// Create update input
|
||||
let update_input = UpdateSensorInput {
|
||||
label: request.label,
|
||||
description: request.description,
|
||||
description: request.description.map(Patch::Set),
|
||||
entrypoint: request.entrypoint,
|
||||
runtime: None,
|
||||
runtime_ref: None,
|
||||
@@ -722,7 +732,10 @@ pub async fn update_sensor(
|
||||
trigger: None,
|
||||
trigger_ref: None,
|
||||
enabled: request.enabled,
|
||||
param_schema: request.param_schema,
|
||||
param_schema: request.param_schema.map(|patch| match patch {
|
||||
SensorJsonPatch::Set(value) => Patch::Set(value),
|
||||
SensorJsonPatch::Clear => Patch::Clear,
|
||||
}),
|
||||
config: None,
|
||||
};
|
||||
|
||||
|
||||
@@ -20,8 +20,11 @@ use attune_common::{
|
||||
},
|
||||
};
|
||||
|
||||
use attune_common::rbac::{Action, AuthorizationContext, Resource};
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
authz::{AuthorizationCheck, AuthorizationService},
|
||||
dto::{
|
||||
trigger::TriggerResponse,
|
||||
webhook::{WebhookReceiverRequest, WebhookReceiverResponse},
|
||||
@@ -170,7 +173,7 @@ fn get_webhook_config_array(
|
||||
)]
|
||||
pub async fn enable_webhook(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
@@ -179,6 +182,26 @@ pub async fn enable_webhook(
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_ref = Some(trigger.r#ref.clone());
|
||||
ctx.pack_ref = trigger.pack_ref.clone();
|
||||
authz
|
||||
.authorize(
|
||||
&user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Triggers,
|
||||
action: Action::Update,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Enable webhooks for this trigger
|
||||
let _webhook_info = TriggerRepository::enable_webhook(&state.db, trigger.id)
|
||||
.await
|
||||
@@ -213,7 +236,7 @@ pub async fn enable_webhook(
|
||||
)]
|
||||
pub async fn disable_webhook(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
@@ -222,6 +245,26 @@ pub async fn disable_webhook(
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_ref = Some(trigger.r#ref.clone());
|
||||
ctx.pack_ref = trigger.pack_ref.clone();
|
||||
authz
|
||||
.authorize(
|
||||
&user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Triggers,
|
||||
action: Action::Update,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Disable webhooks for this trigger
|
||||
TriggerRepository::disable_webhook(&state.db, trigger.id)
|
||||
.await
|
||||
@@ -257,7 +300,7 @@ pub async fn disable_webhook(
|
||||
)]
|
||||
pub async fn regenerate_webhook_key(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
RequireAuth(user): RequireAuth,
|
||||
Path(trigger_ref): Path<String>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// First, find the trigger by ref to get its ID
|
||||
@@ -266,6 +309,26 @@ pub async fn regenerate_webhook_key(
|
||||
.map_err(|e| ApiError::InternalServerError(e.to_string()))?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("Trigger '{}' not found", trigger_ref)))?;
|
||||
|
||||
if user.claims.token_type == crate::auth::jwt::TokenType::Access {
|
||||
let identity_id = user
|
||||
.identity_id()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid user identity".to_string()))?;
|
||||
let authz = AuthorizationService::new(state.db.clone());
|
||||
let mut ctx = AuthorizationContext::new(identity_id);
|
||||
ctx.target_ref = Some(trigger.r#ref.clone());
|
||||
ctx.pack_ref = trigger.pack_ref.clone();
|
||||
authz
|
||||
.authorize(
|
||||
&user,
|
||||
AuthorizationCheck {
|
||||
resource: Resource::Triggers,
|
||||
action: Action::Update,
|
||||
context: ctx,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Check if webhooks are enabled
|
||||
if !trigger.webhook_enabled {
|
||||
return Err(ApiError::BadRequest(
|
||||
|
||||
@@ -18,7 +18,7 @@ use attune_common::repositories::{
|
||||
CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository,
|
||||
WorkflowSearchFilters,
|
||||
},
|
||||
Create, Delete, FindByRef, Update,
|
||||
Create, Delete, FindByRef, Patch, Update,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -66,7 +66,6 @@ pub async fn list_workflows(
|
||||
let filters = WorkflowSearchFilters {
|
||||
pack: None,
|
||||
pack_ref: search_params.pack_ref.clone(),
|
||||
enabled: search_params.enabled,
|
||||
tags,
|
||||
search: search_params.search.clone(),
|
||||
limit: pagination.limit(),
|
||||
@@ -113,7 +112,6 @@ pub async fn list_workflows_by_pack(
|
||||
let filters = WorkflowSearchFilters {
|
||||
pack: None,
|
||||
pack_ref: Some(pack_ref),
|
||||
enabled: None,
|
||||
tags: None,
|
||||
search: None,
|
||||
limit: pagination.limit(),
|
||||
@@ -208,7 +206,6 @@ pub async fn create_workflow(
|
||||
out_schema: request.out_schema.clone(),
|
||||
definition: request.definition,
|
||||
tags: request.tags.clone().unwrap_or_default(),
|
||||
enabled: request.enabled.unwrap_or(true),
|
||||
};
|
||||
|
||||
let workflow = WorkflowDefinitionRepository::create(&state.db, workflow_input).await?;
|
||||
@@ -220,7 +217,7 @@ pub async fn create_workflow(
|
||||
pack.id,
|
||||
&pack.r#ref,
|
||||
&request.label,
|
||||
&request.description.clone().unwrap_or_default(),
|
||||
request.description.as_deref(),
|
||||
"workflow",
|
||||
request.param_schema.as_ref(),
|
||||
request.out_schema.as_ref(),
|
||||
@@ -275,7 +272,6 @@ pub async fn update_workflow(
|
||||
out_schema: request.out_schema.clone(),
|
||||
definition: request.definition,
|
||||
tags: request.tags,
|
||||
enabled: request.enabled,
|
||||
};
|
||||
|
||||
let workflow =
|
||||
@@ -408,7 +404,6 @@ pub async fn save_workflow_file(
|
||||
out_schema: request.out_schema.clone(),
|
||||
definition: definition_json,
|
||||
tags: request.tags.clone().unwrap_or_default(),
|
||||
enabled: request.enabled.unwrap_or(true),
|
||||
};
|
||||
|
||||
let workflow = WorkflowDefinitionRepository::create(&state.db, workflow_input).await?;
|
||||
@@ -421,7 +416,7 @@ pub async fn save_workflow_file(
|
||||
pack.id,
|
||||
&pack.r#ref,
|
||||
&request.label,
|
||||
&request.description.clone().unwrap_or_default(),
|
||||
request.description.as_deref(),
|
||||
&entrypoint,
|
||||
request.param_schema.as_ref(),
|
||||
request.out_schema.as_ref(),
|
||||
@@ -489,7 +484,6 @@ pub async fn update_workflow_file(
|
||||
out_schema: request.out_schema.clone(),
|
||||
definition: Some(definition_json),
|
||||
tags: request.tags,
|
||||
enabled: request.enabled,
|
||||
};
|
||||
|
||||
let workflow =
|
||||
@@ -505,7 +499,7 @@ pub async fn update_workflow_file(
|
||||
pack.id,
|
||||
&pack.r#ref,
|
||||
&request.label,
|
||||
&request.description.unwrap_or_default(),
|
||||
request.description.as_deref(),
|
||||
&entrypoint,
|
||||
request.param_schema.as_ref(),
|
||||
request.out_schema.as_ref(),
|
||||
@@ -647,7 +641,6 @@ fn build_action_yaml(pack_ref: &str, request: &SaveWorkflowFileRequest) -> Strin
|
||||
lines.push(format!("description: \"{}\"", desc.replace('"', "\\\"")));
|
||||
}
|
||||
}
|
||||
lines.push("enabled: true".to_string());
|
||||
lines.push(format!(
|
||||
"workflow_file: workflows/{}.workflow.yaml",
|
||||
request.name
|
||||
@@ -709,7 +702,7 @@ async fn create_companion_action(
|
||||
pack_id: i64,
|
||||
pack_ref: &str,
|
||||
label: &str,
|
||||
description: &str,
|
||||
description: Option<&str>,
|
||||
entrypoint: &str,
|
||||
param_schema: Option<&serde_json::Value>,
|
||||
out_schema: Option<&serde_json::Value>,
|
||||
@@ -720,7 +713,7 @@ async fn create_companion_action(
|
||||
pack: pack_id,
|
||||
pack_ref: pack_ref.to_string(),
|
||||
label: label.to_string(),
|
||||
description: description.to_string(),
|
||||
description: description.map(|s| s.to_string()),
|
||||
entrypoint: entrypoint.to_string(),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
@@ -794,7 +787,7 @@ async fn update_companion_action(
|
||||
if let Some(action) = existing_action {
|
||||
let update_input = UpdateActionInput {
|
||||
label: label.map(|s| s.to_string()),
|
||||
description: description.map(|s| s.to_string()),
|
||||
description: description.map(|s| Patch::Set(s.to_string())),
|
||||
entrypoint: None,
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
@@ -845,7 +838,7 @@ async fn ensure_companion_action(
|
||||
pack_id: i64,
|
||||
pack_ref: &str,
|
||||
label: &str,
|
||||
description: &str,
|
||||
description: Option<&str>,
|
||||
entrypoint: &str,
|
||||
param_schema: Option<&serde_json::Value>,
|
||||
out_schema: Option<&serde_json::Value>,
|
||||
@@ -860,7 +853,10 @@ async fn ensure_companion_action(
|
||||
// Update existing companion action
|
||||
let update_input = UpdateActionInput {
|
||||
label: Some(label.to_string()),
|
||||
description: Some(description.to_string()),
|
||||
description: Some(match description {
|
||||
Some(description) => Patch::Set(description.to_string()),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
entrypoint: Some(entrypoint.to_string()),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
|
||||
@@ -47,6 +47,7 @@ impl Server {
|
||||
let api_v1 = Router::new()
|
||||
.merge(routes::pack_routes())
|
||||
.merge(routes::action_routes())
|
||||
.merge(routes::runtime_routes())
|
||||
.merge(routes::rule_routes())
|
||||
.merge(routes::execution_routes())
|
||||
.merge(routes::trigger_routes())
|
||||
@@ -59,6 +60,7 @@ impl Server {
|
||||
.merge(routes::history_routes())
|
||||
.merge(routes::analytics_routes())
|
||||
.merge(routes::artifact_routes())
|
||||
.merge(routes::agent_routes())
|
||||
.with_state(self.state.clone());
|
||||
|
||||
// Auth routes at root level (not versioned for frontend compatibility)
|
||||
|
||||
@@ -362,7 +362,7 @@ mod tests {
|
||||
pack: 1,
|
||||
pack_ref: "test".to_string(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test action".to_string(),
|
||||
description: Some("Test action".to_string()),
|
||||
entrypoint: "test.sh".to_string(),
|
||||
runtime: Some(1),
|
||||
runtime_version_constraint: None,
|
||||
|
||||
138
crates/api/tests/agent_tests.rs
Normal file
138
crates/api/tests/agent_tests.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
//! Integration tests for agent binary distribution endpoints
|
||||
//!
|
||||
//! The agent endpoints (`/api/v1/agent/binary` and `/api/v1/agent/info`) are
|
||||
//! intentionally unauthenticated — the agent needs to download its binary
|
||||
//! before it has JWT credentials. An optional `bootstrap_token` can restrict
|
||||
//! access, but that is validated inside the handler, not via RequireAuth
|
||||
//! middleware.
|
||||
//!
|
||||
//! The test configuration (`config.test.yaml`) does NOT include an `agent`
|
||||
//! section, so both endpoints return 503 Service Unavailable. This is the
|
||||
//! correct behaviour: the endpoints are reachable (no 401/404 from middleware)
|
||||
//! but the feature is not configured.
|
||||
|
||||
use axum::http::StatusCode;
|
||||
|
||||
#[allow(dead_code)]
|
||||
mod helpers;
|
||||
use helpers::TestContext;
|
||||
|
||||
// ── /api/v1/agent/info ──────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_agent_info_not_configured() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/agent/info", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Agent config is not set in config.test.yaml, so the handler returns 503.
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
assert_eq!(body["error"], "Not configured");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_agent_info_no_auth_required() {
|
||||
// Verify that the endpoint is reachable WITHOUT any JWT token.
|
||||
// If RequireAuth middleware were applied, this would return 401.
|
||||
// Instead we expect 503 (not configured) — proving the endpoint
|
||||
// is publicly accessible.
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/agent/info", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Must NOT be 401 Unauthorized — the endpoint has no auth middleware.
|
||||
assert_ne!(
|
||||
response.status(),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"agent/info should not require authentication"
|
||||
);
|
||||
// Should be 503 because agent config is absent.
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
// ── /api/v1/agent/binary ────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_agent_binary_not_configured() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/agent/binary", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Agent config is not set in config.test.yaml, so the handler returns 503.
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
assert_eq!(body["error"], "Not configured");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_agent_binary_no_auth_required() {
|
||||
// Same reasoning as test_agent_info_no_auth_required: the binary
|
||||
// download endpoint must be publicly accessible (no RequireAuth).
|
||||
// When no bootstrap_token is configured, any caller can reach the
|
||||
// handler. We still get 503 because the agent feature itself is
|
||||
// not configured in the test environment.
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/agent/binary", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Must NOT be 401 Unauthorized — the endpoint has no auth middleware.
|
||||
assert_ne!(
|
||||
response.status(),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"agent/binary should not require authentication when no bootstrap_token is configured"
|
||||
);
|
||||
// Should be 503 because agent config is absent.
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_agent_binary_invalid_arch() {
|
||||
// Architecture validation (`validate_arch`) rejects unsupported values
|
||||
// with 400 Bad Request. However, in the handler the execution order is:
|
||||
// 1. validate_token (passes — no bootstrap_token configured)
|
||||
// 2. check agent config (fails with 503 — not configured)
|
||||
// 3. validate_arch (never reached)
|
||||
//
|
||||
// So even with an invalid arch like "mips", we get 503 from the config
|
||||
// check before the arch is ever validated. The arch validation is covered
|
||||
// by unit tests in routes/agent.rs instead.
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/api/v1/agent/binary?arch=mips", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// 503 from the agent-config-not-set check, NOT 400 from arch validation.
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
@@ -305,6 +305,126 @@ async fn test_login_nonexistent_user() {
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// ── LDAP auth tests ──────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_ldap_login_returns_501_when_not_configured() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/ldap/login",
|
||||
json!({
|
||||
"login": "jdoe",
|
||||
"password": "secret"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// LDAP is not configured in config.test.yaml, so the endpoint
|
||||
// should return 501 Not Implemented.
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_ldap_login_validates_empty_login() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/ldap/login",
|
||||
json!({
|
||||
"login": "",
|
||||
"password": "secret"
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Validation should fail before we even check LDAP config
|
||||
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_ldap_login_validates_empty_password() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/ldap/login",
|
||||
json!({
|
||||
"login": "jdoe",
|
||||
"password": ""
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_ldap_login_validates_missing_fields() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.post("/auth/ldap/login", json!({}), None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
// Missing required fields should return 422
|
||||
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
// ── auth/settings LDAP field tests ──────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_auth_settings_includes_ldap_fields_disabled() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let response = ctx
|
||||
.get("/auth/settings", None)
|
||||
.await
|
||||
.expect("Failed to make request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
|
||||
|
||||
// LDAP is not configured in config.test.yaml, so these should all
|
||||
// reflect the disabled state.
|
||||
assert_eq!(body["data"]["ldap_enabled"], false);
|
||||
assert_eq!(body["data"]["ldap_visible_by_default"], false);
|
||||
assert!(body["data"]["ldap_provider_name"].is_null());
|
||||
assert!(body["data"]["ldap_provider_label"].is_null());
|
||||
assert!(body["data"]["ldap_provider_icon_url"].is_null());
|
||||
|
||||
// Existing fields should still be present
|
||||
assert!(body["data"]["authentication_enabled"].is_boolean());
|
||||
assert!(body["data"]["local_password_enabled"].is_boolean());
|
||||
assert!(body["data"]["oidc_enabled"].is_boolean());
|
||||
assert!(body["data"]["self_registration_enabled"].is_boolean());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_get_current_user() {
|
||||
|
||||
@@ -241,6 +241,7 @@ impl TestContext {
|
||||
}
|
||||
|
||||
/// Create and authenticate a test user
|
||||
#[allow(dead_code)]
|
||||
pub async fn with_auth(mut self) -> Result<Self> {
|
||||
// Generate unique username to avoid conflicts in parallel tests
|
||||
let unique_id = uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
|
||||
@@ -251,6 +252,7 @@ impl TestContext {
|
||||
}
|
||||
|
||||
/// Create and authenticate a test user with identity + permission admin grants.
|
||||
#[allow(dead_code)]
|
||||
pub async fn with_admin_auth(mut self) -> Result<Self> {
|
||||
let unique_id = uuid::Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
|
||||
let login = format!("adminuser_{}", unique_id);
|
||||
@@ -393,6 +395,7 @@ impl TestContext {
|
||||
}
|
||||
|
||||
/// Get authenticated token
|
||||
#[allow(dead_code)]
|
||||
pub fn token(&self) -> Option<&str> {
|
||||
self.token.as_deref()
|
||||
}
|
||||
@@ -494,7 +497,7 @@ pub async fn create_test_action(pool: &PgPool, pack_id: i64, ref_name: &str) ->
|
||||
pack: pack_id,
|
||||
pack_ref: format!("pack_{}", pack_id),
|
||||
label: format!("Test Action {}", ref_name),
|
||||
description: format!("Test action for {}", ref_name),
|
||||
description: Some(format!("Test action for {}", ref_name)),
|
||||
entrypoint: "main.py".to_string(),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
@@ -551,7 +554,6 @@ pub async fn create_test_workflow(
|
||||
]
|
||||
}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
Ok(WorkflowDefinitionRepository::create(pool, input).await?)
|
||||
|
||||
@@ -22,7 +22,6 @@ ref: {}.example_workflow
|
||||
label: Example Workflow
|
||||
description: A test workflow for integration testing
|
||||
version: "1.0.0"
|
||||
enabled: true
|
||||
parameters:
|
||||
message:
|
||||
type: string
|
||||
@@ -46,7 +45,6 @@ ref: {}.another_workflow
|
||||
label: Another Workflow
|
||||
description: Second test workflow
|
||||
version: "1.0.0"
|
||||
enabled: false
|
||||
tasks:
|
||||
- name: task1
|
||||
action: core.noop
|
||||
|
||||
276
crates/api/tests/rbac_scoped_resources_api_tests.rs
Normal file
276
crates/api/tests/rbac_scoped_resources_api_tests.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
use axum::http::StatusCode;
|
||||
use helpers::*;
|
||||
use serde_json::json;
|
||||
|
||||
use attune_common::{
|
||||
models::enums::{ArtifactType, ArtifactVisibility, OwnerType, RetentionPolicyType},
|
||||
repositories::{
|
||||
artifact::{ArtifactRepository, CreateArtifactInput},
|
||||
identity::{
|
||||
CreatePermissionAssignmentInput, CreatePermissionSetInput, IdentityRepository,
|
||||
PermissionAssignmentRepository, PermissionSetRepository,
|
||||
},
|
||||
key::{CreateKeyInput, KeyRepository},
|
||||
Create,
|
||||
},
|
||||
};
|
||||
|
||||
mod helpers;
|
||||
|
||||
async fn register_scoped_user(
|
||||
ctx: &TestContext,
|
||||
login: &str,
|
||||
grants: serde_json::Value,
|
||||
) -> Result<String> {
|
||||
let response = ctx
|
||||
.post(
|
||||
"/auth/register",
|
||||
json!({
|
||||
"login": login,
|
||||
"password": "TestPassword123!",
|
||||
"display_name": format!("Scoped User {}", login),
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::CREATED);
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let token = body["data"]["access_token"]
|
||||
.as_str()
|
||||
.expect("missing access token")
|
||||
.to_string();
|
||||
|
||||
let identity = IdentityRepository::find_by_login(&ctx.pool, login)
|
||||
.await?
|
||||
.expect("registered identity should exist");
|
||||
|
||||
let permset = PermissionSetRepository::create(
|
||||
&ctx.pool,
|
||||
CreatePermissionSetInput {
|
||||
r#ref: format!("test.scoped_{}", uuid::Uuid::new_v4().simple()),
|
||||
pack: None,
|
||||
pack_ref: None,
|
||||
label: Some("Scoped Test Permission Set".to_string()),
|
||||
description: Some("Scoped test grants".to_string()),
|
||||
grants,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
PermissionAssignmentRepository::create(
|
||||
&ctx.pool,
|
||||
CreatePermissionAssignmentInput {
|
||||
identity: identity.id,
|
||||
permset: permset.id,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_pack_scoped_key_permissions_enforce_owner_refs() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let token = register_scoped_user(
|
||||
&ctx,
|
||||
&format!("scoped_keys_{}", uuid::Uuid::new_v4().simple()),
|
||||
json!([
|
||||
{
|
||||
"resource": "keys",
|
||||
"actions": ["read"],
|
||||
"constraints": {
|
||||
"owner_types": ["pack"],
|
||||
"owner_refs": ["python_example"]
|
||||
}
|
||||
}
|
||||
]),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to register scoped user");
|
||||
|
||||
KeyRepository::create(
|
||||
&ctx.pool,
|
||||
CreateKeyInput {
|
||||
r#ref: format!("python_example_key_{}", uuid::Uuid::new_v4().simple()),
|
||||
owner_type: OwnerType::Pack,
|
||||
owner: Some("python_example".to_string()),
|
||||
owner_identity: None,
|
||||
owner_pack: None,
|
||||
owner_pack_ref: Some("python_example".to_string()),
|
||||
owner_action: None,
|
||||
owner_action_ref: None,
|
||||
owner_sensor: None,
|
||||
owner_sensor_ref: None,
|
||||
name: "Python Example Key".to_string(),
|
||||
encrypted: false,
|
||||
encryption_key_hash: None,
|
||||
value: json!("allowed"),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create scoped key");
|
||||
|
||||
let blocked_key = KeyRepository::create(
|
||||
&ctx.pool,
|
||||
CreateKeyInput {
|
||||
r#ref: format!("other_pack_key_{}", uuid::Uuid::new_v4().simple()),
|
||||
owner_type: OwnerType::Pack,
|
||||
owner: Some("other_pack".to_string()),
|
||||
owner_identity: None,
|
||||
owner_pack: None,
|
||||
owner_pack_ref: Some("other_pack".to_string()),
|
||||
owner_action: None,
|
||||
owner_action_ref: None,
|
||||
owner_sensor: None,
|
||||
owner_sensor_ref: None,
|
||||
name: "Other Pack Key".to_string(),
|
||||
encrypted: false,
|
||||
encryption_key_hash: None,
|
||||
value: json!("blocked"),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create blocked key");
|
||||
|
||||
let allowed_list = ctx
|
||||
.get("/api/v1/keys", Some(&token))
|
||||
.await
|
||||
.expect("Failed to list keys");
|
||||
assert_eq!(allowed_list.status(), StatusCode::OK);
|
||||
let allowed_body: serde_json::Value = allowed_list.json().await.expect("Invalid key list");
|
||||
assert_eq!(
|
||||
allowed_body["data"]
|
||||
.as_array()
|
||||
.expect("expected list")
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
assert_eq!(allowed_body["data"][0]["owner"], "python_example");
|
||||
|
||||
let blocked_get = ctx
|
||||
.get(&format!("/api/v1/keys/{}", blocked_key.r#ref), Some(&token))
|
||||
.await
|
||||
.expect("Failed to fetch blocked key");
|
||||
assert_eq!(blocked_get.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "integration test — requires database"]
|
||||
async fn test_pack_scoped_artifact_permissions_enforce_owner_refs() {
|
||||
let ctx = TestContext::new()
|
||||
.await
|
||||
.expect("Failed to create test context");
|
||||
|
||||
let token = register_scoped_user(
|
||||
&ctx,
|
||||
&format!("scoped_artifacts_{}", uuid::Uuid::new_v4().simple()),
|
||||
json!([
|
||||
{
|
||||
"resource": "artifacts",
|
||||
"actions": ["read", "create"],
|
||||
"constraints": {
|
||||
"owner_types": ["pack"],
|
||||
"owner_refs": ["python_example"]
|
||||
}
|
||||
}
|
||||
]),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to register scoped user");
|
||||
|
||||
let allowed_artifact = ArtifactRepository::create(
|
||||
&ctx.pool,
|
||||
CreateArtifactInput {
|
||||
r#ref: format!("python_example.allowed_{}", uuid::Uuid::new_v4().simple()),
|
||||
scope: OwnerType::Pack,
|
||||
owner: "python_example".to_string(),
|
||||
r#type: ArtifactType::FileText,
|
||||
visibility: ArtifactVisibility::Private,
|
||||
retention_policy: RetentionPolicyType::Versions,
|
||||
retention_limit: 5,
|
||||
name: Some("Allowed Artifact".to_string()),
|
||||
description: None,
|
||||
content_type: Some("text/plain".to_string()),
|
||||
execution: None,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create allowed artifact");
|
||||
|
||||
let blocked_artifact = ArtifactRepository::create(
|
||||
&ctx.pool,
|
||||
CreateArtifactInput {
|
||||
r#ref: format!("other_pack.blocked_{}", uuid::Uuid::new_v4().simple()),
|
||||
scope: OwnerType::Pack,
|
||||
owner: "other_pack".to_string(),
|
||||
r#type: ArtifactType::FileText,
|
||||
visibility: ArtifactVisibility::Private,
|
||||
retention_policy: RetentionPolicyType::Versions,
|
||||
retention_limit: 5,
|
||||
name: Some("Blocked Artifact".to_string()),
|
||||
description: None,
|
||||
content_type: Some("text/plain".to_string()),
|
||||
execution: None,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create blocked artifact");
|
||||
|
||||
let allowed_get = ctx
|
||||
.get(
|
||||
&format!("/api/v1/artifacts/{}", allowed_artifact.id),
|
||||
Some(&token),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to fetch allowed artifact");
|
||||
assert_eq!(allowed_get.status(), StatusCode::OK);
|
||||
|
||||
let blocked_get = ctx
|
||||
.get(
|
||||
&format!("/api/v1/artifacts/{}", blocked_artifact.id),
|
||||
Some(&token),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to fetch blocked artifact");
|
||||
assert_eq!(blocked_get.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let create_allowed = ctx
|
||||
.post(
|
||||
"/api/v1/artifacts",
|
||||
json!({
|
||||
"ref": format!("python_example.created_{}", uuid::Uuid::new_v4().simple()),
|
||||
"scope": "pack",
|
||||
"owner": "python_example",
|
||||
"type": "file_text",
|
||||
"name": "Created Artifact"
|
||||
}),
|
||||
Some(&token),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create allowed artifact");
|
||||
assert_eq!(create_allowed.status(), StatusCode::CREATED);
|
||||
|
||||
let create_blocked = ctx
|
||||
.post(
|
||||
"/api/v1/artifacts",
|
||||
json!({
|
||||
"ref": format!("other_pack.created_{}", uuid::Uuid::new_v4().simple()),
|
||||
"scope": "pack",
|
||||
"owner": "other_pack",
|
||||
"type": "file_text",
|
||||
"name": "Blocked Artifact"
|
||||
}),
|
||||
Some(&token),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create blocked artifact");
|
||||
assert_eq!(create_blocked.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
@@ -52,7 +52,7 @@ async fn setup_test_pack_and_action(pool: &PgPool) -> Result<(Pack, Action)> {
|
||||
pack: pack.id,
|
||||
pack_ref: pack.r#ref.clone(),
|
||||
label: "Test Action".to_string(),
|
||||
description: "Test action for SSE tests".to_string(),
|
||||
description: Some("Test action for SSE tests".to_string()),
|
||||
entrypoint: "test.sh".to_string(),
|
||||
runtime: None,
|
||||
runtime_version_constraint: None,
|
||||
|
||||
@@ -46,8 +46,7 @@ async fn test_create_workflow_success() {
|
||||
}
|
||||
]
|
||||
},
|
||||
"tags": ["test", "automation"],
|
||||
"enabled": true
|
||||
"tags": ["test", "automation"]
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
@@ -60,7 +59,6 @@ async fn test_create_workflow_success() {
|
||||
assert_eq!(body["data"]["ref"], "test-pack.test_workflow");
|
||||
assert_eq!(body["data"]["label"], "Test Workflow");
|
||||
assert_eq!(body["data"]["version"], "1.0.0");
|
||||
assert_eq!(body["data"]["enabled"], true);
|
||||
assert!(body["data"]["tags"].as_array().unwrap().len() == 2);
|
||||
}
|
||||
|
||||
@@ -85,7 +83,6 @@ async fn test_create_workflow_duplicate_ref() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -152,7 +149,6 @@ async fn test_get_workflow_by_ref() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": [{"name": "task1"}]}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -206,7 +202,6 @@ async fn test_list_workflows() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: i % 2 == 1, // Odd ones enabled
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -256,7 +251,6 @@ async fn test_list_workflows_by_pack() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -275,7 +269,6 @@ async fn test_list_workflows_by_pack() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -308,14 +301,14 @@ async fn test_list_workflows_with_filters() {
|
||||
let pack_name = unique_pack_name();
|
||||
let pack = create_test_pack(&ctx.pool, &pack_name).await.unwrap();
|
||||
|
||||
// Create workflows with different tags and enabled status
|
||||
// Create workflows with different tags
|
||||
let workflows = vec![
|
||||
("workflow1", vec!["incident", "approval"], true),
|
||||
("workflow2", vec!["incident"], false),
|
||||
("workflow3", vec!["automation"], true),
|
||||
("workflow1", vec!["incident", "approval"]),
|
||||
("workflow2", vec!["incident"]),
|
||||
("workflow3", vec!["automation"]),
|
||||
];
|
||||
|
||||
for (ref_name, tags, enabled) in workflows {
|
||||
for (ref_name, tags) in workflows {
|
||||
let input = CreateWorkflowDefinitionInput {
|
||||
r#ref: format!("test-pack.{}", ref_name),
|
||||
pack: pack.id,
|
||||
@@ -327,24 +320,12 @@ async fn test_list_workflows_with_filters() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: tags.iter().map(|s| s.to_string()).collect(),
|
||||
enabled,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Filter by enabled (and pack_ref for isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
&format!("/api/v1/workflows?enabled=true&pack_ref={}", pack_name),
|
||||
ctx.token(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body: Value = response.json().await.unwrap();
|
||||
assert_eq!(body["data"].as_array().unwrap().len(), 2);
|
||||
|
||||
// Filter by tag (and pack_ref for isolation)
|
||||
let response = ctx
|
||||
.get(
|
||||
@@ -387,7 +368,6 @@ async fn test_update_workflow() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec!["test".to_string()],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
@@ -400,8 +380,7 @@ async fn test_update_workflow() {
|
||||
json!({
|
||||
"label": "Updated Label",
|
||||
"description": "Updated description",
|
||||
"version": "1.1.0",
|
||||
"enabled": false
|
||||
"version": "1.1.0"
|
||||
}),
|
||||
ctx.token(),
|
||||
)
|
||||
@@ -414,7 +393,6 @@ async fn test_update_workflow() {
|
||||
assert_eq!(body["data"]["label"], "Updated Label");
|
||||
assert_eq!(body["data"]["description"], "Updated description");
|
||||
assert_eq!(body["data"]["version"], "1.1.0");
|
||||
assert_eq!(body["data"]["enabled"], false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -455,7 +433,6 @@ async fn test_delete_workflow() {
|
||||
out_schema: None,
|
||||
definition: json!({"tasks": []}),
|
||||
tags: vec![],
|
||||
enabled: true,
|
||||
};
|
||||
WorkflowDefinitionRepository::create(&ctx.pool, input)
|
||||
.await
|
||||
|
||||
@@ -23,6 +23,7 @@ clap = { workspace = true, features = ["derive", "env", "string"] }
|
||||
|
||||
# HTTP client
|
||||
reqwest = { workspace = true, features = ["multipart", "stream"] }
|
||||
reqwest-eventsource = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
@@ -69,7 +70,7 @@ tracing-subscriber = { workspace = true }
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
wiremock = "0.6"
|
||||
assert_cmd = "2.1"
|
||||
assert_cmd = "2.2"
|
||||
predicates = "3.1"
|
||||
mockito = "1.7"
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -21,6 +21,11 @@ pub struct ApiResponse<T> {
|
||||
pub data: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct PaginatedResponse<T> {
|
||||
data: Vec<T>,
|
||||
}
|
||||
|
||||
/// API error response
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct ApiError {
|
||||
@@ -55,6 +60,10 @@ impl ApiClient {
|
||||
&self.base_url
|
||||
}
|
||||
|
||||
pub fn auth_token(&self) -> Option<&str> {
|
||||
self.auth_token.as_deref()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn new(base_url: String, auth_token: Option<String>) -> Self {
|
||||
let client = HttpClient::builder()
|
||||
@@ -255,6 +264,31 @@ impl ApiClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_paginated_response<T: DeserializeOwned>(
|
||||
&self,
|
||||
response: reqwest::Response,
|
||||
) -> Result<Vec<T>> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let paginated: PaginatedResponse<T> = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse paginated API response")?;
|
||||
Ok(paginated.data)
|
||||
} else {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_text) {
|
||||
anyhow::bail!("API error ({}): {}", status, api_error.error);
|
||||
} else {
|
||||
anyhow::bail!("API error ({}): {}", status, error_text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a response where we only care about success/failure, not a body.
|
||||
async fn handle_empty_response(&self, response: reqwest::Response) -> Result<()> {
|
||||
let status = response.status();
|
||||
@@ -281,6 +315,25 @@ impl ApiClient {
|
||||
self.execute_json::<T, ()>(Method::GET, path, None).await
|
||||
}
|
||||
|
||||
pub async fn get_paginated<T: DeserializeOwned>(&mut self, path: &str) -> Result<Vec<T>> {
|
||||
let req = self.build_request(Method::GET, path);
|
||||
let response = req.send().await.context("Failed to send request to API")?;
|
||||
|
||||
if response.status() == StatusCode::UNAUTHORIZED
|
||||
&& self.refresh_token.is_some()
|
||||
&& self.refresh_auth_token().await?
|
||||
{
|
||||
let req = self.build_request(Method::GET, path);
|
||||
let response = req
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to API (retry)")?;
|
||||
return self.handle_paginated_response(response).await;
|
||||
}
|
||||
|
||||
self.handle_paginated_response(response).await
|
||||
}
|
||||
|
||||
/// GET request with query parameters (query string must be in path)
|
||||
///
|
||||
/// Part of REST client API - reserved for future advanced filtering/search features.
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::collections::HashMap;
|
||||
use crate::client::ApiClient;
|
||||
use crate::config::CliConfig;
|
||||
use crate::output::{self, OutputFormat};
|
||||
use crate::wait::{wait_for_execution, WaitOptions};
|
||||
use crate::wait::{extract_stdout, spawn_execution_output_watch, wait_for_execution, WaitOptions};
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum ActionCommands {
|
||||
@@ -90,7 +90,7 @@ struct Action {
|
||||
action_ref: String,
|
||||
pack_ref: String,
|
||||
label: String,
|
||||
description: String,
|
||||
description: Option<String>,
|
||||
entrypoint: String,
|
||||
runtime: Option<i64>,
|
||||
created: String,
|
||||
@@ -105,7 +105,7 @@ struct ActionDetail {
|
||||
pack: i64,
|
||||
pack_ref: String,
|
||||
label: String,
|
||||
description: String,
|
||||
description: Option<String>,
|
||||
entrypoint: String,
|
||||
runtime: Option<i64>,
|
||||
param_schema: Option<serde_json::Value>,
|
||||
@@ -253,7 +253,7 @@ async fn handle_list(
|
||||
.runtime
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
output::truncate(&action.description, 40),
|
||||
output::truncate(&action.description.unwrap_or_default(), 40),
|
||||
]);
|
||||
}
|
||||
|
||||
@@ -288,7 +288,10 @@ async fn handle_show(
|
||||
("Reference", action.action_ref.clone()),
|
||||
("Pack", action.pack_ref.clone()),
|
||||
("Label", action.label.clone()),
|
||||
("Description", action.description.clone()),
|
||||
(
|
||||
"Description",
|
||||
action.description.unwrap_or_else(|| "None".to_string()),
|
||||
),
|
||||
("Entry Point", action.entrypoint.clone()),
|
||||
(
|
||||
"Runtime",
|
||||
@@ -356,7 +359,10 @@ async fn handle_update(
|
||||
("Ref", action.action_ref.clone()),
|
||||
("Pack", action.pack_ref.clone()),
|
||||
("Label", action.label.clone()),
|
||||
("Description", action.description.clone()),
|
||||
(
|
||||
"Description",
|
||||
action.description.unwrap_or_else(|| "None".to_string()),
|
||||
),
|
||||
("Entrypoint", action.entrypoint.clone()),
|
||||
(
|
||||
"Runtime",
|
||||
@@ -487,6 +493,15 @@ async fn handle_execute(
|
||||
}
|
||||
|
||||
let verbose = matches!(output_format, OutputFormat::Table);
|
||||
let watch_task = if verbose {
|
||||
Some(spawn_execution_output_watch(
|
||||
ApiClient::from_config(&config, api_url),
|
||||
execution.id,
|
||||
verbose,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let summary = wait_for_execution(WaitOptions {
|
||||
execution_id: execution.id,
|
||||
timeout_secs: timeout,
|
||||
@@ -495,6 +510,13 @@ async fn handle_execute(
|
||||
verbose,
|
||||
})
|
||||
.await?;
|
||||
let suppress_final_stdout = watch_task
|
||||
.as_ref()
|
||||
.is_some_and(|task| task.delivered_output() && task.root_stdout_completed());
|
||||
|
||||
if let Some(task) = watch_task {
|
||||
let _ = tokio::time::timeout(tokio::time::Duration::from_secs(2), task.handle).await;
|
||||
}
|
||||
|
||||
match output_format {
|
||||
OutputFormat::Json | OutputFormat::Yaml => {
|
||||
@@ -511,7 +533,20 @@ async fn handle_execute(
|
||||
("Updated", output::format_timestamp(&summary.updated)),
|
||||
]);
|
||||
|
||||
if let Some(result) = summary.result {
|
||||
let stdout = extract_stdout(&summary.result);
|
||||
if !suppress_final_stdout {
|
||||
if let Some(stdout) = &stdout {
|
||||
output::print_section("Stdout");
|
||||
println!("{}", stdout);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mut result) = summary.result {
|
||||
if stdout.is_some() {
|
||||
if let Some(obj) = result.as_object_mut() {
|
||||
obj.remove("stdout");
|
||||
}
|
||||
}
|
||||
if !result.is_null() {
|
||||
output::print_section("Result");
|
||||
println!("{}", serde_json::to_string_pretty(&result)?);
|
||||
|
||||
@@ -803,6 +803,7 @@ async fn handle_upload(
|
||||
api_url: &Option<String>,
|
||||
output_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- CLI users explicitly choose a local file to upload; this is not a server-side path sink.
|
||||
let file_path = Path::new(&file);
|
||||
if !file_path.exists() {
|
||||
anyhow::bail!("File not found: {}", file);
|
||||
@@ -811,6 +812,7 @@ async fn handle_upload(
|
||||
anyhow::bail!("Not a file: {}", file);
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The validated CLI-selected upload path is intentionally read and sent to the API.
|
||||
let file_bytes = tokio::fs::read(file_path).await?;
|
||||
let file_name = file_path
|
||||
.file_name()
|
||||
|
||||
@@ -840,6 +840,7 @@ async fn handle_upload(
|
||||
api_url: &Option<String>,
|
||||
output_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- CLI pack commands intentionally operate on operator-supplied local paths.
|
||||
let pack_dir = Path::new(&path);
|
||||
|
||||
// Validate the directory exists and contains pack.yaml
|
||||
@@ -855,6 +856,7 @@ async fn handle_upload(
|
||||
}
|
||||
|
||||
// Read pack ref from pack.yaml so we can display it
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Reading local pack metadata from the user-selected pack directory is expected CLI behavior.
|
||||
let pack_yaml_content =
|
||||
std::fs::read_to_string(&pack_yaml_path).context("Failed to read pack.yaml")?;
|
||||
let pack_yaml: serde_yaml_ng::Value =
|
||||
@@ -957,6 +959,7 @@ fn append_dir_to_tar<W: std::io::Write>(
|
||||
base: &Path,
|
||||
dir: &Path,
|
||||
) -> Result<()> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The archiver walks a validated local directory selected by the CLI operator.
|
||||
for entry in std::fs::read_dir(dir).context("Failed to read directory")? {
|
||||
let entry = entry.context("Failed to read directory entry")?;
|
||||
let entry_path = entry.path();
|
||||
@@ -1061,6 +1064,7 @@ async fn handle_test(
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// Determine if pack is a path or a pack name
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Pack test targets are local CLI inputs, not remote request paths.
|
||||
let pack_path = Path::new(&pack);
|
||||
let (pack_dir, pack_ref, pack_version) = if pack_path.exists() && pack_path.is_dir() {
|
||||
// Local pack directory
|
||||
@@ -1072,6 +1076,7 @@ async fn handle_test(
|
||||
anyhow::bail!("pack.yaml not found in directory: {}", pack);
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- This reads pack.yaml from a local directory explicitly selected by the CLI operator.
|
||||
let pack_yaml_content = std::fs::read_to_string(&pack_yaml_path)?;
|
||||
let pack_yaml: serde_yaml_ng::Value = serde_yaml_ng::from_str(&pack_yaml_content)?;
|
||||
|
||||
@@ -1107,6 +1112,7 @@ async fn handle_test(
|
||||
anyhow::bail!("pack.yaml not found for pack: {}", pack);
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Installed pack tests intentionally read local metadata from the workspace packs directory.
|
||||
let pack_yaml_content = std::fs::read_to_string(&pack_yaml_path)?;
|
||||
let pack_yaml: serde_yaml_ng::Value = serde_yaml_ng::from_str(&pack_yaml_content)?;
|
||||
|
||||
@@ -1120,6 +1126,7 @@ async fn handle_test(
|
||||
|
||||
// Load pack.yaml and extract test configuration
|
||||
let pack_yaml_path = pack_dir.join("pack.yaml");
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Test configuration is loaded from the validated local pack directory.
|
||||
let pack_yaml_content = std::fs::read_to_string(&pack_yaml_path)?;
|
||||
let pack_yaml: serde_yaml_ng::Value = serde_yaml_ng::from_str(&pack_yaml_content)?;
|
||||
|
||||
@@ -1484,6 +1491,7 @@ fn detect_source_type(source: &str, ref_spec: Option<&str>, no_registry: bool) -
|
||||
async fn handle_checksum(path: String, json: bool, output_format: OutputFormat) -> Result<()> {
|
||||
use attune_common::pack_registry::{calculate_directory_checksum, calculate_file_checksum};
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Checksum generation intentionally accepts arbitrary local paths from the CLI operator.
|
||||
let path_obj = Path::new(&path);
|
||||
|
||||
if !path_obj.exists() {
|
||||
@@ -1581,6 +1589,7 @@ async fn handle_index_entry(
|
||||
) -> Result<()> {
|
||||
use attune_common::pack_registry::calculate_directory_checksum;
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Index-entry generation intentionally inspects a local pack directory chosen by the CLI operator.
|
||||
let path_obj = Path::new(&path);
|
||||
|
||||
if !path_obj.exists() {
|
||||
@@ -1606,6 +1615,7 @@ async fn handle_index_entry(
|
||||
}
|
||||
|
||||
// Read and parse pack.yaml
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Reading local pack metadata for index generation is expected CLI behavior.
|
||||
let pack_yaml_content = std::fs::read_to_string(&pack_yaml_path)?;
|
||||
let pack_yaml: serde_yaml_ng::Value = serde_yaml_ng::from_str(&pack_yaml_content)?;
|
||||
|
||||
@@ -1775,19 +1785,25 @@ async fn handle_update(
|
||||
anyhow::bail!("At least one field must be provided to update");
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
enum PackDescriptionPatch {
|
||||
Set(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UpdatePackRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
label: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
description: Option<PackDescriptionPatch>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
version: Option<String>,
|
||||
}
|
||||
|
||||
let request = UpdatePackRequest {
|
||||
label,
|
||||
description,
|
||||
description: description.map(PackDescriptionPatch::Set),
|
||||
version,
|
||||
};
|
||||
|
||||
|
||||
@@ -19,11 +19,13 @@ pub async fn handle_index_update(
|
||||
output_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
// Load existing index
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Registry index maintenance is a local CLI/admin operation over operator-supplied files.
|
||||
let index_file_path = Path::new(&index_path);
|
||||
if !index_file_path.exists() {
|
||||
return Err(anyhow::anyhow!("Index file not found: {}", index_path));
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The CLI intentionally reads the local index file selected by the operator.
|
||||
let index_content = fs::read_to_string(index_file_path)?;
|
||||
let mut index: JsonValue = serde_json::from_str(&index_content)?;
|
||||
|
||||
@@ -34,6 +36,7 @@ pub async fn handle_index_update(
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid index format: missing 'packs' array"))?;
|
||||
|
||||
// Load pack.yaml from the pack directory
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Local pack directories are explicit CLI inputs, not remote taint.
|
||||
let pack_dir = Path::new(&pack_path);
|
||||
if !pack_dir.exists() || !pack_dir.is_dir() {
|
||||
return Err(anyhow::anyhow!("Pack directory not found: {}", pack_path));
|
||||
@@ -47,6 +50,7 @@ pub async fn handle_index_update(
|
||||
));
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Reading pack.yaml from a local operator-selected pack directory is expected CLI behavior.
|
||||
let pack_yaml_content = fs::read_to_string(&pack_yaml_path)?;
|
||||
let pack_yaml: serde_yaml_ng::Value = serde_yaml_ng::from_str(&pack_yaml_content)?;
|
||||
|
||||
@@ -250,6 +254,7 @@ pub async fn handle_index_merge(
|
||||
output_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
// Check if output file exists
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Index merge output is a local CLI path controlled by the operator.
|
||||
let output_file_path = Path::new(&output_path);
|
||||
if output_file_path.exists() && !force {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -265,6 +270,7 @@ pub async fn handle_index_merge(
|
||||
|
||||
// Load and merge all input files
|
||||
for input_path in &input_paths {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Index merge inputs are local operator-selected files.
|
||||
let input_file_path = Path::new(input_path);
|
||||
if !input_file_path.exists() {
|
||||
if output_format == OutputFormat::Table {
|
||||
@@ -277,6 +283,7 @@ pub async fn handle_index_merge(
|
||||
output::print_info(&format!("Loading: {}", input_path));
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The CLI intentionally reads each local input index file during merge.
|
||||
let index_content = fs::read_to_string(input_file_path)?;
|
||||
let index: JsonValue = serde_json::from_str(&index_content)?;
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ struct Rule {
|
||||
pack: Option<i64>,
|
||||
pack_ref: String,
|
||||
label: String,
|
||||
description: String,
|
||||
description: Option<String>,
|
||||
#[serde(default)]
|
||||
trigger: Option<i64>,
|
||||
trigger_ref: String,
|
||||
@@ -133,7 +133,7 @@ struct RuleDetail {
|
||||
pack: Option<i64>,
|
||||
pack_ref: String,
|
||||
label: String,
|
||||
description: String,
|
||||
description: Option<String>,
|
||||
#[serde(default)]
|
||||
trigger: Option<i64>,
|
||||
trigger_ref: String,
|
||||
@@ -321,7 +321,10 @@ async fn handle_show(
|
||||
("Ref", rule.rule_ref.clone()),
|
||||
("Pack", rule.pack_ref.clone()),
|
||||
("Label", rule.label.clone()),
|
||||
("Description", rule.description.clone()),
|
||||
(
|
||||
"Description",
|
||||
rule.description.unwrap_or_else(|| "None".to_string()),
|
||||
),
|
||||
("Trigger", rule.trigger_ref.clone()),
|
||||
("Action", rule.action_ref.clone()),
|
||||
("Enabled", output::format_bool(rule.enabled)),
|
||||
@@ -440,7 +443,10 @@ async fn handle_update(
|
||||
("Ref", rule.rule_ref.clone()),
|
||||
("Pack", rule.pack_ref.clone()),
|
||||
("Label", rule.label.clone()),
|
||||
("Description", rule.description.clone()),
|
||||
(
|
||||
"Description",
|
||||
rule.description.unwrap_or_else(|| "None".to_string()),
|
||||
),
|
||||
("Trigger", rule.trigger_ref.clone()),
|
||||
("Action", rule.action_ref.clone()),
|
||||
("Enabled", output::format_bool(rule.enabled)),
|
||||
|
||||
@@ -254,19 +254,25 @@ async fn handle_update(
|
||||
anyhow::bail!("At least one field must be provided to update");
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "op", content = "value", rename_all = "snake_case")]
|
||||
enum TriggerDescriptionPatch {
|
||||
Set(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UpdateTriggerRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
label: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
description: Option<TriggerDescriptionPatch>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
let request = UpdateTriggerRequest {
|
||||
label,
|
||||
description,
|
||||
description: description.map(TriggerDescriptionPatch::Set),
|
||||
enabled,
|
||||
};
|
||||
|
||||
|
||||
@@ -85,10 +85,6 @@ struct ActionYaml {
|
||||
/// Tags
|
||||
#[serde(default)]
|
||||
tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the action is enabled
|
||||
#[serde(default)]
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
// ── API DTOs ────────────────────────────────────────────────────────────
|
||||
@@ -109,8 +105,6 @@ struct SaveWorkflowFileRequest {
|
||||
out_schema: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tags: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -127,7 +121,6 @@ struct WorkflowResponse {
|
||||
out_schema: Option<serde_json::Value>,
|
||||
definition: serde_json::Value,
|
||||
tags: Vec<String>,
|
||||
enabled: bool,
|
||||
created: String,
|
||||
updated: String,
|
||||
}
|
||||
@@ -142,7 +135,6 @@ struct WorkflowSummary {
|
||||
description: Option<String>,
|
||||
version: String,
|
||||
tags: Vec<String>,
|
||||
enabled: bool,
|
||||
created: String,
|
||||
updated: String,
|
||||
}
|
||||
@@ -180,6 +172,7 @@ async fn handle_upload(
|
||||
api_url: &Option<String>,
|
||||
output_format: OutputFormat,
|
||||
) -> Result<()> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Workflow upload reads local files chosen by the CLI operator; it is not a server-side path sink.
|
||||
let action_path = Path::new(&action_file);
|
||||
|
||||
// ── 1. Validate & read the action YAML ──────────────────────────────
|
||||
@@ -190,6 +183,7 @@ async fn handle_upload(
|
||||
anyhow::bail!("Path is not a file: {}", action_file);
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The action YAML is intentionally read from the validated local CLI path.
|
||||
let action_yaml_content =
|
||||
std::fs::read_to_string(action_path).context("Failed to read action YAML file")?;
|
||||
|
||||
@@ -224,6 +218,7 @@ async fn handle_upload(
|
||||
}
|
||||
|
||||
// ── 4. Read and parse the workflow YAML ─────────────────────────────
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The workflow file path is confined to the pack directory before this local read occurs.
|
||||
let workflow_yaml_content =
|
||||
std::fs::read_to_string(&workflow_path).context("Failed to read workflow YAML file")?;
|
||||
|
||||
@@ -281,7 +276,6 @@ async fn handle_upload(
|
||||
param_schema: action.parameters.clone(),
|
||||
out_schema: action.output.clone(),
|
||||
tags: action.tags.clone(),
|
||||
enabled: action.enabled,
|
||||
};
|
||||
|
||||
// ── 6. Print progress ───────────────────────────────────────────────
|
||||
@@ -357,7 +351,6 @@ async fn handle_upload(
|
||||
response.tags.join(", ")
|
||||
},
|
||||
),
|
||||
("Enabled", output::format_bool(response.enabled)),
|
||||
]);
|
||||
}
|
||||
}
|
||||
@@ -408,15 +401,7 @@ async fn handle_list(
|
||||
let mut table = output::create_table();
|
||||
output::add_header(
|
||||
&mut table,
|
||||
vec![
|
||||
"ID",
|
||||
"Reference",
|
||||
"Pack",
|
||||
"Label",
|
||||
"Version",
|
||||
"Enabled",
|
||||
"Tags",
|
||||
],
|
||||
vec!["ID", "Reference", "Pack", "Label", "Version", "Tags"],
|
||||
);
|
||||
|
||||
for wf in &workflows {
|
||||
@@ -426,7 +411,6 @@ async fn handle_list(
|
||||
wf.pack_ref.clone(),
|
||||
output::truncate(&wf.label, 30),
|
||||
wf.version.clone(),
|
||||
output::format_bool(wf.enabled),
|
||||
if wf.tags.is_empty() {
|
||||
"-".to_string()
|
||||
} else {
|
||||
@@ -478,7 +462,6 @@ async fn handle_show(
|
||||
.unwrap_or_else(|| "-".to_string()),
|
||||
),
|
||||
("Version", workflow.version.clone()),
|
||||
("Enabled", output::format_bool(workflow.enabled)),
|
||||
(
|
||||
"Tags",
|
||||
if workflow.tags.is_empty() {
|
||||
@@ -636,12 +619,41 @@ fn split_action_ref(action_ref: &str) -> Result<(String, String)> {
|
||||
/// resolved relative to the action YAML's parent directory.
|
||||
fn resolve_workflow_path(action_yaml_path: &Path, workflow_file: &str) -> Result<PathBuf> {
|
||||
let action_dir = action_yaml_path.parent().unwrap_or(Path::new("."));
|
||||
let pack_root = action_dir
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow::anyhow!("Action YAML must live inside a pack actions/ directory"))?;
|
||||
let canonical_pack_root = pack_root
|
||||
.canonicalize()
|
||||
.context("Failed to resolve pack root for workflow file")?;
|
||||
let canonical_action_dir = action_dir
|
||||
.canonicalize()
|
||||
.context("Failed to resolve action directory for workflow file")?;
|
||||
let canonical_workflow_path = normalize_path_from_base(&canonical_action_dir, workflow_file);
|
||||
|
||||
let resolved = action_dir.join(workflow_file);
|
||||
if !canonical_workflow_path.starts_with(&canonical_pack_root) {
|
||||
anyhow::bail!(
|
||||
"Workflow file resolves outside the pack directory: {}",
|
||||
workflow_file
|
||||
);
|
||||
}
|
||||
|
||||
// Canonicalize if possible (for better error messages), but don't fail
|
||||
// if the file doesn't exist yet — we'll check existence later.
|
||||
Ok(resolved)
|
||||
Ok(canonical_workflow_path)
|
||||
}
|
||||
|
||||
fn normalize_path_from_base(base: &Path, relative_path: &str) -> PathBuf {
|
||||
let mut normalized = PathBuf::new();
|
||||
for component in base.join(relative_path).components() {
|
||||
match component {
|
||||
std::path::Component::Prefix(prefix) => normalized.push(prefix.as_os_str()),
|
||||
std::path::Component::RootDir => normalized.push(std::path::MAIN_SEPARATOR.to_string()),
|
||||
std::path::Component::CurDir => {}
|
||||
std::path::Component::ParentDir => {
|
||||
normalized.pop();
|
||||
}
|
||||
std::path::Component::Normal(part) => normalized.push(part),
|
||||
}
|
||||
}
|
||||
normalized
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -675,23 +687,62 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_resolve_workflow_path() {
|
||||
let action_path = Path::new("/packs/mypack/actions/deploy.yaml");
|
||||
let temp = tempfile::tempdir().unwrap();
|
||||
let pack_dir = temp.path().join("mypack");
|
||||
let actions_dir = pack_dir.join("actions");
|
||||
let workflow_dir = actions_dir.join("workflows");
|
||||
std::fs::create_dir_all(&workflow_dir).unwrap();
|
||||
|
||||
let action_path = actions_dir.join("deploy.yaml");
|
||||
let workflow_path = workflow_dir.join("deploy.workflow.yaml");
|
||||
std::fs::write(
|
||||
&action_path,
|
||||
"ref: mypack.deploy\nworkflow_file: workflows/deploy.workflow.yaml\n",
|
||||
)
|
||||
.unwrap();
|
||||
std::fs::write(&workflow_path, "version: 1.0.0\n").unwrap();
|
||||
|
||||
let resolved =
|
||||
resolve_workflow_path(action_path, "workflows/deploy.workflow.yaml").unwrap();
|
||||
assert_eq!(
|
||||
resolved,
|
||||
PathBuf::from("/packs/mypack/actions/workflows/deploy.workflow.yaml")
|
||||
);
|
||||
resolve_workflow_path(&action_path, "workflows/deploy.workflow.yaml").unwrap();
|
||||
assert_eq!(resolved, workflow_path.canonicalize().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_workflow_path_relative() {
|
||||
let action_path = Path::new("actions/deploy.yaml");
|
||||
let temp = tempfile::tempdir().unwrap();
|
||||
let pack_dir = temp.path().join("mypack");
|
||||
let actions_dir = pack_dir.join("actions");
|
||||
let workflows_dir = pack_dir.join("workflows");
|
||||
std::fs::create_dir_all(&actions_dir).unwrap();
|
||||
std::fs::create_dir_all(&workflows_dir).unwrap();
|
||||
|
||||
let action_path = actions_dir.join("deploy.yaml");
|
||||
let workflow_path = workflows_dir.join("deploy.workflow.yaml");
|
||||
std::fs::write(
|
||||
&action_path,
|
||||
"ref: mypack.deploy\nworkflow_file: ../workflows/deploy.workflow.yaml\n",
|
||||
)
|
||||
.unwrap();
|
||||
std::fs::write(&workflow_path, "version: 1.0.0\n").unwrap();
|
||||
|
||||
let resolved =
|
||||
resolve_workflow_path(action_path, "workflows/deploy.workflow.yaml").unwrap();
|
||||
assert_eq!(
|
||||
resolved,
|
||||
PathBuf::from("actions/workflows/deploy.workflow.yaml")
|
||||
);
|
||||
resolve_workflow_path(&action_path, "../workflows/deploy.workflow.yaml").unwrap();
|
||||
assert_eq!(resolved, workflow_path.canonicalize().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_workflow_path_rejects_traversal_outside_pack() {
|
||||
let temp = tempfile::tempdir().unwrap();
|
||||
let pack_dir = temp.path().join("mypack");
|
||||
let actions_dir = pack_dir.join("actions");
|
||||
std::fs::create_dir_all(&actions_dir).unwrap();
|
||||
|
||||
let action_path = actions_dir.join("deploy.yaml");
|
||||
let outside = temp.path().join("outside.yaml");
|
||||
std::fs::write(&action_path, "ref: mypack.deploy\n").unwrap();
|
||||
std::fs::write(&outside, "version: 1.0.0\n").unwrap();
|
||||
|
||||
let err = resolve_workflow_path(&action_path, "../../outside.yaml").unwrap_err();
|
||||
assert!(err.to_string().contains("outside the pack directory"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,8 +401,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_effective_format_defaults_to_config() {
|
||||
let mut config = CliConfig::default();
|
||||
config.format = "json".to_string();
|
||||
let config = CliConfig {
|
||||
format: "json".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// No CLI override → uses config
|
||||
assert_eq!(config.effective_format(None), OutputFormat::Json);
|
||||
@@ -410,8 +412,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_effective_format_cli_overrides_config() {
|
||||
let mut config = CliConfig::default();
|
||||
config.format = "json".to_string();
|
||||
let config = CliConfig {
|
||||
format: "json".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// CLI override wins
|
||||
assert_eq!(
|
||||
|
||||
@@ -11,7 +11,13 @@
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use reqwest_eventsource::{Event as SseEvent, EventSource};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
@@ -54,6 +60,22 @@ pub struct WaitOptions<'a> {
|
||||
pub verbose: bool,
|
||||
}
|
||||
|
||||
pub struct OutputWatchTask {
|
||||
pub handle: tokio::task::JoinHandle<()>,
|
||||
delivered_output: Arc<AtomicBool>,
|
||||
root_stdout_completed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl OutputWatchTask {
|
||||
pub fn delivered_output(&self) -> bool {
|
||||
self.delivered_output.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn root_stdout_completed(&self) -> bool {
|
||||
self.root_stdout_completed.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
// ── notifier WebSocket messages (mirrors websocket_server.rs) ────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -102,6 +124,58 @@ struct RestExecution {
|
||||
updated: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct WorkflowTaskMetadata {
|
||||
task_name: String,
|
||||
#[serde(default)]
|
||||
task_index: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct ExecutionListItem {
|
||||
id: i64,
|
||||
action_ref: String,
|
||||
status: String,
|
||||
#[serde(default)]
|
||||
workflow_task: Option<WorkflowTaskMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ChildWatchState {
|
||||
label: String,
|
||||
status: String,
|
||||
announced_terminal: bool,
|
||||
stream_handles: Vec<StreamWatchHandle>,
|
||||
}
|
||||
|
||||
struct RootWatchState {
|
||||
stream_handles: Vec<StreamWatchHandle>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StreamWatchHandle {
|
||||
stream_name: &'static str,
|
||||
offset: Arc<AtomicU64>,
|
||||
handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamWatchConfig {
|
||||
base_url: String,
|
||||
token: String,
|
||||
execution_id: i64,
|
||||
prefix: Option<String>,
|
||||
verbose: bool,
|
||||
delivered_output: Arc<AtomicBool>,
|
||||
root_stdout_completed: Option<Arc<AtomicBool>>,
|
||||
}
|
||||
|
||||
struct StreamLogTask {
|
||||
stream_name: &'static str,
|
||||
offset: Arc<AtomicU64>,
|
||||
config: StreamWatchConfig,
|
||||
}
|
||||
|
||||
impl From<RestExecution> for ExecutionSummary {
|
||||
fn from(e: RestExecution) -> Self {
|
||||
Self {
|
||||
@@ -177,6 +251,260 @@ pub async fn wait_for_execution(opts: WaitOptions<'_>) -> Result<ExecutionSummar
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn spawn_execution_output_watch(
|
||||
mut client: ApiClient,
|
||||
execution_id: i64,
|
||||
verbose: bool,
|
||||
) -> OutputWatchTask {
|
||||
let delivered_output = Arc::new(AtomicBool::new(false));
|
||||
let root_stdout_completed = Arc::new(AtomicBool::new(false));
|
||||
let delivered_output_for_task = delivered_output.clone();
|
||||
let root_stdout_completed_for_task = root_stdout_completed.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
if let Err(err) = watch_execution_output(
|
||||
&mut client,
|
||||
execution_id,
|
||||
verbose,
|
||||
delivered_output_for_task,
|
||||
root_stdout_completed_for_task,
|
||||
)
|
||||
.await
|
||||
{
|
||||
if verbose {
|
||||
eprintln!(" [watch] {}", err);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
OutputWatchTask {
|
||||
handle,
|
||||
delivered_output,
|
||||
root_stdout_completed,
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll an execution until it reaches a terminal status, mirroring its log
/// output (and that of any child executions) to stderr as it arrives.
///
/// Responsibilities per 500ms poll cycle:
/// 1. keep the root execution's stdout/stderr SSE stream tasks alive,
///    respawning any that have finished while the execution is still running;
/// 2. discover child executions, spawn prefixed log streams for new ones, and
///    restart streams for non-terminal children whose tasks ended;
/// 3. announce child terminal transitions once (when `verbose`).
///
/// `delivered_output` is set by the stream tasks whenever any log data was
/// printed; `root_stdout_completed` is set when the root stdout stream sees a
/// "done" event. Both are observed by the caller — TODO confirm against the
/// call site, which is outside this view.
async fn watch_execution_output(
    client: &mut ApiClient,
    execution_id: i64,
    verbose: bool,
    delivered_output: Arc<AtomicBool>,
    root_stdout_completed: Arc<AtomicBool>,
) -> Result<()> {
    let base_url = client.base_url().to_string();
    let mut root_watch: Option<RootWatchState> = None;
    // Child execution id -> per-child stream/announcement state.
    let mut children: HashMap<i64, ChildWatchState> = HashMap::new();

    loop {
        let execution: RestExecution = client.get(&format!("/executions/{}", execution_id)).await?;

        // (Re)establish the root execution's log streams when absent or when
        // any of its stream tasks has exited (e.g. transient SSE disconnect).
        if root_watch
            .as_ref()
            .is_none_or(|state| streams_need_restart(&state.stream_handles))
        {
            // Streams authenticate via query token; without a token we simply
            // don't stream (status polling still works).
            if let Some(token) = client.auth_token().map(str::to_string) {
                match root_watch.as_mut() {
                    Some(state) => restart_finished_streams(
                        &mut state.stream_handles,
                        &StreamWatchConfig {
                            base_url: base_url.clone(),
                            token,
                            execution_id,
                            prefix: None,
                            verbose,
                            delivered_output: delivered_output.clone(),
                            root_stdout_completed: Some(root_stdout_completed.clone()),
                        },
                    ),
                    None => {
                        root_watch = Some(RootWatchState {
                            stream_handles: spawn_execution_log_streams(StreamWatchConfig {
                                base_url: base_url.clone(),
                                token,
                                execution_id,
                                verbose,
                                prefix: None,
                                delivered_output: delivered_output.clone(),
                                root_stdout_completed: Some(root_stdout_completed.clone()),
                            }),
                        });
                    }
                }
            }
        }

        // Child listing is best-effort: a transient API failure must not abort
        // the watch, so errors collapse to an empty list for this cycle.
        let child_items = list_child_executions(client, execution_id)
            .await
            .unwrap_or_default();

        for child in child_items {
            let label = format_task_label(&child.workflow_task, &child.action_ref, child.id);
            // First sighting of a child: announce it and spawn its log
            // streams, prefixed with its label so interleaved output is
            // attributable. Children never drive `root_stdout_completed`.
            let entry = children.entry(child.id).or_insert_with(|| {
                if verbose {
                    eprintln!(" [{}] started ({})", label, child.action_ref);
                }
                let stream_handles = client
                    .auth_token()
                    .map(str::to_string)
                    .map(|token| {
                        spawn_execution_log_streams(StreamWatchConfig {
                            base_url: base_url.clone(),
                            token,
                            execution_id: child.id,
                            prefix: Some(label.clone()),
                            verbose,
                            delivered_output: delivered_output.clone(),
                            root_stdout_completed: None,
                        })
                    })
                    .unwrap_or_default();
                ChildWatchState {
                    label,
                    status: child.status.clone(),
                    announced_terminal: false,
                    stream_handles,
                }
            });

            // Track the latest observed status for the restart decision below.
            if entry.status != child.status {
                entry.status = child.status.clone();
            }

            // Only resurrect streams for children that are still running;
            // a finished child's ended streams are expected.
            let child_is_terminal = is_terminal(&entry.status);
            if !child_is_terminal && streams_need_restart(&entry.stream_handles) {
                if let Some(token) = client.auth_token().map(str::to_string) {
                    restart_finished_streams(
                        &mut entry.stream_handles,
                        &StreamWatchConfig {
                            base_url: base_url.clone(),
                            token,
                            execution_id: child.id,
                            prefix: Some(entry.label.clone()),
                            verbose,
                            delivered_output: delivered_output.clone(),
                            root_stdout_completed: None,
                        },
                    );
                }
            }

            // Announce each child's terminal status exactly once.
            if !entry.announced_terminal && is_terminal(&child.status) {
                entry.announced_terminal = true;
                if verbose {
                    eprintln!(" [{}] {}", entry.label, child.status);
                }
            }
        }

        if is_terminal(&execution.status) {
            break;
        }

        tokio::time::sleep(Duration::from_millis(500)).await;
    }

    // Drain all stream tasks so buffered output is flushed before returning.
    if let Some(root_watch) = root_watch {
        wait_for_stream_handles(root_watch.stream_handles).await;
    }

    for child in children.into_values() {
        wait_for_stream_handles(child.stream_handles).await;
    }

    Ok(())
}
|
||||
|
||||
/// Spawn one background SSE-streaming task per log stream (`stdout` and
/// `stderr`) for a single execution, returning a handle per stream.
///
/// Each handle carries a shared byte `offset` so a later restart of that
/// stream resumes where the previous task stopped. Only the stdout task is
/// given the `root_stdout_completed` flag — completion of stderr is not
/// tracked.
fn spawn_execution_log_streams(config: StreamWatchConfig) -> Vec<StreamWatchHandle> {
    ["stdout", "stderr"]
        .into_iter()
        .map(|stream_name| {
            // Fresh offset per stream; shared between the handle (for
            // restarts) and the spawned task (which advances it).
            let offset = Arc::new(AtomicU64::new(0));
            let completion_flag = if stream_name == "stdout" {
                config.root_stdout_completed.clone()
            } else {
                None
            };
            StreamWatchHandle {
                stream_name,
                handle: tokio::spawn(stream_execution_log(StreamLogTask {
                    stream_name,
                    offset: offset.clone(),
                    // Manual field-by-field copy because each task needs its
                    // own owned config (and a per-stream completion flag).
                    config: StreamWatchConfig {
                        base_url: config.base_url.clone(),
                        token: config.token.clone(),
                        execution_id: config.execution_id,
                        prefix: config.prefix.clone(),
                        verbose: config.verbose,
                        delivered_output: config.delivered_output.clone(),
                        root_stdout_completed: completion_flag,
                    },
                })),
                offset,
            }
        })
        .collect()
}
|
||||
|
||||
fn streams_need_restart(handles: &[StreamWatchHandle]) -> bool {
|
||||
handles.is_empty() || handles.iter().any(|handle| handle.handle.is_finished())
|
||||
}
|
||||
|
||||
/// Respawn every stream task in `handles` whose previous task has finished,
/// resuming each from its recorded byte offset.
///
/// Still-running tasks are left untouched. As in the initial spawn, only the
/// stdout stream receives the `root_stdout_completed` flag.
fn restart_finished_streams(handles: &mut [StreamWatchHandle], config: &StreamWatchConfig) {
    for stream in handles.iter_mut() {
        if stream.handle.is_finished() {
            // Reuse the shared offset so the new task continues where the
            // old one stopped instead of replaying the log from zero.
            let offset = stream.offset.clone();
            let completion_flag = if stream.stream_name == "stdout" {
                config.root_stdout_completed.clone()
            } else {
                None
            };
            stream.handle = tokio::spawn(stream_execution_log(StreamLogTask {
                stream_name: stream.stream_name,
                offset,
                config: StreamWatchConfig {
                    base_url: config.base_url.clone(),
                    token: config.token.clone(),
                    execution_id: config.execution_id,
                    prefix: config.prefix.clone(),
                    verbose: config.verbose,
                    delivered_output: config.delivered_output.clone(),
                    root_stdout_completed: completion_flag,
                },
            }));
        }
    }
}
|
||||
|
||||
async fn wait_for_stream_handles(handles: Vec<StreamWatchHandle>) {
|
||||
for handle in handles {
|
||||
let _ = handle.handle.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch all child executions of `execution_id`, following pagination until a
/// short (fewer than `PER_PAGE` items) page signals the end.
///
/// # Errors
/// Propagates any API error from a page fetch; partial results are discarded.
async fn list_child_executions(
    client: &mut ApiClient,
    execution_id: i64,
) -> Result<Vec<ExecutionListItem>> {
    const PER_PAGE: u32 = 100;

    let mut page = 1;
    let mut all_children = Vec::new();

    loop {
        let path = format!("/executions?parent={execution_id}&page={page}&per_page={PER_PAGE}");
        let mut page_items: Vec<ExecutionListItem> = client.get_paginated(&path).await?;
        let page_len = page_items.len();
        all_children.append(&mut page_items);

        // A page shorter than the page size is necessarily the last one.
        if page_len < PER_PAGE as usize {
            break;
        }

        page += 1;
    }

    Ok(all_children)
}
|
||||
|
||||
// ── WebSocket path ────────────────────────────────────────────────────────────
|
||||
|
||||
async fn wait_via_websocket(
|
||||
@@ -482,6 +810,7 @@ fn resolve_ws_url(opts: &WaitOptions<'_>) -> Option<String> {
|
||||
/// - `https://api.example.com` → `wss://api.example.com:8081`
|
||||
/// - `http://api.example.com:9000` → `ws://api.example.com:8081`
|
||||
fn derive_notifier_url(api_url: &str) -> Option<String> {
|
||||
// nosemgrep: javascript.lang.security.detect-insecure-websocket.detect-insecure-websocket -- The function upgrades https->wss and only returns ws for explicit http base URLs or test examples.
|
||||
let url = url::Url::parse(api_url).ok()?;
|
||||
let ws_scheme = match url.scheme() {
|
||||
"https" => "wss",
|
||||
@@ -491,6 +820,148 @@ fn derive_notifier_url(api_url: &str) -> Option<String> {
|
||||
Some(format!("{}://{}:8081", ws_scheme, host))
|
||||
}
|
||||
|
||||
pub fn extract_stdout(result: &Option<serde_json::Value>) -> Option<String> {
|
||||
result
|
||||
.as_ref()
|
||||
.and_then(|value| value.get("stdout"))
|
||||
.and_then(|stdout| stdout.as_str())
|
||||
.filter(|stdout| !stdout.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn format_task_label(
|
||||
workflow_task: &Option<WorkflowTaskMetadata>,
|
||||
action_ref: &str,
|
||||
execution_id: i64,
|
||||
) -> String {
|
||||
if let Some(workflow_task) = workflow_task {
|
||||
if let Some(index) = workflow_task.task_index {
|
||||
format!("{}[{}]", workflow_task.task_name, index)
|
||||
} else {
|
||||
workflow_task.task_name.clone()
|
||||
}
|
||||
} else {
|
||||
format!("{}#{}", action_ref, execution_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Follow one execution log stream (stdout or stderr) over SSE until the
/// server says "done", an error occurs, or the connection drops.
///
/// Protocol, as handled here:
/// - `content`/`append` events carry log data; `message.id` is the server's
///   byte offset, persisted into the shared `offset` so a restarted task can
///   resume from where this one stopped.
/// - `done` marks end-of-stream; sets `root_stdout_completed` when this task
///   owns that flag (root stdout only).
/// - `error` is logged (verbose only) and ends the task.
/// Partial lines are accumulated in `carry` and flushed on every exit path so
/// trailing output without a newline is not lost.
async fn stream_execution_log(task: StreamLogTask) {
    let StreamLogTask {
        stream_name,
        offset,
        config:
            StreamWatchConfig {
                base_url,
                token,
                execution_id,
                prefix,
                verbose,
                delivered_output,
                root_stdout_completed,
            },
    } = task;

    let mut stream_url = match url::Url::parse(&format!(
        "{}/api/v1/executions/{}/logs/{}/stream",
        base_url.trim_end_matches('/'),
        execution_id,
        stream_name
    )) {
        Ok(url) => url,
        Err(err) => {
            if verbose {
                eprintln!(" [watch] failed to build stream URL: {}", err);
            }
            return;
        }
    };
    // Resume from the last offset recorded by a previous incarnation of this
    // stream task (0 on first spawn). Auth is passed as a query token because
    // EventSource::get does not let us set headers here.
    let current_offset = offset.load(Ordering::Relaxed).to_string();
    stream_url
        .query_pairs_mut()
        .append_pair("token", &token)
        .append_pair("offset", &current_offset);

    let mut event_source = EventSource::get(stream_url);
    // Buffer for an incomplete trailing line across chunks.
    let mut carry = String::new();

    while let Some(event) = event_source.next().await {
        match event {
            Ok(SseEvent::Open) => {}
            Ok(SseEvent::Message(message)) => match message.event.as_str() {
                "content" | "append" => {
                    // Advance the shared resume offset to the server's value.
                    if let Ok(server_offset) = message.id.parse::<u64>() {
                        offset.store(server_offset, Ordering::Relaxed);
                    }
                    if !message.data.is_empty() {
                        delivered_output.store(true, Ordering::Relaxed);
                    }
                    print_stream_chunk(prefix.as_deref(), &message.data, &mut carry);
                }
                "done" => {
                    if let Some(flag) = &root_stdout_completed {
                        flag.store(true, Ordering::Relaxed);
                    }
                    flush_stream_chunk(prefix.as_deref(), &mut carry);
                    break;
                }
                "error" => {
                    if verbose && !message.data.is_empty() {
                        eprintln!(" [watch] {}", message.data);
                    }
                    break;
                }
                _ => {}
            },
            Err(err) => {
                // Transport-level failure: flush buffered output, then exit
                // and let the watcher decide whether to respawn us.
                flush_stream_chunk(prefix.as_deref(), &mut carry);
                if verbose {
                    eprintln!(
                        " [watch] stream error for execution {}: {}",
                        execution_id, err
                    );
                }
                break;
            }
        }
    }

    // Idempotent: flush_stream_chunk is a no-op when carry is empty.
    flush_stream_chunk(prefix.as_deref(), &mut carry);
    event_source.close();
}
|
||||
|
||||
/// Append `chunk` to the line buffer and emit every complete line to stderr,
/// stripping the trailing `\n` (and `\r` for CRLF input) and applying the
/// optional `[prefix] ` label. Any trailing partial line stays in `carry`.
fn print_stream_chunk(prefix: Option<&str>, chunk: &str, carry: &mut String) {
    carry.push_str(chunk);

    loop {
        let Some(newline_at) = carry.find('\n') else {
            break;
        };
        // Remove the line (including its newline) from the front of the buffer.
        let raw: String = carry.drain(..=newline_at).collect();
        let line = raw.strip_suffix('\n').unwrap_or(&raw);
        let line = line.strip_suffix('\r').unwrap_or(line);

        match prefix {
            Some(p) => eprintln!("[{}] {}", p, line),
            None => eprintln!("{}", line),
        }
    }
}
|
||||
|
||||
/// Emit whatever partial line remains in `carry` (with the optional prefix)
/// and clear the buffer; a no-op when the buffer is empty.
fn flush_stream_chunk(prefix: Option<&str>, carry: &mut String) {
    if !carry.is_empty() {
        match prefix {
            Some(p) => eprintln!("[{}] {}", p, carry),
            None => eprintln!("{}", carry),
        }
        carry.clear();
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -553,4 +1024,26 @@ mod tests {
|
||||
assert_eq!(summary.status, "failed");
|
||||
assert_eq!(summary.action_ref, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_stdout() {
|
||||
let result = Some(serde_json::json!({
|
||||
"stdout": "hello world",
|
||||
"stderr_log": "/tmp/stderr.log"
|
||||
}));
|
||||
assert_eq!(extract_stdout(&result).as_deref(), Some("hello world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_label() {
|
||||
let workflow_task = Some(WorkflowTaskMetadata {
|
||||
task_name: "build".to_string(),
|
||||
task_index: Some(2),
|
||||
});
|
||||
assert_eq!(
|
||||
format_task_label(&workflow_task, "core.echo", 42),
|
||||
"build[2]"
|
||||
);
|
||||
assert_eq!(format_task_label(&None, "core.echo", 42), "core.echo#42");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,6 +73,7 @@ regex = { workspace = true }
|
||||
|
||||
# Version matching
|
||||
semver = { workspace = true }
|
||||
url = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mockall = { workspace = true }
|
||||
|
||||
107
crates/common/src/agent_bootstrap.rs
Normal file
107
crates/common/src/agent_bootstrap.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! Shared bootstrap helpers for injected agent binaries.
|
||||
|
||||
use crate::agent_runtime_detection::{
|
||||
detect_runtimes, format_as_env_value, print_detection_report_for_env, DetectedRuntime,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Outcome of [`bootstrap_runtime_env`].
#[derive(Debug, Clone)]
pub struct RuntimeBootstrapResult {
    /// The pre-existing value of the runtime env var, if it was already set
    /// (i.e. the operator supplied an override and auto-detection did not
    /// write the variable).
    pub runtimes_override: Option<String>,
    /// Runtimes confirmed present on this host: either the auto-detected set,
    /// or the subset of detected runtimes matching the override. `None` when
    /// nothing usable was found.
    pub detected_runtimes: Option<Vec<DetectedRuntime>>,
}
|
||||
|
||||
/// Detect runtimes and populate the agent runtime environment variable when needed.
///
/// Behavior:
/// - If `env_var_name` is already set, treat it as an operator override:
///   still run detection, but only to report which override entries actually
///   resolve to interpreters on this host. The variable is never rewritten.
/// - Otherwise auto-detect, log the findings, and export the detected set as
///   a comma-separated value via `std::env::set_var`.
///
/// This must run before the Tokio runtime starts because it may mutate process
/// environment variables.
pub fn bootstrap_runtime_env(env_var_name: &str) -> RuntimeBootstrapResult {
    let runtimes_override = std::env::var(env_var_name).ok();
    let mut detected_runtimes = None;

    if let Some(ref override_value) = runtimes_override {
        info!(
            "{} already set (override): {}",
            env_var_name, override_value
        );
        info!("Running auto-detection for override-specified runtimes...");

        let detected = detect_runtimes();
        // Override is a comma-separated, case-insensitive list of runtime names.
        let override_names: Vec<&str> = override_value.split(',').map(|s| s.trim()).collect();

        // Keep only detected runtimes whose name appears in the override.
        let filtered: Vec<_> = detected
            .into_iter()
            .filter(|rt| {
                let lower_name = rt.name.to_ascii_lowercase();
                override_names
                    .iter()
                    .any(|ov| ov.to_ascii_lowercase() == lower_name)
            })
            .collect();

        if filtered.is_empty() {
            warn!(
                "None of the override runtimes ({}) were found on this system",
                override_value
            );
        } else {
            info!(
                "Matched {} override runtime(s) to detected interpreters:",
                filtered.len()
            );
            for rt in &filtered {
                match &rt.version {
                    Some(ver) => info!(" ✓ {} — {} ({})", rt.name, rt.path, ver),
                    None => info!(" ✓ {} — {}", rt.name, rt.path),
                }
            }
            detected_runtimes = Some(filtered);
        }
    } else {
        info!("No {} override — running auto-detection...", env_var_name);

        let detected = detect_runtimes();

        if detected.is_empty() {
            warn!("No runtimes detected! The agent may not be able to execute any work.");
        } else {
            info!("Detected {} runtime(s):", detected.len());
            for rt in &detected {
                match &rt.version {
                    Some(ver) => info!(" ✓ {} — {} ({})", rt.name, rt.path, ver),
                    None => info!(" ✓ {} — {}", rt.name, rt.path),
                }
            }

            // Export only in the no-override path; the override path leaves
            // the operator's value untouched.
            let runtime_csv = format_as_env_value(&detected);
            info!("Setting {}={}", env_var_name, runtime_csv);
            std::env::set_var(env_var_name, &runtime_csv);
            detected_runtimes = Some(detected);
        }
    }

    RuntimeBootstrapResult {
        runtimes_override,
        detected_runtimes,
    }
}
|
||||
|
||||
/// Print the `--detect-only` report to stdout.
///
/// When an override env var was set, bootstrap skipped writing the variable,
/// so detection is re-run here to show what the system would have offered;
/// otherwise the bootstrap's own detection results are reused (with a
/// defensive re-detect if they are absent).
pub fn print_detect_only_report(env_var_name: &str, result: &RuntimeBootstrapResult) {
    if result.runtimes_override.is_some() {
        info!("--detect-only: re-running detection to show what is available on this system...");
        println!(
            "NOTE: {} is set — auto-detection was skipped during normal startup.",
            env_var_name
        );
        println!(" Showing what auto-detection would find on this system:");
        println!();

        let detected = detect_runtimes();
        print_detection_report_for_env(env_var_name, &detected);
    } else if let Some(ref detected) = result.detected_runtimes {
        // Normal path: reuse what bootstrap already detected.
        print_detection_report_for_env(env_var_name, detected);
    } else {
        // Bootstrap found nothing; re-run so the report shows the (empty)
        // detection output rather than printing nothing at all.
        let detected = detect_runtimes();
        print_detection_report_for_env(env_var_name, &detected);
    }
}
|
||||
306
crates/common/src/agent_runtime_detection.rs
Normal file
306
crates/common/src/agent_runtime_detection.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
//! Runtime auto-detection for injected Attune agent binaries.
|
||||
//!
|
||||
//! This module probes the local system directly for well-known interpreters,
|
||||
//! without requiring database access.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::process::Command;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// A runtime interpreter discovered on the local system.
///
/// Serializable so detection results can be reported or persisted as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedRuntime {
    /// Canonical runtime name (for example, "python" or "node").
    pub name: String,

    /// Absolute path to the interpreter binary.
    pub path: String,

    /// Version string if the version command succeeded.
    /// `None` when the interpreter exists but its version output could not be
    /// obtained or parsed.
    pub version: Option<String>,
}
|
||||
|
||||
impl fmt::Display for DetectedRuntime {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.version {
|
||||
Some(v) => write!(f, "{} ({}, v{})", self.name, self.path, v),
|
||||
None => write!(f, "{} ({})", self.name, self.path),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A runtime the detector knows how to probe for.
struct RuntimeCandidate {
    /// Canonical runtime name reported in `DetectedRuntime::name`.
    name: &'static str,
    /// Binary names to try, in priority order; the first one found wins.
    binaries: &'static [&'static str],
    /// Arguments that make the binary print its version.
    version_args: &'static [&'static str],
    /// How to extract a version number from the command's output.
    version_parser: VersionParser,
}

/// Strategy for parsing a version string out of version-command output.
enum VersionParser {
    /// Bare `X.Y[.Z]` numbers, optionally prefixed with `v` or `go`.
    SemverLike,
    /// `version "X.Y.Z"` quoted form used by java, with a semver fallback.
    JavaStyle,
}
|
||||
|
||||
/// The fixed table of runtimes the agent probes for, in report order.
///
/// Binary lists are ordered by preference (e.g. `python3` before `python`,
/// `bash` before `sh`); java is the only entry using a single-dash version
/// flag and the quoted java-style version format.
fn candidates() -> Vec<RuntimeCandidate> {
    vec![
        RuntimeCandidate {
            name: "shell",
            binaries: &["bash", "sh"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "python",
            binaries: &["python3", "python"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "node",
            binaries: &["node", "nodejs"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "ruby",
            binaries: &["ruby"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "go",
            binaries: &["go"],
            version_args: &["version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "java",
            binaries: &["java"],
            version_args: &["-version"],
            version_parser: VersionParser::JavaStyle,
        },
        RuntimeCandidate {
            name: "r",
            binaries: &["Rscript"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
        RuntimeCandidate {
            name: "perl",
            binaries: &["perl"],
            version_args: &["--version"],
            version_parser: VersionParser::SemverLike,
        },
    ]
}
|
||||
|
||||
/// Detect available runtimes by probing the local system.
///
/// Iterates the fixed candidate table, probing each with [`detect_single_runtime`],
/// and logs a summary. Failure to find a candidate is not an error — it is
/// simply absent from the result.
pub fn detect_runtimes() -> Vec<DetectedRuntime> {
    info!("Starting runtime auto-detection...");

    let mut detected = Vec::new();

    for candidate in candidates() {
        match detect_single_runtime(&candidate) {
            Some(runtime) => {
                info!(" ✓ Detected: {}", runtime);
                detected.push(runtime);
            }
            None => {
                debug!(" ✗ Not found: {}", candidate.name);
            }
        }
    }

    info!(
        "Runtime auto-detection complete: found {} runtime(s): [{}]",
        detected.len(),
        detected
            .iter()
            .map(|r| r.name.as_str())
            .collect::<Vec<_>>()
            .join(", ")
    );

    detected
}
|
||||
|
||||
fn detect_single_runtime(candidate: &RuntimeCandidate) -> Option<DetectedRuntime> {
|
||||
for binary in candidate.binaries {
|
||||
if let Some(path) = which_binary(binary) {
|
||||
let version = get_version(&path, candidate.version_args, &candidate.version_parser);
|
||||
|
||||
return Some(DetectedRuntime {
|
||||
name: candidate.name.to_string(),
|
||||
path,
|
||||
version,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Locate `binary` on this system, returning its absolute path.
///
/// Resolution order:
/// 1. For `bash`/`sh` only, prefer the conventional `/bin/<name>` path when
///    it exists (works even in minimal images without `which`).
/// 2. Ask the external `which` command.
/// 3. If `which` itself cannot be spawned, fall back to the POSIX builtin
///    `command -v` via `sh -c`.
///
/// Returns `None` when the binary cannot be found by any method. Note: a
/// `which` that runs but reports "not found" (case 2 failing with a non-zero
/// status) does NOT fall through to `command -v` — only a spawn error does.
fn which_binary(binary: &str) -> Option<String> {
    if binary == "bash" || binary == "sh" {
        let absolute_path = format!("/bin/{}", binary);
        if std::path::Path::new(&absolute_path).exists() {
            return Some(absolute_path);
        }
    }

    match Command::new("which").arg(binary).output() {
        Ok(output) if output.status.success() => {
            // `which` prints the path followed by a newline; trim it.
            let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
            if path.is_empty() {
                None
            } else {
                Some(path)
            }
        }
        Ok(_) => None,
        Err(e) => {
            debug!("'which' command failed ({}), trying 'command -v'", e);
            match Command::new("sh")
                .args(["-c", &format!("command -v {}", binary)])
                .output()
            {
                Ok(output) if output.status.success() => {
                    let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
                    if path.is_empty() {
                        None
                    } else {
                        Some(path)
                    }
                }
                _ => None,
            }
        }
    }
}
|
||||
|
||||
/// Run `binary_path` with its version arguments and parse a version number
/// out of the combined stdout+stderr (java famously prints its version to
/// stderr). Returns `None` when the command cannot be run or nothing
/// version-like appears in the output; a non-zero exit status is tolerated.
fn get_version(binary_path: &str, version_args: &[&str], parser: &VersionParser) -> Option<String> {
    let output = match Command::new(binary_path).args(version_args).output() {
        Ok(output) => output,
        Err(e) => {
            debug!("Failed to run version command for {}: {}", binary_path, e);
            return None;
        }
    };

    let stdout = String::from_utf8_lossy(&output.stdout);
    let stderr = String::from_utf8_lossy(&output.stderr);
    let combined = format!("{}{}", stdout, stderr);

    match parser {
        VersionParser::SemverLike => parse_semver_like(&combined),
        VersionParser::JavaStyle => parse_java_version(&combined),
    }
}
|
||||
|
||||
fn parse_semver_like(output: &str) -> Option<String> {
|
||||
let re = regex::Regex::new(r"(?:v|go)?(\d+\.\d+(?:\.\d+)?)").ok()?;
|
||||
re.captures(output)
|
||||
.and_then(|captures| captures.get(1).map(|m| m.as_str().to_string()))
|
||||
}
|
||||
|
||||
fn parse_java_version(output: &str) -> Option<String> {
|
||||
let quoted_re = regex::Regex::new(r#"version\s+"([^"]+)""#).ok()?;
|
||||
if let Some(captures) = quoted_re.captures(output) {
|
||||
return captures.get(1).map(|m| m.as_str().to_string());
|
||||
}
|
||||
|
||||
parse_semver_like(output)
|
||||
}
|
||||
|
||||
pub fn format_as_env_value(runtimes: &[DetectedRuntime]) -> String {
|
||||
runtimes
|
||||
.iter()
|
||||
.map(|r| r.name.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
}
|
||||
|
||||
/// Print a human-readable detection report to stdout, ending with the exact
/// `ENV_VAR=value` line an operator could export. Used by `--detect-only`.
pub fn print_detection_report_for_env(env_var_name: &str, runtimes: &[DetectedRuntime]) {
    println!("=== Attune Agent Runtime Detection Report ===");
    println!();

    if runtimes.is_empty() {
        // Keep this list in sync with the candidate table in `candidates()`.
        println!("No runtimes detected!");
        println!();
        println!("The agent could not find any supported interpreter binaries.");
        println!("Ensure at least one of the following is installed and on PATH:");
        println!(" - bash / sh (shell scripts)");
        println!(" - python3 / python (Python scripts)");
        println!(" - node / nodejs (Node.js scripts)");
        println!(" - ruby (Ruby scripts)");
        println!(" - go (Go programs)");
        println!(" - java (Java programs)");
        println!(" - Rscript (R scripts)");
        println!(" - perl (Perl scripts)");
    } else {
        println!("Detected {} runtime(s):", runtimes.len());
        println!();
        for rt in runtimes {
            let version_str = rt.version.as_deref().unwrap_or("unknown version");
            println!(" ✓ {:<10} {} ({})", rt.name, rt.path, version_str);
        }
    }

    println!();
    println!("{}={}", env_var_name, format_as_env_value(runtimes));
    println!();
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Plain "Tool X.Y.Z" output, as printed by python.
    #[test]
    fn test_parse_semver_like_python() {
        assert_eq!(
            parse_semver_like("Python 3.12.1"),
            Some("3.12.1".to_string())
        );
    }

    // node prefixes its version with "v"; the prefix must be stripped.
    #[test]
    fn test_parse_semver_like_node() {
        assert_eq!(parse_semver_like("v20.11.0"), Some("20.11.0".to_string()));
    }

    // go embeds the version as "goX.Y.Z"; the "go" prefix must be stripped.
    #[test]
    fn test_parse_semver_like_go() {
        assert_eq!(
            parse_semver_like("go version go1.22.0 linux/amd64"),
            Some("1.22.0".to_string())
        );
    }

    // java quotes its version; the quoted form takes priority.
    #[test]
    fn test_parse_java_version_openjdk() {
        assert_eq!(
            parse_java_version(r#"openjdk version "21.0.1" 2023-10-17"#),
            Some("21.0.1".to_string())
        );
    }

    // Env value is the comma-joined names, in detection order, no spaces.
    #[test]
    fn test_format_as_env_value_multiple() {
        let runtimes = vec![
            DetectedRuntime {
                name: "shell".to_string(),
                path: "/bin/bash".to_string(),
                version: Some("5.2.15".to_string()),
            },
            DetectedRuntime {
                name: "python".to_string(),
                path: "/usr/bin/python3".to_string(),
                version: Some("3.12.1".to_string()),
            },
        ];

        assert_eq!(format_as_env_value(&runtimes), "shell,python");
    }
}
|
||||
@@ -299,6 +299,18 @@ pub struct SecurityConfig {
|
||||
/// Allow unauthenticated self-service user registration
|
||||
#[serde(default)]
|
||||
pub allow_self_registration: bool,
|
||||
|
||||
/// Login page visibility defaults for the web UI.
|
||||
#[serde(default)]
|
||||
pub login_page: LoginPageConfig,
|
||||
|
||||
/// Optional OpenID Connect configuration for browser login.
|
||||
#[serde(default)]
|
||||
pub oidc: Option<OidcConfig>,
|
||||
|
||||
/// Optional LDAP configuration for username/password login against a directory.
|
||||
#[serde(default)]
|
||||
pub ldap: Option<LdapConfig>,
|
||||
}
|
||||
|
||||
fn default_jwt_access_expiration() -> u64 {
|
||||
@@ -309,6 +321,170 @@ fn default_jwt_refresh_expiration() -> u64 {
|
||||
604800 // 7 days
|
||||
}
|
||||
|
||||
/// Web login page configuration.
///
/// All three toggles default to `true` (both via serde and via the `Default`
/// impl), so an absent section shows every configured login method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoginPageConfig {
    /// Show the local username/password form by default.
    #[serde(default = "default_true")]
    pub show_local_login: bool,

    /// Show the OIDC/SSO option by default when configured.
    #[serde(default = "default_true")]
    pub show_oidc_login: bool,

    /// Show the LDAP option by default when configured.
    #[serde(default = "default_true")]
    pub show_ldap_login: bool,
}
|
||||
|
||||
impl Default for LoginPageConfig {
    /// Everything visible — must stay in sync with the `default_true` serde
    /// defaults on the struct fields.
    fn default() -> Self {
        Self {
            show_local_login: true,
            show_oidc_login: true,
            show_ldap_login: true,
        }
    }
}
|
||||
|
||||
/// OpenID Connect configuration.
///
/// Field-level requirements (discovery URL, client ID/secret, redirect URI
/// when `enabled`) are enforced by `Config` validation elsewhere in this
/// file, not by serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcConfig {
    /// Enable OpenID Connect login flow.
    #[serde(default)]
    pub enabled: bool,

    /// OpenID Provider discovery document URL.
    /// Required when `enabled` is true; ignored otherwise.
    #[serde(default)]
    pub discovery_url: Option<String>,

    /// Confidential client ID.
    /// Required when `enabled` is true; ignored otherwise.
    #[serde(default)]
    pub client_id: Option<String>,

    /// Provider name used in login-page overrides such as `?auth=<provider_name>`.
    #[serde(default = "default_oidc_provider_name")]
    pub provider_name: String,

    /// User-facing provider label shown on the login page.
    pub provider_label: Option<String>,

    /// Optional icon URL shown beside the provider label on the login page.
    pub provider_icon_url: Option<String>,

    /// Confidential client secret.
    /// Required when `enabled` is true (enforced by Config validation).
    pub client_secret: Option<String>,

    /// Redirect URI registered with the provider.
    /// Required when `enabled` is true; ignored otherwise.
    #[serde(default)]
    pub redirect_uri: Option<String>,

    /// Optional post-logout redirect URI.
    pub post_logout_redirect_uri: Option<String>,

    /// Optional requested scopes in addition to `openid email profile`.
    #[serde(default)]
    pub scopes: Vec<String>,
}
|
||||
|
||||
/// Serde default for `OidcConfig::provider_name`.
fn default_oidc_provider_name() -> String {
    String::from("oidc")
}
|
||||
|
||||
/// LDAP authentication configuration.
///
/// Two auth shapes are supported: direct bind via `bind_dn_template`, or a
/// search-then-bind flow using `user_search_base`/`user_filter` with the
/// `search_bind_*` service account. The server URL requirement when enabled
/// is enforced by `Config` validation elsewhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LdapConfig {
    /// Enable LDAP login flow.
    #[serde(default)]
    pub enabled: bool,

    /// LDAP server URL (e.g., "ldap://ldap.example.com:389" or "ldaps://ldap.example.com:636").
    /// Required when `enabled` is true; ignored otherwise.
    #[serde(default)]
    pub url: Option<String>,

    /// Bind DN template. Use `{login}` as placeholder for the user-supplied login.
    /// Example: "uid={login},ou=users,dc=example,dc=com"
    /// If not set, an anonymous bind is attempted first to search for the user.
    pub bind_dn_template: Option<String>,

    /// Base DN for user searches when bind_dn_template is not set.
    /// Example: "ou=users,dc=example,dc=com"
    pub user_search_base: Option<String>,

    /// LDAP search filter template. Use `{login}` as placeholder.
    /// Default: "(uid={login})"
    #[serde(default = "default_ldap_user_filter")]
    pub user_filter: String,

    /// DN of a service account used to search for users (required when using search-based auth).
    pub search_bind_dn: Option<String>,

    /// Password for the search service account.
    pub search_bind_password: Option<String>,

    /// LDAP attribute to use as the login name. Default: "uid"
    #[serde(default = "default_ldap_login_attr")]
    pub login_attr: String,

    /// LDAP attribute to use as the email. Default: "mail"
    #[serde(default = "default_ldap_email_attr")]
    pub email_attr: String,

    /// LDAP attribute to use as the display name. Default: "cn"
    #[serde(default = "default_ldap_display_name_attr")]
    pub display_name_attr: String,

    /// LDAP attribute that contains group membership. Default: "memberOf"
    #[serde(default = "default_ldap_group_attr")]
    pub group_attr: String,

    /// Whether to use STARTTLS. Default: false
    #[serde(default)]
    pub starttls: bool,

    /// Whether to skip TLS certificate verification (insecure!). Default: false
    #[serde(default)]
    pub danger_skip_tls_verify: bool,

    /// Provider name used in login-page overrides such as `?auth=<provider_name>`.
    #[serde(default = "default_ldap_provider_name")]
    pub provider_name: String,

    /// User-facing provider label shown on the login page.
    pub provider_label: Option<String>,

    /// Optional icon URL shown beside the provider label on the login page.
    pub provider_icon_url: Option<String>,
}
|
||||
|
||||
/// Serde default for `LdapConfig::provider_name`.
fn default_ldap_provider_name() -> String {
    String::from("ldap")
}

/// Serde default for `LdapConfig::user_filter`.
fn default_ldap_user_filter() -> String {
    String::from("(uid={login})")
}

/// Serde default for `LdapConfig::login_attr`.
fn default_ldap_login_attr() -> String {
    String::from("uid")
}

/// Serde default for `LdapConfig::email_attr`.
fn default_ldap_email_attr() -> String {
    String::from("mail")
}

/// Serde default for `LdapConfig::display_name_attr`.
fn default_ldap_display_name_attr() -> String {
    String::from("cn")
}

/// Serde default for `LdapConfig::group_attr`.
fn default_ldap_group_attr() -> String {
    String::from("memberOf")
}
|
||||
|
||||
/// Worker configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WorkerConfig {
|
||||
@@ -482,6 +658,11 @@ pub struct PackRegistryConfig {
|
||||
#[serde(default = "default_true")]
|
||||
pub verify_checksums: bool,
|
||||
|
||||
/// Additional remote hosts allowed for pack archive/git downloads.
|
||||
/// Hosts from enabled registry indices are implicitly allowed.
|
||||
#[serde(default)]
|
||||
pub allowed_source_hosts: Vec<String>,
|
||||
|
||||
/// Allow HTTP (non-HTTPS) registries
|
||||
#[serde(default)]
|
||||
pub allow_http: bool,
|
||||
@@ -504,11 +685,21 @@ impl Default for PackRegistryConfig {
|
||||
cache_enabled: true,
|
||||
timeout: default_registry_timeout(),
|
||||
verify_checksums: true,
|
||||
allowed_source_hosts: Vec::new(),
|
||||
allow_http: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent binary distribution configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Directory containing agent binary files.
    /// No serde default — this field is required whenever the `agent` section
    /// is present.
    pub binary_dir: String,
    /// Optional bootstrap token for authenticating agent binary downloads.
    /// `None` means downloads are unauthenticated — TODO confirm against the
    /// download handler, which is outside this view.
    pub bootstrap_token: Option<String>,
}
|
||||
|
||||
/// Executor service configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutorConfig {
|
||||
@@ -602,6 +793,9 @@ pub struct Config {
|
||||
|
||||
/// Executor configuration (optional, for executor service)
|
||||
pub executor: Option<ExecutorConfig>,
|
||||
|
||||
/// Agent configuration (optional, for agent binary distribution)
|
||||
pub agent: Option<AgentConfig>,
|
||||
}
|
||||
|
||||
fn default_service_name() -> String {
|
||||
@@ -681,6 +875,9 @@ impl Default for SecurityConfig {
|
||||
encryption_key: None,
|
||||
enable_auth: true,
|
||||
allow_self_registration: false,
|
||||
login_page: LoginPageConfig::default(),
|
||||
oidc: None,
|
||||
ldap: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -800,6 +997,51 @@ impl Config {
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(oidc) = &self.security.oidc {
|
||||
if oidc.enabled {
|
||||
if oidc
|
||||
.discovery_url
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.trim()
|
||||
.is_empty()
|
||||
{
|
||||
return Err(crate::Error::validation(
|
||||
"OIDC discovery URL is required when OIDC is enabled",
|
||||
));
|
||||
}
|
||||
if oidc.client_id.as_deref().unwrap_or("").trim().is_empty() {
|
||||
return Err(crate::Error::validation(
|
||||
"OIDC client ID is required when OIDC is enabled",
|
||||
));
|
||||
}
|
||||
if oidc
|
||||
.client_secret
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.trim()
|
||||
.is_empty()
|
||||
{
|
||||
return Err(crate::Error::validation(
|
||||
"OIDC client secret is required when OIDC is enabled",
|
||||
));
|
||||
}
|
||||
if oidc.redirect_uri.as_deref().unwrap_or("").trim().is_empty() {
|
||||
return Err(crate::Error::validation(
|
||||
"OIDC redirect URI is required when OIDC is enabled",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ldap) = &self.security.ldap {
|
||||
if ldap.enabled && ldap.url.as_deref().unwrap_or("").trim().is_empty() {
|
||||
return Err(crate::Error::validation(
|
||||
"LDAP server URL is required when LDAP is enabled",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Validate encryption key if provided
|
||||
if let Some(ref key) = self.security.encryption_key {
|
||||
if key.len() < 32 {
|
||||
@@ -864,6 +1106,7 @@ mod tests {
|
||||
notifier: None,
|
||||
pack_registry: PackRegistryConfig::default(),
|
||||
executor: None,
|
||||
agent: None,
|
||||
};
|
||||
|
||||
assert_eq!(config.service_name, "attune");
|
||||
@@ -930,6 +1173,9 @@ mod tests {
|
||||
encryption_key: Some("a".repeat(32)),
|
||||
enable_auth: true,
|
||||
allow_self_registration: false,
|
||||
login_page: LoginPageConfig::default(),
|
||||
oidc: None,
|
||||
ldap: None,
|
||||
},
|
||||
worker: None,
|
||||
sensor: None,
|
||||
@@ -939,6 +1185,7 @@ mod tests {
|
||||
notifier: None,
|
||||
pack_registry: PackRegistryConfig::default(),
|
||||
executor: None,
|
||||
agent: None,
|
||||
};
|
||||
|
||||
assert!(config.validate().is_ok());
|
||||
@@ -952,4 +1199,127 @@ mod tests {
|
||||
config.security.jwt_secret = None;
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oidc_config_disabled_no_urls_required() {
|
||||
let yaml = r#"
|
||||
enabled: false
|
||||
"#;
|
||||
let cfg: OidcConfig = serde_yaml_ng::from_str(yaml).unwrap();
|
||||
assert!(!cfg.enabled);
|
||||
assert!(cfg.discovery_url.is_none());
|
||||
assert!(cfg.client_id.is_none());
|
||||
assert!(cfg.redirect_uri.is_none());
|
||||
assert!(cfg.client_secret.is_none());
|
||||
assert_eq!(cfg.provider_name, "oidc");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ldap_config_disabled_no_url_required() {
|
||||
let yaml = r#"
|
||||
enabled: false
|
||||
"#;
|
||||
let cfg: LdapConfig = serde_yaml_ng::from_str(yaml).unwrap();
|
||||
assert!(!cfg.enabled);
|
||||
assert!(cfg.url.is_none());
|
||||
assert_eq!(cfg.provider_name, "ldap");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ldap_config_defaults() {
|
||||
let yaml = r#"
|
||||
enabled: true
|
||||
url: "ldap://localhost:389"
|
||||
client_id: "test"
|
||||
"#;
|
||||
let cfg: LdapConfig = serde_yaml_ng::from_str(yaml).unwrap();
|
||||
|
||||
assert!(cfg.enabled);
|
||||
assert_eq!(cfg.url.as_deref(), Some("ldap://localhost:389"));
|
||||
assert_eq!(cfg.user_filter, "(uid={login})");
|
||||
assert_eq!(cfg.login_attr, "uid");
|
||||
assert_eq!(cfg.email_attr, "mail");
|
||||
assert_eq!(cfg.display_name_attr, "cn");
|
||||
assert_eq!(cfg.group_attr, "memberOf");
|
||||
assert_eq!(cfg.provider_name, "ldap");
|
||||
assert!(!cfg.starttls);
|
||||
assert!(!cfg.danger_skip_tls_verify);
|
||||
assert!(cfg.bind_dn_template.is_none());
|
||||
assert!(cfg.user_search_base.is_none());
|
||||
assert!(cfg.search_bind_dn.is_none());
|
||||
assert!(cfg.search_bind_password.is_none());
|
||||
assert!(cfg.provider_label.is_none());
|
||||
assert!(cfg.provider_icon_url.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ldap_config_full_deserialization() {
|
||||
let yaml = r#"
|
||||
enabled: true
|
||||
url: "ldaps://ldap.corp.com:636"
|
||||
bind_dn_template: "uid={login},ou=people,dc=corp,dc=com"
|
||||
user_search_base: "ou=people,dc=corp,dc=com"
|
||||
user_filter: "(sAMAccountName={login})"
|
||||
search_bind_dn: "cn=svc,dc=corp,dc=com"
|
||||
search_bind_password: "secret"
|
||||
login_attr: "sAMAccountName"
|
||||
email_attr: "userPrincipalName"
|
||||
display_name_attr: "displayName"
|
||||
group_attr: "memberOf"
|
||||
starttls: true
|
||||
danger_skip_tls_verify: true
|
||||
provider_name: "corpldap"
|
||||
provider_label: "Corporate Directory"
|
||||
provider_icon_url: "https://corp.com/icon.svg"
|
||||
"#;
|
||||
let cfg: LdapConfig = serde_yaml_ng::from_str(yaml).unwrap();
|
||||
|
||||
assert!(cfg.enabled);
|
||||
assert_eq!(cfg.url.as_deref(), Some("ldaps://ldap.corp.com:636"));
|
||||
assert_eq!(
|
||||
cfg.bind_dn_template.as_deref(),
|
||||
Some("uid={login},ou=people,dc=corp,dc=com")
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.user_search_base.as_deref(),
|
||||
Some("ou=people,dc=corp,dc=com")
|
||||
);
|
||||
assert_eq!(cfg.user_filter, "(sAMAccountName={login})");
|
||||
assert_eq!(cfg.search_bind_dn.as_deref(), Some("cn=svc,dc=corp,dc=com"));
|
||||
assert_eq!(cfg.search_bind_password.as_deref(), Some("secret"));
|
||||
assert_eq!(cfg.login_attr, "sAMAccountName");
|
||||
assert_eq!(cfg.email_attr, "userPrincipalName");
|
||||
assert_eq!(cfg.display_name_attr, "displayName");
|
||||
assert_eq!(cfg.group_attr, "memberOf");
|
||||
assert!(cfg.starttls);
|
||||
assert!(cfg.danger_skip_tls_verify);
|
||||
assert_eq!(cfg.provider_name, "corpldap");
|
||||
assert_eq!(cfg.provider_label.as_deref(), Some("Corporate Directory"));
|
||||
assert_eq!(
|
||||
cfg.provider_icon_url.as_deref(),
|
||||
Some("https://corp.com/icon.svg")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_config_ldap_none_by_default() {
|
||||
let yaml = r#"jwt_secret: "s""#;
|
||||
let cfg: SecurityConfig = serde_yaml_ng::from_str(yaml).unwrap();
|
||||
|
||||
assert!(cfg.ldap.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_login_page_show_ldap_default_true() {
|
||||
let cfg: LoginPageConfig = serde_yaml_ng::from_str("{}").unwrap();
|
||||
|
||||
assert!(cfg.show_ldap_login);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_login_page_show_ldap_explicit_false() {
|
||||
let cfg: LoginPageConfig = serde_yaml_ng::from_str("show_ldap_login: false").unwrap();
|
||||
|
||||
assert!(!cfg.show_ldap_login);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
//! - Configuration
|
||||
//! - Utilities
|
||||
|
||||
pub mod agent_bootstrap;
|
||||
pub mod agent_runtime_detection;
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod crypto;
|
||||
|
||||
@@ -444,13 +444,55 @@ pub mod runtime {
|
||||
|
||||
/// Optional environment variables to set during action execution.
|
||||
///
|
||||
/// Values support the same template variables as other fields:
|
||||
/// Entries support the same template variables as other fields:
|
||||
/// `{pack_dir}`, `{env_dir}`, `{interpreter}`, `{manifest_path}`.
|
||||
///
|
||||
/// Example: `{"NODE_PATH": "{env_dir}/node_modules"}` ensures Node.js
|
||||
/// can find packages installed in the isolated runtime environment.
|
||||
/// The shorthand string form replaces the variable entirely:
|
||||
/// `{"NODE_PATH": "{env_dir}/node_modules"}`
|
||||
///
|
||||
/// The object form supports declarative merge semantics:
|
||||
/// `{"PYTHONPATH": {"value": "{pack_dir}/lib", "operation": "prepend"}}`
|
||||
#[serde(default)]
|
||||
pub env_vars: HashMap<String, String>,
|
||||
pub env_vars: HashMap<String, RuntimeEnvVarConfig>,
|
||||
}
|
||||
|
||||
/// Declarative configuration for a single runtime environment variable.
|
||||
///
|
||||
/// The string form is shorthand for `{ "value": "...", "operation": "set" }`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(untagged)]
|
||||
pub enum RuntimeEnvVarConfig {
|
||||
Value(String),
|
||||
Spec(RuntimeEnvVarSpec),
|
||||
}
|
||||
|
||||
/// Full configuration for a runtime environment variable.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RuntimeEnvVarSpec {
|
||||
/// Template value to resolve for this variable.
|
||||
pub value: String,
|
||||
|
||||
/// How the resolved value should be merged with any existing value.
|
||||
#[serde(default)]
|
||||
pub operation: RuntimeEnvVarOperation,
|
||||
|
||||
/// Separator used for prepend/append operations.
|
||||
#[serde(default = "default_env_var_separator")]
|
||||
pub separator: String,
|
||||
}
|
||||
|
||||
/// Merge behavior for runtime-provided environment variables.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RuntimeEnvVarOperation {
|
||||
#[default]
|
||||
Set,
|
||||
Prepend,
|
||||
Append,
|
||||
}
|
||||
|
||||
fn default_env_var_separator() -> String {
|
||||
":".to_string()
|
||||
}
|
||||
|
||||
/// Controls how inline code is materialized before execution.
|
||||
@@ -768,6 +810,43 @@ pub mod runtime {
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeEnvVarConfig {
|
||||
/// Resolve this environment variable against the current template
|
||||
/// variables and any existing value already present in the process env.
|
||||
pub fn resolve(
|
||||
&self,
|
||||
vars: &HashMap<&str, String>,
|
||||
existing_value: Option<&str>,
|
||||
) -> String {
|
||||
match self {
|
||||
Self::Value(value) => RuntimeExecutionConfig::resolve_template(value, vars),
|
||||
Self::Spec(spec) => {
|
||||
let resolved = RuntimeExecutionConfig::resolve_template(&spec.value, vars);
|
||||
match spec.operation {
|
||||
RuntimeEnvVarOperation::Set => resolved,
|
||||
RuntimeEnvVarOperation::Prepend => {
|
||||
join_env_var_values(&resolved, existing_value, &spec.separator)
|
||||
}
|
||||
RuntimeEnvVarOperation::Append => join_env_var_values(
|
||||
existing_value.unwrap_or_default(),
|
||||
Some(&resolved),
|
||||
&spec.separator,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn join_env_var_values(left: &str, right: Option<&str>, separator: &str) -> String {
|
||||
match (left.is_empty(), right.unwrap_or_default().is_empty()) {
|
||||
(true, true) => String::new(),
|
||||
(false, true) => left.to_string(),
|
||||
(true, false) => right.unwrap_or_default().to_string(),
|
||||
(false, false) => format!("{}{}{}", left, separator, right.unwrap_or_default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Runtime {
|
||||
pub id: Id,
|
||||
@@ -776,10 +855,13 @@ pub mod runtime {
|
||||
pub pack_ref: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub aliases: Vec<String>,
|
||||
pub distributions: JsonDict,
|
||||
pub installation: Option<JsonDict>,
|
||||
pub installers: JsonDict,
|
||||
pub execution_config: JsonDict,
|
||||
pub auto_detected: bool,
|
||||
pub detection_config: JsonDict,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
@@ -884,7 +966,7 @@ pub mod trigger {
|
||||
pub pack: Option<Id>,
|
||||
pub pack_ref: Option<String>,
|
||||
pub label: String,
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
pub entrypoint: String,
|
||||
pub runtime: Id,
|
||||
pub runtime_ref: String,
|
||||
@@ -912,7 +994,7 @@ pub mod action {
|
||||
pub pack: Id,
|
||||
pub pack_ref: String,
|
||||
pub label: String,
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
pub entrypoint: String,
|
||||
pub runtime: Option<Id>,
|
||||
/// Optional semver version constraint for the runtime
|
||||
@@ -962,7 +1044,7 @@ pub mod rule {
|
||||
pub pack: Id,
|
||||
pub pack_ref: String,
|
||||
pub label: String,
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
pub action: Option<Id>,
|
||||
pub action_ref: String,
|
||||
pub trigger: Option<Id>,
|
||||
@@ -1218,6 +1300,7 @@ pub mod identity {
|
||||
pub display_name: Option<String>,
|
||||
pub password_hash: Option<String>,
|
||||
pub attributes: JsonDict,
|
||||
pub frozen: bool,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
@@ -1242,6 +1325,25 @@ pub mod identity {
|
||||
pub permset: Id,
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct IdentityRoleAssignment {
|
||||
pub id: Id,
|
||||
pub identity: Id,
|
||||
pub role: String,
|
||||
pub source: String,
|
||||
pub managed: bool,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PermissionSetRoleAssignment {
|
||||
pub id: Id,
|
||||
pub permset: Id,
|
||||
pub role: String,
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
}
|
||||
|
||||
/// Key/Value storage
|
||||
@@ -1310,7 +1412,7 @@ pub mod artifact {
|
||||
pub content_type: Option<String>,
|
||||
/// Size of the latest version's content in bytes
|
||||
pub size_bytes: Option<i64>,
|
||||
/// Execution that produced this artifact (no FK — execution is a hypertable)
|
||||
/// Execution that produced this artifact (no FK by design)
|
||||
pub execution: Option<Id>,
|
||||
/// Structured JSONB data for progress artifacts or metadata
|
||||
pub data: Option<serde_json::Value>,
|
||||
@@ -1385,7 +1487,6 @@ pub mod workflow {
|
||||
pub out_schema: Option<JsonSchema>,
|
||||
pub definition: JsonDict,
|
||||
pub tags: Vec<String>,
|
||||
pub enabled: bool,
|
||||
pub created: DateTime<Utc>,
|
||||
pub updated: DateTime<Utc>,
|
||||
}
|
||||
@@ -1618,3 +1719,68 @@ pub mod entity_history {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::runtime::{
|
||||
RuntimeEnvVarConfig, RuntimeEnvVarOperation, RuntimeEnvVarSpec, RuntimeExecutionConfig,
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn runtime_execution_config_env_vars_accept_string_and_object_forms() {
|
||||
let config: RuntimeExecutionConfig = serde_json::from_value(json!({
|
||||
"env_vars": {
|
||||
"NODE_PATH": "{env_dir}/node_modules",
|
||||
"PYTHONPATH": {
|
||||
"value": "{pack_dir}/lib",
|
||||
"operation": "prepend",
|
||||
"separator": ":"
|
||||
}
|
||||
}
|
||||
}))
|
||||
.expect("runtime execution config should deserialize");
|
||||
|
||||
assert!(matches!(
|
||||
config.env_vars.get("NODE_PATH"),
|
||||
Some(RuntimeEnvVarConfig::Value(value)) if value == "{env_dir}/node_modules"
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
config.env_vars.get("PYTHONPATH"),
|
||||
Some(RuntimeEnvVarConfig::Spec(RuntimeEnvVarSpec {
|
||||
value,
|
||||
operation: RuntimeEnvVarOperation::Prepend,
|
||||
separator,
|
||||
})) if value == "{pack_dir}/lib" && separator == ":"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_env_var_config_resolves_prepend_and_append_against_existing_values() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("pack_dir", "/packs/example".to_string());
|
||||
vars.insert("env_dir", "/runtime_envs/example/python".to_string());
|
||||
|
||||
let prepend = RuntimeEnvVarConfig::Spec(RuntimeEnvVarSpec {
|
||||
value: "{pack_dir}/lib".to_string(),
|
||||
operation: RuntimeEnvVarOperation::Prepend,
|
||||
separator: ":".to_string(),
|
||||
});
|
||||
assert_eq!(
|
||||
prepend.resolve(&vars, Some("/already/set")),
|
||||
"/packs/example/lib:/already/set"
|
||||
);
|
||||
|
||||
let append = RuntimeEnvVarConfig::Spec(RuntimeEnvVarSpec {
|
||||
value: "{env_dir}/node_modules".to_string(),
|
||||
operation: RuntimeEnvVarOperation::Append,
|
||||
separator: ":".to_string(),
|
||||
});
|
||||
assert_eq!(
|
||||
append.resolve(&vars, Some("/base/modules")),
|
||||
"/base/modules:/runtime_envs/example/python/node_modules"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,12 @@ impl MqError {
|
||||
pub fn is_retriable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
MqError::Connection(_) | MqError::Channel(_) | MqError::Timeout(_) | MqError::Pool(_)
|
||||
MqError::Connection(_)
|
||||
| MqError::Channel(_)
|
||||
| MqError::Publish(_)
|
||||
| MqError::Timeout(_)
|
||||
| MqError::Pool(_)
|
||||
| MqError::Lapin(_)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,8 +10,9 @@ use crate::config::Config;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::models::Runtime;
|
||||
use crate::repositories::action::ActionRepository;
|
||||
use crate::repositories::runtime::RuntimeRepository;
|
||||
use crate::repositories::runtime::{self, RuntimeRepository};
|
||||
use crate::repositories::FindById as _;
|
||||
use regex::Regex;
|
||||
use serde_json::Value as JsonValue;
|
||||
use sqlx::{PgPool, Row};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@@ -94,10 +95,7 @@ pub struct PackEnvironmentManager {
|
||||
impl PackEnvironmentManager {
|
||||
/// Create a new pack environment manager
|
||||
pub fn new(pool: PgPool, config: &Config) -> Self {
|
||||
let base_path = PathBuf::from(&config.packs_base_dir)
|
||||
.parent()
|
||||
.map(|p| p.join("packenvs"))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/attune/packenvs"));
|
||||
let base_path = PathBuf::from(&config.runtime_envs_dir);
|
||||
|
||||
Self { pool, base_path }
|
||||
}
|
||||
@@ -370,19 +368,15 @@ impl PackEnvironmentManager {
|
||||
// ========================================================================
|
||||
|
||||
async fn get_runtime(&self, runtime_id: i64) -> Result<Runtime> {
|
||||
sqlx::query_as::<_, Runtime>(
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, description, name,
|
||||
distributions, installation, installers, execution_config,
|
||||
created, updated
|
||||
FROM runtime
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(runtime_id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to fetch runtime: {}", e)))
|
||||
let query = format!(
|
||||
"SELECT {} FROM runtime WHERE id = $1",
|
||||
runtime::SELECT_COLUMNS
|
||||
);
|
||||
sqlx::query_as::<_, Runtime>(&query)
|
||||
.bind(runtime_id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Internal(format!("Failed to fetch runtime: {}", e)))
|
||||
}
|
||||
|
||||
fn runtime_requires_environment(&self, runtime: &Runtime) -> Result<bool> {
|
||||
@@ -403,19 +397,19 @@ impl PackEnvironmentManager {
|
||||
}
|
||||
|
||||
fn calculate_env_path(&self, pack_ref: &str, runtime: &Runtime) -> Result<PathBuf> {
|
||||
let runtime_name_lower = runtime.name.to_lowercase();
|
||||
let template = runtime
|
||||
.installers
|
||||
.get("base_path_template")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("/opt/attune/packenvs/{pack_ref}/{runtime_name_lower}");
|
||||
.unwrap_or("{pack_ref}/{runtime_name_lower}");
|
||||
|
||||
let runtime_name_lower = runtime.name.to_lowercase();
|
||||
let path_str = template
|
||||
.replace("{pack_ref}", pack_ref)
|
||||
.replace("{runtime_ref}", &runtime.r#ref)
|
||||
.replace("{runtime_name_lower}", &runtime_name_lower);
|
||||
|
||||
Ok(PathBuf::from(path_str))
|
||||
resolve_env_path(&self.base_path, &path_str)
|
||||
}
|
||||
|
||||
async fn upsert_environment_record(
|
||||
@@ -532,6 +526,7 @@ impl PackEnvironmentManager {
|
||||
let mut install_log = String::new();
|
||||
|
||||
// Create environment directory
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- env_path comes from validated runtime-env path construction under runtime_envs_dir.
|
||||
let env_path = PathBuf::from(&pack_env.env_path);
|
||||
if env_path.exists() {
|
||||
warn!(
|
||||
@@ -663,6 +658,8 @@ impl PackEnvironmentManager {
|
||||
env_path,
|
||||
&pack_path_str,
|
||||
)?;
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The candidate command path is validated and confined before any execution is attempted.
|
||||
let command = validate_installer_command(&command, pack_path, Path::new(env_path))?;
|
||||
|
||||
let args_template = installer
|
||||
.get("args")
|
||||
@@ -684,12 +681,17 @@ impl PackEnvironmentManager {
|
||||
|
||||
let cwd_template = installer.get("cwd").and_then(|v| v.as_str());
|
||||
let cwd = if let Some(cwd_t) = cwd_template {
|
||||
Some(self.resolve_template(
|
||||
cwd_t,
|
||||
pack_ref,
|
||||
runtime_ref,
|
||||
env_path,
|
||||
&pack_path_str,
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Installer cwd values are validated to stay under the pack root or environment directory.
|
||||
Some(validate_installer_path(
|
||||
&self.resolve_template(
|
||||
cwd_t,
|
||||
pack_ref,
|
||||
runtime_ref,
|
||||
env_path,
|
||||
&pack_path_str,
|
||||
)?,
|
||||
pack_path,
|
||||
Path::new(env_path),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
@@ -767,6 +769,7 @@ impl PackEnvironmentManager {
|
||||
async fn execute_installer_action(&self, action: &InstallerAction) -> Result<String> {
|
||||
debug!("Executing: {} {:?}", action.command, action.args);
|
||||
|
||||
// nosemgrep: rust.actix.command-injection.rust-actix-command-injection.rust-actix-command-injection -- action.command is accepted only after strict validation of executable shape and allowed path roots.
|
||||
let mut cmd = Command::new(&action.command);
|
||||
cmd.args(&action.args);
|
||||
|
||||
@@ -804,7 +807,9 @@ impl PackEnvironmentManager {
|
||||
// Check file_exists condition
|
||||
if let Some(file_path_template) = condition.get("file_exists").and_then(|v| v.as_str()) {
|
||||
let file_path = file_path_template.replace("{pack_path}", &pack_path.to_string_lossy());
|
||||
return Ok(PathBuf::from(file_path).exists());
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Conditional file checks are validated to stay under trusted pack/environment roots before filesystem access.
|
||||
let validated = validate_installer_path(&file_path, pack_path, &self.base_path)?;
|
||||
return Ok(PathBuf::from(validated).exists());
|
||||
}
|
||||
|
||||
// Default: condition is true
|
||||
@@ -820,6 +825,93 @@ impl PackEnvironmentManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_env_path(base_path: &Path, path_str: &str) -> Result<PathBuf> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- This helper normalizes env paths and preserves legacy absolute templates while still rejecting parent traversal.
|
||||
let raw_path = Path::new(path_str);
|
||||
if raw_path.is_absolute() {
|
||||
return normalize_relative_or_absolute_path(raw_path);
|
||||
}
|
||||
|
||||
let joined = base_path.join(raw_path);
|
||||
normalize_relative_or_absolute_path(&joined)
|
||||
}
|
||||
|
||||
fn normalize_relative_or_absolute_path(path: &Path) -> Result<PathBuf> {
|
||||
let mut normalized = PathBuf::new();
|
||||
for component in path.components() {
|
||||
match component {
|
||||
std::path::Component::Prefix(prefix) => normalized.push(prefix.as_os_str()),
|
||||
std::path::Component::RootDir => normalized.push(std::path::MAIN_SEPARATOR.to_string()),
|
||||
std::path::Component::CurDir => {}
|
||||
std::path::Component::ParentDir => {
|
||||
return Err(Error::validation(format!(
|
||||
"Parent-directory traversal is not allowed in installer paths: {}",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
std::path::Component::Normal(part) => normalized.push(part),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(normalized)
|
||||
}
|
||||
|
||||
fn validate_installer_command(command: &str, pack_path: &Path, env_path: &Path) -> Result<String> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Command validation inspects the path form before enforcing allowed executable rules.
|
||||
let command_path = Path::new(command);
|
||||
if command_path.is_absolute() {
|
||||
return validate_installer_path(command, pack_path, env_path);
|
||||
}
|
||||
|
||||
if command.contains(std::path::MAIN_SEPARATOR) {
|
||||
return Err(Error::validation(format!(
|
||||
"Installer command must be a bare executable name or an allowed absolute path: {}",
|
||||
command
|
||||
)));
|
||||
}
|
||||
|
||||
let command_name_re = Regex::new(r"^[A-Za-z0-9._+-]+$").expect("valid installer regex");
|
||||
if !command_name_re.is_match(command) {
|
||||
return Err(Error::validation(format!(
|
||||
"Installer command contains invalid characters: {}",
|
||||
command
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(command.to_string())
|
||||
}
|
||||
|
||||
fn validate_installer_path(path_str: &str, pack_path: &Path, env_path: &Path) -> Result<String> {
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Path validation normalizes candidate installer paths before enforcing root confinement.
|
||||
let path = normalize_path(Path::new(path_str));
|
||||
let normalized_pack_path = normalize_path(pack_path);
|
||||
let normalized_env_path = normalize_path(env_path);
|
||||
if path.starts_with(&normalized_pack_path) || path.starts_with(&normalized_env_path) {
|
||||
Ok(path.to_string_lossy().to_string())
|
||||
} else {
|
||||
Err(Error::validation(format!(
|
||||
"Installer path must remain under the pack or environment directory: {}",
|
||||
path_str
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_path(path: &Path) -> PathBuf {
|
||||
let mut normalized = PathBuf::new();
|
||||
for component in path.components() {
|
||||
match component {
|
||||
std::path::Component::Prefix(prefix) => normalized.push(prefix.as_os_str()),
|
||||
std::path::Component::RootDir => normalized.push(std::path::MAIN_SEPARATOR.to_string()),
|
||||
std::path::Component::CurDir => {}
|
||||
std::path::Component::ParentDir => {
|
||||
normalized.pop();
|
||||
}
|
||||
std::path::Component::Normal(part) => normalized.push(part),
|
||||
}
|
||||
}
|
||||
normalized
|
||||
}
|
||||
|
||||
/// Collect the lowercase runtime names that require environment setup for a pack.
|
||||
///
|
||||
/// This queries the pack's actions, resolves their runtimes, and returns the names
|
||||
|
||||
@@ -349,6 +349,7 @@ mod tests {
|
||||
cache_enabled: true,
|
||||
timeout: 120,
|
||||
verify_checksums: true,
|
||||
allowed_source_hosts: Vec::new(),
|
||||
allow_http: false,
|
||||
};
|
||||
|
||||
|
||||
@@ -11,10 +11,14 @@
|
||||
use super::{Checksum, InstallSource, PackIndexEntry, RegistryClient};
|
||||
use crate::config::PackRegistryConfig;
|
||||
use crate::error::{Error, Result};
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::fs;
|
||||
use tokio::net::lookup_host;
|
||||
use tokio::process::Command;
|
||||
use url::Url;
|
||||
|
||||
/// Progress callback type
|
||||
pub type ProgressCallback = Arc<dyn Fn(ProgressEvent) + Send + Sync>;
|
||||
@@ -53,6 +57,12 @@ pub struct PackInstaller {
|
||||
/// Whether to verify checksums
|
||||
verify_checksums: bool,
|
||||
|
||||
/// Whether HTTP remote sources are allowed
|
||||
allow_http: bool,
|
||||
|
||||
/// Remote hosts allowed for archive/git installs
|
||||
allowed_remote_hosts: Option<HashSet<String>>,
|
||||
|
||||
/// Progress callback (optional)
|
||||
progress_callback: Option<ProgressCallback>,
|
||||
}
|
||||
@@ -106,17 +116,32 @@ impl PackInstaller {
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to create temp directory: {}", e)))?;
|
||||
|
||||
let (registry_client, verify_checksums) = if let Some(config) = registry_config {
|
||||
let verify_checksums = config.verify_checksums;
|
||||
(Some(RegistryClient::new(config)?), verify_checksums)
|
||||
} else {
|
||||
(None, false)
|
||||
};
|
||||
let (registry_client, verify_checksums, allow_http, allowed_remote_hosts) =
|
||||
if let Some(config) = registry_config {
|
||||
let verify_checksums = config.verify_checksums;
|
||||
let allow_http = config.allow_http;
|
||||
let allowed_remote_hosts = collect_allowed_remote_hosts(&config)?;
|
||||
let allowed_remote_hosts = if allowed_remote_hosts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(allowed_remote_hosts)
|
||||
};
|
||||
(
|
||||
Some(RegistryClient::new(config)?),
|
||||
verify_checksums,
|
||||
allow_http,
|
||||
allowed_remote_hosts,
|
||||
)
|
||||
} else {
|
||||
(None, false, false, None)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
temp_dir,
|
||||
registry_client,
|
||||
verify_checksums,
|
||||
allow_http,
|
||||
allowed_remote_hosts,
|
||||
progress_callback: None,
|
||||
})
|
||||
}
|
||||
@@ -152,6 +177,7 @@ impl PackInstaller {
|
||||
|
||||
/// Install from git repository
|
||||
async fn install_from_git(&self, url: &str, git_ref: Option<&str>) -> Result<InstalledPack> {
|
||||
self.validate_git_source(url).await?;
|
||||
tracing::info!("Installing pack from git: {} (ref: {:?})", url, git_ref);
|
||||
|
||||
self.report_progress(ProgressEvent::StepStarted {
|
||||
@@ -405,10 +431,12 @@ impl PackInstaller {
|
||||
|
||||
/// Download an archive from a URL
|
||||
async fn download_archive(&self, url: &str) -> Result<PathBuf> {
|
||||
let parsed_url = self.validate_remote_url(url).await?;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// nosemgrep: rust.actix.ssrf.reqwest-taint.reqwest-taint -- Remote source URLs are restricted to configured allowlisted hosts, HTTPS, and public IPs before request execution.
|
||||
let response = client
|
||||
.get(url)
|
||||
.get(parsed_url.clone())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to download archive: {}", e)))?;
|
||||
@@ -421,11 +449,7 @@ impl PackInstaller {
|
||||
}
|
||||
|
||||
// Determine filename from URL
|
||||
let filename = url
|
||||
.split('/')
|
||||
.next_back()
|
||||
.unwrap_or("archive.zip")
|
||||
.to_string();
|
||||
let filename = archive_filename_from_url(&parsed_url);
|
||||
|
||||
let archive_path = self.temp_dir.join(&filename);
|
||||
|
||||
@@ -442,6 +466,116 @@ impl PackInstaller {
|
||||
Ok(archive_path)
|
||||
}
|
||||
|
||||
async fn validate_remote_url(&self, raw_url: &str) -> Result<Url> {
|
||||
let parsed = Url::parse(raw_url)
|
||||
.map_err(|e| Error::validation(format!("Invalid remote URL '{}': {}", raw_url, e)))?;
|
||||
|
||||
if parsed.scheme() != "https" && !(self.allow_http && parsed.scheme() == "http") {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote URL must use https{}: {}",
|
||||
if self.allow_http {
|
||||
" or http when pack_registry.allow_http is enabled"
|
||||
} else {
|
||||
""
|
||||
},
|
||||
raw_url
|
||||
)));
|
||||
}
|
||||
|
||||
if !parsed.username().is_empty() || parsed.password().is_some() {
|
||||
return Err(Error::validation(
|
||||
"Remote URLs with embedded credentials are not allowed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let host = parsed.host_str().ok_or_else(|| {
|
||||
Error::validation(format!("Remote URL is missing a host: {}", raw_url))
|
||||
})?;
|
||||
let normalized_host = host.to_ascii_lowercase();
|
||||
|
||||
if normalized_host == "localhost" {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote URL host is not allowed: {}",
|
||||
host
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(allowed_remote_hosts) = &self.allowed_remote_hosts {
|
||||
if !allowed_remote_hosts.contains(&normalized_host) {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote URL host '{}' is not in the configured allowlist. Add it to pack_registry.allowed_source_hosts.",
|
||||
host
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ip) = parsed.host().and_then(|host| match host {
|
||||
url::Host::Ipv4(ip) => Some(IpAddr::V4(ip)),
|
||||
url::Host::Ipv6(ip) => Some(IpAddr::V6(ip)),
|
||||
url::Host::Domain(_) => None,
|
||||
}) {
|
||||
ensure_public_ip(ip)?;
|
||||
}
|
||||
|
||||
let port = parsed.port_or_known_default().ok_or_else(|| {
|
||||
Error::validation(format!("Remote URL is missing a usable port: {}", raw_url))
|
||||
})?;
|
||||
|
||||
let resolved = lookup_host((host, port))
|
||||
.await
|
||||
.map_err(|e| Error::validation(format!("Failed to resolve host '{}': {}", host, e)))?;
|
||||
|
||||
let mut saw_address = false;
|
||||
for addr in resolved {
|
||||
saw_address = true;
|
||||
ensure_public_ip(addr.ip())?;
|
||||
}
|
||||
|
||||
if !saw_address {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote URL host did not resolve to any addresses: {}",
|
||||
host
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
|
||||
async fn validate_git_source(&self, raw_url: &str) -> Result<()> {
|
||||
if raw_url.starts_with("http://") || raw_url.starts_with("https://") {
|
||||
self.validate_remote_url(raw_url).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(host) = extract_git_host(raw_url) {
|
||||
self.validate_remote_host(&host)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_remote_host(&self, host: &str) -> Result<()> {
|
||||
let normalized_host = host.to_ascii_lowercase();
|
||||
|
||||
if normalized_host == "localhost" {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote host is not allowed: {}",
|
||||
host
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(allowed_remote_hosts) = &self.allowed_remote_hosts {
|
||||
if !allowed_remote_hosts.contains(&normalized_host) {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote host '{}' is not in the configured allowlist. Add it to pack_registry.allowed_source_hosts.",
|
||||
host
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract an archive (zip or tar.gz)
|
||||
async fn extract_archive(&self, archive_path: &Path) -> Result<PathBuf> {
|
||||
let extract_dir = self.create_temp_dir().await?;
|
||||
@@ -583,6 +717,7 @@ impl PackInstaller {
|
||||
}
|
||||
|
||||
// Check in first subdirectory (common for GitHub archives)
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Archive inspection is limited to the temporary extraction directory created by this installer.
|
||||
let mut entries = fs::read_dir(base_dir)
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to read directory: {}", e)))?;
|
||||
@@ -618,6 +753,7 @@ impl PackInstaller {
|
||||
})?;
|
||||
|
||||
// Read source directory
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Directory copy operates on installer-managed local paths, not request-derived paths.
|
||||
let mut entries = fs::read_dir(src)
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to read source directory: {}", e)))?;
|
||||
@@ -674,6 +810,111 @@ impl PackInstaller {
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_allowed_remote_hosts(config: &PackRegistryConfig) -> Result<HashSet<String>> {
|
||||
let mut hosts = HashSet::new();
|
||||
|
||||
for index in &config.indices {
|
||||
if !index.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let parsed = Url::parse(&index.url).map_err(|e| {
|
||||
Error::validation(format!("Invalid registry index URL '{}': {}", index.url, e))
|
||||
})?;
|
||||
|
||||
let host = parsed.host_str().ok_or_else(|| {
|
||||
Error::validation(format!(
|
||||
"Registry index URL '{}' is missing a host",
|
||||
index.url
|
||||
))
|
||||
})?;
|
||||
|
||||
hosts.insert(host.to_ascii_lowercase());
|
||||
}
|
||||
|
||||
for host in &config.allowed_source_hosts {
|
||||
let normalized = host.trim().to_ascii_lowercase();
|
||||
if !normalized.is_empty() {
|
||||
hosts.insert(normalized);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hosts)
|
||||
}
|
||||
|
||||
fn extract_git_host(raw_url: &str) -> Option<String> {
|
||||
if let Ok(parsed) = Url::parse(raw_url) {
|
||||
return parsed.host_str().map(|host| host.to_ascii_lowercase());
|
||||
}
|
||||
|
||||
raw_url.split_once('@').and_then(|(_, rest)| {
|
||||
rest.split_once(':')
|
||||
.map(|(host, _)| host.to_ascii_lowercase())
|
||||
})
|
||||
}
|
||||
|
||||
fn archive_filename_from_url(url: &Url) -> String {
|
||||
let raw_name = url
|
||||
.path_segments()
|
||||
.and_then(|mut segments| segments.rfind(|segment| !segment.is_empty()))
|
||||
.unwrap_or("archive.bin");
|
||||
|
||||
let sanitized: String = raw_name
|
||||
.chars()
|
||||
.map(|ch| match ch {
|
||||
'a'..='z' | 'A'..='Z' | '0'..='9' | '.' | '-' | '_' => ch,
|
||||
_ => '_',
|
||||
})
|
||||
.collect();
|
||||
|
||||
let filename = sanitized.trim_matches('.');
|
||||
if filename.is_empty() {
|
||||
"archive.bin".to_string()
|
||||
} else {
|
||||
filename.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_public_ip(ip: IpAddr) -> Result<()> {
|
||||
let is_blocked = match ip {
|
||||
IpAddr::V4(ip) => {
|
||||
let octets = ip.octets();
|
||||
let is_documentation_range = matches!(
|
||||
octets,
|
||||
[192, 0, 2, _] | [198, 51, 100, _] | [203, 0, 113, _]
|
||||
);
|
||||
ip.is_private()
|
||||
|| ip.is_loopback()
|
||||
|| ip.is_link_local()
|
||||
|| ip.is_multicast()
|
||||
|| ip.is_broadcast()
|
||||
|| is_documentation_range
|
||||
|| ip.is_unspecified()
|
||||
|| octets[0] == 0
|
||||
}
|
||||
IpAddr::V6(ip) => {
|
||||
let segments = ip.segments();
|
||||
let is_documentation_range = segments[0] == 0x2001 && segments[1] == 0x0db8;
|
||||
ip.is_loopback()
|
||||
|| ip.is_unspecified()
|
||||
|| ip.is_multicast()
|
||||
|| ip.is_unique_local()
|
||||
|| ip.is_unicast_link_local()
|
||||
|| is_documentation_range
|
||||
|| ip == Ipv6Addr::LOCALHOST
|
||||
}
|
||||
};
|
||||
|
||||
if is_blocked {
|
||||
return Err(Error::validation(format!(
|
||||
"Remote URL resolved to a non-public address: {}",
|
||||
ip
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -721,4 +962,52 @@ mod tests {
|
||||
|
||||
assert!(matches!(source, InstallSource::Git { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_archive_filename_from_url_sanitizes_path_segments() {
|
||||
let url = Url::parse("https://example.com/releases/../../pack.zip?token=x").unwrap();
|
||||
assert_eq!(archive_filename_from_url(&url), "pack.zip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ensure_public_ip_rejects_private_ipv4() {
|
||||
let err = ensure_public_ip(IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))).unwrap_err();
|
||||
assert!(err.to_string().contains("non-public"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_allowed_remote_hosts_includes_indices_and_overrides() {
|
||||
let config = PackRegistryConfig {
|
||||
indices: vec![crate::config::RegistryIndexConfig {
|
||||
url: "https://registry.example.com/index.json".to_string(),
|
||||
priority: 1,
|
||||
enabled: true,
|
||||
name: None,
|
||||
headers: std::collections::HashMap::new(),
|
||||
}],
|
||||
allowed_source_hosts: vec!["github.com".to_string(), "cdn.example.com".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let hosts = collect_allowed_remote_hosts(&config).unwrap();
|
||||
assert!(hosts.contains("registry.example.com"));
|
||||
assert!(hosts.contains("github.com"));
|
||||
assert!(hosts.contains("cdn.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_git_host_from_scp_style_source() {
|
||||
assert_eq!(
|
||||
extract_git_host("git@github.com:org/repo.git"),
|
||||
Some("github.com".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_git_host_from_git_scheme_source() {
|
||||
assert_eq!(
|
||||
extract_git_host("git://github.com/org/repo.git"),
|
||||
Some("github.com".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
//! can reference the same workflow file with different configurations.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use sqlx::PgPool;
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -42,7 +42,6 @@ use crate::repositories::action::{ActionRepository, UpdateActionInput};
|
||||
use crate::repositories::identity::{
|
||||
CreatePermissionSetInput, PermissionSetRepository, UpdatePermissionSetInput,
|
||||
};
|
||||
use crate::repositories::runtime::{CreateRuntimeInput, RuntimeRepository, UpdateRuntimeInput};
|
||||
use crate::repositories::runtime_version::{
|
||||
CreateRuntimeVersionInput, RuntimeVersionRepository, UpdateRuntimeVersionInput,
|
||||
};
|
||||
@@ -53,7 +52,10 @@ use crate::repositories::trigger::{
|
||||
use crate::repositories::workflow::{
|
||||
CreateWorkflowDefinitionInput, UpdateWorkflowDefinitionInput, WorkflowDefinitionRepository,
|
||||
};
|
||||
use crate::repositories::{Create, Delete, FindById, FindByRef, Update};
|
||||
use crate::repositories::{
|
||||
runtime::{CreateRuntimeInput, RuntimeRepository, UpdateRuntimeInput},
|
||||
Create, Delete, FindById, FindByRef, Patch, Update,
|
||||
};
|
||||
use crate::version_matching::extract_version_components;
|
||||
use crate::workflow::parser::parse_workflow_yaml;
|
||||
|
||||
@@ -402,14 +404,32 @@ impl<'a> PackComponentLoader<'a> {
|
||||
.and_then(|v| serde_json::to_value(v).ok())
|
||||
.unwrap_or_else(|| serde_json::json!({}));
|
||||
|
||||
let aliases: Vec<String> = data
|
||||
.get("aliases")
|
||||
.and_then(|v| v.as_sequence())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_ascii_lowercase()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Check if runtime already exists — update in place if so
|
||||
if let Some(existing) = RuntimeRepository::find_by_ref(self.pool, &runtime_ref).await? {
|
||||
let update_input = UpdateRuntimeInput {
|
||||
description,
|
||||
description: Some(match description {
|
||||
Some(description) => Patch::Set(description),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
name: Some(name),
|
||||
distributions: Some(distributions),
|
||||
installation,
|
||||
installation: Some(match installation {
|
||||
Some(installation) => Patch::Set(installation),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
execution_config: Some(execution_config),
|
||||
aliases: Some(aliases),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
match RuntimeRepository::update(self.pool, existing.id, update_input).await {
|
||||
@@ -440,6 +460,9 @@ impl<'a> PackComponentLoader<'a> {
|
||||
distributions,
|
||||
installation,
|
||||
execution_config,
|
||||
aliases,
|
||||
auto_detected: false,
|
||||
detection_config: serde_json::json!({}),
|
||||
};
|
||||
|
||||
match RuntimeRepository::create(self.pool, input).await {
|
||||
@@ -547,9 +570,18 @@ impl<'a> PackComponentLoader<'a> {
|
||||
{
|
||||
let update_input = UpdateRuntimeVersionInput {
|
||||
version: None, // version string doesn't change
|
||||
version_major: Some(version_major),
|
||||
version_minor: Some(version_minor),
|
||||
version_patch: Some(version_patch),
|
||||
version_major: Some(match version_major {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
version_minor: Some(match version_minor {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
version_patch: Some(match version_patch {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
execution_config: Some(execution_config),
|
||||
distributions: Some(distributions),
|
||||
is_default: Some(is_default),
|
||||
@@ -693,8 +725,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
let description = data
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let enabled = data
|
||||
.get("enabled")
|
||||
@@ -713,10 +744,19 @@ impl<'a> PackComponentLoader<'a> {
|
||||
if let Some(existing) = TriggerRepository::find_by_ref(self.pool, &trigger_ref).await? {
|
||||
let update_input = UpdateTriggerInput {
|
||||
label: Some(label),
|
||||
description: Some(description),
|
||||
description: Some(match description {
|
||||
Some(description) => Patch::Set(description),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
enabled: Some(enabled),
|
||||
param_schema,
|
||||
out_schema,
|
||||
param_schema: Some(match param_schema {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
out_schema: Some(match out_schema {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
};
|
||||
|
||||
match TriggerRepository::update(self.pool, existing.id, update_input).await {
|
||||
@@ -740,7 +780,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
pack: Some(self.pack_id),
|
||||
pack_ref: Some(self.pack_ref.clone()),
|
||||
label,
|
||||
description: Some(description),
|
||||
description,
|
||||
enabled,
|
||||
param_schema,
|
||||
out_schema,
|
||||
@@ -820,8 +860,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
let description = data
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// ── Workflow file handling ──────────────────────────────────
|
||||
// If the action declares `workflow_file`, load the referenced
|
||||
@@ -838,7 +877,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
wf_path,
|
||||
&action_ref,
|
||||
&label,
|
||||
&description,
|
||||
description.as_deref().unwrap_or(""),
|
||||
&data,
|
||||
)
|
||||
.await
|
||||
@@ -918,10 +957,16 @@ impl<'a> PackComponentLoader<'a> {
|
||||
if let Some(existing) = ActionRepository::find_by_ref(self.pool, &action_ref).await? {
|
||||
let update_input = UpdateActionInput {
|
||||
label: Some(label),
|
||||
description: Some(description),
|
||||
description: Some(match description {
|
||||
Some(description) => Patch::Set(description),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
entrypoint: Some(entrypoint),
|
||||
runtime: runtime_id,
|
||||
runtime_version_constraint: Some(runtime_version_constraint),
|
||||
runtime_version_constraint: Some(match runtime_version_constraint {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
param_schema,
|
||||
out_schema,
|
||||
parameter_delivery: Some(parameter_delivery),
|
||||
@@ -1046,7 +1091,10 @@ impl<'a> PackComponentLoader<'a> {
|
||||
action_description: &str,
|
||||
action_data: &serde_yaml_ng::Value,
|
||||
) -> Result<Id> {
|
||||
let full_path = actions_dir.join(workflow_file_path);
|
||||
let pack_root = actions_dir.parent().ok_or_else(|| {
|
||||
Error::validation("Actions directory must live inside a pack directory".to_string())
|
||||
})?;
|
||||
let full_path = resolve_pack_relative_path(pack_root, actions_dir, workflow_file_path)?;
|
||||
if !full_path.exists() {
|
||||
return Err(Error::validation(format!(
|
||||
"Workflow file '{}' not found at '{}'",
|
||||
@@ -1055,6 +1103,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
)));
|
||||
}
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- The workflow path is normalized and confined to the pack root before this local read.
|
||||
let content = std::fs::read_to_string(&full_path).map_err(|e| {
|
||||
Error::io(format!(
|
||||
"Failed to read workflow file '{}': {}",
|
||||
@@ -1131,7 +1180,6 @@ impl<'a> PackComponentLoader<'a> {
|
||||
out_schema,
|
||||
definition: Some(definition_json),
|
||||
tags: Some(tags),
|
||||
enabled: Some(true),
|
||||
};
|
||||
|
||||
WorkflowDefinitionRepository::update(self.pool, existing.id, update_input).await?;
|
||||
@@ -1159,7 +1207,6 @@ impl<'a> PackComponentLoader<'a> {
|
||||
out_schema,
|
||||
definition: definition_json,
|
||||
tags,
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
let created = WorkflowDefinitionRepository::create(self.pool, create_input).await?;
|
||||
@@ -1271,8 +1318,7 @@ impl<'a> PackComponentLoader<'a> {
|
||||
let description = data
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let enabled = data
|
||||
.get("enabled")
|
||||
@@ -1308,15 +1354,24 @@ impl<'a> PackComponentLoader<'a> {
|
||||
if let Some(existing) = SensorRepository::find_by_ref(self.pool, &sensor_ref).await? {
|
||||
let update_input = UpdateSensorInput {
|
||||
label: Some(label),
|
||||
description: Some(description),
|
||||
description: Some(match description {
|
||||
Some(description) => Patch::Set(description),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
entrypoint: Some(entrypoint),
|
||||
runtime: Some(sensor_runtime_id),
|
||||
runtime_ref: Some(sensor_runtime_ref.clone()),
|
||||
runtime_version_constraint: Some(runtime_version_constraint.clone()),
|
||||
runtime_version_constraint: Some(match runtime_version_constraint.clone() {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
trigger: Some(trigger_id.unwrap_or(existing.trigger)),
|
||||
trigger_ref: Some(trigger_ref.unwrap_or(existing.trigger_ref.clone())),
|
||||
enabled: Some(enabled),
|
||||
param_schema,
|
||||
param_schema: Some(match param_schema {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Clear,
|
||||
}),
|
||||
config: Some(config),
|
||||
};
|
||||
|
||||
@@ -1598,11 +1653,60 @@ impl<'a> PackComponentLoader<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_pack_relative_path(
|
||||
pack_root: &Path,
|
||||
base_dir: &Path,
|
||||
relative_path: &str,
|
||||
) -> Result<PathBuf> {
|
||||
let canonical_pack_root = pack_root.canonicalize().map_err(|e| {
|
||||
Error::io(format!(
|
||||
"Failed to resolve pack root '{}': {}",
|
||||
pack_root.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
let canonical_base_dir = base_dir.canonicalize().map_err(|e| {
|
||||
Error::io(format!(
|
||||
"Failed to resolve base directory '{}': {}",
|
||||
base_dir.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
let canonical_candidate = normalize_path_from_base(&canonical_base_dir, relative_path);
|
||||
|
||||
if !canonical_candidate.starts_with(&canonical_pack_root) {
|
||||
return Err(Error::validation(format!(
|
||||
"Resolved path '{}' escapes pack root '{}'",
|
||||
canonical_candidate.display(),
|
||||
canonical_pack_root.display()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(canonical_candidate)
|
||||
}
|
||||
|
||||
fn normalize_path_from_base(base: &Path, relative_path: &str) -> PathBuf {
|
||||
let mut normalized = PathBuf::new();
|
||||
for component in base.join(relative_path).components() {
|
||||
match component {
|
||||
std::path::Component::Prefix(prefix) => normalized.push(prefix.as_os_str()),
|
||||
std::path::Component::RootDir => normalized.push(std::path::MAIN_SEPARATOR.to_string()),
|
||||
std::path::Component::CurDir => {}
|
||||
std::path::Component::ParentDir => {
|
||||
normalized.pop();
|
||||
}
|
||||
std::path::Component::Normal(part) => normalized.push(part),
|
||||
}
|
||||
}
|
||||
normalized
|
||||
}
|
||||
|
||||
/// Read all YAML files from a directory, returning `(filename, content)` pairs
|
||||
/// sorted by filename for deterministic ordering.
|
||||
fn read_yaml_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
||||
let mut files = Vec::new();
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Pack loader scans pack-owned directories on disk after selecting the pack root.
|
||||
let entries = std::fs::read_dir(dir)
|
||||
.map_err(|e| Error::io(format!("Failed to read directory {}: {}", dir.display(), e)))?;
|
||||
|
||||
@@ -1625,6 +1729,7 @@ fn read_yaml_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
||||
let path = entry.path();
|
||||
let filename = entry.file_name().to_string_lossy().to_string();
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- YAML files are read only after being discovered under the selected pack directory.
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.map_err(|e| Error::io(format!("Failed to read file {}: {}", path.display(), e)))?;
|
||||
|
||||
|
||||
@@ -292,6 +292,7 @@ fn copy_dir_all(src: &Path, dst: &Path) -> Result<()> {
|
||||
))
|
||||
})?;
|
||||
|
||||
// nosemgrep: rust.actix.path-traversal.tainted-path.tainted-path -- Pack storage copy recursively processes validated local directories under the configured pack store.
|
||||
for entry in fs::read_dir(src).map_err(|e| {
|
||||
Error::io(format!(
|
||||
"Failed to read source directory {}: {}",
|
||||
|
||||
@@ -21,10 +21,6 @@ pub enum Resource {
|
||||
Inquiries,
|
||||
Keys,
|
||||
Artifacts,
|
||||
Workflows,
|
||||
Webhooks,
|
||||
Analytics,
|
||||
History,
|
||||
Identities,
|
||||
Permissions,
|
||||
}
|
||||
@@ -40,6 +36,7 @@ pub enum Action {
|
||||
Cancel,
|
||||
Respond,
|
||||
Manage,
|
||||
Decrypt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
@@ -69,6 +66,8 @@ pub struct GrantConstraints {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub owner_types: Option<Vec<OwnerType>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub owner_refs: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub visibility: Option<Vec<ArtifactVisibility>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub execution_scope: Option<ExecutionScopeConstraint>,
|
||||
@@ -99,6 +98,7 @@ pub struct AuthorizationContext {
|
||||
pub pack_ref: Option<String>,
|
||||
pub owner_identity_id: Option<Id>,
|
||||
pub owner_type: Option<OwnerType>,
|
||||
pub owner_ref: Option<String>,
|
||||
pub visibility: Option<ArtifactVisibility>,
|
||||
pub encrypted: Option<bool>,
|
||||
pub execution_owner_identity_id: Option<Id>,
|
||||
@@ -115,6 +115,7 @@ impl AuthorizationContext {
|
||||
pack_ref: None,
|
||||
owner_identity_id: None,
|
||||
owner_type: None,
|
||||
owner_ref: None,
|
||||
visibility: None,
|
||||
encrypted: None,
|
||||
execution_owner_identity_id: None,
|
||||
@@ -162,6 +163,15 @@ impl Grant {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(owner_refs) = &constraints.owner_refs {
|
||||
let Some(owner_ref) = &ctx.owner_ref else {
|
||||
return false;
|
||||
};
|
||||
if !owner_refs.contains(owner_ref) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(visibility) = &constraints.visibility {
|
||||
let Some(target_visibility) = ctx.visibility else {
|
||||
return false;
|
||||
@@ -289,4 +299,28 @@ mod tests {
|
||||
.insert("team".to_string(), json!("infra"));
|
||||
assert!(!grant.allows(Resource::Packs, Action::Read, &ctx));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn owner_ref_constraint_requires_exact_value_match() {
|
||||
let grant = Grant {
|
||||
resource: Resource::Artifacts,
|
||||
actions: vec![Action::Read],
|
||||
constraints: Some(GrantConstraints {
|
||||
owner_types: Some(vec![OwnerType::Pack]),
|
||||
owner_refs: Some(vec!["python_example".to_string()]),
|
||||
..Default::default()
|
||||
}),
|
||||
};
|
||||
|
||||
let mut ctx = AuthorizationContext::new(1);
|
||||
ctx.owner_type = Some(OwnerType::Pack);
|
||||
ctx.owner_ref = Some("python_example".to_string());
|
||||
assert!(grant.allows(Resource::Artifacts, Action::Read, &ctx));
|
||||
|
||||
ctx.owner_ref = Some("other_pack".to_string());
|
||||
assert!(!grant.allows(Resource::Artifacts, Action::Read, &ctx));
|
||||
|
||||
ctx.owner_ref = None;
|
||||
assert!(!grant.allows(Resource::Artifacts, Action::Read, &ctx));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::models::{action::*, enums::PolicyMethod, Id, JsonSchema};
|
||||
use crate::{Error, Result};
|
||||
use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Patch, Repository, Update};
|
||||
|
||||
/// Columns selected in all Action queries. Must match the `Action` model's `FromRow` fields.
|
||||
pub const ACTION_COLUMNS: &str = "id, ref, pack, pack_ref, label, description, entrypoint, \
|
||||
@@ -51,7 +51,7 @@ pub struct CreateActionInput {
|
||||
pub pack: Id,
|
||||
pub pack_ref: String,
|
||||
pub label: String,
|
||||
pub description: String,
|
||||
pub description: Option<String>,
|
||||
pub entrypoint: String,
|
||||
pub runtime: Option<Id>,
|
||||
pub runtime_version_constraint: Option<String>,
|
||||
@@ -64,10 +64,10 @@ pub struct CreateActionInput {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UpdateActionInput {
|
||||
pub label: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub description: Option<Patch<String>>,
|
||||
pub entrypoint: Option<String>,
|
||||
pub runtime: Option<Id>,
|
||||
pub runtime_version_constraint: Option<Option<String>>,
|
||||
pub runtime_version_constraint: Option<Patch<String>>,
|
||||
pub param_schema: Option<JsonSchema>,
|
||||
pub out_schema: Option<JsonSchema>,
|
||||
pub parameter_delivery: Option<String>,
|
||||
@@ -210,7 +210,10 @@ impl Update for ActionRepository {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("description = ");
|
||||
query.push_bind(description);
|
||||
match description {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<String>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
@@ -237,7 +240,10 @@ impl Update for ActionRepository {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("runtime_version_constraint = ");
|
||||
query.push_bind(runtime_version_constraint);
|
||||
match runtime_version_constraint {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<String>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
@@ -565,7 +571,7 @@ impl Repository for PolicyRepository {
|
||||
type Entity = Policy;
|
||||
|
||||
fn table_name() -> &'static str {
|
||||
"policies"
|
||||
"policy"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -606,7 +612,7 @@ impl FindById for PolicyRepository {
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policies
|
||||
FROM policy
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
@@ -628,7 +634,7 @@ impl FindByRef for PolicyRepository {
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policies
|
||||
FROM policy
|
||||
WHERE ref = $1
|
||||
"#,
|
||||
)
|
||||
@@ -650,7 +656,7 @@ impl List for PolicyRepository {
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policies
|
||||
FROM policy
|
||||
ORDER BY ref ASC
|
||||
"#,
|
||||
)
|
||||
@@ -672,7 +678,7 @@ impl Create for PolicyRepository {
|
||||
// Try to insert - database will enforce uniqueness constraint
|
||||
let policy = sqlx::query_as::<_, Policy>(
|
||||
r#"
|
||||
INSERT INTO policies (ref, pack, pack_ref, action, action_ref, parameters,
|
||||
INSERT INTO policy (ref, pack, pack_ref, action, action_ref, parameters,
|
||||
method, threshold, name, description, tags)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
@@ -714,7 +720,7 @@ impl Update for PolicyRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let mut query = QueryBuilder::new("UPDATE policies SET ");
|
||||
let mut query = QueryBuilder::new("UPDATE policy SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(parameters) = &input.parameters {
|
||||
@@ -792,7 +798,7 @@ impl Delete for PolicyRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM policies WHERE id = $1")
|
||||
let result = sqlx::query("DELETE FROM policy WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
@@ -811,7 +817,7 @@ impl PolicyRepository {
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policies
|
||||
FROM policy
|
||||
WHERE action = $1
|
||||
ORDER BY ref ASC
|
||||
"#,
|
||||
@@ -832,7 +838,7 @@ impl PolicyRepository {
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policies
|
||||
FROM policy
|
||||
WHERE $1 = ANY(tags)
|
||||
ORDER BY ref ASC
|
||||
"#,
|
||||
@@ -843,4 +849,69 @@ impl PolicyRepository {
|
||||
|
||||
Ok(policies)
|
||||
}
|
||||
|
||||
/// Find the most recent action-specific policy.
|
||||
pub async fn find_latest_by_action<'e, E>(executor: E, action_id: Id) -> Result<Option<Policy>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let policy = sqlx::query_as::<_, Policy>(
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policy
|
||||
WHERE action = $1
|
||||
ORDER BY created DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
Ok(policy)
|
||||
}
|
||||
|
||||
/// Find the most recent pack-specific policy.
|
||||
pub async fn find_latest_by_pack<'e, E>(executor: E, pack_id: Id) -> Result<Option<Policy>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let policy = sqlx::query_as::<_, Policy>(
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policy
|
||||
WHERE pack = $1 AND action IS NULL
|
||||
ORDER BY created DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(pack_id)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
Ok(policy)
|
||||
}
|
||||
|
||||
/// Find the most recent global policy.
|
||||
pub async fn find_latest_global<'e, E>(executor: E) -> Result<Option<Policy>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let policy = sqlx::query_as::<_, Policy>(
|
||||
r#"
|
||||
SELECT id, ref, pack, pack_ref, action, action_ref, parameters, method,
|
||||
threshold, name, description, tags, created, updated
|
||||
FROM policy
|
||||
WHERE pack IS NULL AND action IS NULL
|
||||
ORDER BY created DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
Ok(policy)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ pub struct EnforcementVolumeBucket {
|
||||
pub enforcement_count: i64,
|
||||
}
|
||||
|
||||
/// A single hourly bucket of execution volume (from execution hypertable directly).
|
||||
/// A single hourly bucket of execution volume (from the execution table directly).
|
||||
#[derive(Debug, Clone, Serialize, FromRow)]
|
||||
pub struct ExecutionVolumeBucket {
|
||||
/// Start of the 1-hour bucket
|
||||
@@ -468,7 +468,7 @@ impl AnalyticsRepository {
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// Execution volume (from execution hypertable directly)
|
||||
// Execution volume (from the execution table directly)
|
||||
// =======================================================================
|
||||
|
||||
/// Query the `execution_volume_hourly` continuous aggregate for execution
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::models::{
|
||||
use crate::Result;
|
||||
use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Repository, Update};
|
||||
use super::{Create, Delete, FindById, FindByRef, List, Patch, Repository, Update};
|
||||
|
||||
// ============================================================================
|
||||
// ArtifactRepository
|
||||
@@ -48,12 +48,12 @@ pub struct UpdateArtifactInput {
|
||||
pub visibility: Option<ArtifactVisibility>,
|
||||
pub retention_policy: Option<RetentionPolicyType>,
|
||||
pub retention_limit: Option<i32>,
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub name: Option<Patch<String>>,
|
||||
pub description: Option<Patch<String>>,
|
||||
pub content_type: Option<Patch<String>>,
|
||||
pub size_bytes: Option<i64>,
|
||||
pub execution: Option<Option<i64>>,
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub execution: Option<Patch<i64>>,
|
||||
pub data: Option<Patch<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Filters for searching artifacts
|
||||
@@ -186,20 +186,62 @@ impl Update for ArtifactRepository {
|
||||
push_field!(input.visibility, "visibility");
|
||||
push_field!(input.retention_policy, "retention_policy");
|
||||
push_field!(input.retention_limit, "retention_limit");
|
||||
push_field!(&input.name, "name");
|
||||
push_field!(&input.description, "description");
|
||||
push_field!(&input.content_type, "content_type");
|
||||
if let Some(name) = &input.name {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("name = ");
|
||||
match name {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<String>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(description) = &input.description {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("description = ");
|
||||
match description {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<String>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(content_type) = &input.content_type {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("content_type = ");
|
||||
match content_type {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<String>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
push_field!(input.size_bytes, "size_bytes");
|
||||
// execution is Option<Option<i64>> — outer Option = "was field provided?",
|
||||
// inner Option = nullable column value
|
||||
if let Some(exec_val) = input.execution {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("execution = ").push_bind(exec_val);
|
||||
query.push("execution = ");
|
||||
match exec_val {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<i64>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(data) = &input.data {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("data = ");
|
||||
match data {
|
||||
Patch::Set(value) => query.push_bind(value),
|
||||
Patch::Clear => query.push_bind(Option::<serde_json::Value>::None),
|
||||
};
|
||||
has_updates = true;
|
||||
}
|
||||
push_field!(&input.data, "data");
|
||||
|
||||
if !has_updates {
|
||||
return Self::get_by_id(executor, id).await;
|
||||
@@ -535,6 +577,14 @@ pub struct CreateArtifactVersionInput {
|
||||
}
|
||||
|
||||
impl ArtifactVersionRepository {
|
||||
fn select_columns_with_alias(alias: &str) -> String {
|
||||
format!(
|
||||
"{alias}.id, {alias}.artifact, {alias}.version, {alias}.content_type, \
|
||||
{alias}.size_bytes, NULL::bytea AS content, {alias}.content_json, \
|
||||
{alias}.file_path, {alias}.meta, {alias}.created_by, {alias}.created"
|
||||
)
|
||||
}
|
||||
|
||||
/// Find a version by ID (without binary content for performance)
|
||||
pub async fn find_by_id<'e, E>(executor: E, id: i64) -> Result<Option<ArtifactVersion>>
|
||||
where
|
||||
@@ -770,14 +820,11 @@ impl ArtifactVersionRepository {
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT av.{} \
|
||||
"SELECT {} \
|
||||
FROM artifact_version av \
|
||||
JOIN artifact a ON av.artifact = a.id \
|
||||
WHERE a.execution = $1 AND av.file_path IS NOT NULL",
|
||||
artifact_version::SELECT_COLUMNS
|
||||
.split(", ")
|
||||
.collect::<Vec<_>>()
|
||||
.join(", av.")
|
||||
Self::select_columns_with_alias("av")
|
||||
);
|
||||
sqlx::query_as::<_, ArtifactVersion>(&query)
|
||||
.bind(execution_id)
|
||||
@@ -805,3 +852,18 @@ impl ArtifactVersionRepository {
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ArtifactVersionRepository;
|
||||
|
||||
#[test]
|
||||
fn aliased_select_columns_keep_null_content_expression_unqualified() {
|
||||
let columns = ArtifactVersionRepository::select_columns_with_alias("av");
|
||||
|
||||
assert!(columns.contains("av.id"));
|
||||
assert!(columns.contains("av.file_path"));
|
||||
assert!(columns.contains("NULL::bytea AS content"));
|
||||
assert!(!columns.contains("av.NULL::bytea AS content"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,12 @@ pub struct EnforcementSearchResult {
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EnforcementCreateOrGetResult {
|
||||
pub enforcement: Enforcement,
|
||||
pub created: bool,
|
||||
}
|
||||
|
||||
/// Repository for Event operations
|
||||
pub struct EventRepository;
|
||||
|
||||
@@ -416,7 +422,115 @@ impl Update for EnforcementRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
// Build update query
|
||||
if input.status.is_none() && input.payload.is_none() && input.resolved_at.is_none() {
|
||||
return Self::get_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
Self::update_with_locator(executor, input, |query| {
|
||||
query.push(" WHERE id = ");
|
||||
query.push_bind(id);
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Delete for EnforcementRepository {
|
||||
async fn delete<'e, E>(executor: E, id: i64) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM enforcement WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnforcementRepository {
|
||||
async fn update_with_locator<'e, E, F>(
|
||||
executor: E,
|
||||
input: UpdateEnforcementInput,
|
||||
where_clause: F,
|
||||
) -> Result<Enforcement>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
F: FnOnce(&mut QueryBuilder<'_, Postgres>),
|
||||
{
|
||||
let mut query = QueryBuilder::new("UPDATE enforcement SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(status) = input.status {
|
||||
query.push("status = ");
|
||||
query.push_bind(status);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(payload) = &input.payload {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("payload = ");
|
||||
query.push_bind(payload);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(resolved_at) = input.resolved_at {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("resolved_at = ");
|
||||
query.push_bind(resolved_at);
|
||||
}
|
||||
|
||||
where_clause(&mut query);
|
||||
query.push(
|
||||
" RETURNING id, rule, rule_ref, trigger_ref, config, event, status, payload, \
|
||||
condition, conditions, created, resolved_at",
|
||||
);
|
||||
|
||||
let enforcement = query
|
||||
.build_query_as::<Enforcement>()
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
|
||||
Ok(enforcement)
|
||||
}
|
||||
|
||||
/// Update an enforcement using the loaded row's primary key.
|
||||
pub async fn update_loaded<'e, E>(
|
||||
executor: E,
|
||||
enforcement: &Enforcement,
|
||||
input: UpdateEnforcementInput,
|
||||
) -> Result<Enforcement>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none() && input.payload.is_none() && input.resolved_at.is_none() {
|
||||
return Ok(enforcement.clone());
|
||||
}
|
||||
|
||||
Self::update_with_locator(executor, input, |query| {
|
||||
query.push(" WHERE id = ");
|
||||
query.push_bind(enforcement.id);
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update_loaded_if_status<'e, E>(
|
||||
executor: E,
|
||||
enforcement: &Enforcement,
|
||||
expected_status: EnforcementStatus,
|
||||
input: UpdateEnforcementInput,
|
||||
) -> Result<Option<Enforcement>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none() && input.payload.is_none() && input.resolved_at.is_none() {
|
||||
return Ok(Some(enforcement.clone()));
|
||||
}
|
||||
|
||||
let mut query = QueryBuilder::new("UPDATE enforcement SET ");
|
||||
let mut has_updates = false;
|
||||
@@ -446,39 +560,25 @@ impl Update for EnforcementRepository {
|
||||
}
|
||||
|
||||
if !has_updates {
|
||||
// No updates requested, fetch and return existing entity
|
||||
return Self::get_by_id(executor, id).await;
|
||||
return Ok(Some(enforcement.clone()));
|
||||
}
|
||||
|
||||
query.push(" WHERE id = ");
|
||||
query.push_bind(id);
|
||||
query.push(" RETURNING id, rule, rule_ref, trigger_ref, config, event, status, payload, condition, conditions, created, resolved_at");
|
||||
query.push_bind(enforcement.id);
|
||||
query.push(" AND status = ");
|
||||
query.push_bind(expected_status);
|
||||
query.push(
|
||||
" RETURNING id, rule, rule_ref, trigger_ref, config, event, status, payload, \
|
||||
condition, conditions, created, resolved_at",
|
||||
);
|
||||
|
||||
let enforcement = query
|
||||
query
|
||||
.build_query_as::<Enforcement>()
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
|
||||
Ok(enforcement)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Delete for EnforcementRepository {
|
||||
async fn delete<'e, E>(executor: E, id: i64) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM enforcement WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnforcementRepository {
|
||||
/// Find enforcements by rule ID
|
||||
pub async fn find_by_rule<'e, E>(executor: E, rule_id: Id) -> Result<Vec<Enforcement>>
|
||||
where
|
||||
@@ -545,6 +645,90 @@ impl EnforcementRepository {
|
||||
Ok(enforcements)
|
||||
}
|
||||
|
||||
pub async fn find_by_rule_and_event<'e, E>(
|
||||
executor: E,
|
||||
rule_id: Id,
|
||||
event_id: Id,
|
||||
) -> Result<Option<Enforcement>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Enforcement>(
|
||||
r#"
|
||||
SELECT id, rule, rule_ref, trigger_ref, config, event, status, payload,
|
||||
condition, conditions, created, resolved_at
|
||||
FROM enforcement
|
||||
WHERE rule = $1 AND event = $2
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(rule_id)
|
||||
.bind(event_id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn create_or_get_by_rule_event<'e, E>(
|
||||
executor: E,
|
||||
input: CreateEnforcementInput,
|
||||
) -> Result<EnforcementCreateOrGetResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let (Some(rule_id), Some(event_id)) = (input.rule, input.event) else {
|
||||
let enforcement = Self::create(executor, input).await?;
|
||||
return Ok(EnforcementCreateOrGetResult {
|
||||
enforcement,
|
||||
created: true,
|
||||
});
|
||||
};
|
||||
|
||||
let inserted = sqlx::query_as::<_, Enforcement>(
|
||||
r#"
|
||||
INSERT INTO enforcement (rule, rule_ref, trigger_ref, config, event, status,
|
||||
payload, condition, conditions)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (rule, event) WHERE rule IS NOT NULL AND event IS NOT NULL DO NOTHING
|
||||
RETURNING id, rule, rule_ref, trigger_ref, config, event, status, payload,
|
||||
condition, conditions, created, resolved_at
|
||||
"#,
|
||||
)
|
||||
.bind(input.rule)
|
||||
.bind(&input.rule_ref)
|
||||
.bind(&input.trigger_ref)
|
||||
.bind(&input.config)
|
||||
.bind(input.event)
|
||||
.bind(input.status)
|
||||
.bind(&input.payload)
|
||||
.bind(input.condition)
|
||||
.bind(&input.conditions)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
if let Some(enforcement) = inserted {
|
||||
return Ok(EnforcementCreateOrGetResult {
|
||||
enforcement,
|
||||
created: true,
|
||||
});
|
||||
}
|
||||
|
||||
let enforcement = Self::find_by_rule_and_event(executor, rule_id, event_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"enforcement for rule {} and event {} disappeared after dedupe conflict",
|
||||
rule_id,
|
||||
event_id
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(EnforcementCreateOrGetResult {
|
||||
enforcement,
|
||||
created: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Search enforcements with all filters pushed into SQL.
|
||||
///
|
||||
/// All filter fields are combinable (AND). Pagination is server-side.
|
||||
|
||||
@@ -4,7 +4,8 @@ use chrono::{DateTime, Utc};
|
||||
|
||||
use crate::models::{enums::ExecutionStatus, execution::*, Id, JsonDict};
|
||||
use crate::Result;
|
||||
use sqlx::{Executor, Postgres, QueryBuilder};
|
||||
use sqlx::{Executor, PgConnection, PgPool, Postgres, QueryBuilder};
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use super::{Create, Delete, FindById, List, Repository, Update};
|
||||
|
||||
@@ -41,6 +42,18 @@ pub struct ExecutionSearchResult {
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorkflowTaskExecutionCreateOrGetResult {
|
||||
pub execution: Execution,
|
||||
pub created: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EnforcementExecutionCreateOrGetResult {
|
||||
pub execution: Execution,
|
||||
pub created: bool,
|
||||
}
|
||||
|
||||
/// An execution row with optional `rule_ref` / `trigger_ref` populated from
|
||||
/// the joined `enforcement` table. This avoids a separate in-memory lookup.
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
@@ -191,7 +204,577 @@ impl Update for ExecutionRepository {
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
// Build update query
|
||||
if input.status.is_none()
|
||||
&& input.result.is_none()
|
||||
&& input.executor.is_none()
|
||||
&& input.worker.is_none()
|
||||
&& input.started_at.is_none()
|
||||
&& input.workflow_task.is_none()
|
||||
{
|
||||
return Self::get_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
Self::update_with_locator(executor, input, |query| {
|
||||
query.push(" WHERE id = ").push_bind(id);
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionRepository {
|
||||
pub async fn find_top_level_by_enforcement<'e, E>(
|
||||
executor: E,
|
||||
enforcement_id: Id,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let sql = format!(
|
||||
"SELECT {SELECT_COLUMNS} \
|
||||
FROM execution \
|
||||
WHERE enforcement = $1
|
||||
AND parent IS NULL
|
||||
AND (config IS NULL OR NOT (config ? 'retry_of')) \
|
||||
ORDER BY created ASC \
|
||||
LIMIT 1"
|
||||
);
|
||||
|
||||
sqlx::query_as::<_, Execution>(&sql)
|
||||
.bind(enforcement_id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn create_top_level_for_enforcement_if_absent<'e, E>(
|
||||
executor: E,
|
||||
input: CreateExecutionInput,
|
||||
enforcement_id: Id,
|
||||
) -> Result<EnforcementExecutionCreateOrGetResult>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
let inserted = sqlx::query_as::<_, Execution>(&format!(
|
||||
"INSERT INTO execution \
|
||||
(action, action_ref, config, env_vars, parent, enforcement, executor, worker, status, result, workflow_task) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) \
|
||||
ON CONFLICT (enforcement)
|
||||
WHERE enforcement IS NOT NULL
|
||||
AND parent IS NULL
|
||||
AND (config IS NULL OR NOT (config ? 'retry_of'))
|
||||
DO NOTHING \
|
||||
RETURNING {SELECT_COLUMNS}"
|
||||
))
|
||||
.bind(input.action)
|
||||
.bind(&input.action_ref)
|
||||
.bind(&input.config)
|
||||
.bind(&input.env_vars)
|
||||
.bind(input.parent)
|
||||
.bind(input.enforcement)
|
||||
.bind(input.executor)
|
||||
.bind(input.worker)
|
||||
.bind(input.status)
|
||||
.bind(&input.result)
|
||||
.bind(sqlx::types::Json(&input.workflow_task))
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
if let Some(execution) = inserted {
|
||||
return Ok(EnforcementExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: true,
|
||||
});
|
||||
}
|
||||
|
||||
let execution = Self::find_top_level_by_enforcement(executor, enforcement_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"top-level execution for enforcement {} disappeared after dedupe conflict",
|
||||
enforcement_id
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(EnforcementExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: false,
|
||||
})
|
||||
}
|
||||
|
||||
async fn claim_workflow_task_dispatch<'e, E>(
|
||||
executor: E,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let inserted: Option<(i64,)> = sqlx::query_as(
|
||||
"INSERT INTO workflow_task_dispatch (workflow_execution, task_name, task_index)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workflow_execution, task_name, COALESCE(task_index, -1)) DO NOTHING
|
||||
RETURNING id",
|
||||
)
|
||||
.bind(workflow_execution_id)
|
||||
.bind(task_name)
|
||||
.bind(task_index)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
Ok(inserted.is_some())
|
||||
}
|
||||
|
||||
async fn assign_workflow_task_dispatch_execution<'e, E>(
|
||||
executor: E,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
execution_id: Id,
|
||||
) -> Result<()>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query(
|
||||
"UPDATE workflow_task_dispatch
|
||||
SET execution_id = COALESCE(execution_id, $4)
|
||||
WHERE workflow_execution = $1
|
||||
AND task_name = $2
|
||||
AND task_index IS NOT DISTINCT FROM $3",
|
||||
)
|
||||
.bind(workflow_execution_id)
|
||||
.bind(task_name)
|
||||
.bind(task_index)
|
||||
.bind(execution_id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn lock_workflow_task_dispatch<'e, E>(
|
||||
executor: E,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<Option<Option<Id>>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let row: Option<(Option<i64>,)> = sqlx::query_as(
|
||||
"SELECT execution_id
|
||||
FROM workflow_task_dispatch
|
||||
WHERE workflow_execution = $1
|
||||
AND task_name = $2
|
||||
AND task_index IS NOT DISTINCT FROM $3
|
||||
FOR UPDATE",
|
||||
)
|
||||
.bind(workflow_execution_id)
|
||||
.bind(task_name)
|
||||
.bind(task_index)
|
||||
.fetch_optional(executor)
|
||||
.await?;
|
||||
|
||||
// Map the outer Option to distinguish three cases:
|
||||
// - None → no row exists
|
||||
// - Some(None) → row exists but execution_id is still NULL (mid-creation)
|
||||
// - Some(Some(id)) → row exists with a completed execution_id
|
||||
Ok(row.map(|(execution_id,)| execution_id))
|
||||
}
|
||||
|
||||
async fn create_workflow_task_if_absent_in_conn(
|
||||
conn: &mut PgConnection,
|
||||
input: CreateExecutionInput,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<WorkflowTaskExecutionCreateOrGetResult> {
|
||||
let claimed = Self::claim_workflow_task_dispatch(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if claimed {
|
||||
let execution = Self::create(&mut *conn, input).await?;
|
||||
Self::assign_workflow_task_dispatch_execution(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
execution.id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok(WorkflowTaskExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: true,
|
||||
});
|
||||
}
|
||||
|
||||
let dispatch_state = Self::lock_workflow_task_dispatch(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await?;
|
||||
|
||||
match dispatch_state {
|
||||
Some(Some(existing_execution_id)) => {
|
||||
// Row exists with execution_id — return the existing execution.
|
||||
let execution = Self::find_by_id(&mut *conn, existing_execution_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"workflow child execution {} missing for workflow_execution {} task '{}' index {:?}",
|
||||
existing_execution_id,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(WorkflowTaskExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: false,
|
||||
})
|
||||
}
|
||||
|
||||
Some(None) => {
|
||||
// Row exists but execution_id is still NULL: another transaction is
|
||||
// mid-creation (between claim and assign). Retry until it's filled in.
|
||||
// If the original creator's transaction rolled back, the row also
|
||||
// disappears — handled by the `None` branch inside the loop.
|
||||
'wait: {
|
||||
for _ in 0..20_u32 {
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
match Self::lock_workflow_task_dispatch(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(Some(execution_id)) => {
|
||||
let execution =
|
||||
Self::find_by_id(&mut *conn, execution_id).await?.ok_or_else(
|
||||
|| {
|
||||
anyhow::anyhow!(
|
||||
"workflow child execution {} missing for workflow_execution {} task '{}' index {:?}",
|
||||
execution_id,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index
|
||||
)
|
||||
},
|
||||
)?;
|
||||
return Ok(WorkflowTaskExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: false,
|
||||
});
|
||||
}
|
||||
Some(None) => {} // still NULL, keep waiting
|
||||
None => break 'wait, // row rolled back; fall through to re-claim
|
||||
}
|
||||
}
|
||||
// Exhausted all retries without the execution_id being set.
|
||||
return Err(anyhow::anyhow!(
|
||||
"Timed out waiting for workflow task dispatch execution_id to be set \
|
||||
for workflow_execution {} task '{}' index {:?}",
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Row disappeared (original creator rolled back) — re-claim and create.
|
||||
let re_claimed = Self::claim_workflow_task_dispatch(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await?;
|
||||
if !re_claimed {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Workflow task dispatch for workflow_execution {} task '{}' index {:?} \
|
||||
was reclaimed by another executor after rollback",
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let execution = Self::create(&mut *conn, input).await?;
|
||||
Self::assign_workflow_task_dispatch_execution(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
execution.id,
|
||||
)
|
||||
.await?;
|
||||
Ok(WorkflowTaskExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: true,
|
||||
})
|
||||
}
|
||||
|
||||
None => {
|
||||
// No row at all — the original INSERT was rolled back before we arrived.
|
||||
// Attempt to re-claim and create as if this were a fresh dispatch.
|
||||
let re_claimed = Self::claim_workflow_task_dispatch(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await?;
|
||||
if !re_claimed {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Workflow task dispatch for workflow_execution {} task '{}' index {:?} \
|
||||
was claimed by another executor",
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let execution = Self::create(&mut *conn, input).await?;
|
||||
Self::assign_workflow_task_dispatch_execution(
|
||||
&mut *conn,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
execution.id,
|
||||
)
|
||||
.await?;
|
||||
Ok(WorkflowTaskExecutionCreateOrGetResult {
|
||||
execution,
|
||||
created: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_workflow_task_if_absent(
|
||||
pool: &PgPool,
|
||||
input: CreateExecutionInput,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<WorkflowTaskExecutionCreateOrGetResult> {
|
||||
let mut conn = pool.acquire().await?;
|
||||
sqlx::query("BEGIN").execute(&mut *conn).await?;
|
||||
|
||||
let result = Self::create_workflow_task_if_absent_in_conn(
|
||||
&mut conn,
|
||||
input,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(result) => {
|
||||
sqlx::query("COMMIT").execute(&mut *conn).await?;
|
||||
Ok(result)
|
||||
}
|
||||
Err(err) => {
|
||||
sqlx::query("ROLLBACK").execute(&mut *conn).await?;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_workflow_task_if_absent_with_conn(
|
||||
conn: &mut PgConnection,
|
||||
input: CreateExecutionInput,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<WorkflowTaskExecutionCreateOrGetResult> {
|
||||
Self::create_workflow_task_if_absent_in_conn(
|
||||
conn,
|
||||
input,
|
||||
workflow_execution_id,
|
||||
task_name,
|
||||
task_index,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn claim_for_scheduling<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
claiming_executor: Option<Id>,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let sql = format!(
|
||||
"UPDATE execution \
|
||||
SET status = $2, executor = COALESCE($3, executor), updated = NOW() \
|
||||
WHERE id = $1 AND status = $4 \
|
||||
RETURNING {SELECT_COLUMNS}"
|
||||
);
|
||||
|
||||
sqlx::query_as::<_, Execution>(&sql)
|
||||
.bind(id)
|
||||
.bind(ExecutionStatus::Scheduling)
|
||||
.bind(claiming_executor)
|
||||
.bind(ExecutionStatus::Requested)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn reclaim_stale_scheduling<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
claiming_executor: Option<Id>,
|
||||
stale_before: DateTime<Utc>,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let sql = format!(
|
||||
"UPDATE execution \
|
||||
SET executor = COALESCE($2, executor), updated = NOW() \
|
||||
WHERE id = $1 AND status = $3 AND updated <= $4 \
|
||||
RETURNING {SELECT_COLUMNS}"
|
||||
);
|
||||
|
||||
sqlx::query_as::<_, Execution>(&sql)
|
||||
.bind(id)
|
||||
.bind(claiming_executor)
|
||||
.bind(ExecutionStatus::Scheduling)
|
||||
.bind(stale_before)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn update_if_status<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
expected_status: ExecutionStatus,
|
||||
input: UpdateExecutionInput,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none()
|
||||
&& input.result.is_none()
|
||||
&& input.executor.is_none()
|
||||
&& input.worker.is_none()
|
||||
&& input.started_at.is_none()
|
||||
&& input.workflow_task.is_none()
|
||||
{
|
||||
return Self::find_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
Self::update_with_locator_optional(executor, input, |query| {
|
||||
query.push(" WHERE id = ").push_bind(id);
|
||||
query.push(" AND status = ").push_bind(expected_status);
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update_if_status_and_updated_before<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
expected_status: ExecutionStatus,
|
||||
stale_before: DateTime<Utc>,
|
||||
input: UpdateExecutionInput,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none()
|
||||
&& input.result.is_none()
|
||||
&& input.executor.is_none()
|
||||
&& input.worker.is_none()
|
||||
&& input.started_at.is_none()
|
||||
&& input.workflow_task.is_none()
|
||||
{
|
||||
return Self::find_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
Self::update_with_locator_optional(executor, input, |query| {
|
||||
query.push(" WHERE id = ").push_bind(id);
|
||||
query.push(" AND status = ").push_bind(expected_status);
|
||||
query.push(" AND updated < ").push_bind(stale_before);
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update_if_status_and_updated_at<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
expected_status: ExecutionStatus,
|
||||
expected_updated: DateTime<Utc>,
|
||||
input: UpdateExecutionInput,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none()
|
||||
&& input.result.is_none()
|
||||
&& input.executor.is_none()
|
||||
&& input.worker.is_none()
|
||||
&& input.started_at.is_none()
|
||||
&& input.workflow_task.is_none()
|
||||
{
|
||||
return Self::find_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
Self::update_with_locator_optional(executor, input, |query| {
|
||||
query.push(" WHERE id = ").push_bind(id);
|
||||
query.push(" AND status = ").push_bind(expected_status);
|
||||
query.push(" AND updated = ").push_bind(expected_updated);
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn revert_scheduled_to_requested<'e, E>(
|
||||
executor: E,
|
||||
id: Id,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let sql = format!(
|
||||
"UPDATE execution \
|
||||
SET status = $2, worker = NULL, executor = NULL, updated = NOW() \
|
||||
WHERE id = $1 AND status = $3 \
|
||||
RETURNING {SELECT_COLUMNS}"
|
||||
);
|
||||
|
||||
sqlx::query_as::<_, Execution>(&sql)
|
||||
.bind(id)
|
||||
.bind(ExecutionStatus::Requested)
|
||||
.bind(ExecutionStatus::Scheduled)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn update_with_locator<'e, E, F>(
|
||||
executor: E,
|
||||
input: UpdateExecutionInput,
|
||||
where_clause: F,
|
||||
) -> Result<Execution>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
F: FnOnce(&mut QueryBuilder<'_, Postgres>),
|
||||
{
|
||||
let mut query = QueryBuilder::new("UPDATE execution SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
@@ -234,15 +817,10 @@ impl Update for ExecutionRepository {
|
||||
query
|
||||
.push("workflow_task = ")
|
||||
.push_bind(sqlx::types::Json(workflow_task));
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if !has_updates {
|
||||
// No updates requested, fetch and return existing entity
|
||||
return Self::get_by_id(executor, id).await;
|
||||
}
|
||||
|
||||
query.push(", updated = NOW() WHERE id = ").push_bind(id);
|
||||
query.push(", updated = NOW()");
|
||||
where_clause(&mut query);
|
||||
query.push(" RETURNING ");
|
||||
query.push(SELECT_COLUMNS);
|
||||
|
||||
@@ -252,6 +830,96 @@ impl Update for ExecutionRepository {
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn update_with_locator_optional<'e, E, F>(
|
||||
executor: E,
|
||||
input: UpdateExecutionInput,
|
||||
where_clause: F,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
F: FnOnce(&mut QueryBuilder<'_, Postgres>),
|
||||
{
|
||||
let mut query = QueryBuilder::new("UPDATE execution SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(status) = input.status {
|
||||
query.push("status = ").push_bind(status);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(result) = &input.result {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("result = ").push_bind(result);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(executor_id) = input.executor {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("executor = ").push_bind(executor_id);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(worker_id) = input.worker {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("worker = ").push_bind(worker_id);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(started_at) = input.started_at {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("started_at = ").push_bind(started_at);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(workflow_task) = &input.workflow_task {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query
|
||||
.push("workflow_task = ")
|
||||
.push_bind(sqlx::types::Json(workflow_task));
|
||||
}
|
||||
|
||||
query.push(", updated = NOW()");
|
||||
where_clause(&mut query);
|
||||
query.push(" RETURNING ");
|
||||
query.push(SELECT_COLUMNS);
|
||||
|
||||
query
|
||||
.build_query_as::<Execution>()
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Update an execution using the loaded row's primary key.
|
||||
pub async fn update_loaded<'e, E>(
|
||||
executor: E,
|
||||
execution: &Execution,
|
||||
input: UpdateExecutionInput,
|
||||
) -> Result<Execution>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if input.status.is_none()
|
||||
&& input.result.is_none()
|
||||
&& input.executor.is_none()
|
||||
&& input.worker.is_none()
|
||||
&& input.started_at.is_none()
|
||||
&& input.workflow_task.is_none()
|
||||
{
|
||||
return Ok(execution.clone());
|
||||
}
|
||||
|
||||
Self::update_with_locator(executor, input, |query| {
|
||||
query.push(" WHERE id = ").push_bind(execution.id);
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -303,6 +971,34 @@ impl ExecutionRepository {
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn find_by_workflow_task<'e, E>(
|
||||
executor: E,
|
||||
workflow_execution_id: Id,
|
||||
task_name: &str,
|
||||
task_index: Option<i32>,
|
||||
) -> Result<Option<Execution>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let sql = format!(
|
||||
"SELECT {SELECT_COLUMNS} \
|
||||
FROM execution \
|
||||
WHERE workflow_task->>'workflow_execution' = $1::text \
|
||||
AND workflow_task->>'task_name' = $2 \
|
||||
AND (workflow_task->>'task_index')::int IS NOT DISTINCT FROM $3 \
|
||||
ORDER BY created ASC \
|
||||
LIMIT 1"
|
||||
);
|
||||
|
||||
sqlx::query_as::<_, Execution>(&sql)
|
||||
.bind(workflow_execution_id.to_string())
|
||||
.bind(task_name)
|
||||
.bind(task_index)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Find all child executions for a given parent execution ID.
|
||||
///
|
||||
/// Returns child executions ordered by creation time (ascending),
|
||||
|
||||
909
crates/common/src/repositories/execution_admission.rs
Normal file
909
crates/common/src/repositories/execution_admission.rs
Normal file
@@ -0,0 +1,909 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::{PgPool, Postgres, Row, Transaction};
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::models::Id;
|
||||
use crate::repositories::queue_stats::{QueueStatsRepository, UpsertQueueStatsInput};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdmissionSlotAcquireOutcome {
|
||||
pub acquired: bool,
|
||||
pub current_count: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AdmissionEnqueueOutcome {
|
||||
Acquired,
|
||||
Enqueued,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdmissionSlotReleaseOutcome {
|
||||
pub action_id: Id,
|
||||
pub group_key: Option<String>,
|
||||
pub next_execution_id: Option<Id>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdmissionQueuedRemovalOutcome {
|
||||
pub action_id: Id,
|
||||
pub group_key: Option<String>,
|
||||
pub next_execution_id: Option<Id>,
|
||||
pub execution_id: Id,
|
||||
pub queue_order: i64,
|
||||
pub enqueued_at: DateTime<Utc>,
|
||||
pub removed_index: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdmissionQueueStats {
|
||||
pub action_id: Id,
|
||||
pub queue_length: usize,
|
||||
pub active_count: u32,
|
||||
pub max_concurrent: u32,
|
||||
pub oldest_enqueued_at: Option<DateTime<Utc>>,
|
||||
pub total_enqueued: u64,
|
||||
pub total_completed: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct AdmissionState {
|
||||
id: Id,
|
||||
action_id: Id,
|
||||
group_key: Option<String>,
|
||||
max_concurrent: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ExecutionEntry {
|
||||
state_id: Id,
|
||||
action_id: Id,
|
||||
group_key: Option<String>,
|
||||
status: String,
|
||||
queue_order: i64,
|
||||
enqueued_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub struct ExecutionAdmissionRepository;
|
||||
|
||||
impl ExecutionAdmissionRepository {
|
||||
pub async fn enqueue(
|
||||
pool: &PgPool,
|
||||
max_queue_length: usize,
|
||||
action_id: Id,
|
||||
execution_id: Id,
|
||||
max_concurrent: u32,
|
||||
group_key: Option<String>,
|
||||
) -> Result<AdmissionEnqueueOutcome> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let state = Self::lock_state(&mut tx, action_id, group_key, max_concurrent).await?;
|
||||
let outcome =
|
||||
Self::enqueue_in_state(&mut tx, &state, max_queue_length, execution_id, true).await?;
|
||||
Self::refresh_queue_stats(&mut tx, action_id).await?;
|
||||
tx.commit().await?;
|
||||
Ok(outcome)
|
||||
}
|
||||
|
||||
pub async fn wait_status(pool: &PgPool, execution_id: Id) -> Result<Option<bool>> {
|
||||
let row = sqlx::query_scalar::<Postgres, bool>(
|
||||
r#"
|
||||
SELECT status = 'active'
|
||||
FROM execution_admission_entry
|
||||
WHERE execution_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(execution_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row)
|
||||
}
|
||||
|
||||
pub async fn try_acquire(
|
||||
pool: &PgPool,
|
||||
action_id: Id,
|
||||
execution_id: Id,
|
||||
max_concurrent: u32,
|
||||
group_key: Option<String>,
|
||||
) -> Result<AdmissionSlotAcquireOutcome> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let state = Self::lock_state(&mut tx, action_id, group_key, max_concurrent).await?;
|
||||
let active_count = Self::active_count(&mut tx, state.id).await? as u32;
|
||||
|
||||
let outcome = match Self::find_execution_entry(&mut tx, execution_id).await? {
|
||||
Some(entry) if entry.status == "active" => AdmissionSlotAcquireOutcome {
|
||||
acquired: true,
|
||||
current_count: active_count,
|
||||
},
|
||||
Some(entry) if entry.status == "queued" && entry.state_id == state.id => {
|
||||
let promoted =
|
||||
Self::maybe_promote_existing_queued(&mut tx, &state, execution_id).await?;
|
||||
AdmissionSlotAcquireOutcome {
|
||||
acquired: promoted,
|
||||
current_count: active_count,
|
||||
}
|
||||
}
|
||||
Some(_) => AdmissionSlotAcquireOutcome {
|
||||
acquired: false,
|
||||
current_count: active_count,
|
||||
},
|
||||
None => {
|
||||
if active_count < max_concurrent
|
||||
&& Self::queued_count(&mut tx, state.id).await? == 0
|
||||
{
|
||||
let queue_order = Self::allocate_queue_order(&mut tx, state.id).await?;
|
||||
Self::insert_entry(
|
||||
&mut tx,
|
||||
state.id,
|
||||
execution_id,
|
||||
"active",
|
||||
queue_order,
|
||||
Utc::now(),
|
||||
)
|
||||
.await?;
|
||||
Self::increment_total_enqueued(&mut tx, state.id).await?;
|
||||
Self::refresh_queue_stats(&mut tx, action_id).await?;
|
||||
AdmissionSlotAcquireOutcome {
|
||||
acquired: true,
|
||||
current_count: active_count,
|
||||
}
|
||||
} else {
|
||||
AdmissionSlotAcquireOutcome {
|
||||
acquired: false,
|
||||
current_count: active_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(outcome)
|
||||
}
|
||||
|
||||
pub async fn release_active_slot(
|
||||
pool: &PgPool,
|
||||
execution_id: Id,
|
||||
) -> Result<Option<AdmissionSlotReleaseOutcome>> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let Some(entry) = Self::find_execution_entry_for_update(&mut tx, execution_id).await?
|
||||
else {
|
||||
tx.commit().await?;
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if entry.status != "active" {
|
||||
tx.commit().await?;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let state = Self::lock_existing_state(&mut tx, entry.action_id, entry.group_key.clone())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
crate::Error::internal("missing execution_admission_state for active execution")
|
||||
})?;
|
||||
|
||||
sqlx::query("DELETE FROM execution_admission_entry WHERE execution_id = $1")
|
||||
.bind(execution_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
Self::increment_total_completed(&mut tx, state.id).await?;
|
||||
|
||||
let next_execution_id = Self::promote_next_queued(&mut tx, &state).await?;
|
||||
Self::refresh_queue_stats(&mut tx, state.action_id).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(Some(AdmissionSlotReleaseOutcome {
|
||||
action_id: state.action_id,
|
||||
group_key: state.group_key,
|
||||
next_execution_id,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn restore_active_slot(
|
||||
pool: &PgPool,
|
||||
execution_id: Id,
|
||||
outcome: &AdmissionSlotReleaseOutcome,
|
||||
) -> Result<()> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let state =
|
||||
Self::lock_existing_state(&mut tx, outcome.action_id, outcome.group_key.clone())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
crate::Error::internal("missing execution_admission_state on restore")
|
||||
})?;
|
||||
|
||||
if let Some(next_execution_id) = outcome.next_execution_id {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_entry
|
||||
SET status = 'queued', activated_at = NULL
|
||||
WHERE execution_id = $1
|
||||
AND state_id = $2
|
||||
AND status = 'active'
|
||||
"#,
|
||||
)
|
||||
.bind(next_execution_id)
|
||||
.bind(state.id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO execution_admission_entry (
|
||||
state_id, execution_id, status, queue_order, enqueued_at, activated_at
|
||||
) VALUES ($1, $2, 'active', $3, NOW(), NOW())
|
||||
ON CONFLICT (execution_id) DO UPDATE
|
||||
SET state_id = EXCLUDED.state_id,
|
||||
status = 'active',
|
||||
activated_at = EXCLUDED.activated_at
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.bind(execution_id)
|
||||
.bind(Self::allocate_queue_order(&mut tx, state.id).await?)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_state
|
||||
SET total_completed = GREATEST(total_completed - 1, 0)
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
Self::refresh_queue_stats(&mut tx, state.action_id).await?;
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_queued_execution(
|
||||
pool: &PgPool,
|
||||
execution_id: Id,
|
||||
) -> Result<Option<AdmissionQueuedRemovalOutcome>> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let Some(entry) = Self::find_execution_entry_for_update(&mut tx, execution_id).await?
|
||||
else {
|
||||
tx.commit().await?;
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if entry.status != "queued" {
|
||||
tx.commit().await?;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let state = Self::lock_existing_state(&mut tx, entry.action_id, entry.group_key.clone())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
crate::Error::internal("missing execution_admission_state for queued execution")
|
||||
})?;
|
||||
|
||||
let removed_index = sqlx::query_scalar::<Postgres, i64>(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM execution_admission_entry
|
||||
WHERE state_id = $1
|
||||
AND status = 'queued'
|
||||
AND (enqueued_at, id) < (
|
||||
SELECT enqueued_at, id
|
||||
FROM execution_admission_entry
|
||||
WHERE execution_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.bind(execution_id)
|
||||
.fetch_one(&mut *tx)
|
||||
.await? as usize;
|
||||
|
||||
sqlx::query("DELETE FROM execution_admission_entry WHERE execution_id = $1")
|
||||
.bind(execution_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let next_execution_id =
|
||||
if Self::active_count(&mut tx, state.id).await? < state.max_concurrent as i64 {
|
||||
Self::promote_next_queued(&mut tx, &state).await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self::refresh_queue_stats(&mut tx, state.action_id).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(Some(AdmissionQueuedRemovalOutcome {
|
||||
action_id: state.action_id,
|
||||
group_key: state.group_key,
|
||||
next_execution_id,
|
||||
execution_id,
|
||||
queue_order: entry.queue_order,
|
||||
enqueued_at: entry.enqueued_at,
|
||||
removed_index,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn restore_queued_execution(
|
||||
pool: &PgPool,
|
||||
outcome: &AdmissionQueuedRemovalOutcome,
|
||||
) -> Result<()> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let state =
|
||||
Self::lock_existing_state(&mut tx, outcome.action_id, outcome.group_key.clone())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
crate::Error::internal("missing execution_admission_state on queued restore")
|
||||
})?;
|
||||
|
||||
if let Some(next_execution_id) = outcome.next_execution_id {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_entry
|
||||
SET status = 'queued', activated_at = NULL
|
||||
WHERE execution_id = $1
|
||||
AND state_id = $2
|
||||
AND status = 'active'
|
||||
"#,
|
||||
)
|
||||
.bind(next_execution_id)
|
||||
.bind(state.id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO execution_admission_entry (
|
||||
state_id, execution_id, status, queue_order, enqueued_at, activated_at
|
||||
) VALUES ($1, $2, 'queued', $3, $4, NULL)
|
||||
ON CONFLICT (execution_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.bind(outcome.execution_id)
|
||||
.bind(outcome.queue_order)
|
||||
.bind(outcome.enqueued_at)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
Self::refresh_queue_stats(&mut tx, state.action_id).await?;
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_queue_stats(
|
||||
pool: &PgPool,
|
||||
action_id: Id,
|
||||
) -> Result<Option<AdmissionQueueStats>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
WITH state_rows AS (
|
||||
SELECT
|
||||
COUNT(*) AS state_count,
|
||||
COALESCE(SUM(max_concurrent), 0) AS max_concurrent,
|
||||
COALESCE(SUM(total_enqueued), 0) AS total_enqueued,
|
||||
COALESCE(SUM(total_completed), 0) AS total_completed
|
||||
FROM execution_admission_state
|
||||
WHERE action_id = $1
|
||||
),
|
||||
entry_rows AS (
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE e.status = 'queued') AS queue_length,
|
||||
COUNT(*) FILTER (WHERE e.status = 'active') AS active_count,
|
||||
MIN(e.enqueued_at) FILTER (WHERE e.status = 'queued') AS oldest_enqueued_at
|
||||
FROM execution_admission_state s
|
||||
LEFT JOIN execution_admission_entry e ON e.state_id = s.id
|
||||
WHERE s.action_id = $1
|
||||
)
|
||||
SELECT
|
||||
sr.state_count,
|
||||
er.queue_length,
|
||||
er.active_count,
|
||||
sr.max_concurrent,
|
||||
er.oldest_enqueued_at,
|
||||
sr.total_enqueued,
|
||||
sr.total_completed
|
||||
FROM state_rows sr
|
||||
CROSS JOIN entry_rows er
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let state_count: i64 = row.try_get("state_count")?;
|
||||
if state_count == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(AdmissionQueueStats {
|
||||
action_id,
|
||||
queue_length: row.try_get::<i64, _>("queue_length")? as usize,
|
||||
active_count: row.try_get::<i64, _>("active_count")? as u32,
|
||||
max_concurrent: row.try_get::<i64, _>("max_concurrent")? as u32,
|
||||
oldest_enqueued_at: row.try_get("oldest_enqueued_at")?,
|
||||
total_enqueued: row.try_get::<i64, _>("total_enqueued")? as u64,
|
||||
total_completed: row.try_get::<i64, _>("total_completed")? as u64,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn enqueue_in_state(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state: &AdmissionState,
|
||||
max_queue_length: usize,
|
||||
execution_id: Id,
|
||||
allow_queue: bool,
|
||||
) -> Result<AdmissionEnqueueOutcome> {
|
||||
if let Some(entry) = Self::find_execution_entry(tx, execution_id).await? {
|
||||
if entry.status == "active" {
|
||||
return Ok(AdmissionEnqueueOutcome::Acquired);
|
||||
}
|
||||
|
||||
if entry.status == "queued" && entry.state_id == state.id {
|
||||
if Self::maybe_promote_existing_queued(tx, state, execution_id).await? {
|
||||
return Ok(AdmissionEnqueueOutcome::Acquired);
|
||||
}
|
||||
return Ok(AdmissionEnqueueOutcome::Enqueued);
|
||||
}
|
||||
|
||||
return Ok(AdmissionEnqueueOutcome::Enqueued);
|
||||
}
|
||||
|
||||
let active_count = Self::active_count(tx, state.id).await?;
|
||||
let queued_count = Self::queued_count(tx, state.id).await?;
|
||||
|
||||
if active_count < state.max_concurrent as i64 && queued_count == 0 {
|
||||
let queue_order = Self::allocate_queue_order(tx, state.id).await?;
|
||||
Self::insert_entry(
|
||||
tx,
|
||||
state.id,
|
||||
execution_id,
|
||||
"active",
|
||||
queue_order,
|
||||
Utc::now(),
|
||||
)
|
||||
.await?;
|
||||
Self::increment_total_enqueued(tx, state.id).await?;
|
||||
return Ok(AdmissionEnqueueOutcome::Acquired);
|
||||
}
|
||||
|
||||
if !allow_queue {
|
||||
return Ok(AdmissionEnqueueOutcome::Enqueued);
|
||||
}
|
||||
|
||||
if queued_count >= max_queue_length as i64 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Queue full for action {}: maximum {} entries",
|
||||
state.action_id,
|
||||
max_queue_length
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let queue_order = Self::allocate_queue_order(tx, state.id).await?;
|
||||
Self::insert_entry(
|
||||
tx,
|
||||
state.id,
|
||||
execution_id,
|
||||
"queued",
|
||||
queue_order,
|
||||
Utc::now(),
|
||||
)
|
||||
.await?;
|
||||
Self::increment_total_enqueued(tx, state.id).await?;
|
||||
Ok(AdmissionEnqueueOutcome::Enqueued)
|
||||
}
|
||||
|
||||
async fn maybe_promote_existing_queued(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state: &AdmissionState,
|
||||
execution_id: Id,
|
||||
) -> Result<bool> {
|
||||
let active_count = Self::active_count(tx, state.id).await?;
|
||||
if active_count >= state.max_concurrent as i64 {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let front_execution_id = sqlx::query_scalar::<Postgres, Id>(
|
||||
r#"
|
||||
SELECT execution_id
|
||||
FROM execution_admission_entry
|
||||
WHERE state_id = $1
|
||||
AND status = 'queued'
|
||||
ORDER BY queue_order ASC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await?;
|
||||
|
||||
if front_execution_id != Some(execution_id) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_entry
|
||||
SET status = 'active',
|
||||
activated_at = NOW()
|
||||
WHERE execution_id = $1
|
||||
AND state_id = $2
|
||||
AND status = 'queued'
|
||||
"#,
|
||||
)
|
||||
.bind(execution_id)
|
||||
.bind(state.id)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn promote_next_queued(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state: &AdmissionState,
|
||||
) -> Result<Option<Id>> {
|
||||
let next_execution_id = sqlx::query_scalar::<Postgres, Id>(
|
||||
r#"
|
||||
SELECT execution_id
|
||||
FROM execution_admission_entry
|
||||
WHERE state_id = $1
|
||||
AND status = 'queued'
|
||||
ORDER BY queue_order ASC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(state.id)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await?;
|
||||
|
||||
if let Some(next_execution_id) = next_execution_id {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_entry
|
||||
SET status = 'active',
|
||||
activated_at = NOW()
|
||||
WHERE execution_id = $1
|
||||
AND state_id = $2
|
||||
AND status = 'queued'
|
||||
"#,
|
||||
)
|
||||
.bind(next_execution_id)
|
||||
.bind(state.id)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(next_execution_id)
|
||||
}
|
||||
|
||||
async fn lock_state(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
action_id: Id,
|
||||
group_key: Option<String>,
|
||||
max_concurrent: u32,
|
||||
) -> Result<AdmissionState> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO execution_admission_state (action_id, group_key, max_concurrent)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (action_id, group_key_normalized)
|
||||
DO UPDATE SET max_concurrent = EXCLUDED.max_concurrent
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.bind(group_key.clone())
|
||||
.bind(max_concurrent as i32)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
let state = sqlx::query(
|
||||
r#"
|
||||
SELECT id, action_id, group_key, max_concurrent
|
||||
FROM execution_admission_state
|
||||
WHERE action_id = $1
|
||||
AND group_key_normalized = COALESCE($2, '')
|
||||
FOR UPDATE
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.bind(group_key)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(AdmissionState {
|
||||
id: state.try_get("id")?,
|
||||
action_id: state.try_get("action_id")?,
|
||||
group_key: state.try_get("group_key")?,
|
||||
max_concurrent: state.try_get("max_concurrent")?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn lock_existing_state(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
action_id: Id,
|
||||
group_key: Option<String>,
|
||||
) -> Result<Option<AdmissionState>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, action_id, group_key, max_concurrent
|
||||
FROM execution_admission_state
|
||||
WHERE action_id = $1
|
||||
AND group_key_normalized = COALESCE($2, '')
|
||||
FOR UPDATE
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.bind(group_key)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|state| AdmissionState {
|
||||
id: state.try_get("id").expect("state.id"),
|
||||
action_id: state.try_get("action_id").expect("state.action_id"),
|
||||
group_key: state.try_get("group_key").expect("state.group_key"),
|
||||
max_concurrent: state
|
||||
.try_get("max_concurrent")
|
||||
.expect("state.max_concurrent"),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn find_execution_entry(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
execution_id: Id,
|
||||
) -> Result<Option<ExecutionEntry>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
e.state_id,
|
||||
s.action_id,
|
||||
s.group_key,
|
||||
e.execution_id,
|
||||
e.status,
|
||||
e.queue_order,
|
||||
e.enqueued_at
|
||||
FROM execution_admission_entry e
|
||||
JOIN execution_admission_state s ON s.id = e.state_id
|
||||
WHERE e.execution_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(execution_id)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|entry| ExecutionEntry {
|
||||
state_id: entry.try_get("state_id").expect("entry.state_id"),
|
||||
action_id: entry.try_get("action_id").expect("entry.action_id"),
|
||||
group_key: entry.try_get("group_key").expect("entry.group_key"),
|
||||
status: entry.try_get("status").expect("entry.status"),
|
||||
queue_order: entry.try_get("queue_order").expect("entry.queue_order"),
|
||||
enqueued_at: entry.try_get("enqueued_at").expect("entry.enqueued_at"),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn find_execution_entry_for_update(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
execution_id: Id,
|
||||
) -> Result<Option<ExecutionEntry>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
e.state_id,
|
||||
s.action_id,
|
||||
s.group_key,
|
||||
e.execution_id,
|
||||
e.status,
|
||||
e.queue_order,
|
||||
e.enqueued_at
|
||||
FROM execution_admission_entry e
|
||||
JOIN execution_admission_state s ON s.id = e.state_id
|
||||
WHERE e.execution_id = $1
|
||||
FOR UPDATE OF e, s
|
||||
"#,
|
||||
)
|
||||
.bind(execution_id)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|entry| ExecutionEntry {
|
||||
state_id: entry.try_get("state_id").expect("entry.state_id"),
|
||||
action_id: entry.try_get("action_id").expect("entry.action_id"),
|
||||
group_key: entry.try_get("group_key").expect("entry.group_key"),
|
||||
status: entry.try_get("status").expect("entry.status"),
|
||||
queue_order: entry.try_get("queue_order").expect("entry.queue_order"),
|
||||
enqueued_at: entry.try_get("enqueued_at").expect("entry.enqueued_at"),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn active_count(tx: &mut Transaction<'_, Postgres>, state_id: Id) -> Result<i64> {
|
||||
Ok(sqlx::query_scalar::<Postgres, i64>(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM execution_admission_entry
|
||||
WHERE state_id = $1
|
||||
AND status = 'active'
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn queued_count(tx: &mut Transaction<'_, Postgres>, state_id: Id) -> Result<i64> {
|
||||
Ok(sqlx::query_scalar::<Postgres, i64>(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM execution_admission_entry
|
||||
WHERE state_id = $1
|
||||
AND status = 'queued'
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn insert_entry(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state_id: Id,
|
||||
execution_id: Id,
|
||||
status: &str,
|
||||
queue_order: i64,
|
||||
enqueued_at: DateTime<Utc>,
|
||||
) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO execution_admission_entry (
|
||||
state_id, execution_id, status, queue_order, enqueued_at, activated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
CASE WHEN $3 = 'active' THEN NOW() ELSE NULL END
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.bind(execution_id)
|
||||
.bind(status)
|
||||
.bind(queue_order)
|
||||
.bind(enqueued_at)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn allocate_queue_order(tx: &mut Transaction<'_, Postgres>, state_id: Id) -> Result<i64> {
|
||||
let queue_order = sqlx::query_scalar::<Postgres, i64>(
|
||||
r#"
|
||||
UPDATE execution_admission_state
|
||||
SET next_queue_order = next_queue_order + 1
|
||||
WHERE id = $1
|
||||
RETURNING next_queue_order - 1
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(queue_order)
|
||||
}
|
||||
|
||||
async fn increment_total_enqueued(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state_id: Id,
|
||||
) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_state
|
||||
SET total_enqueued = total_enqueued + 1
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn increment_total_completed(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
state_id: Id,
|
||||
) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE execution_admission_state
|
||||
SET total_completed = total_completed + 1
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(state_id)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_queue_stats(tx: &mut Transaction<'_, Postgres>, action_id: Id) -> Result<()> {
|
||||
let Some(stats) = Self::get_queue_stats_from_tx(tx, action_id).await? else {
|
||||
QueueStatsRepository::delete(&mut **tx, action_id).await?;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
QueueStatsRepository::upsert(
|
||||
&mut **tx,
|
||||
UpsertQueueStatsInput {
|
||||
action_id,
|
||||
queue_length: stats.queue_length as i32,
|
||||
active_count: stats.active_count as i32,
|
||||
max_concurrent: stats.max_concurrent as i32,
|
||||
oldest_enqueued_at: stats.oldest_enqueued_at,
|
||||
total_enqueued: stats.total_enqueued as i64,
|
||||
total_completed: stats.total_completed as i64,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_queue_stats_from_tx(
|
||||
tx: &mut Transaction<'_, Postgres>,
|
||||
action_id: Id,
|
||||
) -> Result<Option<AdmissionQueueStats>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
WITH state_rows AS (
|
||||
SELECT
|
||||
COUNT(*) AS state_count,
|
||||
COALESCE(SUM(max_concurrent), 0) AS max_concurrent,
|
||||
COALESCE(SUM(total_enqueued), 0) AS total_enqueued,
|
||||
COALESCE(SUM(total_completed), 0) AS total_completed
|
||||
FROM execution_admission_state
|
||||
WHERE action_id = $1
|
||||
),
|
||||
entry_rows AS (
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE e.status = 'queued') AS queue_length,
|
||||
COUNT(*) FILTER (WHERE e.status = 'active') AS active_count,
|
||||
MIN(e.enqueued_at) FILTER (WHERE e.status = 'queued') AS oldest_enqueued_at
|
||||
FROM execution_admission_state s
|
||||
LEFT JOIN execution_admission_entry e ON e.state_id = s.id
|
||||
WHERE s.action_id = $1
|
||||
)
|
||||
SELECT
|
||||
sr.state_count,
|
||||
er.queue_length,
|
||||
er.active_count,
|
||||
sr.max_concurrent,
|
||||
er.oldest_enqueued_at,
|
||||
sr.total_enqueued,
|
||||
sr.total_completed
|
||||
FROM state_rows sr
|
||||
CROSS JOIN entry_rows er
|
||||
"#,
|
||||
)
|
||||
.bind(action_id)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?;
|
||||
|
||||
let state_count: i64 = row.try_get("state_count")?;
|
||||
if state_count == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(AdmissionQueueStats {
|
||||
action_id,
|
||||
queue_length: row.try_get::<i64, _>("queue_length")? as usize,
|
||||
active_count: row.try_get::<i64, _>("active_count")? as u32,
|
||||
max_concurrent: row.try_get::<i64, _>("max_concurrent")? as u32,
|
||||
oldest_enqueued_at: row.try_get("oldest_enqueued_at")?,
|
||||
total_enqueued: row.try_get::<i64, _>("total_enqueued")? as u64,
|
||||
total_completed: row.try_get::<i64, _>("total_completed")? as u64,
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@ pub struct UpdateIdentityInput {
|
||||
pub display_name: Option<String>,
|
||||
pub password_hash: Option<String>,
|
||||
pub attributes: Option<JsonDict>,
|
||||
pub frozen: Option<bool>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -37,7 +38,7 @@ impl FindById for IdentityRepository {
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"SELECT id, login, display_name, password_hash, attributes, created, updated FROM identity WHERE id = $1"
|
||||
"SELECT id, login, display_name, password_hash, attributes, frozen, created, updated FROM identity WHERE id = $1"
|
||||
).bind(id).fetch_optional(executor).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
@@ -49,7 +50,7 @@ impl List for IdentityRepository {
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"SELECT id, login, display_name, password_hash, attributes, created, updated FROM identity ORDER BY login ASC"
|
||||
"SELECT id, login, display_name, password_hash, attributes, frozen, created, updated FROM identity ORDER BY login ASC"
|
||||
).fetch_all(executor).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
@@ -62,7 +63,7 @@ impl Create for IdentityRepository {
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"INSERT INTO identity (login, display_name, password_hash, attributes) VALUES ($1, $2, $3, $4) RETURNING id, login, display_name, password_hash, attributes, created, updated"
|
||||
"INSERT INTO identity (login, display_name, password_hash, attributes) VALUES ($1, $2, $3, $4) RETURNING id, login, display_name, password_hash, attributes, frozen, created, updated"
|
||||
)
|
||||
.bind(&input.login)
|
||||
.bind(&input.display_name)
|
||||
@@ -111,6 +112,13 @@ impl Update for IdentityRepository {
|
||||
query.push("attributes = ").push_bind(attributes);
|
||||
has_updates = true;
|
||||
}
|
||||
if let Some(frozen) = input.frozen {
|
||||
if has_updates {
|
||||
query.push(", ");
|
||||
}
|
||||
query.push("frozen = ").push_bind(frozen);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if !has_updates {
|
||||
// No updates requested, fetch and return existing entity
|
||||
@@ -119,7 +127,7 @@ impl Update for IdentityRepository {
|
||||
|
||||
query.push(", updated = NOW() WHERE id = ").push_bind(id);
|
||||
query.push(
|
||||
" RETURNING id, login, display_name, password_hash, attributes, created, updated",
|
||||
" RETURNING id, login, display_name, password_hash, attributes, frozen, created, updated",
|
||||
);
|
||||
|
||||
query
|
||||
@@ -156,9 +164,51 @@ impl IdentityRepository {
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"SELECT id, login, display_name, password_hash, attributes, created, updated FROM identity WHERE login = $1"
|
||||
"SELECT id, login, display_name, password_hash, attributes, frozen, created, updated FROM identity WHERE login = $1"
|
||||
).bind(login).fetch_optional(executor).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn find_by_oidc_subject<'e, E>(
|
||||
executor: E,
|
||||
issuer: &str,
|
||||
subject: &str,
|
||||
) -> Result<Option<Identity>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"SELECT id, login, display_name, password_hash, attributes, frozen, created, updated
|
||||
FROM identity
|
||||
WHERE attributes->'oidc'->>'issuer' = $1
|
||||
AND attributes->'oidc'->>'sub' = $2",
|
||||
)
|
||||
.bind(issuer)
|
||||
.bind(subject)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn find_by_ldap_dn<'e, E>(
|
||||
executor: E,
|
||||
server_url: &str,
|
||||
dn: &str,
|
||||
) -> Result<Option<Identity>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, Identity>(
|
||||
"SELECT id, login, display_name, password_hash, attributes, frozen, created, updated
|
||||
FROM identity
|
||||
WHERE attributes->'ldap'->>'server_url' = $1
|
||||
AND attributes->'ldap'->>'dn' = $2",
|
||||
)
|
||||
.bind(server_url)
|
||||
.bind(dn)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
// Permission Set Repository
|
||||
@@ -321,6 +371,27 @@ impl PermissionSetRepository {
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn find_by_roles<'e, E>(executor: E, roles: &[String]) -> Result<Vec<PermissionSet>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
if roles.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
sqlx::query_as::<_, PermissionSet>(
|
||||
"SELECT DISTINCT ps.id, ps.ref, ps.pack, ps.pack_ref, ps.label, ps.description, ps.grants, ps.created, ps.updated
|
||||
FROM permission_set ps
|
||||
INNER JOIN permission_set_role_assignment psra ON psra.permset = ps.id
|
||||
WHERE psra.role = ANY($1)
|
||||
ORDER BY ps.ref ASC",
|
||||
)
|
||||
.bind(roles)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Delete permission sets belonging to a pack whose refs are NOT in the given set.
|
||||
///
|
||||
/// Used during pack reinstallation to clean up permission sets that were
|
||||
@@ -439,3 +510,231 @@ impl PermissionAssignmentRepository {
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdentityRoleAssignmentRepository;
|
||||
|
||||
impl Repository for IdentityRoleAssignmentRepository {
|
||||
type Entity = IdentityRoleAssignment;
|
||||
fn table_name() -> &'static str {
|
||||
"identity_role_assignment"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CreateIdentityRoleAssignmentInput {
|
||||
pub identity: Id,
|
||||
pub role: String,
|
||||
pub source: String,
|
||||
pub managed: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl FindById for IdentityRoleAssignmentRepository {
|
||||
async fn find_by_id<'e, E>(executor: E, id: i64) -> Result<Option<Self::Entity>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, IdentityRoleAssignment>(
|
||||
"SELECT id, identity, role, source, managed, created, updated FROM identity_role_assignment WHERE id = $1"
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Create for IdentityRoleAssignmentRepository {
|
||||
type CreateInput = CreateIdentityRoleAssignmentInput;
|
||||
async fn create<'e, E>(executor: E, input: Self::CreateInput) -> Result<Self::Entity>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, IdentityRoleAssignment>(
|
||||
"INSERT INTO identity_role_assignment (identity, role, source, managed)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, identity, role, source, managed, created, updated",
|
||||
)
|
||||
.bind(input.identity)
|
||||
.bind(&input.role)
|
||||
.bind(&input.source)
|
||||
.bind(input.managed)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Delete for IdentityRoleAssignmentRepository {
|
||||
async fn delete<'e, E>(executor: E, id: i64) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM identity_role_assignment WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl IdentityRoleAssignmentRepository {
|
||||
pub async fn find_by_identity<'e, E>(
|
||||
executor: E,
|
||||
identity_id: Id,
|
||||
) -> Result<Vec<IdentityRoleAssignment>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, IdentityRoleAssignment>(
|
||||
"SELECT id, identity, role, source, managed, created, updated
|
||||
FROM identity_role_assignment
|
||||
WHERE identity = $1
|
||||
ORDER BY role ASC",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn find_role_names_by_identity<'e, E>(
|
||||
executor: E,
|
||||
identity_id: Id,
|
||||
) -> Result<Vec<String>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_scalar::<_, String>(
|
||||
"SELECT role FROM identity_role_assignment WHERE identity = $1 ORDER BY role ASC",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub async fn replace_managed_roles<'e, E>(
|
||||
executor: E,
|
||||
identity_id: Id,
|
||||
source: &str,
|
||||
roles: &[String],
|
||||
) -> Result<()>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + Copy + 'e,
|
||||
{
|
||||
sqlx::query(
|
||||
"DELETE FROM identity_role_assignment WHERE identity = $1 AND source = $2 AND managed = true",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(source)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
for role in roles {
|
||||
sqlx::query(
|
||||
"INSERT INTO identity_role_assignment (identity, role, source, managed)
|
||||
VALUES ($1, $2, $3, true)
|
||||
ON CONFLICT (identity, role) DO UPDATE
|
||||
SET source = EXCLUDED.source,
|
||||
managed = EXCLUDED.managed,
|
||||
updated = NOW()",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(role)
|
||||
.bind(source)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PermissionSetRoleAssignmentRepository;
|
||||
|
||||
impl Repository for PermissionSetRoleAssignmentRepository {
|
||||
type Entity = PermissionSetRoleAssignment;
|
||||
fn table_name() -> &'static str {
|
||||
"permission_set_role_assignment"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CreatePermissionSetRoleAssignmentInput {
|
||||
pub permset: Id,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl FindById for PermissionSetRoleAssignmentRepository {
|
||||
async fn find_by_id<'e, E>(executor: E, id: i64) -> Result<Option<Self::Entity>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, PermissionSetRoleAssignment>(
|
||||
"SELECT id, permset, role, created FROM permission_set_role_assignment WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Create for PermissionSetRoleAssignmentRepository {
|
||||
type CreateInput = CreatePermissionSetRoleAssignmentInput;
|
||||
async fn create<'e, E>(executor: E, input: Self::CreateInput) -> Result<Self::Entity>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, PermissionSetRoleAssignment>(
|
||||
"INSERT INTO permission_set_role_assignment (permset, role)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, permset, role, created",
|
||||
)
|
||||
.bind(input.permset)
|
||||
.bind(&input.role)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Delete for PermissionSetRoleAssignmentRepository {
|
||||
async fn delete<'e, E>(executor: E, id: i64) -> Result<bool>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM permission_set_role_assignment WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionSetRoleAssignmentRepository {
|
||||
pub async fn find_by_permission_set<'e, E>(
|
||||
executor: E,
|
||||
permset_id: Id,
|
||||
) -> Result<Vec<PermissionSetRoleAssignment>>
|
||||
where
|
||||
E: Executor<'e, Database = Postgres> + 'e,
|
||||
{
|
||||
sqlx::query_as::<_, PermissionSetRoleAssignment>(
|
||||
"SELECT id, permset, role, created
|
||||
FROM permission_set_role_assignment
|
||||
WHERE permset = $1
|
||||
ORDER BY role ASC",
|
||||
)
|
||||
.bind(permset_id)
|
||||
.fetch_all(executor)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user