Files
attune/crates/notifier/src/subscriber_manager.rs
2026-03-03 13:42:41 -06:00

472 lines
14 KiB
Rust

//! Subscriber management for WebSocket clients
use dashmap::DashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{debug, info};
use crate::service::Notification;
/// Unique identifier for a WebSocket client connection
pub type ClientId = String;
/// Subscription filter for notifications
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionFilter {
/// Subscribe to all notifications
All,
/// Subscribe to notifications for a specific entity type
EntityType(String),
/// Subscribe to notifications for a specific entity
Entity { entity_type: String, entity_id: i64 },
/// Subscribe to notifications for a specific user
User(i64),
/// Subscribe to a specific notification type
NotificationType(String),
}
impl SubscriptionFilter {
/// Check if this filter matches a notification
pub fn matches(&self, notification: &Notification) -> bool {
match self {
SubscriptionFilter::All => true,
SubscriptionFilter::EntityType(entity_type) => &notification.entity_type == entity_type,
SubscriptionFilter::Entity {
entity_type,
entity_id,
} => &notification.entity_type == entity_type && notification.entity_id == *entity_id,
SubscriptionFilter::User(user_id) => notification.user_id == Some(*user_id),
SubscriptionFilter::NotificationType(notification_type) => {
&notification.notification_type == notification_type
}
}
}
}
/// A WebSocket client subscriber
pub struct Subscriber {
/// Unique client identifier
#[allow(dead_code)]
pub client_id: ClientId,
/// Optional user ID associated with this client
#[allow(dead_code)]
pub user_id: Option<i64>,
/// Channel to send notifications to this client
pub tx: mpsc::UnboundedSender<Notification>,
/// Filters that determine which notifications this client receives
pub filters: Vec<SubscriptionFilter>,
}
impl Subscriber {
/// Check if this subscriber should receive a notification
pub fn should_receive(&self, notification: &Notification) -> bool {
// If no filters, don't receive anything (must explicitly subscribe)
if self.filters.is_empty() {
return false;
}
// Check if any filter matches
self.filters
.iter()
.any(|filter| filter.matches(notification))
}
}
/// Manages all WebSocket subscribers
pub struct SubscriberManager {
/// Map of client ID to subscriber
subscribers: Arc<DashMap<ClientId, Subscriber>>,
/// Counter for generating unique client IDs
next_id: AtomicUsize,
}
impl SubscriberManager {
/// Create a new subscriber manager
pub fn new() -> Self {
Self {
subscribers: Arc::new(DashMap::new()),
next_id: AtomicUsize::new(1),
}
}
/// Generate a unique client ID
pub fn generate_client_id(&self) -> ClientId {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
format!("client_{}", id)
}
/// Register a new subscriber
pub fn register(
&self,
client_id: ClientId,
user_id: Option<i64>,
tx: mpsc::UnboundedSender<Notification>,
) {
let subscriber = Subscriber {
client_id: client_id.clone(),
user_id,
tx,
filters: vec![],
};
self.subscribers.insert(client_id.clone(), subscriber);
info!("Registered new subscriber: {}", client_id);
}
/// Unregister a subscriber
pub fn unregister(&self, client_id: &ClientId) {
if self.subscribers.remove(client_id).is_some() {
info!("Unregistered subscriber: {}", client_id);
}
}
/// Add a subscription filter for a client
pub fn subscribe(&self, client_id: &ClientId, filter: SubscriptionFilter) -> bool {
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
if !subscriber.filters.contains(&filter) {
subscriber.filters.push(filter.clone());
debug!("Client {} subscribed to {:?}", client_id, filter);
return true;
}
}
false
}
/// Remove a subscription filter for a client
pub fn unsubscribe(&self, client_id: &ClientId, filter: &SubscriptionFilter) -> bool {
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
if let Some(pos) = subscriber.filters.iter().position(|f| f == filter) {
subscriber.filters.remove(pos);
debug!("Client {} unsubscribed from {:?}", client_id, filter);
return true;
}
}
false
}
/// Broadcast a notification to all matching subscribers
pub fn broadcast(&self, notification: Notification) {
let mut sent_count = 0;
let mut failed_count = 0;
// Collect client IDs to remove (if send fails)
let mut to_remove = Vec::new();
for entry in self.subscribers.iter() {
let client_id = entry.key();
let subscriber = entry.value();
// Check if this subscriber should receive the notification
if !subscriber.should_receive(&notification) {
continue;
}
// Try to send the notification
match subscriber.tx.send(notification.clone()) {
Ok(_) => {
sent_count += 1;
debug!("Sent notification to client: {}", client_id);
}
Err(_) => {
// Channel closed, client disconnected
failed_count += 1;
to_remove.push(client_id.clone());
debug!("Client {} disconnected — removing", client_id);
}
}
}
// Remove disconnected clients
for client_id in to_remove {
self.unregister(&client_id);
}
if sent_count > 0 {
debug!(
"Broadcast notification: sent={}, failed={}, type={}, entity_type={}, entity_id={}",
sent_count,
failed_count,
notification.notification_type,
notification.entity_type,
notification.entity_id,
);
}
}
/// Get the number of connected clients
pub fn client_count(&self) -> usize {
self.subscribers.len()
}
/// Get the total number of subscriptions across all clients
pub fn subscription_count(&self) -> usize {
self.subscribers
.iter()
.map(|entry| entry.value().filters.len())
.sum()
}
/// Disconnect all subscribers
pub async fn disconnect_all(&self) {
let client_ids: Vec<ClientId> = self
.subscribers
.iter()
.map(|entry| entry.key().clone())
.collect();
for client_id in client_ids {
self.unregister(&client_id);
}
info!("Disconnected all subscribers");
}
/// Get subscriber information for a client
#[allow(dead_code)]
pub fn get_subscriber_info(&self, client_id: &ClientId) -> Option<SubscriberInfo> {
self.subscribers
.get(client_id)
.map(|subscriber| SubscriberInfo {
client_id: subscriber.client_id.clone(),
user_id: subscriber.user_id,
filter_count: subscriber.filters.len(),
})
}
}
impl Default for SubscriberManager {
fn default() -> Self {
Self::new()
}
}
/// Information about a subscriber (for status/debugging)
#[derive(Debug, Clone, serde::Serialize)]
#[allow(dead_code)]
pub struct SubscriberInfo {
pub client_id: ClientId,
pub user_id: Option<i64>,
pub filter_count: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_subscription_filter_all_matches_everything() {
let filter = SubscriptionFilter::All;
let notification = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(456),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification));
}
#[test]
fn test_subscription_filter_entity_type() {
let filter = SubscriptionFilter::EntityType("execution".to_string());
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "inquiry".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscription_filter_specific_entity() {
let filter = SubscriptionFilter::Entity {
entity_type: "execution".to_string(),
entity_id: 123,
};
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscription_filter_user() {
let filter = SubscriptionFilter::User(456);
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(456),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: Some(789),
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(filter.matches(&notification1));
assert!(!filter.matches(&notification2));
}
#[test]
fn test_subscriber_manager_register_unregister() {
let manager = SubscriberManager::new();
let client_id = manager.generate_client_id();
assert_eq!(manager.client_count(), 0);
let (tx, _rx) = mpsc::unbounded_channel();
manager.register(client_id.clone(), Some(123), tx);
assert_eq!(manager.client_count(), 1);
manager.unregister(&client_id);
assert_eq!(manager.client_count(), 0);
}
#[test]
fn test_subscriber_manager_subscribe() {
let manager = SubscriberManager::new();
let client_id = manager.generate_client_id();
let (tx, _rx) = mpsc::unbounded_channel();
manager.register(client_id.clone(), None, tx);
// Subscribe to all notifications
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
assert!(result);
assert_eq!(manager.subscription_count(), 1);
// Subscribing to the same filter again should not increase count
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
assert!(!result);
assert_eq!(manager.subscription_count(), 1);
}
#[test]
fn test_subscriber_should_receive() {
let (tx, _rx) = mpsc::unbounded_channel();
let subscriber = Subscriber {
client_id: "test".to_string(),
user_id: Some(456),
tx,
filters: vec![SubscriptionFilter::EntityType("execution".to_string())],
};
let notification1 = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
let notification2 = Notification {
notification_type: "test".to_string(),
entity_type: "inquiry".to_string(),
entity_id: 456,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
assert!(subscriber.should_receive(&notification1));
assert!(!subscriber.should_receive(&notification2));
}
#[test]
fn test_broadcast_to_matching_subscribers() {
let manager = SubscriberManager::new();
let client1_id = manager.generate_client_id();
let (tx1, mut rx1) = mpsc::unbounded_channel();
manager.register(client1_id.clone(), None, tx1);
manager.subscribe(
&client1_id,
SubscriptionFilter::EntityType("execution".to_string()),
);
let client2_id = manager.generate_client_id();
let (tx2, mut rx2) = mpsc::unbounded_channel();
manager.register(client2_id.clone(), None, tx2);
manager.subscribe(
&client2_id,
SubscriptionFilter::EntityType("inquiry".to_string()),
);
let notification = Notification {
notification_type: "test".to_string(),
entity_type: "execution".to_string(),
entity_id: 123,
user_id: None,
payload: serde_json::json!({}),
timestamp: chrono::Utc::now(),
};
manager.broadcast(notification.clone());
// Client 1 should receive the notification
let received1 = rx1.try_recv();
assert!(received1.is_ok());
assert_eq!(received1.unwrap().entity_id, 123);
// Client 2 should not receive the notification
let received2 = rx2.try_recv();
assert!(received2.is_err());
}
}