re-uploading work

This commit is contained in:
2026-02-04 17:46:30 -06:00
commit 3b14c65998
1388 changed files with 381262 additions and 0 deletions

View File

@@ -0,0 +1,62 @@
[package]
name = "attune-notifier"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
[[bin]]
name = "attune-notifier"
path = "src/main.rs"
[dependencies]
attune-common = { path = "../common" }
# Async runtime
tokio = { workspace = true }
tokio-util = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
# Database
sqlx = { workspace = true }
# Web framework & WebSocket
axum = { workspace = true, features = ["ws"] }
tower = { workspace = true }
tower-http = { workspace = true }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Logging and tracing
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
# Error handling
anyhow = { workspace = true }
thiserror = { workspace = true }
# Configuration
config = { workspace = true }
# Date/Time
chrono = { workspace = true }
# CLI
clap = { workspace = true }
# Redis (optional, for distributed notifications)
redis = { workspace = true }
# UUID
uuid = { workspace = true }
# Concurrent data structures
dashmap = { workspace = true }
[dev-dependencies]
mockall = { workspace = true }
tempfile = { workspace = true }

128
crates/notifier/src/main.rs Normal file
View File

@@ -0,0 +1,128 @@
//! Attune Notifier Service - Real-time notification delivery
use anyhow::Result;
use attune_common::config::Config;
use clap::Parser;
use tracing::{error, info};
mod postgres_listener;
mod service;
mod subscriber_manager;
mod websocket_server;
use service::NotifierService;
#[derive(Parser, Debug)]
#[command(name = "attune-notifier")]
#[command(about = "Attune Notifier Service - Real-time notifications", long_about = None)]
struct Args {
/// Path to configuration file
#[arg(short, long)]
config: Option<String>,
/// Log level (trace, debug, info, warn, error)
#[arg(short, long, default_value = "info")]
log_level: String,
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
// Initialize tracing with specified log level
let log_level = args
.log_level
.parse::<tracing::Level>()
.unwrap_or(tracing::Level::INFO);
tracing_subscriber::fmt()
.with_max_level(log_level)
.with_target(false)
.with_thread_ids(true)
.init();
info!("Starting Attune Notifier Service");
// Load configuration
if let Some(config_path) = args.config {
std::env::set_var("ATTUNE_CONFIG", config_path);
}
let config = Config::load()?;
config.validate()?;
info!("Configuration loaded successfully");
info!("Environment: {}", config.environment);
info!("Database: {}", mask_password(&config.database.url));
let notifier_config = config
.notifier
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config file"))?;
info!(
"Listening on: {}:{}",
notifier_config.host, notifier_config.port
);
// Create and start the notifier service
let service = NotifierService::new(config).await?;
info!("Notifier Service initialized successfully");
// Set up graceful shutdown handler
let service_clone = std::sync::Arc::new(service);
let service_for_shutdown = service_clone.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen for Ctrl+C");
info!("Received shutdown signal");
if let Err(e) = service_for_shutdown.shutdown().await {
error!("Error during shutdown: {}", e);
}
});
// Start the service (blocks until shutdown)
if let Err(e) = service_clone.start().await {
error!("Notifier service error: {}", e);
return Err(e);
}
info!("Attune Notifier Service stopped");
Ok(())
}
/// Mask password in database URL for logging
fn mask_password(url: &str) -> String {
if let Some(at_pos) = url.rfind('@') {
if let Some(colon_pos) = url[..at_pos].rfind(':') {
let mut masked = url.to_string();
masked.replace_range(colon_pos + 1..at_pos, "****");
return masked;
}
}
url.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mask_password() {
let url = "postgresql://user:password@localhost:5432/db";
let masked = mask_password(url);
assert_eq!(masked, "postgresql://user:****@localhost:5432/db");
}
#[test]
fn test_mask_password_no_password() {
let url = "postgresql://localhost:5432/db";
let masked = mask_password(url);
assert_eq!(masked, "postgresql://localhost:5432/db");
}
}

View File

@@ -0,0 +1,232 @@
//! PostgreSQL LISTEN/NOTIFY integration for real-time notifications
use anyhow::{Context, Result};
use sqlx::postgres::PgListener;
use tokio::sync::broadcast;
use tracing::{debug, error, info, warn};
use crate::service::Notification;
/// Channels to listen on for PostgreSQL notifications
const NOTIFICATION_CHANNELS: &[&str] = &[
"attune_notifications",
"execution_status_changed",
"execution_created",
"inquiry_created",
"inquiry_responded",
"enforcement_created",
"event_created",
"workflow_execution_status_changed",
];
/// PostgreSQL listener that receives NOTIFY events and broadcasts them
pub struct PostgresListener {
database_url: String,
notification_tx: broadcast::Sender<Notification>,
}
impl PostgresListener {
/// Create a new PostgreSQL listener
pub async fn new(
database_url: String,
notification_tx: broadcast::Sender<Notification>,
) -> Result<Self> {
Ok(Self {
database_url,
notification_tx,
})
}
/// Start listening for PostgreSQL notifications
pub async fn listen(&self) -> Result<()> {
info!(
"Starting PostgreSQL LISTEN on channels: {:?}",
NOTIFICATION_CHANNELS
);
// Create a dedicated listener connection
let mut listener = PgListener::connect(&self.database_url)
.await
.context("Failed to connect PostgreSQL listener")?;
// Listen on all notification channels
for channel in NOTIFICATION_CHANNELS {
listener
.listen(channel)
.await
.context(format!("Failed to LISTEN on channel '{}'", channel))?;
info!("Listening on PostgreSQL channel: {}", channel);
}
// Process notifications in a loop
loop {
match listener.recv().await {
Ok(pg_notification) => {
debug!(
"Received PostgreSQL notification: channel={}, payload={}",
pg_notification.channel(),
pg_notification.payload()
);
// Parse and broadcast notification
if let Err(e) = self
.process_notification(pg_notification.channel(), pg_notification.payload())
{
error!(
"Failed to process notification from channel '{}': {}",
pg_notification.channel(),
e
);
}
}
Err(e) => {
error!("Error receiving PostgreSQL notification: {}", e);
// Sleep briefly before retrying to avoid tight loop on persistent errors
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// Try to reconnect
warn!("Attempting to reconnect PostgreSQL listener...");
match PgListener::connect(&self.database_url).await {
Ok(new_listener) => {
listener = new_listener;
// Re-subscribe to all channels
for channel in NOTIFICATION_CHANNELS {
if let Err(e) = listener.listen(channel).await {
error!(
"Failed to re-subscribe to channel '{}': {}",
channel, e
);
}
}
info!("PostgreSQL listener reconnected successfully");
}
Err(e) => {
error!("Failed to reconnect PostgreSQL listener: {}", e);
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
}
}
}
}
}
/// Process a PostgreSQL notification and broadcast it to WebSocket clients
fn process_notification(&self, channel: &str, payload: &str) -> Result<()> {
// Parse the JSON payload
let payload_json: serde_json::Value = serde_json::from_str(payload)
.context("Failed to parse notification payload as JSON")?;
// Extract common fields
let entity_type = payload_json
.get("entity_type")
.and_then(|v| v.as_str())
.context("Missing 'entity_type' in notification payload")?
.to_string();
let entity_id = payload_json
.get("entity_id")
.and_then(|v| v.as_i64())
.context("Missing 'entity_id' in notification payload")?;
let user_id = payload_json.get("user_id").and_then(|v| v.as_i64());
// Create notification
let notification = Notification {
notification_type: channel.to_string(),
entity_type,
entity_id,
user_id,
payload: payload_json,
timestamp: chrono::Utc::now(),
};
// Broadcast to all subscribers (ignore errors if no receivers)
match self.notification_tx.send(notification) {
Ok(receiver_count) => {
debug!(
"Broadcast notification to {} receivers: type={}, entity_id={}",
receiver_count, channel, entity_id
);
}
Err(_) => {
// No active receivers, this is fine
debug!("No active receivers for notification: type={}", channel);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_notification_channels_defined() {
assert!(!NOTIFICATION_CHANNELS.is_empty());
assert!(NOTIFICATION_CHANNELS.contains(&"execution_status_changed"));
assert!(NOTIFICATION_CHANNELS.contains(&"inquiry_created"));
}
#[test]
fn test_process_notification_valid_payload() {
let (tx, mut rx) = broadcast::channel(10);
let listener = PostgresListener {
database_url: "postgresql://test".to_string(),
notification_tx: tx,
};
let payload = serde_json::json!({
"entity_type": "execution",
"entity_id": 123,
"user_id": 456,
"status": "succeeded"
});
let result =
listener.process_notification("execution_status_changed", &payload.to_string());
assert!(result.is_ok());
// Should receive the notification
let notification = rx.try_recv().unwrap();
assert_eq!(notification.notification_type, "execution_status_changed");
assert_eq!(notification.entity_type, "execution");
assert_eq!(notification.entity_id, 123);
assert_eq!(notification.user_id, Some(456));
}
#[test]
fn test_process_notification_missing_fields() {
let (tx, _rx) = broadcast::channel(10);
let listener = PostgresListener {
database_url: "postgresql://test".to_string(),
notification_tx: tx,
};
// Missing entity_id
let payload = serde_json::json!({
"entity_type": "execution"
});
let result =
listener.process_notification("execution_status_changed", &payload.to_string());
assert!(result.is_err());
}
#[test]
fn test_process_notification_invalid_json() {
let (tx, _rx) = broadcast::channel(10);
let listener = PostgresListener {
database_url: "postgresql://test".to_string(),
notification_tx: tx,
};
let result = listener.process_notification("execution_status_changed", "not valid json");
assert!(result.is_err());
}
}

View File

@@ -0,0 +1,204 @@
//! Notifier Service - Real-time notification orchestration
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::broadcast;
use tracing::{error, info};
use attune_common::config::Config;
use crate::postgres_listener::PostgresListener;
use crate::subscriber_manager::SubscriberManager;
use crate::websocket_server::WebSocketServer;
/// Notification message that can be broadcast to subscribers
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Notification {
/// Type of notification (e.g., "execution_status_changed", "inquiry_created")
pub notification_type: String,
/// Entity type (e.g., "execution", "inquiry", "enforcement")
pub entity_type: String,
/// Entity ID
pub entity_id: i64,
/// Optional user/identity ID that should receive this notification
pub user_id: Option<i64>,
/// Notification payload (varies by type)
pub payload: serde_json::Value,
/// Timestamp when notification was created
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Main notifier service that coordinates all components
pub struct NotifierService {
config: Config,
postgres_listener: Arc<PostgresListener>,
subscriber_manager: Arc<SubscriberManager>,
websocket_server: WebSocketServer,
shutdown_tx: broadcast::Sender<()>,
}
impl NotifierService {
/// Create a new notifier service
pub async fn new(config: Config) -> Result<Self> {
info!("Initializing Notifier Service");
// Create shutdown broadcast channel
let (shutdown_tx, _) = broadcast::channel(16);
// Create notification broadcast channel
let (notification_tx, _) = broadcast::channel(1000);
// Create subscriber manager
let subscriber_manager = Arc::new(SubscriberManager::new());
// Create PostgreSQL listener
let postgres_listener = Arc::new(
PostgresListener::new(config.database.url.clone(), notification_tx.clone()).await?,
);
// Create WebSocket server
let websocket_server = WebSocketServer::new(
config.clone(),
notification_tx.clone(),
subscriber_manager.clone(),
shutdown_tx.clone(),
);
Ok(Self {
config,
postgres_listener,
subscriber_manager,
websocket_server,
shutdown_tx,
})
}
/// Start the notifier service
pub async fn start(&self) -> Result<()> {
info!("Starting Notifier Service components");
// Start PostgreSQL listener
let listener_handle = {
let listener = self.postgres_listener.clone();
let mut shutdown_rx = self.shutdown_tx.subscribe();
tokio::spawn(async move {
tokio::select! {
result = listener.listen() => {
if let Err(e) = result {
error!("PostgreSQL listener error: {}", e);
}
}
_ = shutdown_rx.recv() => {
info!("PostgreSQL listener shutting down");
}
}
})
};
// Start notification broadcaster (forwards notifications to WebSocket clients)
let broadcast_handle = {
let subscriber_manager = self.subscriber_manager.clone();
let mut notification_rx = self.websocket_server.notification_tx.subscribe();
let mut shutdown_rx = self.shutdown_tx.subscribe();
tokio::spawn(async move {
loop {
tokio::select! {
Ok(notification) = notification_rx.recv() => {
subscriber_manager.broadcast(notification);
}
_ = shutdown_rx.recv() => {
info!("Notification broadcaster shutting down");
break;
}
}
}
})
};
// Start WebSocket server
let server_handle = {
let server = self.websocket_server.clone();
tokio::spawn(async move {
if let Err(e) = server.start().await {
error!("WebSocket server error: {}", e);
}
})
};
let notifier_config = self
.config
.notifier
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config"))?;
info!(
"Notifier Service started on {}:{}",
notifier_config.host, notifier_config.port
);
// Wait for any task to complete (they shouldn't unless there's an error)
tokio::select! {
_ = listener_handle => {
error!("PostgreSQL listener stopped unexpectedly");
}
_ = broadcast_handle => {
error!("Notification broadcaster stopped unexpectedly");
}
_ = server_handle => {
error!("WebSocket server stopped unexpectedly");
}
}
Ok(())
}
/// Shutdown the notifier service gracefully
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down Notifier Service");
// Send shutdown signal to all components
let _ = self.shutdown_tx.send(());
// Disconnect all WebSocket clients
self.subscriber_manager.disconnect_all().await;
info!("Notifier Service shutdown complete");
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_notification_serialization() {
let notification = Notification {
notification_type: "execution_status_changed".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(456),
payload: serde_json::json!({
"status": "succeeded",
"action": "core.echo"
}),
timestamp: chrono::Utc::now(),
};
let json = serde_json::to_string(&notification).unwrap();
let deserialized: Notification = serde_json::from_str(&json).unwrap();
assert_eq!(
notification.notification_type,
deserialized.notification_type
);
assert_eq!(notification.entity_type, deserialized.entity_type);
assert_eq!(notification.entity_id, deserialized.entity_id);
}
}

View File

@@ -0,0 +1,466 @@
//! Subscriber management for WebSocket clients
use dashmap::DashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{debug, info};
use crate::service::Notification;
/// Unique identifier for a WebSocket client connection
pub type ClientId = String;
/// Subscription filter for notifications
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionFilter {
/// Subscribe to all notifications
All,
/// Subscribe to notifications for a specific entity type
EntityType(String),
/// Subscribe to notifications for a specific entity
Entity { entity_type: String, entity_id: i64 },
/// Subscribe to notifications for a specific user
User(i64),
/// Subscribe to a specific notification type
NotificationType(String),
}
impl SubscriptionFilter {
/// Check if this filter matches a notification
pub fn matches(&self, notification: &Notification) -> bool {
match self {
SubscriptionFilter::All => true,
SubscriptionFilter::EntityType(entity_type) => &notification.entity_type == entity_type,
SubscriptionFilter::Entity {
entity_type,
entity_id,
} => &notification.entity_type == entity_type && notification.entity_id == *entity_id,
SubscriptionFilter::User(user_id) => notification.user_id == Some(*user_id),
SubscriptionFilter::NotificationType(notification_type) => {
&notification.notification_type == notification_type
}
}
}
}
/// A WebSocket client subscriber
pub struct Subscriber {
/// Unique client identifier
#[allow(dead_code)]
pub client_id: ClientId,
/// Optional user ID associated with this client
#[allow(dead_code)]
pub user_id: Option<i64>,
/// Channel to send notifications to this client
pub tx: mpsc::UnboundedSender<Notification>,
/// Filters that determine which notifications this client receives
pub filters: Vec<SubscriptionFilter>,
}
impl Subscriber {
/// Check if this subscriber should receive a notification
pub fn should_receive(&self, notification: &Notification) -> bool {
// If no filters, don't receive anything (must explicitly subscribe)
if self.filters.is_empty() {
return false;
}
// Check if any filter matches
self.filters
.iter()
.any(|filter| filter.matches(notification))
}
}
/// Manages all WebSocket subscribers
pub struct SubscriberManager {
/// Map of client ID to subscriber
subscribers: Arc<DashMap<ClientId, Subscriber>>,
/// Counter for generating unique client IDs
next_id: AtomicUsize,
}
impl SubscriberManager {
/// Create a new subscriber manager
pub fn new() -> Self {
Self {
subscribers: Arc::new(DashMap::new()),
next_id: AtomicUsize::new(1),
}
}
/// Generate a unique client ID
pub fn generate_client_id(&self) -> ClientId {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
format!("client_{}", id)
}
/// Register a new subscriber
pub fn register(
&self,
client_id: ClientId,
user_id: Option<i64>,
tx: mpsc::UnboundedSender<Notification>,
) {
let subscriber = Subscriber {
client_id: client_id.clone(),
user_id,
tx,
filters: vec![],
};
self.subscribers.insert(client_id.clone(), subscriber);
info!("Registered new subscriber: {}", client_id);
}
/// Unregister a subscriber
pub fn unregister(&self, client_id: &ClientId) {
if self.subscribers.remove(client_id).is_some() {
info!("Unregistered subscriber: {}", client_id);
}
}
/// Add a subscription filter for a client
pub fn subscribe(&self, client_id: &ClientId, filter: SubscriptionFilter) -> bool {
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
if !subscriber.filters.contains(&filter) {
subscriber.filters.push(filter.clone());
debug!("Client {} subscribed to {:?}", client_id, filter);
return true;
}
}
false
}
/// Remove a subscription filter for a client
pub fn unsubscribe(&self, client_id: &ClientId, filter: &SubscriptionFilter) -> bool {
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
if let Some(pos) = subscriber.filters.iter().position(|f| f == filter) {
subscriber.filters.remove(pos);
debug!("Client {} unsubscribed from {:?}", client_id, filter);
return true;
}
}
false
}
/// Broadcast a notification to all matching subscribers
pub fn broadcast(&self, notification: Notification) {
let mut sent_count = 0;
let mut failed_count = 0;
// Collect client IDs to remove (if send fails)
let mut to_remove = Vec::new();
for entry in self.subscribers.iter() {
let client_id = entry.key();
let subscriber = entry.value();
// Check if this subscriber should receive the notification
if !subscriber.should_receive(&notification) {
continue;
}
// Try to send the notification
match subscriber.tx.send(notification.clone()) {
Ok(_) => {
sent_count += 1;
debug!("Sent notification to client: {}", client_id);
}
Err(_) => {
// Channel closed, client disconnected
failed_count += 1;
to_remove.push(client_id.clone());
}
}
}
// Remove disconnected clients
for client_id in to_remove {
self.unregister(&client_id);
}
if sent_count > 0 {
debug!(
"Broadcast notification: sent={}, failed={}, type={}",
sent_count, failed_count, notification.notification_type
);
}
}
/// Get the number of connected clients
pub fn client_count(&self) -> usize {
self.subscribers.len()
}
/// Get the total number of subscriptions across all clients
pub fn subscription_count(&self) -> usize {
self.subscribers
.iter()
.map(|entry| entry.value().filters.len())
.sum()
}
/// Disconnect all subscribers
pub async fn disconnect_all(&self) {
let client_ids: Vec<ClientId> = self
.subscribers
.iter()
.map(|entry| entry.key().clone())
.collect();
for client_id in client_ids {
self.unregister(&client_id);
}
info!("Disconnected all subscribers");
}
/// Get subscriber information for a client
#[allow(dead_code)]
pub fn get_subscriber_info(&self, client_id: &ClientId) -> Option<SubscriberInfo> {
self.subscribers
.get(client_id)
.map(|subscriber| SubscriberInfo {
client_id: subscriber.client_id.clone(),
user_id: subscriber.user_id,
filter_count: subscriber.filters.len(),
})
}
}
impl Default for SubscriberManager {
fn default() -> Self {
Self::new()
}
}
/// Information about a subscriber (for status/debugging)
#[derive(Debug, Clone, serde::Serialize)]
#[allow(dead_code)]
pub struct SubscriberInfo {
pub client_id: ClientId,
pub user_id: Option<i64>,
pub filter_count: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_subscription_filter_all_matches_everything() {
let filter = SubscriptionFilter::All;
let notification = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(456),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification));
}
#[test]
fn test_subscription_filter_entity_type() {
let filter = SubscriptionFilter::EntityType("execution".to_string());
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "inquiry".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscription_filter_specific_entity() {
let filter = SubscriptionFilter::Entity {
entity_type: "execution".to_string(),
entity_id: 123,
};
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscription_filter_user() {
let filter = SubscriptionFilter::User(456);
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(456),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(789),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscriber_manager_register_unregister() {
let manager = SubscriberManager::new();
let client_id = manager.generate_client_id();
assert_eq!(manager.client_count(), 0);
let (tx, _rx) = mpsc::unbounded_channel();
manager.register(client_id.clone(), Some(123), tx);
assert_eq!(manager.client_count(), 1);
manager.unregister(&client_id);
assert_eq!(manager.client_count(), 0);
}
#[test]
fn test_subscriber_manager_subscribe() {
let manager = SubscriberManager::new();
let client_id = manager.generate_client_id();
let (tx, _rx) = mpsc::unbounded_channel();
manager.register(client_id.clone(), None, tx);
// Subscribe to all notifications
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
assert!(result);
assert_eq!(manager.subscription_count(), 1);
// Subscribing to the same filter again should not increase count
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
assert!(!result);
assert_eq!(manager.subscription_count(), 1);
}
#[test]
fn test_subscriber_should_receive() {
let (tx, _rx) = mpsc::unbounded_channel();
let subscriber = Subscriber {
client_id: "test".to_string(),
user_id: Some(456),
tx,
filters: vec![SubscriptionFilter::EntityType("execution".to_string())],
};
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "inquiry".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(subscriber.should_receive(&notification1));
assert!(!subscriber.should_receive(&notification2));
}
#[test]
fn test_broadcast_to_matching_subscribers() {
let manager = SubscriberManager::new();
let client1_id = manager.generate_client_id();
let (tx1, mut rx1) = mpsc::unbounded_channel();
manager.register(client1_id.clone(), None, tx1);
manager.subscribe(
&client1_id,
SubscriptionFilter::EntityType("execution".to_string()),
);
let client2_id = manager.generate_client_id();
let (tx2, mut rx2) = mpsc::unbounded_channel();
manager.register(client2_id.clone(), None, tx2);
manager.subscribe(
&client2_id,
SubscriptionFilter::EntityType("inquiry".to_string()),
);
let notification = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
manager.broadcast(notification.clone());
// Client 1 should receive the notification
let received1 = rx1.try_recv();
assert!(received1.is_ok());
assert_eq!(received1.unwrap().entity_id, 123);
// Client 2 should not receive the notification
let received2 = rx2.try_recv();
assert!(received2.is_err());
}
}

View File

@@ -0,0 +1,367 @@
//! WebSocket server for real-time notifications
use anyhow::{Context, Result};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
http::StatusCode,
response::IntoResponse,
routing::get,
Json, Router,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, error, info, warn};
use attune_common::config::Config;
use crate::service::Notification;
use crate::subscriber_manager::{ClientId, SubscriberManager, SubscriptionFilter};
/// WebSocket server for handling client connections
pub struct WebSocketServer {
config: Config,
pub notification_tx: broadcast::Sender<Notification>,
subscriber_manager: Arc<SubscriberManager>,
shutdown_tx: broadcast::Sender<()>,
}
impl WebSocketServer {
/// Create a new WebSocket server
pub fn new(
config: Config,
notification_tx: broadcast::Sender<Notification>,
subscriber_manager: Arc<SubscriberManager>,
shutdown_tx: broadcast::Sender<()>,
) -> Self {
Self {
config,
notification_tx,
subscriber_manager,
shutdown_tx,
}
}
/// Clone method for spawning tasks
pub fn clone(&self) -> Self {
Self {
config: self.config.clone(),
notification_tx: self.notification_tx.clone(),
subscriber_manager: self.subscriber_manager.clone(),
shutdown_tx: self.shutdown_tx.clone(),
}
}
/// Start the WebSocket server
pub async fn start(&self) -> Result<()> {
let app_state = Arc::new(AppState {
notification_tx: self.notification_tx.clone(),
subscriber_manager: self.subscriber_manager.clone(),
});
// Build router with WebSocket endpoint
let app = Router::new()
.route("/ws", get(websocket_handler))
.route("/health", get(health_handler))
.route("/stats", get(stats_handler))
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
.with_state(app_state);
let notifier_config = self
.config
.notifier
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config"))?;
let addr = format!("{}:{}", notifier_config.host, notifier_config.port);
let listener = tokio::net::TcpListener::bind(&addr)
.await
.context(format!("Failed to bind to {}", addr))?;
info!("WebSocket server listening on {}", addr);
axum::serve(listener, app)
.await
.context("WebSocket server error")?;
Ok(())
}
}
/// Shared application state
struct AppState {
#[allow(dead_code)]
notification_tx: broadcast::Sender<Notification>,
subscriber_manager: Arc<SubscriberManager>,
}
/// Health check endpoint
async fn health_handler() -> impl IntoResponse {
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
}
/// Stats endpoint
async fn stats_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let stats = serde_json::json!({
"connected_clients": state.subscriber_manager.client_count(),
"total_subscriptions": state.subscriber_manager.subscription_count(),
});
(StatusCode::OK, Json(stats))
}
/// WebSocket handler - upgrades HTTP connection to WebSocket
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket(socket, state))
}
/// Handle individual WebSocket connection
async fn handle_websocket(socket: WebSocket, state: Arc<AppState>) {
let client_id = state.subscriber_manager.generate_client_id();
info!("New WebSocket connection: {}", client_id);
// Split the socket into sender and receiver
let (mut ws_sender, mut ws_receiver) = socket.split();
// Create channel for sending notifications to this client
let (tx, mut rx) = mpsc::unbounded_channel::<Notification>();
// Register the subscriber
state
.subscriber_manager
.register(client_id.clone(), None, tx);
// Send welcome message
let welcome = ClientMessage::Welcome {
client_id: client_id.clone(),
message: "Connected to Attune Notifier".to_string(),
};
if let Ok(json) = serde_json::to_string(&welcome) {
let _ = ws_sender.send(Message::Text(json.into())).await;
}
// Spawn task to handle outgoing notifications
let client_id_clone = client_id.clone();
let subscriber_manager_clone = state.subscriber_manager.clone();
let outgoing_task = tokio::spawn(async move {
while let Some(notification) = rx.recv().await {
// Serialize notification to JSON
match serde_json::to_string(&notification) {
Ok(json) => {
if let Err(e) = ws_sender.send(Message::Text(json.into())).await {
error!("Failed to send notification to {}: {}", client_id_clone, e);
break;
}
}
Err(e) => {
error!("Failed to serialize notification: {}", e);
}
}
}
debug!("Outgoing task stopped for client: {}", client_id_clone);
subscriber_manager_clone.unregister(&client_id_clone);
});
// Handle incoming messages from client (subscriptions, etc.)
let subscriber_manager_clone = state.subscriber_manager.clone();
let client_id_clone = client_id.clone();
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) =
handle_client_message(&client_id_clone, &text, &subscriber_manager_clone).await
{
error!("Error handling client message: {}", e);
}
}
Ok(Message::Binary(_)) => {
warn!("Received binary message from {}, ignoring", client_id_clone);
}
Ok(Message::Close(_)) => {
info!("Client {} closed connection", client_id_clone);
break;
}
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
// Handled automatically by axum
}
Err(e) => {
error!("WebSocket error for {}: {}", client_id_clone, e);
break;
}
}
}
// Clean up
subscriber_manager_clone.unregister(&client_id);
outgoing_task.abort();
info!("WebSocket connection closed: {}", client_id);
}
/// Handle incoming message from client
async fn handle_client_message(
client_id: &ClientId,
message: &str,
subscriber_manager: &SubscriberManager,
) -> Result<()> {
let msg: ServerMessage =
serde_json::from_str(message).context("Failed to parse client message")?;
match msg {
ServerMessage::Subscribe { filter } => {
let subscription_filter = parse_subscription_filter(&filter)?;
subscriber_manager.subscribe(client_id, subscription_filter);
info!("Client {} subscribed to: {:?}", client_id, filter);
}
ServerMessage::Unsubscribe { filter } => {
let subscription_filter = parse_subscription_filter(&filter)?;
subscriber_manager.unsubscribe(client_id, &subscription_filter);
info!("Client {} unsubscribed from: {:?}", client_id, filter);
}
ServerMessage::Ping => {
debug!("Received ping from {}", client_id);
// Pong is handled automatically
}
}
Ok(())
}
/// Parse subscription filter from string
fn parse_subscription_filter(filter_str: &str) -> Result<SubscriptionFilter> {
// Format: "type:value" or "all"
if filter_str == "all" {
return Ok(SubscriptionFilter::All);
}
let parts: Vec<&str> = filter_str.split(':').collect();
if parts.len() < 2 {
anyhow::bail!("Invalid filter format: {}", filter_str);
}
match parts[0] {
"entity_type" => Ok(SubscriptionFilter::EntityType(parts[1].to_string())),
"notification_type" => Ok(SubscriptionFilter::NotificationType(parts[1].to_string())),
"user" => {
let user_id: i64 = parts[1].parse().context("Invalid user ID")?;
Ok(SubscriptionFilter::User(user_id))
}
"entity" => {
if parts.len() < 3 {
anyhow::bail!("Entity filter requires type and id: entity:type:id");
}
let entity_id: i64 = parts[2].parse().context("Invalid entity ID")?;
Ok(SubscriptionFilter::Entity {
entity_type: parts[1].to_string(),
entity_id,
})
}
_ => anyhow::bail!("Unknown filter type: {}", parts[0]),
}
}
/// Messages sent from server to client
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
#[allow(dead_code)]
enum ClientMessage {
#[serde(rename = "welcome")]
Welcome { client_id: String, message: String },
#[serde(rename = "notification")]
Notification(Notification),
#[serde(rename = "error")]
Error { message: String },
}
/// Messages sent from client to server
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
enum ServerMessage {
#[serde(rename = "subscribe")]
Subscribe { filter: String },
#[serde(rename = "unsubscribe")]
Unsubscribe { filter: String },
#[serde(rename = "ping")]
Ping,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_subscription_filter_all() {
let filter = parse_subscription_filter("all").unwrap();
assert_eq!(filter, SubscriptionFilter::All);
}
#[test]
fn test_parse_subscription_filter_entity_type() {
let filter = parse_subscription_filter("entity_type:execution").unwrap();
assert_eq!(
filter,
SubscriptionFilter::EntityType("execution".to_string())
);
}
#[test]
fn test_parse_subscription_filter_notification_type() {
let filter =
parse_subscription_filter("notification_type:execution_status_changed").unwrap();
assert_eq!(
filter,
SubscriptionFilter::NotificationType("execution_status_changed".to_string())
);
}
#[test]
fn test_parse_subscription_filter_user() {
let filter = parse_subscription_filter("user:123").unwrap();
assert_eq!(filter, SubscriptionFilter::User(123));
}
#[test]
fn test_parse_subscription_filter_entity() {
let filter = parse_subscription_filter("entity:execution:456").unwrap();
assert_eq!(
filter,
SubscriptionFilter::Entity {
entity_type: "execution".to_string(),
entity_id: 456
}
);
}
#[test]
fn test_parse_subscription_filter_invalid() {
let result = parse_subscription_filter("invalid");
assert!(result.is_err());
}
#[test]
fn test_parse_subscription_filter_invalid_user_id() {
let result = parse_subscription_filter("user:not_a_number");
assert!(result.is_err());
}
#[test]
fn test_parse_subscription_filter_entity_missing_id() {
let result = parse_subscription_filter("entity:execution");
assert!(result.is_err());
}
}