proper sql filtering
This commit is contained in:
@@ -18,9 +18,10 @@ use attune_common::models::enums::ExecutionStatus;
|
||||
use attune_common::mq::{ExecutionRequestedPayload, MessageEnvelope, MessageType};
|
||||
use attune_common::repositories::{
|
||||
action::ActionRepository,
|
||||
execution::{CreateExecutionInput, ExecutionRepository},
|
||||
Create, EnforcementRepository, FindById, FindByRef, List,
|
||||
execution::{CreateExecutionInput, ExecutionRepository, ExecutionSearchFilters},
|
||||
Create, FindById, FindByRef,
|
||||
};
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::{
|
||||
auth::middleware::RequireAuth,
|
||||
@@ -125,117 +126,37 @@ pub async fn list_executions(
|
||||
RequireAuth(_user): RequireAuth,
|
||||
Query(query): Query<ExecutionQueryParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions based on filters
|
||||
let executions = if let Some(status) = query.status {
|
||||
// Filter by status
|
||||
ExecutionRepository::find_by_status(&state.db, status).await?
|
||||
} else if let Some(enforcement_id) = query.enforcement {
|
||||
// Filter by enforcement
|
||||
ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?
|
||||
} else {
|
||||
// Get all executions
|
||||
ExecutionRepository::list(&state.db).await?
|
||||
// All filtering, pagination, and the enforcement JOIN happen in a single
|
||||
// SQL query — no in-memory filtering or post-fetch lookups.
|
||||
let filters = ExecutionSearchFilters {
|
||||
status: query.status,
|
||||
action_ref: query.action_ref.clone(),
|
||||
pack_name: query.pack_name.clone(),
|
||||
rule_ref: query.rule_ref.clone(),
|
||||
trigger_ref: query.trigger_ref.clone(),
|
||||
executor: query.executor,
|
||||
result_contains: query.result_contains.clone(),
|
||||
enforcement: query.enforcement,
|
||||
parent: query.parent,
|
||||
top_level_only: query.top_level_only == Some(true),
|
||||
limit: query.limit(),
|
||||
offset: query.offset(),
|
||||
};
|
||||
|
||||
// Apply additional filters in memory (could be optimized with database queries)
|
||||
let mut filtered_executions = executions;
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
if let Some(action_ref) = &query.action_ref {
|
||||
filtered_executions.retain(|e| e.action_ref == *action_ref);
|
||||
}
|
||||
|
||||
if let Some(pack_name) = &query.pack_name {
|
||||
filtered_executions.retain(|e| {
|
||||
// action_ref format is "pack.action"
|
||||
e.action_ref.starts_with(&format!("{}.", pack_name))
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(result_search) = &query.result_contains {
|
||||
let search_lower = result_search.to_lowercase();
|
||||
filtered_executions.retain(|e| {
|
||||
if let Some(result) = &e.result {
|
||||
// Convert result to JSON string and search case-insensitively
|
||||
let result_str = serde_json::to_string(result).unwrap_or_default();
|
||||
result_str.to_lowercase().contains(&search_lower)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(parent_id) = query.parent {
|
||||
filtered_executions.retain(|e| e.parent == Some(parent_id));
|
||||
}
|
||||
|
||||
if query.top_level_only == Some(true) {
|
||||
filtered_executions.retain(|e| e.parent.is_none());
|
||||
}
|
||||
|
||||
if let Some(executor_id) = query.executor {
|
||||
filtered_executions.retain(|e| e.executor == Some(executor_id));
|
||||
}
|
||||
|
||||
// Fetch enforcements for all executions to populate rule_ref and trigger_ref
|
||||
let enforcement_ids: Vec<i64> = filtered_executions
|
||||
.iter()
|
||||
.filter_map(|e| e.enforcement)
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let enforcement_map: std::collections::HashMap<i64, _> = if !enforcement_ids.is_empty() {
|
||||
let enforcements = EnforcementRepository::list(&state.db).await?;
|
||||
enforcements.into_iter().map(|enf| (enf.id, enf)).collect()
|
||||
} else {
|
||||
std::collections::HashMap::new()
|
||||
};
|
||||
|
||||
// Filter by rule_ref if specified
|
||||
if let Some(rule_ref) = &query.rule_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.rule_ref == *rule_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Filter by trigger_ref if specified
|
||||
if let Some(trigger_ref) = &query.trigger_ref {
|
||||
filtered_executions.retain(|e| {
|
||||
e.enforcement
|
||||
.and_then(|enf_id| enforcement_map.get(&enf_id))
|
||||
.map(|enf| enf.trigger_ref == *trigger_ref)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate pagination
|
||||
let total = filtered_executions.len() as u64;
|
||||
let start = query.offset() as usize;
|
||||
let end = (start + query.limit() as usize).min(filtered_executions.len());
|
||||
|
||||
// Get paginated slice and populate rule_ref/trigger_ref from enforcements
|
||||
let paginated_executions: Vec<ExecutionSummary> = filtered_executions[start..end]
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let mut summary = ExecutionSummary::from(e.clone());
|
||||
if let Some(enf_id) = e.enforcement {
|
||||
if let Some(enforcement) = enforcement_map.get(&enf_id) {
|
||||
summary.rule_ref = Some(enforcement.rule_ref.clone());
|
||||
summary.trigger_ref = Some(enforcement.trigger_ref.clone());
|
||||
}
|
||||
}
|
||||
summary
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Convert query params to pagination params for response
|
||||
let pagination_params = PaginationParams {
|
||||
page: query.page,
|
||||
page_size: query.per_page,
|
||||
};
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination_params, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination_params, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -310,21 +231,23 @@ pub async fn list_executions_by_status(
|
||||
}
|
||||
};
|
||||
|
||||
// Get executions by status
|
||||
let executions = ExecutionRepository::find_by_status(&state.db, status).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = ExecutionSearchFilters {
|
||||
status: Some(status),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -350,21 +273,23 @@ pub async fn list_executions_by_enforcement(
|
||||
Path(enforcement_id): Path<i64>,
|
||||
Query(pagination): Query<PaginationParams>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get executions by enforcement
|
||||
let executions = ExecutionRepository::find_by_enforcement(&state.db, enforcement_id).await?;
|
||||
// Use the search method for SQL-side filtering + pagination.
|
||||
let filters = ExecutionSearchFilters {
|
||||
enforcement: Some(enforcement_id),
|
||||
limit: pagination.limit(),
|
||||
offset: pagination.offset(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Calculate pagination
|
||||
let total = executions.len() as u64;
|
||||
let start = ((pagination.page - 1) * pagination.limit()) as usize;
|
||||
let end = (start + pagination.limit() as usize).min(executions.len());
|
||||
let result = ExecutionRepository::search(&state.db, &filters).await?;
|
||||
|
||||
// Get paginated slice
|
||||
let paginated_executions: Vec<ExecutionSummary> = executions[start..end]
|
||||
.iter()
|
||||
.map(|e| ExecutionSummary::from(e.clone()))
|
||||
let paginated_executions: Vec<ExecutionSummary> = result
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(ExecutionSummary::from)
|
||||
.collect();
|
||||
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, total);
|
||||
let response = PaginatedResponse::new(paginated_executions, &pagination, result.total);
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
@@ -384,34 +309,37 @@ pub async fn get_execution_stats(
|
||||
State(state): State<Arc<AppState>>,
|
||||
RequireAuth(_user): RequireAuth,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
// Get all executions (limited by repository to 1000)
|
||||
let executions = ExecutionRepository::list(&state.db).await?;
|
||||
// Use a single SQL query with COUNT + GROUP BY instead of fetching all rows.
|
||||
let rows = sqlx::query(
|
||||
"SELECT status::text AS status, COUNT(*) AS cnt FROM execution GROUP BY status",
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
// Calculate statistics
|
||||
let total = executions.len();
|
||||
let completed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Completed)
|
||||
.count();
|
||||
let failed = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Failed)
|
||||
.count();
|
||||
let running = executions
|
||||
.iter()
|
||||
.filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Running)
|
||||
.count();
|
||||
let pending = executions
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
matches!(
|
||||
e.status,
|
||||
attune_common::models::enums::ExecutionStatus::Requested
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduling
|
||||
| attune_common::models::enums::ExecutionStatus::Scheduled
|
||||
)
|
||||
})
|
||||
.count();
|
||||
let mut completed: i64 = 0;
|
||||
let mut failed: i64 = 0;
|
||||
let mut running: i64 = 0;
|
||||
let mut pending: i64 = 0;
|
||||
let mut cancelled: i64 = 0;
|
||||
let mut timeout: i64 = 0;
|
||||
let mut abandoned: i64 = 0;
|
||||
let mut total: i64 = 0;
|
||||
|
||||
for row in &rows {
|
||||
let status: &str = row.get("status");
|
||||
let cnt: i64 = row.get("cnt");
|
||||
total += cnt;
|
||||
match status {
|
||||
"completed" => completed = cnt,
|
||||
"failed" => failed = cnt,
|
||||
"running" => running = cnt,
|
||||
"requested" | "scheduling" | "scheduled" => pending += cnt,
|
||||
"cancelled" | "canceling" => cancelled += cnt,
|
||||
"timeout" => timeout = cnt,
|
||||
"abandoned" => abandoned = cnt,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let stats = serde_json::json!({
|
||||
"total": total,
|
||||
@@ -419,9 +347,9 @@ pub async fn get_execution_stats(
|
||||
"failed": failed,
|
||||
"running": running,
|
||||
"pending": pending,
|
||||
"cancelled": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Cancelled).count(),
|
||||
"timeout": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Timeout).count(),
|
||||
"abandoned": executions.iter().filter(|e| e.status == attune_common::models::enums::ExecutionStatus::Abandoned).count(),
|
||||
"cancelled": cancelled,
|
||||
"timeout": timeout,
|
||||
"abandoned": abandoned,
|
||||
});
|
||||
|
||||
let response = ApiResponse::new(stats);
|
||||
|
||||
Reference in New Issue
Block a user