chrysopedia/scripts/load_test_chat.py
jlightner 160adc24bf test: Created standalone async load test script that fires concurrent c…
- "scripts/load_test_chat.py"

GSD-Task: S08/T02
2026-04-04 14:33:29 +00:00

366 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# NOTE: this module docstring is also shown as the argparse --help epilog
# (build_parser passes epilog=__doc__ with RawDescriptionHelpFormatter, which
# preserves the line breaks and indentation below).
"""Load test for Chrysopedia chat SSE endpoint.

Fires N concurrent chat requests, parses the SSE stream to measure
time-to-first-token (TTFT) and total response time, and reports
min / p50 / p95 / max latency statistics.

Requirements:
    pip install httpx    (already a project dependency)

Rate-limit note:
    The default anonymous rate limit is 10 requests/hour per IP.
    Running 10 concurrent requests from one IP will saturate that quota.
    Use --auth-token to authenticate (per-user limit is higher), or
    temporarily raise the rate limit in the API config.

Examples:
    # Quick smoke test (1 request)
    python scripts/load_test_chat.py --concurrency 1

    # Full load test with auth token and JSON output
    python scripts/load_test_chat.py --concurrency 10 \\
        --auth-token eyJ... --output results.json

    # Dry-run to verify SSE parsing without a live server
    python scripts/load_test_chat.py --dry-run
"""
from __future__ import annotations
import argparse
import asyncio
import json
import statistics
import sys
import time
from dataclasses import asdict, dataclass, field
from typing import Any
@dataclass
class ChatResult:
    """Metrics from a single chat request.

    Field order is part of the contract: it is the positional-constructor
    order and the key order of the ``asdict`` output written to JSON.
    """

    # Index of the request within the run (used as the table's "#" column).
    request_id: int = 0
    # Milliseconds until the first "event: token" line; None if no token arrived.
    ttft_ms: float | None = None
    # Wall-clock milliseconds from request start to end of stream (or failure).
    total_ms: float = 0.0
    # Number of parsed "token" SSE events.
    token_count: int = 0
    # Error description; None means the request is counted as a success.
    error: str | None = None
    # HTTP status code; stays None if no response was received.
    status_code: int | None = None
    # "fallback_used" value from the "done" event payload, when present.
    fallback_used: bool | None = None
# ---------------------------------------------------------------------------
# SSE parsing
# ---------------------------------------------------------------------------
def parse_sse_lines(raw_lines: list[str]):
    """Yield ``(event_type, data_str)`` tuples from raw SSE lines.

    Applies the SSE field rules this client relies on:

    * ``event:`` and ``data:`` are recognised with or without a space after
      the colon. Exactly one leading space is removed from field values, and
      event names are additionally stripped of surrounding whitespace.
    * Lines beginning with ``:`` are comments (keep-alives) and are ignored.
    * A blank line terminates the current event. Several ``data`` lines in one
      event are joined with ``"\\n"``.
    * An event without an ``event:`` field is reported as ``event_type == ""``.
    * A trailing event not followed by a blank line is still yielded.

    Args:
        raw_lines: Response body lines with line terminators removed.

    Yields:
        ``(event_type, data_str)`` for each event found.
    """
    current_event = ""
    data_buf: list[str] = []
    for line in raw_lines:
        if not line.strip():
            # Blank line: dispatch whatever has accumulated, then reset.
            if current_event or data_buf:
                yield current_event, "\n".join(data_buf)
            current_event = ""
            data_buf = []
            continue
        if line.startswith(":"):
            # Comment / keep-alive line; carries no event data.
            continue
        field_name, _, value = line.partition(":")
        # The SSE format strips exactly one space after the colon, if present.
        if value.startswith(" "):
            value = value[1:]
        if field_name == "event":
            current_event = value.strip()
        elif field_name == "data":
            data_buf.append(value)
        # Other fields (id, retry, unknown names) are irrelevant here.
    # Flush any remaining partial event (stream ended without a blank line).
    if current_event or data_buf:
        yield current_event, "\n".join(data_buf)
# ---------------------------------------------------------------------------
# Single request runner
# ---------------------------------------------------------------------------
async def run_single_chat(
    client: Any,  # httpx.AsyncClient
    url: str,
    query: str,
    request_id: int,
    *,
    timeout: float = 60.0,
) -> ChatResult:
    """POST one chat query to ``{url}/api/v1/chat`` and measure the SSE stream.

    The body is buffered line by line, then parsed with ``parse_sse_lines``:
    ``token`` events are counted, a ``done`` event supplies ``fallback_used``
    and an ``error`` event is recorded as the request's error. TTFT is taken
    when the first ``event: token`` line arrives.

    Args:
        client: An ``httpx.AsyncClient``. It is typed ``Any`` so httpx is only
            needed for live runs.
        url: Base URL of the API. A trailing slash is tolerated.
        query: Chat query text, sent as ``{"query": query}``.
        request_id: Index recorded on the result for reporting.
        timeout: Per-request timeout in seconds passed to ``client.stream``.

    Returns:
        A ``ChatResult``. Failures never raise: non-200 responses and any
        ``Exception`` are captured in ``result.error``.
    """
    result = ChatResult(request_id=request_id)
    t0 = time.monotonic()
    try:
        async with client.stream(
            "POST",
            # rstrip avoids "//api/v1/chat" when --url ends with a slash.
            f"{url.rstrip('/')}/api/v1/chat",
            json={"query": query},
            timeout=timeout,
        ) as resp:
            result.status_code = resp.status_code
            if resp.status_code != 200:
                body = await resp.aread()
                result.error = (
                    f"HTTP {resp.status_code}: {body.decode(errors='replace')[:200]}"
                )
            else:
                raw_lines: list[str] = []
                async for line in resp.aiter_lines():
                    raw_lines.append(line)
                    # TTFT: first "event" field whose name is exactly "token"
                    # (space after the colon optional, as in parse_sse_lines).
                    if (
                        result.ttft_ms is None
                        and line.startswith("event:")
                        and line[6:].strip() == "token"
                    ):
                        result.ttft_ms = (time.monotonic() - t0) * 1000
                for event_type, data_str in parse_sse_lines(raw_lines):
                    if event_type == "token":
                        result.token_count += 1
                    elif event_type == "done":
                        try:
                            done = json.loads(data_str)
                        except json.JSONDecodeError:
                            # Malformed summary payload; timing/token metrics
                            # are still valid, so do not mark a failure.
                            continue
                        if isinstance(done, dict):
                            result.fallback_used = done.get("fallback_used")
                    elif event_type == "error":
                        result.error = data_str
    except Exception as exc:  # a load test records failures, it never aborts
        result.error = f"{type(exc).__name__}: {exc}"
    result.total_ms = (time.monotonic() - t0) * 1000
    return result
# ---------------------------------------------------------------------------
# Dry-run: mock SSE stream for offline testing
# ---------------------------------------------------------------------------
# Canned SSE body for --dry-run. Events MUST be separated by blank lines:
# parse_sse_lines only dispatches an event on a blank line (or at the very
# end), so without separators every field collapses into one trailing event.
_MOCK_SSE = """\
event: sources
data: [{"title":"Test","url":"/t/test"}]

event: token
data: Hello

event: token
data: world

event: token
data: !

event: done
data: {"cascade_tier":"global","conversation_id":"test-123","fallback_used":false}
"""
async def run_dry_run() -> list[ChatResult]:
    """Feed the canned ``_MOCK_SSE`` stream through the SSE handling offline.

    Performs the same steps as the live path: TTFT on the first
    ``event: token`` line, token counting, and ``done``/``error`` handling.
    It needs no server. The timings it records are close to zero and only
    show that the fields get populated.

    Returns:
        A single-element list with the ``ChatResult`` for the mock stream.
    """
    mock = ChatResult(request_id=0, status_code=200)
    started = time.monotonic()
    lines = _MOCK_SSE.strip().splitlines()

    if any(ln.startswith("event: token") for ln in lines):
        mock.ttft_ms = 1000 * (time.monotonic() - started)

    for kind, payload_text in parse_sse_lines(lines):
        if kind == "error":
            mock.error = payload_text
        elif kind == "token":
            mock.token_count += 1
        elif kind == "done":
            try:
                payload = json.loads(payload_text)
            except json.JSONDecodeError:
                continue
            mock.fallback_used = payload.get("fallback_used")

    mock.total_ms = 1000 * (time.monotonic() - started)
    return [mock]
# ---------------------------------------------------------------------------
# Load test orchestrator
# ---------------------------------------------------------------------------
async def run_load_test(
    url: str,
    concurrency: int,
    query: str,
    auth_token: str | None = None,
) -> list[ChatResult]:
    """Fire ``concurrency`` chat requests at once and collect their results.

    All requests share one ``httpx.AsyncClient``. The client's connection pool
    is sized to ``concurrency`` so every request can open a connection
    immediately. httpx's default pool limit is 100 connections; above that,
    extra requests would wait inside the client. Because each request starts
    its clock before acquiring a connection, that wait would be counted in its
    latency.

    Args:
        url: Base URL of the Chrysopedia API.
        concurrency: Number of simultaneous requests; ``<= 0`` runs none.
        query: Chat query sent by every request.
        auth_token: Optional bearer token sent as an ``Authorization`` header.

    Returns:
        One ``ChatResult`` per request, in request-id order.
    """
    import httpx  # imported lazily: the --dry-run path never needs httpx

    headers: dict[str, str] = {}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"
    pool_size = max(concurrency, 1)
    limits = httpx.Limits(
        max_connections=pool_size,
        max_keepalive_connections=pool_size,
    )
    async with httpx.AsyncClient(headers=headers, limits=limits) as client:
        # gather preserves argument order, so results line up with request ids.
        results = await asyncio.gather(
            *(run_single_chat(client, url, query, i) for i in range(concurrency))
        )
    return list(results)
# ---------------------------------------------------------------------------
# Statistics & reporting
# ---------------------------------------------------------------------------
def percentile(values: list[float], p: float) -> float:
    """Return the p-th percentile of *values*, with *p* on a 0-100 scale.

    Uses linear interpolation between the two closest ranks. *values* does
    not need to be pre-sorted (it is sorted here; already-sorted input is
    cheap to re-sort). *p* outside ``[0, 100]`` is clamped to the nearest
    bound instead of indexing out of range. An empty list yields ``0.0``.

    >>> percentile([1, 2, 3, 4], 50)
    2.5
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    p = min(max(p, 0.0), 100.0)
    rank = (len(ordered) - 1) * (p / 100)
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    frac = rank - lo
    return ordered[lo] + frac * (ordered[hi] - ordered[lo])
def print_stats(results: list[ChatResult]) -> None:
"""Print summary statistics and per-request table."""
successes = [r for r in results if r.error is None]
errors = [r for r in results if r.error is not None]
print(f"\n{'='*60}")
print(f" Chat Load Test Results ({len(results)} requests)")
print(f"{'='*60}")
print(f" Successes: {len(successes)} | Errors: {len(errors)}")
if successes:
totals = sorted(r.total_ms for r in successes)
ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
tokens = [r.token_count for r in successes]
print(f"\n Total Response Time (ms):")
print(f" min={totals[0]:.0f} p50={percentile(totals, 50):.0f}"
f" p95={percentile(totals, 95):.0f} max={totals[-1]:.0f}")
if ttfts:
print(f" Time to First Token (ms):")
print(f" min={ttfts[0]:.0f} p50={percentile(ttfts, 50):.0f}"
f" p95={percentile(ttfts, 95):.0f} max={ttfts[-1]:.0f}")
print(f" Tokens per response:")
print(f" min={min(tokens)} avg={statistics.mean(tokens):.1f}"
f" max={max(tokens)}")
fallback_count = sum(1 for r in successes if r.fallback_used)
if fallback_count:
print(f" Fallback used: {fallback_count}/{len(successes)}")
# Per-request table
print(f"\n {'#':>3} {'Status':>6} {'TTFT':>8} {'Total':>8} {'Tokens':>6} Error")
print(f" {'-'*3} {'-'*6} {'-'*8} {'-'*8} {'-'*6} {'-'*20}")
for r in results:
status = str(r.status_code or "---")
ttft = f"{r.ttft_ms:.0f}ms" if r.ttft_ms is not None else "---"
total = f"{r.total_ms:.0f}ms"
err = (r.error or "")[:40]
print(f" {r.request_id:>3} {status:>6} {ttft:>8} {total:>8} {r.token_count:>6} {err}")
print(f"{'='*60}\n")
def write_json_output(results: list[ChatResult], path: str) -> None:
"""Write results to a JSON file."""
data = {
"results": [asdict(r) for r in results],
"summary": {},
}
successes = [r for r in results if r.error is None]
if successes:
totals = sorted(r.total_ms for r in successes)
ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
data["summary"] = {
"total_requests": len(results),
"successes": len(successes),
"errors": len(results) - len(successes),
"total_ms": {
"min": totals[0],
"p50": percentile(totals, 50),
"p95": percentile(totals, 95),
"max": totals[-1],
},
}
if ttfts:
data["summary"]["ttft_ms"] = {
"min": ttfts[0],
"p50": percentile(ttfts, 50),
"p95": percentile(ttfts, 95),
"max": ttfts[-1],
}
with open(path, "w") as f:
json.dump(data, f, indent=2)
print(f"Results written to {path}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser; the module docstring is the epilog.

    ``--concurrency`` is validated to be a positive integer. A value of 0 or
    less would otherwise silently run zero requests.
    """
    parser = argparse.ArgumentParser(
        description="Load test the Chrysopedia chat SSE endpoint.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--url",
        default="http://localhost:8096",
        help="Base URL of the Chrysopedia API (default: http://localhost:8096)",
    )
    parser.add_argument(
        "--concurrency", "-c",
        type=_positive_int,
        default=10,
        help="Number of concurrent chat requests (default: 10)",
    )
    parser.add_argument(
        "--query", "-q",
        default="What are common compression techniques?",
        help="Chat query to send",
    )
    parser.add_argument(
        "--auth-token",
        default=None,
        help="Bearer token for authenticated requests (avoids IP rate limit)",
    )
    parser.add_argument(
        "--output", "-o",
        default=None,
        help="Write results as JSON to this file",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse a mock SSE response without making network requests",
    )
    return parser
def main() -> None:
    """CLI entry point: run a dry run or a live load test, then report.

    Prints the statistics table and, when ``--output`` is given, also writes
    the results as JSON.
    """
    args = build_parser().parse_args()

    if not args.dry_run:
        print(f"Running load test: {args.concurrency} concurrent requests → {args.url}")
        coro = run_load_test(args.url, args.concurrency, args.query, args.auth_token)
    else:
        print("Dry-run mode: parsing mock SSE response...")
        coro = run_dry_run()
    results = asyncio.run(coro)

    print_stats(results)
    if args.output:
        write_json_output(results, args.output)


if __name__ == "__main__":
    main()