re-uploading work
This commit is contained in:
466
crates/notifier/src/subscriber_manager.rs
Normal file
466
crates/notifier/src/subscriber_manager.rs
Normal file
@@ -0,0 +1,466 @@
|
||||
//! Subscriber management for WebSocket clients
|
||||
|
||||
use dashmap::DashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::service::Notification;
|
||||
|
||||
/// Unique identifier for a WebSocket client connection
|
||||
pub type ClientId = String;
|
||||
|
||||
/// Subscription filter for notifications
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SubscriptionFilter {
|
||||
/// Subscribe to all notifications
|
||||
All,
|
||||
|
||||
/// Subscribe to notifications for a specific entity type
|
||||
EntityType(String),
|
||||
|
||||
/// Subscribe to notifications for a specific entity
|
||||
Entity { entity_type: String, entity_id: i64 },
|
||||
|
||||
/// Subscribe to notifications for a specific user
|
||||
User(i64),
|
||||
|
||||
/// Subscribe to a specific notification type
|
||||
NotificationType(String),
|
||||
}
|
||||
|
||||
impl SubscriptionFilter {
|
||||
/// Check if this filter matches a notification
|
||||
pub fn matches(&self, notification: &Notification) -> bool {
|
||||
match self {
|
||||
SubscriptionFilter::All => true,
|
||||
SubscriptionFilter::EntityType(entity_type) => ¬ification.entity_type == entity_type,
|
||||
SubscriptionFilter::Entity {
|
||||
entity_type,
|
||||
entity_id,
|
||||
} => ¬ification.entity_type == entity_type && notification.entity_id == *entity_id,
|
||||
SubscriptionFilter::User(user_id) => notification.user_id == Some(*user_id),
|
||||
SubscriptionFilter::NotificationType(notification_type) => {
|
||||
¬ification.notification_type == notification_type
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A WebSocket client subscriber
|
||||
pub struct Subscriber {
|
||||
/// Unique client identifier
|
||||
#[allow(dead_code)]
|
||||
pub client_id: ClientId,
|
||||
|
||||
/// Optional user ID associated with this client
|
||||
#[allow(dead_code)]
|
||||
pub user_id: Option<i64>,
|
||||
|
||||
/// Channel to send notifications to this client
|
||||
pub tx: mpsc::UnboundedSender<Notification>,
|
||||
|
||||
/// Filters that determine which notifications this client receives
|
||||
pub filters: Vec<SubscriptionFilter>,
|
||||
}
|
||||
|
||||
impl Subscriber {
|
||||
/// Check if this subscriber should receive a notification
|
||||
pub fn should_receive(&self, notification: &Notification) -> bool {
|
||||
// If no filters, don't receive anything (must explicitly subscribe)
|
||||
if self.filters.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any filter matches
|
||||
self.filters
|
||||
.iter()
|
||||
.any(|filter| filter.matches(notification))
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages all WebSocket subscribers
|
||||
pub struct SubscriberManager {
|
||||
/// Map of client ID to subscriber
|
||||
subscribers: Arc<DashMap<ClientId, Subscriber>>,
|
||||
|
||||
/// Counter for generating unique client IDs
|
||||
next_id: AtomicUsize,
|
||||
}
|
||||
|
||||
impl SubscriberManager {
|
||||
/// Create a new subscriber manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
subscribers: Arc::new(DashMap::new()),
|
||||
next_id: AtomicUsize::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a unique client ID
|
||||
pub fn generate_client_id(&self) -> ClientId {
|
||||
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
format!("client_{}", id)
|
||||
}
|
||||
|
||||
/// Register a new subscriber
|
||||
pub fn register(
|
||||
&self,
|
||||
client_id: ClientId,
|
||||
user_id: Option<i64>,
|
||||
tx: mpsc::UnboundedSender<Notification>,
|
||||
) {
|
||||
let subscriber = Subscriber {
|
||||
client_id: client_id.clone(),
|
||||
user_id,
|
||||
tx,
|
||||
filters: vec![],
|
||||
};
|
||||
|
||||
self.subscribers.insert(client_id.clone(), subscriber);
|
||||
info!("Registered new subscriber: {}", client_id);
|
||||
}
|
||||
|
||||
/// Unregister a subscriber
|
||||
pub fn unregister(&self, client_id: &ClientId) {
|
||||
if self.subscribers.remove(client_id).is_some() {
|
||||
info!("Unregistered subscriber: {}", client_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a subscription filter for a client
|
||||
pub fn subscribe(&self, client_id: &ClientId, filter: SubscriptionFilter) -> bool {
|
||||
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
|
||||
if !subscriber.filters.contains(&filter) {
|
||||
subscriber.filters.push(filter.clone());
|
||||
debug!("Client {} subscribed to {:?}", client_id, filter);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Remove a subscription filter for a client
|
||||
pub fn unsubscribe(&self, client_id: &ClientId, filter: &SubscriptionFilter) -> bool {
|
||||
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
|
||||
if let Some(pos) = subscriber.filters.iter().position(|f| f == filter) {
|
||||
subscriber.filters.remove(pos);
|
||||
debug!("Client {} unsubscribed from {:?}", client_id, filter);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Broadcast a notification to all matching subscribers
|
||||
pub fn broadcast(&self, notification: Notification) {
|
||||
let mut sent_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
// Collect client IDs to remove (if send fails)
|
||||
let mut to_remove = Vec::new();
|
||||
|
||||
for entry in self.subscribers.iter() {
|
||||
let client_id = entry.key();
|
||||
let subscriber = entry.value();
|
||||
|
||||
// Check if this subscriber should receive the notification
|
||||
if !subscriber.should_receive(¬ification) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to send the notification
|
||||
match subscriber.tx.send(notification.clone()) {
|
||||
Ok(_) => {
|
||||
sent_count += 1;
|
||||
debug!("Sent notification to client: {}", client_id);
|
||||
}
|
||||
Err(_) => {
|
||||
// Channel closed, client disconnected
|
||||
failed_count += 1;
|
||||
to_remove.push(client_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove disconnected clients
|
||||
for client_id in to_remove {
|
||||
self.unregister(&client_id);
|
||||
}
|
||||
|
||||
if sent_count > 0 {
|
||||
debug!(
|
||||
"Broadcast notification: sent={}, failed={}, type={}",
|
||||
sent_count, failed_count, notification.notification_type
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of connected clients
|
||||
pub fn client_count(&self) -> usize {
|
||||
self.subscribers.len()
|
||||
}
|
||||
|
||||
/// Get the total number of subscriptions across all clients
|
||||
pub fn subscription_count(&self) -> usize {
|
||||
self.subscribers
|
||||
.iter()
|
||||
.map(|entry| entry.value().filters.len())
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Disconnect all subscribers
|
||||
pub async fn disconnect_all(&self) {
|
||||
let client_ids: Vec<ClientId> = self
|
||||
.subscribers
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
for client_id in client_ids {
|
||||
self.unregister(&client_id);
|
||||
}
|
||||
|
||||
info!("Disconnected all subscribers");
|
||||
}
|
||||
|
||||
/// Get subscriber information for a client
|
||||
#[allow(dead_code)]
|
||||
pub fn get_subscriber_info(&self, client_id: &ClientId) -> Option<SubscriberInfo> {
|
||||
self.subscribers
|
||||
.get(client_id)
|
||||
.map(|subscriber| SubscriberInfo {
|
||||
client_id: subscriber.client_id.clone(),
|
||||
user_id: subscriber.user_id,
|
||||
filter_count: subscriber.filters.len(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SubscriberManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a subscriber (for status/debugging)
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SubscriberInfo {
|
||||
pub client_id: ClientId,
|
||||
pub user_id: Option<i64>,
|
||||
pub filter_count: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_all_matches_everything() {
|
||||
let filter = SubscriptionFilter::All;
|
||||
let notification = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(456),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_entity_type() {
|
||||
let filter = SubscriptionFilter::EntityType("execution".to_string());
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "inquiry".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_specific_entity() {
|
||||
let filter = SubscriptionFilter::Entity {
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
};
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_user() {
|
||||
let filter = SubscriptionFilter::User(456);
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(456),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(789),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_manager_register_unregister() {
|
||||
let manager = SubscriberManager::new();
|
||||
let client_id = manager.generate_client_id();
|
||||
|
||||
assert_eq!(manager.client_count(), 0);
|
||||
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
manager.register(client_id.clone(), Some(123), tx);
|
||||
|
||||
assert_eq!(manager.client_count(), 1);
|
||||
|
||||
manager.unregister(&client_id);
|
||||
|
||||
assert_eq!(manager.client_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_manager_subscribe() {
|
||||
let manager = SubscriberManager::new();
|
||||
let client_id = manager.generate_client_id();
|
||||
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
manager.register(client_id.clone(), None, tx);
|
||||
|
||||
// Subscribe to all notifications
|
||||
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
|
||||
assert!(result);
|
||||
|
||||
assert_eq!(manager.subscription_count(), 1);
|
||||
|
||||
// Subscribing to the same filter again should not increase count
|
||||
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
|
||||
assert!(!result);
|
||||
|
||||
assert_eq!(manager.subscription_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_should_receive() {
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
let subscriber = Subscriber {
|
||||
client_id: "test".to_string(),
|
||||
user_id: Some(456),
|
||||
tx,
|
||||
filters: vec![SubscriptionFilter::EntityType("execution".to_string())],
|
||||
};
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "inquiry".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(subscriber.should_receive(¬ification1));
|
||||
assert!(!subscriber.should_receive(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_to_matching_subscribers() {
|
||||
let manager = SubscriberManager::new();
|
||||
|
||||
let client1_id = manager.generate_client_id();
|
||||
let (tx1, mut rx1) = mpsc::unbounded_channel();
|
||||
manager.register(client1_id.clone(), None, tx1);
|
||||
manager.subscribe(
|
||||
&client1_id,
|
||||
SubscriptionFilter::EntityType("execution".to_string()),
|
||||
);
|
||||
|
||||
let client2_id = manager.generate_client_id();
|
||||
let (tx2, mut rx2) = mpsc::unbounded_channel();
|
||||
manager.register(client2_id.clone(), None, tx2);
|
||||
manager.subscribe(
|
||||
&client2_id,
|
||||
SubscriptionFilter::EntityType("inquiry".to_string()),
|
||||
);
|
||||
|
||||
let notification = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
manager.broadcast(notification.clone());
|
||||
|
||||
// Client 1 should receive the notification
|
||||
let received1 = rx1.try_recv();
|
||||
assert!(received1.is_ok());
|
||||
assert_eq!(received1.unwrap().entity_id, 123);
|
||||
|
||||
// Client 2 should not receive the notification
|
||||
let received2 = rx2.try_recv();
|
||||
assert!(received2.is_err());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user