re-uploading work
This commit is contained in:
61
crates/api/src/middleware/cors.rs
Normal file
61
crates/api/src/middleware/cors.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! CORS middleware configuration
|
||||
|
||||
use axum::http::{header, HeaderValue, Method};
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
|
||||
/// Create CORS layer configured from allowed origins
|
||||
///
|
||||
/// If no origins are provided, defaults to common development origins.
|
||||
/// Cannot use `allow_origin(Any)` with credentials enabled.
|
||||
pub fn create_cors_layer(allowed_origins: Vec<String>) -> CorsLayer {
|
||||
// Get the list of allowed origins
|
||||
let origins = if allowed_origins.is_empty() {
|
||||
// Default development origins
|
||||
vec![
|
||||
"http://localhost:3000".to_string(),
|
||||
"http://localhost:5173".to_string(),
|
||||
"http://localhost:8080".to_string(),
|
||||
"http://127.0.0.1:3000".to_string(),
|
||||
"http://127.0.0.1:5173".to_string(),
|
||||
"http://127.0.0.1:8080".to_string(),
|
||||
]
|
||||
} else {
|
||||
allowed_origins
|
||||
};
|
||||
|
||||
// Convert origins to HeaderValues for matching
|
||||
let allowed_origin_values: Arc<Vec<HeaderValue>> = Arc::new(
|
||||
origins
|
||||
.iter()
|
||||
.filter_map(|o| o.parse::<HeaderValue>().ok())
|
||||
.collect(),
|
||||
);
|
||||
|
||||
CorsLayer::new()
|
||||
// Allow common HTTP methods
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::PATCH,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
// Allow specific headers (required when using credentials)
|
||||
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
|
||||
// Expose headers to the frontend
|
||||
.expose_headers([
|
||||
header::AUTHORIZATION,
|
||||
header::CONTENT_TYPE,
|
||||
header::CONTENT_LENGTH,
|
||||
header::ACCEPT,
|
||||
])
|
||||
// Allow credentials (cookies, authorization headers)
|
||||
.allow_credentials(true)
|
||||
// Use predicate to match against allowed origins
|
||||
// Arc allows the closure to be called multiple times (preflight + actual request)
|
||||
.allow_origin(AllowOrigin::predicate(move |origin: &HeaderValue, _| {
|
||||
allowed_origin_values.contains(origin)
|
||||
}))
|
||||
}
|
||||
251
crates/api/src/middleware/error.rs
Normal file
251
crates/api/src/middleware/error.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
//! Error handling middleware and response types
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Standard API error response
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
/// Error message
|
||||
pub error: String,
|
||||
/// Optional error code
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<String>,
|
||||
/// Optional additional details
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
/// Create a new error response
|
||||
pub fn new(error: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error: error.into(),
|
||||
code: None,
|
||||
details: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set error code
|
||||
pub fn with_code(mut self, code: impl Into<String>) -> Self {
|
||||
self.code = Some(code.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set error details
|
||||
pub fn with_details(mut self, details: serde_json::Value) -> Self {
|
||||
self.details = Some(details);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// API error type that can be converted to HTTP responses
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
/// Bad request (400)
|
||||
BadRequest(String),
|
||||
/// Unauthorized (401)
|
||||
Unauthorized(String),
|
||||
/// Forbidden (403)
|
||||
Forbidden(String),
|
||||
/// Not found (404)
|
||||
NotFound(String),
|
||||
/// Conflict (409)
|
||||
Conflict(String),
|
||||
/// Unprocessable entity (422)
|
||||
UnprocessableEntity(String),
|
||||
/// Too many requests (429)
|
||||
TooManyRequests(String),
|
||||
/// Internal server error (500)
|
||||
InternalServerError(String),
|
||||
/// Not implemented (501)
|
||||
NotImplemented(String),
|
||||
/// Database error
|
||||
DatabaseError(String),
|
||||
/// Validation error
|
||||
ValidationError(String),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Get the HTTP status code for this error
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ApiError::BadRequest(_) => StatusCode::BAD_REQUEST,
|
||||
ApiError::Unauthorized(_) => StatusCode::UNAUTHORIZED,
|
||||
ApiError::Forbidden(_) => StatusCode::FORBIDDEN,
|
||||
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
ApiError::Conflict(_) => StatusCode::CONFLICT,
|
||||
ApiError::UnprocessableEntity(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
ApiError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
ApiError::TooManyRequests(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||
ApiError::NotImplemented(_) => StatusCode::NOT_IMPLEMENTED,
|
||||
ApiError::InternalServerError(_) | ApiError::DatabaseError(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error message
|
||||
pub fn message(&self) -> &str {
|
||||
match self {
|
||||
ApiError::BadRequest(msg)
|
||||
| ApiError::Unauthorized(msg)
|
||||
| ApiError::Forbidden(msg)
|
||||
| ApiError::NotFound(msg)
|
||||
| ApiError::Conflict(msg)
|
||||
| ApiError::UnprocessableEntity(msg)
|
||||
| ApiError::TooManyRequests(msg)
|
||||
| ApiError::NotImplemented(msg)
|
||||
| ApiError::InternalServerError(msg)
|
||||
| ApiError::DatabaseError(msg)
|
||||
| ApiError::ValidationError(msg) => msg,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error code
|
||||
pub fn code(&self) -> &str {
|
||||
match self {
|
||||
ApiError::BadRequest(_) => "BAD_REQUEST",
|
||||
ApiError::Unauthorized(_) => "UNAUTHORIZED",
|
||||
ApiError::Forbidden(_) => "FORBIDDEN",
|
||||
ApiError::NotFound(_) => "NOT_FOUND",
|
||||
ApiError::Conflict(_) => "CONFLICT",
|
||||
ApiError::UnprocessableEntity(_) => "UNPROCESSABLE_ENTITY",
|
||||
ApiError::TooManyRequests(_) => "TOO_MANY_REQUESTS",
|
||||
ApiError::NotImplemented(_) => "NOT_IMPLEMENTED",
|
||||
ApiError::ValidationError(_) => "VALIDATION_ERROR",
|
||||
ApiError::DatabaseError(_) => "DATABASE_ERROR",
|
||||
ApiError::InternalServerError(_) => "INTERNAL_SERVER_ERROR",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ApiError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.message())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ApiError {}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let error_response = ErrorResponse::new(self.message()).with_code(self.code());
|
||||
|
||||
(status, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from common error types
|
||||
impl From<sqlx::Error> for ApiError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
match err {
|
||||
sqlx::Error::RowNotFound => ApiError::NotFound("Resource not found".to_string()),
|
||||
sqlx::Error::Database(db_err) => {
|
||||
// Check for unique constraint violations
|
||||
if let Some(constraint) = db_err.constraint() {
|
||||
ApiError::Conflict(format!("Constraint violation: {}", constraint))
|
||||
} else {
|
||||
ApiError::DatabaseError(format!("Database error: {}", db_err))
|
||||
}
|
||||
}
|
||||
_ => ApiError::DatabaseError(format!("Database error: {}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<attune_common::error::Error> for ApiError {
|
||||
fn from(err: attune_common::error::Error) -> Self {
|
||||
match err {
|
||||
attune_common::error::Error::NotFound {
|
||||
entity,
|
||||
field,
|
||||
value,
|
||||
} => ApiError::NotFound(format!("{} with {}={} not found", entity, field, value)),
|
||||
attune_common::error::Error::AlreadyExists {
|
||||
entity,
|
||||
field,
|
||||
value,
|
||||
} => ApiError::Conflict(format!(
|
||||
"{} with {}={} already exists",
|
||||
entity, field, value
|
||||
)),
|
||||
attune_common::error::Error::Validation(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::SchemaValidation(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::Database(err) => ApiError::from(err),
|
||||
attune_common::error::Error::InvalidState(msg) => ApiError::BadRequest(msg),
|
||||
attune_common::error::Error::PermissionDenied(msg) => ApiError::Forbidden(msg),
|
||||
attune_common::error::Error::AuthenticationFailed(msg) => ApiError::Unauthorized(msg),
|
||||
attune_common::error::Error::Configuration(msg) => ApiError::InternalServerError(msg),
|
||||
attune_common::error::Error::Serialization(err) => {
|
||||
ApiError::InternalServerError(format!("{}", err))
|
||||
}
|
||||
attune_common::error::Error::Io(msg)
|
||||
| attune_common::error::Error::Encryption(msg)
|
||||
| attune_common::error::Error::Timeout(msg)
|
||||
| attune_common::error::Error::ExternalService(msg)
|
||||
| attune_common::error::Error::Worker(msg)
|
||||
| attune_common::error::Error::Execution(msg)
|
||||
| attune_common::error::Error::Internal(msg) => ApiError::InternalServerError(msg),
|
||||
attune_common::error::Error::Other(err) => {
|
||||
ApiError::InternalServerError(format!("{}", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<validator::ValidationErrors> for ApiError {
|
||||
fn from(err: validator::ValidationErrors) -> Self {
|
||||
ApiError::ValidationError(format!("Validation failed: {}", err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::auth::jwt::JwtError> for ApiError {
|
||||
fn from(err: crate::auth::jwt::JwtError) -> Self {
|
||||
match err {
|
||||
crate::auth::jwt::JwtError::Expired => {
|
||||
ApiError::Unauthorized("Token has expired".to_string())
|
||||
}
|
||||
crate::auth::jwt::JwtError::Invalid => {
|
||||
ApiError::Unauthorized("Invalid token".to_string())
|
||||
}
|
||||
crate::auth::jwt::JwtError::EncodeError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to encode token: {}", msg))
|
||||
}
|
||||
crate::auth::jwt::JwtError::DecodeError(msg) => {
|
||||
ApiError::Unauthorized(format!("Failed to decode token: {}", msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::auth::password::PasswordError> for ApiError {
|
||||
fn from(err: crate::auth::password::PasswordError) -> Self {
|
||||
match err {
|
||||
crate::auth::password::PasswordError::HashError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to hash password: {}", msg))
|
||||
}
|
||||
crate::auth::password::PasswordError::VerifyError(msg) => {
|
||||
ApiError::InternalServerError(format!("Failed to verify password: {}", msg))
|
||||
}
|
||||
crate::auth::password::PasswordError::InvalidHash => {
|
||||
ApiError::InternalServerError("Invalid password hash format".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::num::ParseIntError> for ApiError {
|
||||
fn from(err: std::num::ParseIntError) -> Self {
|
||||
ApiError::BadRequest(format!("Invalid number format: {}", err))
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for API handlers
|
||||
pub type ApiResult<T> = Result<T, ApiError>;
|
||||
54
crates/api/src/middleware/logging.rs
Normal file
54
crates/api/src/middleware/logging.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
//! Request/Response logging middleware
|
||||
|
||||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
use std::time::Instant;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Middleware for logging HTTP requests and responses
|
||||
pub async fn log_request(req: Request, next: Next) -> Response {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let version = req.version();
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
info!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
version = ?version,
|
||||
"request started"
|
||||
);
|
||||
|
||||
let response = next.run(req).await;
|
||||
|
||||
let duration = start.elapsed();
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
info!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request completed"
|
||||
);
|
||||
} else if status.is_client_error() {
|
||||
warn!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request failed (client error)"
|
||||
);
|
||||
} else if status.is_server_error() {
|
||||
warn!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"request failed (server error)"
|
||||
);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
9
crates/api/src/middleware/mod.rs
Normal file
9
crates/api/src/middleware/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Middleware modules for the API service
|
||||
|
||||
pub mod cors;
|
||||
pub mod error;
|
||||
pub mod logging;
|
||||
|
||||
pub use cors::create_cors_layer;
|
||||
pub use error::{ApiError, ApiResult};
|
||||
pub use logging::log_request;
|
||||
Reference in New Issue
Block a user