re-uploading work
This commit is contained in:
492
tests/helpers/polling.py
Normal file
492
tests/helpers/polling.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""
|
||||
Polling Utilities for E2E Tests
|
||||
|
||||
Provides helper functions for waiting on asynchronous conditions
|
||||
during end-to-end testing.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from requests.models import HTTPError
|
||||
|
||||
from .client import AttuneClient
|
||||
|
||||
|
||||
def wait_for_condition(
|
||||
condition_fn: Callable[[], bool],
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
error_message: str = "Condition not met within timeout",
|
||||
) -> bool:
|
||||
"""
|
||||
Wait for a condition function to return True
|
||||
|
||||
Args:
|
||||
condition_fn: Function that returns True when condition is met
|
||||
timeout: Maximum time to wait in seconds
|
||||
poll_interval: Time between checks in seconds
|
||||
error_message: Error message if timeout occurs
|
||||
|
||||
Returns:
|
||||
True if condition met
|
||||
|
||||
Raises:
|
||||
TimeoutError: If condition not met within timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
elapsed = 0.0
|
||||
|
||||
while elapsed < timeout:
|
||||
try:
|
||||
if condition_fn():
|
||||
return True
|
||||
except Exception:
|
||||
# Ignore exceptions during polling (e.g., 404 errors)
|
||||
pass
|
||||
|
||||
time.sleep(poll_interval)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
raise TimeoutError(f"{error_message} (waited {elapsed:.1f}s)")
|
||||
|
||||
|
||||
def wait_for_execution_status(
|
||||
client: AttuneClient,
|
||||
execution_id: int,
|
||||
expected_status: str,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
) -> dict:
|
||||
"""
|
||||
Wait for execution to reach expected status
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
execution_id: Execution ID to monitor
|
||||
expected_status: Expected status (succeeded, failed, canceled, etc)
|
||||
timeout: Maximum time to wait in seconds
|
||||
poll_interval: Time between status checks
|
||||
|
||||
Returns:
|
||||
Final execution object
|
||||
|
||||
Raises:
|
||||
TimeoutError: If status not reached within timeout
|
||||
"""
|
||||
execution = client.get_execution(execution_id)
|
||||
|
||||
def check_status():
|
||||
nonlocal execution
|
||||
execution = client.get_execution(execution_id)
|
||||
return execution["status"] == expected_status
|
||||
|
||||
wait_for_condition(
|
||||
check_status,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Execution {execution_id} did not reach status '{expected_status}'",
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
def wait_for_execution_completion(
|
||||
client: AttuneClient,
|
||||
execution_id: int,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
) -> dict:
|
||||
"""
|
||||
Wait for execution to complete (reach terminal status)
|
||||
|
||||
Terminal statuses are: succeeded, failed, canceled, timeout
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
execution_id: Execution ID to monitor
|
||||
timeout: Maximum time to wait in seconds
|
||||
poll_interval: Time between status checks
|
||||
|
||||
Returns:
|
||||
Final execution object
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution doesn't complete within timeout
|
||||
"""
|
||||
execution = client.get_execution(execution_id)
|
||||
|
||||
def check_completion():
|
||||
nonlocal execution
|
||||
execution = client.get_execution(execution_id)
|
||||
terminal_statuses = ["succeeded", "failed", "canceled", "timeout"]
|
||||
return execution["status"] in terminal_statuses
|
||||
|
||||
wait_for_condition(
|
||||
check_completion,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Execution {execution_id} did not complete",
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
def wait_for_execution_count(
|
||||
client: AttuneClient,
|
||||
expected_count: int,
|
||||
action_ref: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
enforcement_id: Optional[int] = None,
|
||||
rule_id: Optional[int] = None,
|
||||
created_after: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
operator: str = ">=",
|
||||
verbose: bool = False,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Wait for execution count to reach threshold
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
expected_count: Expected number of executions
|
||||
action_ref: Optional filter by action reference
|
||||
status: Optional filter by status
|
||||
enforcement_id: Optional filter by enforcement ID (most precise)
|
||||
rule_id: Optional filter by rule ID (via enforcement)
|
||||
created_after: Optional ISO timestamp to filter executions created after this time
|
||||
timeout: Maximum time to wait
|
||||
poll_interval: Time between checks
|
||||
operator: Comparison operator (>=, ==, <=, >, <)
|
||||
verbose: Print debug information during polling
|
||||
|
||||
Returns:
|
||||
List of executions
|
||||
|
||||
Raises:
|
||||
TimeoutError: If count not reached within timeout
|
||||
"""
|
||||
executions = []
|
||||
|
||||
def check_count():
|
||||
nonlocal executions
|
||||
|
||||
# If rule_id is provided, get executions via enforcements
|
||||
if rule_id is not None:
|
||||
# Get all enforcements for this rule
|
||||
enforcements = client.list_enforcements(rule_id=rule_id, limit=1000)
|
||||
if verbose:
|
||||
print(
|
||||
f" [DEBUG] Found {len(enforcements)} enforcements for rule {rule_id}"
|
||||
)
|
||||
# Get executions for each enforcement
|
||||
all_executions = []
|
||||
for enf in enforcements:
|
||||
enf_executions = client.list_executions(
|
||||
enforcement_id=enf["id"], status=status, limit=1000
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
f" [DEBUG] Enforcement {enf['id']}: {len(enf_executions)} executions"
|
||||
)
|
||||
all_executions.extend(enf_executions)
|
||||
executions = all_executions
|
||||
elif enforcement_id is not None:
|
||||
# Filter by specific enforcement
|
||||
executions = client.list_executions(
|
||||
enforcement_id=enforcement_id, status=status, limit=1000
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
f" [DEBUG] Found {len(executions)} executions for enforcement {enforcement_id}"
|
||||
)
|
||||
else:
|
||||
# Use action_ref and status filters
|
||||
executions = client.list_executions(
|
||||
action_ref=action_ref, status=status, limit=1000
|
||||
)
|
||||
if verbose:
|
||||
filter_str = f"action_ref={action_ref}" if action_ref else "all"
|
||||
if status:
|
||||
filter_str += f", status={status}"
|
||||
print(f" [DEBUG] Found {len(executions)} executions ({filter_str})")
|
||||
|
||||
# Apply timestamp filter if provided
|
||||
if created_after:
|
||||
from datetime import datetime
|
||||
|
||||
cutoff = datetime.fromisoformat(created_after.replace("Z", "+00:00"))
|
||||
filtered = []
|
||||
for exec in executions:
|
||||
exec_time = datetime.fromisoformat(
|
||||
exec["created"].replace("Z", "+00:00")
|
||||
)
|
||||
if exec_time > cutoff:
|
||||
filtered.append(exec)
|
||||
if verbose:
|
||||
print(
|
||||
f" [DEBUG] After timestamp filter: {len(filtered)} executions (was {len(executions)})"
|
||||
)
|
||||
executions = filtered
|
||||
|
||||
actual_count = len(executions)
|
||||
|
||||
if verbose:
|
||||
print(f" [DEBUG] Checking: {actual_count} {operator} {expected_count}")
|
||||
|
||||
if operator == ">=":
|
||||
return actual_count >= expected_count
|
||||
elif operator == "==":
|
||||
return actual_count == expected_count
|
||||
elif operator == "<=":
|
||||
return actual_count <= expected_count
|
||||
elif operator == ">":
|
||||
return actual_count > expected_count
|
||||
elif operator == "<":
|
||||
return actual_count < expected_count
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
filter_desc = ""
|
||||
if rule_id:
|
||||
filter_desc += f" for rule {rule_id}"
|
||||
elif enforcement_id:
|
||||
filter_desc += f" for enforcement {enforcement_id}"
|
||||
elif action_ref:
|
||||
filter_desc += f" for action {action_ref}"
|
||||
if status:
|
||||
filter_desc += f" with status {status}"
|
||||
if created_after:
|
||||
filter_desc += f" created after {created_after}"
|
||||
|
||||
wait_for_condition(
|
||||
check_count,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Execution count did not reach {operator} {expected_count}{filter_desc}",
|
||||
)
|
||||
|
||||
return executions
|
||||
|
||||
|
||||
def wait_for_event_count(
|
||||
client: AttuneClient,
|
||||
expected_count: int,
|
||||
trigger_id: Optional[int] = None,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
operator: str = ">=",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Wait for event count to reach threshold
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
expected_count: Expected number of events
|
||||
trigger_id: Optional filter by trigger ID
|
||||
timeout: Maximum time to wait
|
||||
poll_interval: Time between checks
|
||||
operator: Comparison operator (>=, ==, <=, >, <)
|
||||
|
||||
Returns:
|
||||
List of events
|
||||
|
||||
Raises:
|
||||
TimeoutError: If count not reached within timeout
|
||||
"""
|
||||
events = []
|
||||
|
||||
def check_count():
|
||||
nonlocal events
|
||||
events = client.list_events(trigger_id=trigger_id, limit=1000)
|
||||
actual_count = len(events)
|
||||
|
||||
if operator == ">=":
|
||||
return actual_count >= expected_count
|
||||
elif operator == "==":
|
||||
return actual_count == expected_count
|
||||
elif operator == "<=":
|
||||
return actual_count <= expected_count
|
||||
elif operator == ">":
|
||||
return actual_count > expected_count
|
||||
elif operator == "<":
|
||||
return actual_count < expected_count
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
filter_desc = f" for trigger {trigger_id}" if trigger_id else ""
|
||||
|
||||
wait_for_condition(
|
||||
check_count,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Event count did not reach {operator} {expected_count}{filter_desc}",
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def wait_for_enforcement_count(
|
||||
client: AttuneClient,
|
||||
expected_count: int,
|
||||
rule_id: Optional[int] = None,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
operator: str = ">=",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Wait for enforcement count to reach threshold
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
expected_count: Expected number of enforcements
|
||||
rule_id: Optional filter by rule ID
|
||||
timeout: Maximum time to wait
|
||||
poll_interval: Time between checks
|
||||
operator: Comparison operator (>=, ==, <=, >, <)
|
||||
|
||||
Returns:
|
||||
List of enforcements
|
||||
|
||||
Raises:
|
||||
TimeoutError: If count not reached within timeout
|
||||
"""
|
||||
enforcements = []
|
||||
|
||||
def check_count():
|
||||
nonlocal enforcements
|
||||
enforcements = client.list_enforcements(rule_id=rule_id, limit=1000)
|
||||
actual_count = len(enforcements)
|
||||
|
||||
if operator == ">=":
|
||||
return actual_count >= expected_count
|
||||
elif operator == "==":
|
||||
return actual_count == expected_count
|
||||
elif operator == "<=":
|
||||
return actual_count <= expected_count
|
||||
elif operator == ">":
|
||||
return actual_count > expected_count
|
||||
elif operator == "<":
|
||||
return actual_count < expected_count
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
filter_desc = f" for rule {rule_id}" if rule_id else ""
|
||||
|
||||
wait_for_condition(
|
||||
check_count,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Enforcement count did not reach {operator} {expected_count}{filter_desc}",
|
||||
)
|
||||
|
||||
return enforcements
|
||||
|
||||
|
||||
def wait_for_inquiry_status(
|
||||
client: AttuneClient,
|
||||
inquiry_id: int,
|
||||
expected_status: str,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
) -> dict:
|
||||
"""
|
||||
Wait for inquiry to reach expected status
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
inquiry_id: Inquiry ID to monitor
|
||||
expected_status: Expected status (pending, responded, expired)
|
||||
timeout: Maximum time to wait
|
||||
poll_interval: Time between checks
|
||||
|
||||
Returns:
|
||||
Final inquiry object
|
||||
|
||||
Raises:
|
||||
TimeoutError: If status not reached within timeout
|
||||
"""
|
||||
inquiry = client.get_inquiry(inquiry_id)
|
||||
|
||||
def check_status():
|
||||
nonlocal inquiry
|
||||
inquiry = client.get_inquiry(inquiry_id)
|
||||
return inquiry["status"] == expected_status
|
||||
|
||||
wait_for_condition(
|
||||
check_status,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Inquiry {inquiry_id} did not reach status '{expected_status}'",
|
||||
)
|
||||
|
||||
return inquiry
|
||||
|
||||
|
||||
def wait_for_inquiry_count(
|
||||
client: AttuneClient,
|
||||
expected_count: int,
|
||||
status: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
poll_interval: float = 0.5,
|
||||
operator: str = ">=",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Wait for inquiry count to reach expected value
|
||||
|
||||
Args:
|
||||
client: AttuneClient instance
|
||||
expected_count: Expected number of inquiries
|
||||
status: Optional status filter (pending, responded, expired, etc)
|
||||
timeout: Maximum time to wait
|
||||
poll_interval: Time between checks
|
||||
operator: Comparison operator (>=, ==, <=, >, <)
|
||||
|
||||
Returns:
|
||||
List of inquiries matching criteria
|
||||
|
||||
Raises:
|
||||
TimeoutError: If count not reached within timeout
|
||||
"""
|
||||
inquiries = []
|
||||
|
||||
def check_count():
|
||||
nonlocal inquiries
|
||||
try:
|
||||
response = client.get("/inquiries")
|
||||
except HTTPError:
|
||||
return False
|
||||
|
||||
inquiries = response.get("data", [])
|
||||
|
||||
# Filter by status if specified
|
||||
if status:
|
||||
inquiries = [i for i in inquiries if i.get("status") == status]
|
||||
|
||||
actual_count = len(inquiries)
|
||||
|
||||
# Check count based on operator
|
||||
if operator == "==":
|
||||
return actual_count == expected_count
|
||||
elif operator == ">=":
|
||||
return actual_count >= expected_count
|
||||
elif operator == "<=":
|
||||
return actual_count <= expected_count
|
||||
elif operator == ">":
|
||||
return actual_count > expected_count
|
||||
elif operator == "<":
|
||||
return actual_count < expected_count
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
filter_desc = f" with status {status}" if status else ""
|
||||
|
||||
wait_for_condition(
|
||||
check_count,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
error_message=f"Inquiry count did not reach {operator} {expected_count}{filter_desc}",
|
||||
)
|
||||
|
||||
return inquiries
|
||||
Reference in New Issue
Block a user