167 lines
5.1 KiB
Rust
167 lines
5.1 KiB
Rust
//! Database connection and management
|
|
//!
|
|
//! This module provides database connection pooling and utilities for
|
|
//! interacting with the PostgreSQL database.
|
|
|
|
use sqlx::postgres::{PgPool, PgPoolOptions};
|
|
use std::time::Duration;
|
|
use tracing::{info, warn};
|
|
|
|
use crate::config::DatabaseConfig;
|
|
use crate::error::Result;
|
|
|
|
/// Database connection pool
|
|
#[derive(Debug, Clone)]
|
|
pub struct Database {
|
|
pool: PgPool,
|
|
schema: String,
|
|
}
|
|
|
|
impl Database {
|
|
/// Create a new database connection from configuration
|
|
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
|
|
// Default to "attune" schema for production safety
|
|
let schema = config
|
|
.schema
|
|
.clone()
|
|
.unwrap_or_else(|| "attune".to_string());
|
|
|
|
// Validate schema name (prevent SQL injection)
|
|
Self::validate_schema_name(&schema)?;
|
|
|
|
// Log schema configuration prominently
|
|
info!(
|
|
"Connecting to database with max_connections={}, schema={}",
|
|
config.max_connections, schema
|
|
);
|
|
|
|
// Clone schema for use in closure
|
|
let schema_for_hook = schema.clone();
|
|
|
|
let pool = PgPoolOptions::new()
|
|
.max_connections(config.max_connections)
|
|
.min_connections(config.min_connections)
|
|
.acquire_timeout(Duration::from_secs(config.connect_timeout))
|
|
.idle_timeout(Duration::from_secs(config.idle_timeout))
|
|
.after_connect(move |conn, _meta| {
|
|
let schema = schema_for_hook.clone();
|
|
Box::pin(async move {
|
|
// Set search_path for every connection in the pool
|
|
// Only include 'public' for production schemas (attune), not test schemas
|
|
// This ensures test schemas have isolated migrations tables
|
|
let search_path = if schema.starts_with("test_") {
|
|
format!("SET search_path TO {}", schema)
|
|
} else {
|
|
format!("SET search_path TO {}, public", schema)
|
|
};
|
|
sqlx::query(&search_path).execute(&mut *conn).await?;
|
|
Ok(())
|
|
})
|
|
})
|
|
.connect(&config.url)
|
|
.await?;
|
|
|
|
// Run a test query to verify connection
|
|
sqlx::query("SELECT 1").execute(&pool).await.map_err(|e| {
|
|
warn!("Failed to verify database connection: {}", e);
|
|
e
|
|
})?;
|
|
|
|
info!("Successfully connected to database");
|
|
|
|
Ok(Self { pool, schema })
|
|
}
|
|
|
|
/// Validate schema name to prevent SQL injection
|
|
fn validate_schema_name(schema: &str) -> Result<()> {
|
|
if schema.is_empty() {
|
|
return Err(crate::error::Error::Configuration(
|
|
"Schema name cannot be empty".to_string(),
|
|
));
|
|
}
|
|
|
|
// Only allow alphanumeric and underscores
|
|
if !schema.chars().all(|c| c.is_alphanumeric() || c == '_') {
|
|
return Err(crate::error::Error::Configuration(format!(
|
|
"Invalid schema name '{}': only alphanumeric and underscores allowed",
|
|
schema
|
|
)));
|
|
}
|
|
|
|
// Prevent excessively long names (PostgreSQL limit is 63 chars)
|
|
if schema.len() > 63 {
|
|
return Err(crate::error::Error::Configuration(format!(
|
|
"Schema name '{}' too long (max 63 characters)",
|
|
schema
|
|
)));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get a reference to the connection pool
|
|
pub fn pool(&self) -> &PgPool {
|
|
&self.pool
|
|
}
|
|
|
|
/// Get the current schema name
|
|
pub fn schema(&self) -> &str {
|
|
&self.schema
|
|
}
|
|
|
|
/// Close the database connection pool
|
|
pub async fn close(&self) {
|
|
self.pool.close().await;
|
|
info!("Database connection pool closed");
|
|
}
|
|
|
|
/// Run database migrations
|
|
/// Note: Migrations should be in the workspace root migrations directory
|
|
pub async fn migrate(&self) -> Result<()> {
|
|
info!("Running database migrations");
|
|
// TODO: Implement migrations when migration files are created
|
|
// sqlx::migrate!("../../migrations")
|
|
// .run(&self.pool)
|
|
// .await?;
|
|
info!("Database migrations will be implemented with migration files");
|
|
Ok(())
|
|
}
|
|
|
|
/// Check if the database connection is healthy
|
|
pub async fn health_check(&self) -> Result<()> {
|
|
sqlx::query("SELECT 1").execute(&self.pool).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Get pool statistics
|
|
pub fn stats(&self) -> PoolStats {
|
|
PoolStats {
|
|
connections: self.pool.size(),
|
|
idle_connections: self.pool.num_idle(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Database pool statistics
|
|
#[derive(Debug, Clone)]
|
|
pub struct PoolStats {
|
|
pub connections: u32,
|
|
pub idle_connections: usize,
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_pool_stats() {
|
|
// Test that PoolStats can be created
|
|
let stats = PoolStats {
|
|
connections: 10,
|
|
idle_connections: 5,
|
|
};
|
|
assert_eq!(stats.connections, 10);
|
|
assert_eq!(stats.idle_connections, 5);
|
|
}
|
|
}
|