re-uploading work
This commit is contained in:
62
crates/notifier/Cargo.toml
Normal file
62
crates/notifier/Cargo.toml
Normal file
@@ -0,0 +1,62 @@
|
||||
[package]
|
||||
name = "attune-notifier"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "attune-notifier"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
attune-common = { path = "../common" }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# Database
|
||||
sqlx = { workspace = true }
|
||||
|
||||
# Web framework & WebSocket
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Logging and tracing
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
# Error handling
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
# Configuration
|
||||
config = { workspace = true }
|
||||
|
||||
# Date/Time
|
||||
chrono = { workspace = true }
|
||||
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
|
||||
# Redis (optional, for distributed notifications)
|
||||
redis = { workspace = true }
|
||||
|
||||
# UUID
|
||||
uuid = { workspace = true }
|
||||
|
||||
# Concurrent data structures
|
||||
dashmap = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mockall = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
128
crates/notifier/src/main.rs
Normal file
128
crates/notifier/src/main.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
//! Attune Notifier Service - Real-time notification delivery
|
||||
|
||||
use anyhow::Result;
|
||||
use attune_common::config::Config;
|
||||
use clap::Parser;
|
||||
use tracing::{error, info};
|
||||
|
||||
mod postgres_listener;
|
||||
mod service;
|
||||
mod subscriber_manager;
|
||||
mod websocket_server;
|
||||
|
||||
use service::NotifierService;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "attune-notifier")]
|
||||
#[command(about = "Attune Notifier Service - Real-time notifications", long_about = None)]
|
||||
struct Args {
|
||||
/// Path to configuration file
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
|
||||
/// Log level (trace, debug, info, warn, error)
|
||||
#[arg(short, long, default_value = "info")]
|
||||
log_level: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Initialize tracing with specified log level
|
||||
let log_level = args
|
||||
.log_level
|
||||
.parse::<tracing::Level>()
|
||||
.unwrap_or(tracing::Level::INFO);
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(log_level)
|
||||
.with_target(false)
|
||||
.with_thread_ids(true)
|
||||
.init();
|
||||
|
||||
info!("Starting Attune Notifier Service");
|
||||
|
||||
// Load configuration
|
||||
if let Some(config_path) = args.config {
|
||||
std::env::set_var("ATTUNE_CONFIG", config_path);
|
||||
}
|
||||
|
||||
let config = Config::load()?;
|
||||
config.validate()?;
|
||||
|
||||
info!("Configuration loaded successfully");
|
||||
info!("Environment: {}", config.environment);
|
||||
info!("Database: {}", mask_password(&config.database.url));
|
||||
|
||||
let notifier_config = config
|
||||
.notifier
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config file"))?;
|
||||
|
||||
info!(
|
||||
"Listening on: {}:{}",
|
||||
notifier_config.host, notifier_config.port
|
||||
);
|
||||
|
||||
// Create and start the notifier service
|
||||
let service = NotifierService::new(config).await?;
|
||||
|
||||
info!("Notifier Service initialized successfully");
|
||||
|
||||
// Set up graceful shutdown handler
|
||||
let service_clone = std::sync::Arc::new(service);
|
||||
let service_for_shutdown = service_clone.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to listen for Ctrl+C");
|
||||
info!("Received shutdown signal");
|
||||
|
||||
if let Err(e) = service_for_shutdown.shutdown().await {
|
||||
error!("Error during shutdown: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the service (blocks until shutdown)
|
||||
if let Err(e) = service_clone.start().await {
|
||||
error!("Notifier service error: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
info!("Attune Notifier Service stopped");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mask password in database URL for logging
|
||||
fn mask_password(url: &str) -> String {
|
||||
if let Some(at_pos) = url.rfind('@') {
|
||||
if let Some(colon_pos) = url[..at_pos].rfind(':') {
|
||||
let mut masked = url.to_string();
|
||||
masked.replace_range(colon_pos + 1..at_pos, "****");
|
||||
return masked;
|
||||
}
|
||||
}
|
||||
url.to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mask_password() {
|
||||
let url = "postgresql://user:password@localhost:5432/db";
|
||||
let masked = mask_password(url);
|
||||
assert_eq!(masked, "postgresql://user:****@localhost:5432/db");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_password_no_password() {
|
||||
let url = "postgresql://localhost:5432/db";
|
||||
let masked = mask_password(url);
|
||||
assert_eq!(masked, "postgresql://localhost:5432/db");
|
||||
}
|
||||
}
|
||||
232
crates/notifier/src/postgres_listener.rs
Normal file
232
crates/notifier/src/postgres_listener.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
//! PostgreSQL LISTEN/NOTIFY integration for real-time notifications
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use sqlx::postgres::PgListener;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::service::Notification;
|
||||
|
||||
/// Channels to listen on for PostgreSQL notifications
|
||||
const NOTIFICATION_CHANNELS: &[&str] = &[
|
||||
"attune_notifications",
|
||||
"execution_status_changed",
|
||||
"execution_created",
|
||||
"inquiry_created",
|
||||
"inquiry_responded",
|
||||
"enforcement_created",
|
||||
"event_created",
|
||||
"workflow_execution_status_changed",
|
||||
];
|
||||
|
||||
/// PostgreSQL listener that receives NOTIFY events and broadcasts them
|
||||
pub struct PostgresListener {
|
||||
database_url: String,
|
||||
notification_tx: broadcast::Sender<Notification>,
|
||||
}
|
||||
|
||||
impl PostgresListener {
|
||||
/// Create a new PostgreSQL listener
|
||||
pub async fn new(
|
||||
database_url: String,
|
||||
notification_tx: broadcast::Sender<Notification>,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
database_url,
|
||||
notification_tx,
|
||||
})
|
||||
}
|
||||
|
||||
/// Start listening for PostgreSQL notifications
|
||||
pub async fn listen(&self) -> Result<()> {
|
||||
info!(
|
||||
"Starting PostgreSQL LISTEN on channels: {:?}",
|
||||
NOTIFICATION_CHANNELS
|
||||
);
|
||||
|
||||
// Create a dedicated listener connection
|
||||
let mut listener = PgListener::connect(&self.database_url)
|
||||
.await
|
||||
.context("Failed to connect PostgreSQL listener")?;
|
||||
|
||||
// Listen on all notification channels
|
||||
for channel in NOTIFICATION_CHANNELS {
|
||||
listener
|
||||
.listen(channel)
|
||||
.await
|
||||
.context(format!("Failed to LISTEN on channel '{}'", channel))?;
|
||||
info!("Listening on PostgreSQL channel: {}", channel);
|
||||
}
|
||||
|
||||
// Process notifications in a loop
|
||||
loop {
|
||||
match listener.recv().await {
|
||||
Ok(pg_notification) => {
|
||||
debug!(
|
||||
"Received PostgreSQL notification: channel={}, payload={}",
|
||||
pg_notification.channel(),
|
||||
pg_notification.payload()
|
||||
);
|
||||
|
||||
// Parse and broadcast notification
|
||||
if let Err(e) = self
|
||||
.process_notification(pg_notification.channel(), pg_notification.payload())
|
||||
{
|
||||
error!(
|
||||
"Failed to process notification from channel '{}': {}",
|
||||
pg_notification.channel(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error receiving PostgreSQL notification: {}", e);
|
||||
|
||||
// Sleep briefly before retrying to avoid tight loop on persistent errors
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Try to reconnect
|
||||
warn!("Attempting to reconnect PostgreSQL listener...");
|
||||
match PgListener::connect(&self.database_url).await {
|
||||
Ok(new_listener) => {
|
||||
listener = new_listener;
|
||||
// Re-subscribe to all channels
|
||||
for channel in NOTIFICATION_CHANNELS {
|
||||
if let Err(e) = listener.listen(channel).await {
|
||||
error!(
|
||||
"Failed to re-subscribe to channel '{}': {}",
|
||||
channel, e
|
||||
);
|
||||
}
|
||||
}
|
||||
info!("PostgreSQL listener reconnected successfully");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to reconnect PostgreSQL listener: {}", e);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a PostgreSQL notification and broadcast it to WebSocket clients
|
||||
fn process_notification(&self, channel: &str, payload: &str) -> Result<()> {
|
||||
// Parse the JSON payload
|
||||
let payload_json: serde_json::Value = serde_json::from_str(payload)
|
||||
.context("Failed to parse notification payload as JSON")?;
|
||||
|
||||
// Extract common fields
|
||||
let entity_type = payload_json
|
||||
.get("entity_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.context("Missing 'entity_type' in notification payload")?
|
||||
.to_string();
|
||||
|
||||
let entity_id = payload_json
|
||||
.get("entity_id")
|
||||
.and_then(|v| v.as_i64())
|
||||
.context("Missing 'entity_id' in notification payload")?;
|
||||
|
||||
let user_id = payload_json.get("user_id").and_then(|v| v.as_i64());
|
||||
|
||||
// Create notification
|
||||
let notification = Notification {
|
||||
notification_type: channel.to_string(),
|
||||
entity_type,
|
||||
entity_id,
|
||||
user_id,
|
||||
payload: payload_json,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
// Broadcast to all subscribers (ignore errors if no receivers)
|
||||
match self.notification_tx.send(notification) {
|
||||
Ok(receiver_count) => {
|
||||
debug!(
|
||||
"Broadcast notification to {} receivers: type={}, entity_id={}",
|
||||
receiver_count, channel, entity_id
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
// No active receivers, this is fine
|
||||
debug!("No active receivers for notification: type={}", channel);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_notification_channels_defined() {
|
||||
assert!(!NOTIFICATION_CHANNELS.is_empty());
|
||||
assert!(NOTIFICATION_CHANNELS.contains(&"execution_status_changed"));
|
||||
assert!(NOTIFICATION_CHANNELS.contains(&"inquiry_created"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_notification_valid_payload() {
|
||||
let (tx, mut rx) = broadcast::channel(10);
|
||||
let listener = PostgresListener {
|
||||
database_url: "postgresql://test".to_string(),
|
||||
notification_tx: tx,
|
||||
};
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"entity_type": "execution",
|
||||
"entity_id": 123,
|
||||
"user_id": 456,
|
||||
"status": "succeeded"
|
||||
});
|
||||
|
||||
let result =
|
||||
listener.process_notification("execution_status_changed", &payload.to_string());
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Should receive the notification
|
||||
let notification = rx.try_recv().unwrap();
|
||||
assert_eq!(notification.notification_type, "execution_status_changed");
|
||||
assert_eq!(notification.entity_type, "execution");
|
||||
assert_eq!(notification.entity_id, 123);
|
||||
assert_eq!(notification.user_id, Some(456));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_notification_missing_fields() {
|
||||
let (tx, _rx) = broadcast::channel(10);
|
||||
let listener = PostgresListener {
|
||||
database_url: "postgresql://test".to_string(),
|
||||
notification_tx: tx,
|
||||
};
|
||||
|
||||
// Missing entity_id
|
||||
let payload = serde_json::json!({
|
||||
"entity_type": "execution"
|
||||
});
|
||||
|
||||
let result =
|
||||
listener.process_notification("execution_status_changed", &payload.to_string());
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_notification_invalid_json() {
|
||||
let (tx, _rx) = broadcast::channel(10);
|
||||
let listener = PostgresListener {
|
||||
database_url: "postgresql://test".to_string(),
|
||||
notification_tx: tx,
|
||||
};
|
||||
|
||||
let result = listener.process_notification("execution_status_changed", "not valid json");
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
204
crates/notifier/src/service.rs
Normal file
204
crates/notifier/src/service.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
//! Notifier Service - Real-time notification orchestration
|
||||
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{error, info};
|
||||
|
||||
use attune_common::config::Config;
|
||||
|
||||
use crate::postgres_listener::PostgresListener;
|
||||
use crate::subscriber_manager::SubscriberManager;
|
||||
use crate::websocket_server::WebSocketServer;
|
||||
|
||||
/// Notification message that can be broadcast to subscribers
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Notification {
|
||||
/// Type of notification (e.g., "execution_status_changed", "inquiry_created")
|
||||
pub notification_type: String,
|
||||
|
||||
/// Entity type (e.g., "execution", "inquiry", "enforcement")
|
||||
pub entity_type: String,
|
||||
|
||||
/// Entity ID
|
||||
pub entity_id: i64,
|
||||
|
||||
/// Optional user/identity ID that should receive this notification
|
||||
pub user_id: Option<i64>,
|
||||
|
||||
/// Notification payload (varies by type)
|
||||
pub payload: serde_json::Value,
|
||||
|
||||
/// Timestamp when notification was created
|
||||
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
/// Main notifier service that coordinates all components
|
||||
pub struct NotifierService {
|
||||
config: Config,
|
||||
postgres_listener: Arc<PostgresListener>,
|
||||
subscriber_manager: Arc<SubscriberManager>,
|
||||
websocket_server: WebSocketServer,
|
||||
shutdown_tx: broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
impl NotifierService {
|
||||
/// Create a new notifier service
|
||||
pub async fn new(config: Config) -> Result<Self> {
|
||||
info!("Initializing Notifier Service");
|
||||
|
||||
// Create shutdown broadcast channel
|
||||
let (shutdown_tx, _) = broadcast::channel(16);
|
||||
|
||||
// Create notification broadcast channel
|
||||
let (notification_tx, _) = broadcast::channel(1000);
|
||||
|
||||
// Create subscriber manager
|
||||
let subscriber_manager = Arc::new(SubscriberManager::new());
|
||||
|
||||
// Create PostgreSQL listener
|
||||
let postgres_listener = Arc::new(
|
||||
PostgresListener::new(config.database.url.clone(), notification_tx.clone()).await?,
|
||||
);
|
||||
|
||||
// Create WebSocket server
|
||||
let websocket_server = WebSocketServer::new(
|
||||
config.clone(),
|
||||
notification_tx.clone(),
|
||||
subscriber_manager.clone(),
|
||||
shutdown_tx.clone(),
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
postgres_listener,
|
||||
subscriber_manager,
|
||||
websocket_server,
|
||||
shutdown_tx,
|
||||
})
|
||||
}
|
||||
|
||||
/// Start the notifier service
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
info!("Starting Notifier Service components");
|
||||
|
||||
// Start PostgreSQL listener
|
||||
let listener_handle = {
|
||||
let listener = self.postgres_listener.clone();
|
||||
let mut shutdown_rx = self.shutdown_tx.subscribe();
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
result = listener.listen() => {
|
||||
if let Err(e) = result {
|
||||
error!("PostgreSQL listener error: {}", e);
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("PostgreSQL listener shutting down");
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Start notification broadcaster (forwards notifications to WebSocket clients)
|
||||
let broadcast_handle = {
|
||||
let subscriber_manager = self.subscriber_manager.clone();
|
||||
let mut notification_rx = self.websocket_server.notification_tx.subscribe();
|
||||
let mut shutdown_rx = self.shutdown_tx.subscribe();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(notification) = notification_rx.recv() => {
|
||||
subscriber_manager.broadcast(notification);
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("Notification broadcaster shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Start WebSocket server
|
||||
let server_handle = {
|
||||
let server = self.websocket_server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = server.start().await {
|
||||
error!("WebSocket server error: {}", e);
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let notifier_config = self
|
||||
.config
|
||||
.notifier
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config"))?;
|
||||
|
||||
info!(
|
||||
"Notifier Service started on {}:{}",
|
||||
notifier_config.host, notifier_config.port
|
||||
);
|
||||
|
||||
// Wait for any task to complete (they shouldn't unless there's an error)
|
||||
tokio::select! {
|
||||
_ = listener_handle => {
|
||||
error!("PostgreSQL listener stopped unexpectedly");
|
||||
}
|
||||
_ = broadcast_handle => {
|
||||
error!("Notification broadcaster stopped unexpectedly");
|
||||
}
|
||||
_ = server_handle => {
|
||||
error!("WebSocket server stopped unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the notifier service gracefully
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down Notifier Service");
|
||||
|
||||
// Send shutdown signal to all components
|
||||
let _ = self.shutdown_tx.send(());
|
||||
|
||||
// Disconnect all WebSocket clients
|
||||
self.subscriber_manager.disconnect_all().await;
|
||||
|
||||
info!("Notifier Service shutdown complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_notification_serialization() {
|
||||
let notification = Notification {
|
||||
notification_type: "execution_status_changed".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(456),
|
||||
payload: serde_json::json!({
|
||||
"status": "succeeded",
|
||||
"action": "core.echo"
|
||||
}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
let deserialized: Notification = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
notification.notification_type,
|
||||
deserialized.notification_type
|
||||
);
|
||||
assert_eq!(notification.entity_type, deserialized.entity_type);
|
||||
assert_eq!(notification.entity_id, deserialized.entity_id);
|
||||
}
|
||||
}
|
||||
466
crates/notifier/src/subscriber_manager.rs
Normal file
466
crates/notifier/src/subscriber_manager.rs
Normal file
@@ -0,0 +1,466 @@
|
||||
//! Subscriber management for WebSocket clients
|
||||
|
||||
use dashmap::DashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::service::Notification;
|
||||
|
||||
/// Unique identifier for a WebSocket client connection
|
||||
pub type ClientId = String;
|
||||
|
||||
/// Subscription filter for notifications
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SubscriptionFilter {
|
||||
/// Subscribe to all notifications
|
||||
All,
|
||||
|
||||
/// Subscribe to notifications for a specific entity type
|
||||
EntityType(String),
|
||||
|
||||
/// Subscribe to notifications for a specific entity
|
||||
Entity { entity_type: String, entity_id: i64 },
|
||||
|
||||
/// Subscribe to notifications for a specific user
|
||||
User(i64),
|
||||
|
||||
/// Subscribe to a specific notification type
|
||||
NotificationType(String),
|
||||
}
|
||||
|
||||
impl SubscriptionFilter {
|
||||
/// Check if this filter matches a notification
|
||||
pub fn matches(&self, notification: &Notification) -> bool {
|
||||
match self {
|
||||
SubscriptionFilter::All => true,
|
||||
SubscriptionFilter::EntityType(entity_type) => ¬ification.entity_type == entity_type,
|
||||
SubscriptionFilter::Entity {
|
||||
entity_type,
|
||||
entity_id,
|
||||
} => ¬ification.entity_type == entity_type && notification.entity_id == *entity_id,
|
||||
SubscriptionFilter::User(user_id) => notification.user_id == Some(*user_id),
|
||||
SubscriptionFilter::NotificationType(notification_type) => {
|
||||
¬ification.notification_type == notification_type
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A WebSocket client subscriber
|
||||
pub struct Subscriber {
|
||||
/// Unique client identifier
|
||||
#[allow(dead_code)]
|
||||
pub client_id: ClientId,
|
||||
|
||||
/// Optional user ID associated with this client
|
||||
#[allow(dead_code)]
|
||||
pub user_id: Option<i64>,
|
||||
|
||||
/// Channel to send notifications to this client
|
||||
pub tx: mpsc::UnboundedSender<Notification>,
|
||||
|
||||
/// Filters that determine which notifications this client receives
|
||||
pub filters: Vec<SubscriptionFilter>,
|
||||
}
|
||||
|
||||
impl Subscriber {
|
||||
/// Check if this subscriber should receive a notification
|
||||
pub fn should_receive(&self, notification: &Notification) -> bool {
|
||||
// If no filters, don't receive anything (must explicitly subscribe)
|
||||
if self.filters.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any filter matches
|
||||
self.filters
|
||||
.iter()
|
||||
.any(|filter| filter.matches(notification))
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages all WebSocket subscribers
|
||||
pub struct SubscriberManager {
|
||||
/// Map of client ID to subscriber
|
||||
subscribers: Arc<DashMap<ClientId, Subscriber>>,
|
||||
|
||||
/// Counter for generating unique client IDs
|
||||
next_id: AtomicUsize,
|
||||
}
|
||||
|
||||
impl SubscriberManager {
|
||||
/// Create a new subscriber manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
subscribers: Arc::new(DashMap::new()),
|
||||
next_id: AtomicUsize::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a unique client ID
|
||||
pub fn generate_client_id(&self) -> ClientId {
|
||||
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
format!("client_{}", id)
|
||||
}
|
||||
|
||||
/// Register a new subscriber
|
||||
pub fn register(
|
||||
&self,
|
||||
client_id: ClientId,
|
||||
user_id: Option<i64>,
|
||||
tx: mpsc::UnboundedSender<Notification>,
|
||||
) {
|
||||
let subscriber = Subscriber {
|
||||
client_id: client_id.clone(),
|
||||
user_id,
|
||||
tx,
|
||||
filters: vec![],
|
||||
};
|
||||
|
||||
self.subscribers.insert(client_id.clone(), subscriber);
|
||||
info!("Registered new subscriber: {}", client_id);
|
||||
}
|
||||
|
||||
/// Unregister a subscriber
|
||||
pub fn unregister(&self, client_id: &ClientId) {
|
||||
if self.subscribers.remove(client_id).is_some() {
|
||||
info!("Unregistered subscriber: {}", client_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a subscription filter for a client
|
||||
pub fn subscribe(&self, client_id: &ClientId, filter: SubscriptionFilter) -> bool {
|
||||
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
|
||||
if !subscriber.filters.contains(&filter) {
|
||||
subscriber.filters.push(filter.clone());
|
||||
debug!("Client {} subscribed to {:?}", client_id, filter);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Remove a subscription filter for a client
|
||||
pub fn unsubscribe(&self, client_id: &ClientId, filter: &SubscriptionFilter) -> bool {
|
||||
if let Some(mut subscriber) = self.subscribers.get_mut(client_id) {
|
||||
if let Some(pos) = subscriber.filters.iter().position(|f| f == filter) {
|
||||
subscriber.filters.remove(pos);
|
||||
debug!("Client {} unsubscribed from {:?}", client_id, filter);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Broadcast a notification to all matching subscribers
|
||||
pub fn broadcast(&self, notification: Notification) {
|
||||
let mut sent_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
// Collect client IDs to remove (if send fails)
|
||||
let mut to_remove = Vec::new();
|
||||
|
||||
for entry in self.subscribers.iter() {
|
||||
let client_id = entry.key();
|
||||
let subscriber = entry.value();
|
||||
|
||||
// Check if this subscriber should receive the notification
|
||||
if !subscriber.should_receive(¬ification) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to send the notification
|
||||
match subscriber.tx.send(notification.clone()) {
|
||||
Ok(_) => {
|
||||
sent_count += 1;
|
||||
debug!("Sent notification to client: {}", client_id);
|
||||
}
|
||||
Err(_) => {
|
||||
// Channel closed, client disconnected
|
||||
failed_count += 1;
|
||||
to_remove.push(client_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove disconnected clients
|
||||
for client_id in to_remove {
|
||||
self.unregister(&client_id);
|
||||
}
|
||||
|
||||
if sent_count > 0 {
|
||||
debug!(
|
||||
"Broadcast notification: sent={}, failed={}, type={}",
|
||||
sent_count, failed_count, notification.notification_type
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of connected clients
|
||||
pub fn client_count(&self) -> usize {
|
||||
self.subscribers.len()
|
||||
}
|
||||
|
||||
/// Get the total number of subscriptions across all clients
|
||||
pub fn subscription_count(&self) -> usize {
|
||||
self.subscribers
|
||||
.iter()
|
||||
.map(|entry| entry.value().filters.len())
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Disconnect all subscribers
|
||||
pub async fn disconnect_all(&self) {
|
||||
let client_ids: Vec<ClientId> = self
|
||||
.subscribers
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
for client_id in client_ids {
|
||||
self.unregister(&client_id);
|
||||
}
|
||||
|
||||
info!("Disconnected all subscribers");
|
||||
}
|
||||
|
||||
/// Get subscriber information for a client
|
||||
#[allow(dead_code)]
|
||||
pub fn get_subscriber_info(&self, client_id: &ClientId) -> Option<SubscriberInfo> {
|
||||
self.subscribers
|
||||
.get(client_id)
|
||||
.map(|subscriber| SubscriberInfo {
|
||||
client_id: subscriber.client_id.clone(),
|
||||
user_id: subscriber.user_id,
|
||||
filter_count: subscriber.filters.len(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SubscriberManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a subscriber (for status/debugging)
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SubscriberInfo {
|
||||
pub client_id: ClientId,
|
||||
pub user_id: Option<i64>,
|
||||
pub filter_count: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_all_matches_everything() {
|
||||
let filter = SubscriptionFilter::All;
|
||||
let notification = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(456),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_entity_type() {
|
||||
let filter = SubscriptionFilter::EntityType("execution".to_string());
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "inquiry".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_specific_entity() {
|
||||
let filter = SubscriptionFilter::Entity {
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
};
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscription_filter_user() {
|
||||
let filter = SubscriptionFilter::User(456);
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(456),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: Some(789),
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(filter.matches(¬ification1));
|
||||
assert!(!filter.matches(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_manager_register_unregister() {
|
||||
let manager = SubscriberManager::new();
|
||||
let client_id = manager.generate_client_id();
|
||||
|
||||
assert_eq!(manager.client_count(), 0);
|
||||
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
manager.register(client_id.clone(), Some(123), tx);
|
||||
|
||||
assert_eq!(manager.client_count(), 1);
|
||||
|
||||
manager.unregister(&client_id);
|
||||
|
||||
assert_eq!(manager.client_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_manager_subscribe() {
|
||||
let manager = SubscriberManager::new();
|
||||
let client_id = manager.generate_client_id();
|
||||
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
manager.register(client_id.clone(), None, tx);
|
||||
|
||||
// Subscribe to all notifications
|
||||
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
|
||||
assert!(result);
|
||||
|
||||
assert_eq!(manager.subscription_count(), 1);
|
||||
|
||||
// Subscribing to the same filter again should not increase count
|
||||
let result = manager.subscribe(&client_id, SubscriptionFilter::All);
|
||||
assert!(!result);
|
||||
|
||||
assert_eq!(manager.subscription_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscriber_should_receive() {
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
let subscriber = Subscriber {
|
||||
client_id: "test".to_string(),
|
||||
user_id: Some(456),
|
||||
tx,
|
||||
filters: vec![SubscriptionFilter::EntityType("execution".to_string())],
|
||||
};
|
||||
|
||||
let notification1 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let notification2 = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "inquiry".to_string(),
|
||||
entity_id: 456,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert!(subscriber.should_receive(¬ification1));
|
||||
assert!(!subscriber.should_receive(¬ification2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_to_matching_subscribers() {
|
||||
let manager = SubscriberManager::new();
|
||||
|
||||
let client1_id = manager.generate_client_id();
|
||||
let (tx1, mut rx1) = mpsc::unbounded_channel();
|
||||
manager.register(client1_id.clone(), None, tx1);
|
||||
manager.subscribe(
|
||||
&client1_id,
|
||||
SubscriptionFilter::EntityType("execution".to_string()),
|
||||
);
|
||||
|
||||
let client2_id = manager.generate_client_id();
|
||||
let (tx2, mut rx2) = mpsc::unbounded_channel();
|
||||
manager.register(client2_id.clone(), None, tx2);
|
||||
manager.subscribe(
|
||||
&client2_id,
|
||||
SubscriptionFilter::EntityType("inquiry".to_string()),
|
||||
);
|
||||
|
||||
let notification = Notification {
|
||||
notification_type: "test".to_string(),
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 123,
|
||||
user_id: None,
|
||||
payload: serde_json::json!({}),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
manager.broadcast(notification.clone());
|
||||
|
||||
// Client 1 should receive the notification
|
||||
let received1 = rx1.try_recv();
|
||||
assert!(received1.is_ok());
|
||||
assert_eq!(received1.unwrap().entity_id, 123);
|
||||
|
||||
// Client 2 should not receive the notification
|
||||
let received2 = rx2.try_recv();
|
||||
assert!(received2.is_err());
|
||||
}
|
||||
}
|
||||
367
crates/notifier/src/websocket_server.rs
Normal file
367
crates/notifier/src/websocket_server.rs
Normal file
@@ -0,0 +1,367 @@
|
||||
//! WebSocket server for real-time notifications
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use attune_common::config::Config;
|
||||
|
||||
use crate::service::Notification;
|
||||
use crate::subscriber_manager::{ClientId, SubscriberManager, SubscriptionFilter};
|
||||
|
||||
/// WebSocket server for handling client connections
|
||||
pub struct WebSocketServer {
|
||||
config: Config,
|
||||
pub notification_tx: broadcast::Sender<Notification>,
|
||||
subscriber_manager: Arc<SubscriberManager>,
|
||||
shutdown_tx: broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
impl WebSocketServer {
|
||||
/// Create a new WebSocket server
|
||||
pub fn new(
|
||||
config: Config,
|
||||
notification_tx: broadcast::Sender<Notification>,
|
||||
subscriber_manager: Arc<SubscriberManager>,
|
||||
shutdown_tx: broadcast::Sender<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
notification_tx,
|
||||
subscriber_manager,
|
||||
shutdown_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone method for spawning tasks
|
||||
pub fn clone(&self) -> Self {
|
||||
Self {
|
||||
config: self.config.clone(),
|
||||
notification_tx: self.notification_tx.clone(),
|
||||
subscriber_manager: self.subscriber_manager.clone(),
|
||||
shutdown_tx: self.shutdown_tx.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the WebSocket server
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
let app_state = Arc::new(AppState {
|
||||
notification_tx: self.notification_tx.clone(),
|
||||
subscriber_manager: self.subscriber_manager.clone(),
|
||||
});
|
||||
|
||||
// Build router with WebSocket endpoint
|
||||
let app = Router::new()
|
||||
.route("/ws", get(websocket_handler))
|
||||
.route("/health", get(health_handler))
|
||||
.route("/stats", get(stats_handler))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any),
|
||||
)
|
||||
.with_state(app_state);
|
||||
|
||||
let notifier_config = self
|
||||
.config
|
||||
.notifier
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Notifier configuration not found in config"))?;
|
||||
|
||||
let addr = format!("{}:{}", notifier_config.host, notifier_config.port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr)
|
||||
.await
|
||||
.context(format!("Failed to bind to {}", addr))?;
|
||||
|
||||
info!("WebSocket server listening on {}", addr);
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.context("WebSocket server error")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared application state
|
||||
struct AppState {
|
||||
#[allow(dead_code)]
|
||||
notification_tx: broadcast::Sender<Notification>,
|
||||
subscriber_manager: Arc<SubscriberManager>,
|
||||
}
|
||||
|
||||
/// Health check endpoint
|
||||
async fn health_handler() -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
/// Stats endpoint
|
||||
async fn stats_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
let stats = serde_json::json!({
|
||||
"connected_clients": state.subscriber_manager.client_count(),
|
||||
"total_subscriptions": state.subscriber_manager.subscription_count(),
|
||||
});
|
||||
(StatusCode::OK, Json(stats))
|
||||
}
|
||||
|
||||
/// WebSocket handler - upgrades HTTP connection to WebSocket
|
||||
async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| handle_websocket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle individual WebSocket connection
|
||||
async fn handle_websocket(socket: WebSocket, state: Arc<AppState>) {
|
||||
let client_id = state.subscriber_manager.generate_client_id();
|
||||
info!("New WebSocket connection: {}", client_id);
|
||||
|
||||
// Split the socket into sender and receiver
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
|
||||
// Create channel for sending notifications to this client
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<Notification>();
|
||||
|
||||
// Register the subscriber
|
||||
state
|
||||
.subscriber_manager
|
||||
.register(client_id.clone(), None, tx);
|
||||
|
||||
// Send welcome message
|
||||
let welcome = ClientMessage::Welcome {
|
||||
client_id: client_id.clone(),
|
||||
message: "Connected to Attune Notifier".to_string(),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&welcome) {
|
||||
let _ = ws_sender.send(Message::Text(json.into())).await;
|
||||
}
|
||||
|
||||
// Spawn task to handle outgoing notifications
|
||||
let client_id_clone = client_id.clone();
|
||||
let subscriber_manager_clone = state.subscriber_manager.clone();
|
||||
let outgoing_task = tokio::spawn(async move {
|
||||
while let Some(notification) = rx.recv().await {
|
||||
// Serialize notification to JSON
|
||||
match serde_json::to_string(¬ification) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = ws_sender.send(Message::Text(json.into())).await {
|
||||
error!("Failed to send notification to {}: {}", client_id_clone, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to serialize notification: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Outgoing task stopped for client: {}", client_id_clone);
|
||||
subscriber_manager_clone.unregister(&client_id_clone);
|
||||
});
|
||||
|
||||
// Handle incoming messages from client (subscriptions, etc.)
|
||||
let subscriber_manager_clone = state.subscriber_manager.clone();
|
||||
let client_id_clone = client_id.clone();
|
||||
while let Some(msg) = ws_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(text)) => {
|
||||
if let Err(e) =
|
||||
handle_client_message(&client_id_clone, &text, &subscriber_manager_clone).await
|
||||
{
|
||||
error!("Error handling client message: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(Message::Binary(_)) => {
|
||||
warn!("Received binary message from {}, ignoring", client_id_clone);
|
||||
}
|
||||
Ok(Message::Close(_)) => {
|
||||
info!("Client {} closed connection", client_id_clone);
|
||||
break;
|
||||
}
|
||||
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
|
||||
// Handled automatically by axum
|
||||
}
|
||||
Err(e) => {
|
||||
error!("WebSocket error for {}: {}", client_id_clone, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
subscriber_manager_clone.unregister(&client_id);
|
||||
outgoing_task.abort();
|
||||
info!("WebSocket connection closed: {}", client_id);
|
||||
}
|
||||
|
||||
/// Handle incoming message from client
|
||||
async fn handle_client_message(
|
||||
client_id: &ClientId,
|
||||
message: &str,
|
||||
subscriber_manager: &SubscriberManager,
|
||||
) -> Result<()> {
|
||||
let msg: ServerMessage =
|
||||
serde_json::from_str(message).context("Failed to parse client message")?;
|
||||
|
||||
match msg {
|
||||
ServerMessage::Subscribe { filter } => {
|
||||
let subscription_filter = parse_subscription_filter(&filter)?;
|
||||
subscriber_manager.subscribe(client_id, subscription_filter);
|
||||
info!("Client {} subscribed to: {:?}", client_id, filter);
|
||||
}
|
||||
ServerMessage::Unsubscribe { filter } => {
|
||||
let subscription_filter = parse_subscription_filter(&filter)?;
|
||||
subscriber_manager.unsubscribe(client_id, &subscription_filter);
|
||||
info!("Client {} unsubscribed from: {:?}", client_id, filter);
|
||||
}
|
||||
ServerMessage::Ping => {
|
||||
debug!("Received ping from {}", client_id);
|
||||
// Pong is handled automatically
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse subscription filter from string
|
||||
fn parse_subscription_filter(filter_str: &str) -> Result<SubscriptionFilter> {
|
||||
// Format: "type:value" or "all"
|
||||
if filter_str == "all" {
|
||||
return Ok(SubscriptionFilter::All);
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = filter_str.split(':').collect();
|
||||
if parts.len() < 2 {
|
||||
anyhow::bail!("Invalid filter format: {}", filter_str);
|
||||
}
|
||||
|
||||
match parts[0] {
|
||||
"entity_type" => Ok(SubscriptionFilter::EntityType(parts[1].to_string())),
|
||||
"notification_type" => Ok(SubscriptionFilter::NotificationType(parts[1].to_string())),
|
||||
"user" => {
|
||||
let user_id: i64 = parts[1].parse().context("Invalid user ID")?;
|
||||
Ok(SubscriptionFilter::User(user_id))
|
||||
}
|
||||
"entity" => {
|
||||
if parts.len() < 3 {
|
||||
anyhow::bail!("Entity filter requires type and id: entity:type:id");
|
||||
}
|
||||
let entity_id: i64 = parts[2].parse().context("Invalid entity ID")?;
|
||||
Ok(SubscriptionFilter::Entity {
|
||||
entity_type: parts[1].to_string(),
|
||||
entity_id,
|
||||
})
|
||||
}
|
||||
_ => anyhow::bail!("Unknown filter type: {}", parts[0]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Messages sent from server to client
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[allow(dead_code)]
|
||||
enum ClientMessage {
|
||||
#[serde(rename = "welcome")]
|
||||
Welcome { client_id: String, message: String },
|
||||
|
||||
#[serde(rename = "notification")]
|
||||
Notification(Notification),
|
||||
|
||||
#[serde(rename = "error")]
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// Messages sent from client to server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum ServerMessage {
|
||||
#[serde(rename = "subscribe")]
|
||||
Subscribe { filter: String },
|
||||
|
||||
#[serde(rename = "unsubscribe")]
|
||||
Unsubscribe { filter: String },
|
||||
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_all() {
|
||||
let filter = parse_subscription_filter("all").unwrap();
|
||||
assert_eq!(filter, SubscriptionFilter::All);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_entity_type() {
|
||||
let filter = parse_subscription_filter("entity_type:execution").unwrap();
|
||||
assert_eq!(
|
||||
filter,
|
||||
SubscriptionFilter::EntityType("execution".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_notification_type() {
|
||||
let filter =
|
||||
parse_subscription_filter("notification_type:execution_status_changed").unwrap();
|
||||
assert_eq!(
|
||||
filter,
|
||||
SubscriptionFilter::NotificationType("execution_status_changed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_user() {
|
||||
let filter = parse_subscription_filter("user:123").unwrap();
|
||||
assert_eq!(filter, SubscriptionFilter::User(123));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_entity() {
|
||||
let filter = parse_subscription_filter("entity:execution:456").unwrap();
|
||||
assert_eq!(
|
||||
filter,
|
||||
SubscriptionFilter::Entity {
|
||||
entity_type: "execution".to_string(),
|
||||
entity_id: 456
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_invalid() {
|
||||
let result = parse_subscription_filter("invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_invalid_user_id() {
|
||||
let result = parse_subscription_filter("user:not_a_number");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_subscription_filter_entity_missing_id() {
|
||||
let result = parse_subscription_filter("entity:execution");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user