366 lines
12 KiB
Python
366 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""Load test for Chrysopedia chat SSE endpoint.
|
||
|
||
Fires N concurrent chat requests, parses the SSE stream to measure
|
||
time-to-first-token (TTFT) and total response time, and reports
|
||
min / p50 / p95 / max latency statistics.
|
||
|
||
Requirements:
|
||
pip install httpx (already a project dependency)
|
||
|
||
Rate-limit note:
|
||
The default anonymous rate limit is 10 requests/hour per IP.
|
||
Running 10 concurrent requests from one IP will saturate that quota.
|
||
Use --auth-token to authenticate (per-user limit is higher), or
|
||
temporarily raise the rate limit in the API config.
|
||
|
||
Examples:
|
||
# Quick smoke test (1 request)
|
||
python scripts/load_test_chat.py --concurrency 1
|
||
|
||
# Full load test with auth token and JSON output
|
||
python scripts/load_test_chat.py --concurrency 10 \\
|
||
--auth-token eyJ... --output results.json
|
||
|
||
# Dry-run to verify SSE parsing without a live server
|
||
python scripts/load_test_chat.py --dry-run
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import asyncio
|
||
import json
|
||
import statistics
|
||
import sys
|
||
import time
|
||
from dataclasses import asdict, dataclass, field
|
||
from typing import Any
|
||
|
||
|
||
@dataclass
class ChatResult:
    """Outcome of one chat request: latencies, token tally and any failure.

    All timings are wall-clock milliseconds measured from the moment the
    request started. A result whose ``error`` is ``None`` counts as a success.
    """

    # Position of this request within the batch (used to label report rows).
    request_id: int = 0
    # Time until the first "token" event arrived; None if no token was seen.
    ttft_ms: float | None = None
    # Time until the stream finished or the request failed.
    total_ms: float = 0.0
    # Number of "token" SSE events received.
    token_count: int = 0
    # Human-readable failure description; None means the request succeeded.
    error: str | None = None
    # HTTP status of the response; None when no response was obtained.
    status_code: int | None = None
    # "fallback_used" flag reported by the "done" event, if any.
    fallback_used: bool | None = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# SSE parsing
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def parse_sse_lines(raw_lines: list[str]):
    """Yield ``(event_type, data_str)`` tuples from raw Server-Sent Events lines.

    Each non-blank line is parsed with the SSE field grammar:
    - The field name is the text before the first ``:``.
    - A single leading space in the value is dropped, so ``event: token`` and
      ``event:token`` are equivalent.
    - A line with no colon is a field with an empty value, for example ``data``.
    - Lines starting with ``:`` are comments or keep-alives and are skipped.
    - Fields other than ``event`` and ``data`` (``id``, ``retry``, ...) are
      ignored.
    - Several ``data`` lines in one event are joined with newlines.

    A blank line ends the current event. The event is yielded if it had an
    ``event`` field or at least one ``data`` line. ``event_type`` is ``""``
    when the event did not name itself. A final event that is not followed by
    a blank line is still yielded.

    Args:
        raw_lines: Stream lines with their line terminators already removed.

    Yields:
        ``(event_type, data_str)`` for each event.
    """
    current_event = ""
    data_buf: list[str] = []

    for line in raw_lines:
        if not line.strip():
            # Blank line: dispatch whatever has accumulated, then reset.
            if current_event or data_buf:
                yield current_event, "\n".join(data_buf)
            current_event = ""
            data_buf = []
            continue

        if line.startswith(":"):
            continue  # SSE comment / keep-alive

        name, sep, value = line.partition(":")
        if sep and value.startswith(" "):
            value = value[1:]  # the spec strips exactly one leading space

        if name == "event":
            current_event = value.strip()
        elif name == "data":
            data_buf.append(value)

    # Flush any remaining partial event (stream ended without a blank line).
    if current_event or data_buf:
        yield current_event, "\n".join(data_buf)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Single request runner
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def run_single_chat(
    client: Any,  # httpx.AsyncClient
    url: str,
    query: str,
    request_id: int,
    *,
    timeout: float = 60.0,
) -> ChatResult:
    """POST one query to the chat endpoint and measure its SSE stream.

    TTFT is recorded when the first ``event`` line naming ``token`` arrives.
    After the stream ends it is parsed with ``parse_sse_lines`` to:
    - count token events,
    - read ``fallback_used`` from the ``done`` event,
    - capture any ``error`` event.

    Failures never propagate. A non-200 response, or any exception, is
    recorded in ``ChatResult.error``, so one bad request cannot abort the
    whole batch.

    Args:
        client: An ``httpx.AsyncClient``, or any object with a compatible
            ``stream()`` async context manager.
        url: Base URL of the API. A trailing ``/`` is tolerated.
        query: Chat query, sent as ``{"query": query}``.
        request_id: Index used to label this request in the results.
        timeout: Per-request timeout in seconds, passed to ``client.stream``.

    Returns:
        A ``ChatResult`` with ``total_ms`` always set.
    """
    result = ChatResult(request_id=request_id)
    t0 = time.monotonic()

    try:
        async with client.stream(
            "POST",
            f"{url.rstrip('/')}/api/v1/chat",
            json={"query": query},
            timeout=timeout,
        ) as resp:
            result.status_code = resp.status_code
            if resp.status_code != 200:
                body = await resp.aread()
                result.error = f"HTTP {resp.status_code}: {body.decode(errors='replace')[:200]}"
            else:
                raw_lines: list[str] = []
                async for line in resp.aiter_lines():
                    raw_lines.append(line)
                    # First token event -> TTFT. Accept the optional space
                    # after the colon ("event:token") like the SSE spec does.
                    if result.ttft_ms is None:
                        name, _, value = line.partition(":")
                        if name == "event" and value.strip() == "token":
                            result.ttft_ms = (time.monotonic() - t0) * 1000

                for event_type, data_str in parse_sse_lines(raw_lines):
                    if event_type == "token":
                        result.token_count += 1
                    elif event_type == "done":
                        try:
                            done = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue  # malformed payload: leave fallback_used unknown
                        # A non-object payload must not turn a good stream into an error.
                        if isinstance(done, dict):
                            result.fallback_used = done.get("fallback_used")
                    elif event_type == "error":
                        result.error = data_str
    except Exception as exc:  # boundary: record any failure per request
        result.error = f"{type(exc).__name__}: {exc}"

    result.total_ms = (time.monotonic() - t0) * 1000
    return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dry-run: mock SSE stream for offline testing
|
||
# ---------------------------------------------------------------------------
|
||
|
||
_MOCK_SSE = """\
|
||
event: sources
|
||
data: [{"title":"Test","url":"/t/test"}]
|
||
|
||
event: token
|
||
data: Hello
|
||
|
||
event: token
|
||
data: world
|
||
|
||
event: token
|
||
data: !
|
||
|
||
event: done
|
||
data: {"cascade_tier":"global","conversation_id":"test-123","fallback_used":false}
|
||
|
||
"""
|
||
|
||
|
||
async def run_dry_run() -> list[ChatResult]:
    """Run the SSE parsing pipeline against the canned ``_MOCK_SSE`` stream.

    The live path's event handling is mirrored without any network I/O:
    - TTFT detection,
    - token counting,
    - ``done`` and ``error`` handling.

    This lets parsing regressions be caught offline. The measured timings
    reflect only local parsing overhead, not server latency.

    Returns:
        A single-element list holding the parsed ``ChatResult`` (status 200).
    """
    result = ChatResult(request_id=0, status_code=200)
    t0 = time.monotonic()

    raw_lines = _MOCK_SSE.strip().splitlines()

    # Same TTFT rule as the live path: first "event" field whose value is "token".
    for line in raw_lines:
        name, _, value = line.partition(":")
        if name == "event" and value.strip() == "token":
            result.ttft_ms = (time.monotonic() - t0) * 1000
            break

    for event_type, data_str in parse_sse_lines(raw_lines):
        if event_type == "token":
            result.token_count += 1
        elif event_type == "done":
            try:
                done = json.loads(data_str)
            except json.JSONDecodeError:
                continue  # malformed payload: leave fallback_used unknown
            if isinstance(done, dict):
                result.fallback_used = done.get("fallback_used")
        elif event_type == "error":
            result.error = data_str

    result.total_ms = (time.monotonic() - t0) * 1000
    return [result]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Load test orchestrator
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def run_load_test(
    url: str,
    concurrency: int,
    query: str,
    auth_token: str | None = None,
) -> list[ChatResult]:
    """Fire ``concurrency`` chat requests at once and collect their results.

    All requests share one ``httpx.AsyncClient``. Its connection pool is sized
    to ``concurrency`` so that every streaming request gets a connection
    immediately.

    With httpx's default pool of 100 connections, requests beyond the 100th
    would wait for a free connection. That queueing time would then be counted
    in their TTFT and total latency, skewing the results.

    Args:
        url: Base URL of the Chrysopedia API.
        concurrency: Number of simultaneous requests to issue.
        query: Chat query sent by every request.
        auth_token: Optional bearer token added as an ``Authorization`` header.

    Returns:
        One ``ChatResult`` per request, in request-id order.
    """
    import httpx

    headers: dict[str, str] = {}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"

    pool_size = max(concurrency, 1)
    limits = httpx.Limits(max_connections=pool_size, max_keepalive_connections=pool_size)

    async with httpx.AsyncClient(headers=headers, limits=limits) as client:
        # gather() preserves argument order, so results line up with request ids.
        results = await asyncio.gather(
            *(run_single_chat(client, url, query, i) for i in range(concurrency))
        )

    return list(results)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Statistics & reporting
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def percentile(values: list[float], p: float) -> float:
    """Return the *p*-th percentile (0-100 scale) of an ascending-sorted list.

    The result interpolates linearly between the two nearest ranks. This is
    the same scheme as NumPy's default ``linear`` method. *values* must
    already be sorted; callers sort once and reuse the list for min/max.

    Args:
        values: Ascending-sorted samples.
        p: Percentile in the closed range [0, 100].

    Returns:
        The interpolated percentile, or ``0.0`` when *values* is empty.

    Raises:
        ValueError: If *p* is outside [0, 100] (or NaN). Before, this either
            raised IndexError or extrapolated outside the data.
    """
    if not 0 <= p <= 100:
        raise ValueError(f"percentile p must be within [0, 100], got {p!r}")
    if not values:
        return 0.0
    rank = (len(values) - 1) * (p / 100)
    lo = int(rank)  # rank >= 0 here, so int() is floor
    hi = min(lo + 1, len(values) - 1)
    return values[lo] + (rank - lo) * (values[hi] - values[lo])
|
||
|
||
|
||
def print_stats(results: list[ChatResult]) -> None:
    """Print summary statistics followed by a per-request table.

    A result is a success when its ``error`` is ``None``.
    - Latency percentiles, token counts and fallback usage cover successes only.
    - TTFT statistics cover only the successes that received a token event.
    - The table at the end lists every result, failures included.
    """
    successes = [r for r in results if r.error is None]
    errors = [r for r in results if r.error is not None]

    print(f"\n{'='*60}")
    print(f" Chat Load Test Results ({len(results)} requests)")
    print(f"{'='*60}")
    print(f" Successes: {len(successes)} | Errors: {len(errors)}")

    if successes:
        totals = sorted(r.total_ms for r in successes)
        ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
        tokens = [r.token_count for r in successes]

        print("\n Total Response Time (ms):")
        print(f" min={totals[0]:.0f} p50={percentile(totals, 50):.0f}"
              f" p95={percentile(totals, 95):.0f} max={totals[-1]:.0f}")

        if ttfts:
            print(" Time to First Token (ms):")
            print(f" min={ttfts[0]:.0f} p50={percentile(ttfts, 50):.0f}"
                  f" p95={percentile(ttfts, 95):.0f} max={ttfts[-1]:.0f}")

        print(" Tokens per response:")
        print(f" min={min(tokens)} avg={statistics.mean(tokens):.1f}"
              f" max={max(tokens)}")

        fallback_count = sum(1 for r in successes if r.fallback_used)
        if fallback_count:
            print(f" Fallback used: {fallback_count}/{len(successes)}")

    # Per-request table
    print(f"\n {'#':>3} {'Status':>6} {'TTFT':>8} {'Total':>8} {'Tokens':>6} Error")
    print(f" {'-'*3} {'-'*6} {'-'*8} {'-'*8} {'-'*6} {'-'*20}")
    for r in results:
        status = str(r.status_code or "---")
        ttft = f"{r.ttft_ms:.0f}ms" if r.ttft_ms is not None else "---"
        total = f"{r.total_ms:.0f}ms"
        # Collapse newlines / whitespace runs (e.g. from an HTTP error body)
        # so every request stays on a single table row.
        err = " ".join((r.error or "").split())[:40]
        print(f" {r.request_id:>3} {status:>6} {ttft:>8} {total:>8} {r.token_count:>6} {err}")

    print(f"{'='*60}\n")
|
||
|
||
|
||
def write_json_output(results: list[ChatResult], path: str) -> None:
    """Write the per-request results plus a summary to *path* as JSON.

    The summary always includes ``total_requests``, ``successes`` and
    ``errors``, so a run in which every request failed still records how many
    were attempted.

    Further statistics are added only when at least one request succeeded:
    - ``total_ms`` latency statistics,
    - ``tokens_per_response`` statistics,
    - ``ttft_ms`` statistics, if at least one success received a token event.

    Args:
        results: Results to serialise. Each entry is written via ``asdict``.
        path: Destination file. It is overwritten and written as UTF-8.
    """
    successes = [r for r in results if r.error is None]
    summary: dict[str, Any] = {
        "total_requests": len(results),
        "successes": len(successes),
        "errors": len(results) - len(successes),
    }

    if successes:
        totals = sorted(r.total_ms for r in successes)
        summary["total_ms"] = {
            "min": totals[0],
            "p50": percentile(totals, 50),
            "p95": percentile(totals, 95),
            "max": totals[-1],
        }
        ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
        if ttfts:
            summary["ttft_ms"] = {
                "min": ttfts[0],
                "p50": percentile(ttfts, 50),
                "p95": percentile(ttfts, 95),
                "max": ttfts[-1],
            }
        tokens = [r.token_count for r in successes]
        summary["tokens_per_response"] = {
            "min": min(tokens),
            "avg": statistics.mean(tokens),
            "max": max(tokens),
        }

    data = {
        "results": [asdict(r) for r in results],
        "summary": summary,
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    print(f"Results written to {path}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CLI
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the chat load test.

    ``--concurrency`` is validated as a positive integer. A zero or negative
    value would otherwise run no requests at all and print an empty report.
    """
    parser = argparse.ArgumentParser(
        description="Load test the Chrysopedia chat SSE endpoint.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--url",
        default="http://localhost:8096",
        help="Base URL of the Chrysopedia API (default: http://localhost:8096)",
    )
    parser.add_argument(
        "--concurrency", "-c",
        type=_positive_int,
        default=10,
        help="Number of concurrent chat requests (default: 10)",
    )
    parser.add_argument(
        "--query", "-q",
        default="What are common compression techniques?",
        help="Chat query to send",
    )
    parser.add_argument(
        "--auth-token",
        default=None,
        help="Bearer token for authenticated requests (avoids IP rate limit)",
    )
    parser.add_argument(
        "--output", "-o",
        default=None,
        help="Write results as JSON to this file",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse a mock SSE response without making network requests",
    )
    return parser
|
||
|
||
|
||
def main() -> None:
    """CLI entry point: run the dry-run or the live load test, then report.

    Statistics are always printed. A JSON report is written only when
    ``--output`` is given.
    """
    args = build_parser().parse_args()

    if not args.dry_run:
        print(f"Running load test: {args.concurrency} concurrent requests → {args.url}")
        job = run_load_test(args.url, args.concurrency, args.query, args.auth_token)
    else:
        print("Dry-run mode: parsing mock SSE response...")
        job = run_dry_run()
    results = asyncio.run(job)

    print_stats(results)
    if args.output:
        write_json_output(results, args.output)


# Only run when executed as a script, never on import.
if __name__ == "__main__":
    main()
|