//! Subscriber management for WebSocket clients use dashmap::DashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; use tracing::{debug, info}; use crate::service::Notification; /// Unique identifier for a WebSocket client connection pub type ClientId = String; /// Subscription filter for notifications #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum SubscriptionFilter { /// Subscribe to all notifications All, /// Subscribe to notifications for a specific entity type EntityType(String), /// Subscribe to notifications for a specific entity Entity { entity_type: String, entity_id: i64 }, /// Subscribe to notifications for a specific user User(i64), /// Subscribe to a specific notification type NotificationType(String), } impl SubscriptionFilter { /// Check if this filter matches a notification pub fn matches(&self, notification: &Notification) -> bool { match self { SubscriptionFilter::All => true, SubscriptionFilter::EntityType(entity_type) => ¬ification.entity_type == entity_type, SubscriptionFilter::Entity { entity_type, entity_id, } => ¬ification.entity_type == entity_type && notification.entity_id == *entity_id, SubscriptionFilter::User(user_id) => notification.user_id == Some(*user_id), SubscriptionFilter::NotificationType(notification_type) => { ¬ification.notification_type == notification_type } } } } /// A WebSocket client subscriber pub struct Subscriber { /// Unique client identifier #[allow(dead_code)] pub client_id: ClientId, /// Optional user ID associated with this client #[allow(dead_code)] pub user_id: Option, /// Channel to send notifications to this client pub tx: mpsc::UnboundedSender, /// Filters that determine which notifications this client receives pub filters: Vec, } impl Subscriber { /// Check if this subscriber should receive a notification pub fn should_receive(&self, notification: &Notification) -> bool { // If no filters, don't receive anything (must explicitly subscribe) if self.filters.is_empty() { return false; } // Check if any filter matches self.filters .iter() .any(|filter| filter.matches(notification)) } } /// Manages all WebSocket subscribers pub struct SubscriberManager { /// Map of client ID to subscriber subscribers: Arc>, /// Counter for generating unique client IDs next_id: AtomicUsize, } impl SubscriberManager { /// Create a new subscriber manager pub fn new() -> Self { Self { subscribers: Arc::new(DashMap::new()), next_id: AtomicUsize::new(1), } } /// Generate a unique client ID pub fn generate_client_id(&self) -> ClientId { let id = self.next_id.fetch_add(1, Ordering::SeqCst); format!("client_{}", id) } /// Register a new subscriber pub fn register( &self, client_id: ClientId, user_id: Option, tx: mpsc::UnboundedSender, ) { let subscriber = Subscriber { client_id: client_id.clone(), user_id, tx, filters: vec![], }; self.subscribers.insert(client_id.clone(), subscriber); info!("Registered new subscriber: {}", client_id); } /// Unregister a subscriber pub fn unregister(&self, client_id: &ClientId) { if self.subscribers.remove(client_id).is_some() { info!("Unregistered subscriber: {}", client_id); } } /// Add a subscription filter for a client pub fn subscribe(&self, client_id: &ClientId, filter: SubscriptionFilter) -> bool { if let Some(mut subscriber) = self.subscribers.get_mut(client_id) { if !subscriber.filters.contains(&filter) { subscriber.filters.push(filter.clone()); debug!("Client {} subscribed to {:?}", client_id, filter); return true; } } false } /// Remove a subscription filter for a client pub fn unsubscribe(&self, client_id: &ClientId, filter: &SubscriptionFilter) -> bool { if let Some(mut subscriber) = self.subscribers.get_mut(client_id) { if let Some(pos) = subscriber.filters.iter().position(|f| f == filter) { subscriber.filters.remove(pos); debug!("Client {} unsubscribed from {:?}", client_id, filter); return true; } } false } /// Broadcast a notification to all matching subscribers pub fn broadcast(&self, notification: Notification) { let mut sent_count = 0; let mut failed_count = 0; // Collect client IDs to remove (if send fails) let mut to_remove = Vec::new(); for entry in self.subscribers.iter() { let client_id = entry.key(); let subscriber = entry.value(); // Check if this subscriber should receive the notification if !subscriber.should_receive(¬ification) { continue; } // Try to send the notification match subscriber.tx.send(notification.clone()) { Ok(_) => { sent_count += 1; debug!("Sent notification to client: {}", client_id); } Err(_) => { // Channel closed, client disconnected failed_count += 1; to_remove.push(client_id.clone()); debug!("Client {} disconnected — removing", client_id); } } } // Remove disconnected clients for client_id in to_remove { self.unregister(&client_id); } if sent_count > 0 { debug!( "Broadcast notification: sent={}, failed={}, type={}, entity_type={}, entity_id={}", sent_count, failed_count, notification.notification_type, notification.entity_type, notification.entity_id, ); } } /// Get the number of connected clients pub fn client_count(&self) -> usize { self.subscribers.len() } /// Get the total number of subscriptions across all clients pub fn subscription_count(&self) -> usize { self.subscribers .iter() .map(|entry| entry.value().filters.len()) .sum() } /// Disconnect all subscribers pub async fn disconnect_all(&self) { let client_ids: Vec = self .subscribers .iter() .map(|entry| entry.key().clone()) .collect(); for client_id in client_ids { self.unregister(&client_id); } info!("Disconnected all subscribers"); } /// Get subscriber information for a client #[allow(dead_code)] pub fn get_subscriber_info(&self, client_id: &ClientId) -> Option { self.subscribers .get(client_id) .map(|subscriber| SubscriberInfo { client_id: subscriber.client_id.clone(), user_id: subscriber.user_id, filter_count: subscriber.filters.len(), }) } } impl Default for SubscriberManager { fn default() -> Self { Self::new() } } /// Information about a subscriber (for status/debugging) #[derive(Debug, Clone, serde::Serialize)] #[allow(dead_code)] pub struct SubscriberInfo { pub client_id: ClientId, pub user_id: Option, pub filter_count: usize, } #[cfg(test)] mod tests { use super::*; #[test] fn test_subscription_filter_all_matches_everything() { let filter = SubscriptionFilter::All; let notification = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: Some(456), payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; assert!(filter.matches(¬ification)); } #[test] fn test_subscription_filter_entity_type() { let filter = SubscriptionFilter::EntityType("execution".to_string()); let notification1 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; let notification2 = Notification { notification_type: "test".to_string(), entity_type: "inquiry".to_string(), entity_id: 456, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; assert!(filter.matches(¬ification1)); assert!(!filter.matches(¬ification2)); } #[test] fn test_subscription_filter_specific_entity() { let filter = SubscriptionFilter::Entity { entity_type: "execution".to_string(), entity_id: 123, }; let notification1 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; let notification2 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 456, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; assert!(filter.matches(¬ification1)); assert!(!filter.matches(¬ification2)); } #[test] fn test_subscription_filter_user() { let filter = SubscriptionFilter::User(456); let notification1 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: Some(456), payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; let notification2 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: Some(789), payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; assert!(filter.matches(¬ification1)); assert!(!filter.matches(¬ification2)); } #[test] fn test_subscriber_manager_register_unregister() { let manager = SubscriberManager::new(); let client_id = manager.generate_client_id(); assert_eq!(manager.client_count(), 0); let (tx, _rx) = mpsc::unbounded_channel(); manager.register(client_id.clone(), Some(123), tx); assert_eq!(manager.client_count(), 1); manager.unregister(&client_id); assert_eq!(manager.client_count(), 0); } #[test] fn test_subscriber_manager_subscribe() { let manager = SubscriberManager::new(); let client_id = manager.generate_client_id(); let (tx, _rx) = mpsc::unbounded_channel(); manager.register(client_id.clone(), None, tx); // Subscribe to all notifications let result = manager.subscribe(&client_id, SubscriptionFilter::All); assert!(result); assert_eq!(manager.subscription_count(), 1); // Subscribing to the same filter again should not increase count let result = manager.subscribe(&client_id, SubscriptionFilter::All); assert!(!result); assert_eq!(manager.subscription_count(), 1); } #[test] fn test_subscriber_should_receive() { let (tx, _rx) = mpsc::unbounded_channel(); let subscriber = Subscriber { client_id: "test".to_string(), user_id: Some(456), tx, filters: vec![SubscriptionFilter::EntityType("execution".to_string())], }; let notification1 = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; let notification2 = Notification { notification_type: "test".to_string(), entity_type: "inquiry".to_string(), entity_id: 456, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; assert!(subscriber.should_receive(¬ification1)); assert!(!subscriber.should_receive(¬ification2)); } #[test] fn test_broadcast_to_matching_subscribers() { let manager = SubscriberManager::new(); let client1_id = manager.generate_client_id(); let (tx1, mut rx1) = mpsc::unbounded_channel(); manager.register(client1_id.clone(), None, tx1); manager.subscribe( &client1_id, SubscriptionFilter::EntityType("execution".to_string()), ); let client2_id = manager.generate_client_id(); let (tx2, mut rx2) = mpsc::unbounded_channel(); manager.register(client2_id.clone(), None, tx2); manager.subscribe( &client2_id, SubscriptionFilter::EntityType("inquiry".to_string()), ); let notification = Notification { notification_type: "test".to_string(), entity_type: "execution".to_string(), entity_id: 123, user_id: None, payload: serde_json::json!({}), timestamp: chrono::Utc::now(), }; manager.broadcast(notification.clone()); // Client 1 should receive the notification let received1 = rx1.try_recv(); assert!(received1.is_ok()); assert_eq!(received1.unwrap().entity_id, 123); // Client 2 should not receive the notification let received2 = rx2.try_recv(); assert!(received2.is_err()); } }