Proper SQL filtering

This commit is contained in:
2026-03-01 20:43:48 -06:00
parent 6b9d7d6cf2
commit bbe94d75f8
54 changed files with 6692 additions and 928 deletions

View File

@@ -18,9 +18,10 @@ use attune_common::models::enums::ExecutionStatus;
use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType};
use attune_common::repositories::{
action::ActionRepository,
execution::{CreateExecutionInput, ExecutionRepository},
Create, EnforcementRepository, FindById, FindByRef, List,
execution::{CreateExecutionInput, ExecutionRepository, ExecutionSearchFilters},
Create, FindById, FindByRef,
};
use sqlx::Row;
use crate::{
auth::middleware::RequireAuth,
@@ -125,117 +126,37 @@ pub async fn list_executions(
RequireAuth(_user): RequireAuth,
Query(query): Query<ExecutionQueryParams>,
) -> ApiResult<impl IntoResponse> {
// Get executions based on filters
let executions = if let Some(status) = query.status {
// Filter by status
ExecutionRepository::find_by_status(&state.db, status).await?
} else if let Some(enforcement_id) = query.enforcement {
// Filter by enforcement
ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?
} else {
// Get all executions
ExecutionRepository::list(&state.db).await?
// All filtering, pagination, and the enforcement JOIN happen in a single
// SQL query — no in-memory filtering or post-fetch lookups.
let filters = ExecutionSearchFilters {
status: query.status,
action_ref: query.action_ref.clone(),
pack_name: query.pack_name.clone(),
rule_ref: query.rule_ref.clone(),
trigger_ref: query.trigger_ref.clone(),
executor: query.executor,
result_contains: query.result_contains.clone(),
enforcement: query.enforcement,
parent: query.parent,
top_level_only: query.top_level_only == Some(true),
limit: query.limit(),
offset: query.offset(),
};
// Apply additional filters in memory (could be optimized with database queries)
let mut filtered_executions = executions;
let result = ExecutionRepository::search(&state.db, &filters).await?;
if let Some(action_ref) = &query.action_ref {
filtered_executions.retain(|e| e.action_ref == *action_ref);
}
if let Some(pack_name) = &query.pack_name {
filtered_executions.retain(|e| {
// action_ref format is "pack.action"
e.action_ref.starts_with(&format!("{}.", pack_name))
});
}
if let Some(result_search) = &query.result_contains {
let search_lower = result_search.to_lowercase();
filtered_executions.retain(|e| {
if let Some(result) = &e.result {
// Convert result to JSON string and search case-insensitively
let result_str = serde_json::to_string(result).unwrap_or_default();
result_str.to_lowercase().contains(&search_lower)
} else {
false
}
});
}
if let Some(parent_id) = query.parent {
filtered_executions.retain(|e| e.parent == Some(parent_id));
}
if query.top_level_only == Some(true) {
filtered_executions.retain(|e| e.parent.is_none());
}
if let Some(executor_id) = query.executor {
filtered_executions.retain(|e| e.executor == Some(executor_id));
}
// Fetch enforcements for all executions to populate rule_ref and trigger_ref
let enforcement_ids: Vec<i64> = filtered_executions
.iter()
.filter_map(|e| e.enforcement)
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let enforcement_map: std::collections::HashMap<i64, _> = if !enforcement_ids.is_empty() {
let enforcements = EnforcementRepository::list(&state.db).await?;
enforcements.into_iter().map(|enf| (enf.id, enf)).collect()
} else {
std::collections::HashMap::new()
};
// Filter by rule_ref if specified
if let Some(rule_ref) = &query.rule_ref {
filtered_executions.retain(|e| {
e.enforcement
.and_then(|enf_id| enforcement_map.get(&enf_id))
.map(|enf| enf.rule_ref == *rule_ref)
.unwrap_or(false)
});
}
// Filter by trigger_ref if specified
if let Some(trigger_ref) = &query.trigger_ref {
filtered_executions.retain(|e| {
e.enforcement
.and_then(|enf_id| enforcement_map.get(&enf_id))
.map(|enf| enf.trigger_ref == *trigger_ref)
.unwrap_or(false)
});
}
// Calculate pagination
let total = filtered_executions.len() as u64;
let start = query.offset() as usize;
let end = (start + query.limit() as usize).min(filtered_executions.len());
// Get paginated slice and populate rule_ref/trigger_ref from enforcements
let paginated_executions: Vec<ExecutionSummary> = filtered_executions[start..end]
.iter()
.map(|e| {
let mut summary = ExecutionSummary::from(e.clone());
if let Some(enf_id) = e.enforcement {
if let Some(enforcement) = enforcement_map.get(&enf_id) {
summary.rule_ref = Some(enforcement.rule_ref.clone());
summary.trigger_ref = Some(enforcement.trigger_ref.clone());
}
}
summary
})
.collect();
// Convert query params to pagination params for response
let pagination_params = PaginationParams {
page: query.page,
page_size: query.per_page,
};
let response = PaginatedResponse::new(paginated_executions, &pagination_params, total);
let response = PaginatedResponse::new(paginated_executions, &pagination_params, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -310,21 +231,23 @@ pub async fn list_executions_by_status(
}
};
// Get executions by status
let executions = ExecutionRepository::find_by_status(&state.db, status).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = ExecutionSearchFilters {
status: Some(status),
limit: pagination.limit(),
offset: pagination.offset(),
..Default::default()
};
// Calculate pagination
let total = executions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(executions.len());
let result = ExecutionRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
.iter()
.map(|e| ExecutionSummary::from(e.clone()))
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -350,21 +273,23 @@ pub async fn list_executions_by_enforcement(
Path(enforcement_id): Path<i64>,
Query(pagination): Query<PaginationParams>,
) -> ApiResult<impl IntoResponse> {
// Get executions by enforcement
let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?;
// Use the search method for SQL-side filtering + pagination.
let filters = ExecutionSearchFilters {
enforcement: Some(enforcement_id),
limit: pagination.limit(),
offset: pagination.offset(),
..Default::default()
};
// Calculate pagination
let total = executions.len() as u64;
let start = ((pagination.page - 1) * pagination.limit()) as usize;
let end = (start + pagination.limit() as usize).min(executions.len());
let result = ExecutionRepository::search(&state.db, &filters).await?;
// Get paginated slice
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
.iter()
.map(|e| ExecutionSummary::from(e.clone()))
let paginated_executions: Vec<ExecutionSummary> = result
.rows
.into_iter()
.map(ExecutionSummary::from)
.collect();
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
Ok((StatusCode::OK, Json(response)))
}
@@ -384,34 +309,37 @@ pub async fn get_execution_stats(
State(state): State<Arc<AppState>>,
RequireAuth(_user): RequireAuth,
) -> ApiResult<impl IntoResponse> {
// Get all executions (limited by repository to 1000)
let executions = ExecutionRepository::list(&state.db).await?;
// Use a single SQL query with COUNT + GROUP BY instead of fetching all rows.
let rows = sqlx::query(
"SELECT status::text AS status, COUNT(*) AS cnt FROM execution GROUP BY status",
)
.fetch_all(&state.db)
.await?;
// Calculate statistics
let total = executions.len();
let completed = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed)
.count();
let failed = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed)
.count();
let running = executions
.iter()
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running)
.count();
let pending = executions
.iter()
.filter(|e| {
matches!(
e.status,
attune_common::models::enums::ExecutionStatus::Requested
| attune_common::models::enums::ExecutionStatus::Scheduling
| attune_common::models::enums::ExecutionStatus::Scheduled
)
})
.count();
let mut completed: i64 = 0;
let mut failed: i64 = 0;
let mut running: i64 = 0;
let mut pending: i64 = 0;
let mut cancelled: i64 = 0;
let mut timeout: i64 = 0;
let mut abandoned: i64 = 0;
let mut total: i64 = 0;
for row in &rows {
let status: &str = row.get("status");
let cnt: i64 = row.get("cnt");
total += cnt;
match status {
"completed" => completed = cnt,
"failed" => failed = cnt,
"running" => running = cnt,
"requested" | "scheduling" | "scheduled" => pending += cnt,
"cancelled" | "canceling" => cancelled += cnt,
"timeout" => timeout = cnt,
"abandoned" => abandoned = cnt,
_ => {}
}
}
let stats = serde_json::json!({
"total": total,
@@ -419,9 +347,9 @@ pub async fn get_execution_stats(
"failed": failed,
"running": running,
"pending": pending,
"cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(),
"timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(),
"abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(),
"cancelled": cancelled,
"timeout": timeout,
"abandoned": abandoned,
});
let response = ApiResponse::new(stats);