test: Created standalone async load test script that fires concurrent c…
- "scripts/load_test_chat.py" GSD-Task: S08/T02
This commit is contained in:
parent
7b048ccbaf
commit
183d852f31
1 changed files with 366 additions and 0 deletions
366
scripts/load_test_chat.py
Normal file
366
scripts/load_test_chat.py
Normal file
|
|
@ -0,0 +1,366 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Load test for Chrysopedia chat SSE endpoint.
|
||||
|
||||
Fires N concurrent chat requests, parses the SSE stream to measure
|
||||
time-to-first-token (TTFT) and total response time, and reports
|
||||
min / p50 / p95 / max latency statistics.
|
||||
|
||||
Requirements:
|
||||
pip install httpx (already a project dependency)
|
||||
|
||||
Rate-limit note:
|
||||
The default anonymous rate limit is 10 requests/hour per IP.
|
||||
Running 10 concurrent requests from one IP will saturate that quota.
|
||||
Use --auth-token to authenticate (per-user limit is higher), or
|
||||
temporarily raise the rate limit in the API config.
|
||||
|
||||
Examples:
|
||||
# Quick smoke test (1 request)
|
||||
python scripts/load_test_chat.py --concurrency 1
|
||||
|
||||
# Full load test with auth token and JSON output
|
||||
python scripts/load_test_chat.py --concurrency 10 \\
|
||||
--auth-token eyJ... --output results.json
|
||||
|
||||
# Dry-run to verify SSE parsing without a live server
|
||||
python scripts/load_test_chat.py --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatResult:
|
||||
"""Metrics from a single chat request."""
|
||||
|
||||
request_id: int = 0
|
||||
ttft_ms: float | None = None
|
||||
total_ms: float = 0.0
|
||||
token_count: int = 0
|
||||
error: str | None = None
|
||||
status_code: int | None = None
|
||||
fallback_used: bool | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_sse_lines(raw_lines: list[str]):
|
||||
"""Yield (event_type, data_str) tuples from raw SSE lines."""
|
||||
current_event = ""
|
||||
data_buf: list[str] = []
|
||||
for line in raw_lines:
|
||||
if line.startswith("event: "):
|
||||
current_event = line[7:].strip()
|
||||
elif line.startswith("data: "):
|
||||
data_buf.append(line[6:])
|
||||
elif line.strip() == "" and (current_event or data_buf):
|
||||
yield current_event, "\n".join(data_buf)
|
||||
current_event = ""
|
||||
data_buf = []
|
||||
# Flush any remaining partial event
|
||||
if current_event or data_buf:
|
||||
yield current_event, "\n".join(data_buf)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single request runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_single_chat(
|
||||
client: Any, # httpx.AsyncClient
|
||||
url: str,
|
||||
query: str,
|
||||
request_id: int,
|
||||
) -> ChatResult:
|
||||
"""POST to the chat endpoint and parse the SSE stream."""
|
||||
result = ChatResult(request_id=request_id)
|
||||
t0 = time.monotonic()
|
||||
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{url}/api/v1/chat",
|
||||
json={"query": query},
|
||||
timeout=60.0,
|
||||
) as resp:
|
||||
result.status_code = resp.status_code
|
||||
if resp.status_code != 200:
|
||||
body = await resp.aread()
|
||||
result.error = f"HTTP {resp.status_code}: {body.decode(errors='replace')[:200]}"
|
||||
result.total_ms = (time.monotonic() - t0) * 1000
|
||||
return result
|
||||
|
||||
raw_lines: list[str] = []
|
||||
async for line in resp.aiter_lines():
|
||||
raw_lines.append(line)
|
||||
|
||||
# Detect first token for TTFT
|
||||
if result.ttft_ms is None and line.startswith("event: token"):
|
||||
result.ttft_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
# Parse collected SSE events
|
||||
for event_type, data_str in parse_sse_lines(raw_lines):
|
||||
if event_type == "token":
|
||||
result.token_count += 1
|
||||
elif event_type == "done":
|
||||
try:
|
||||
done = json.loads(data_str)
|
||||
result.fallback_used = done.get("fallback_used")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif event_type == "error":
|
||||
result.error = data_str
|
||||
|
||||
except Exception as exc:
|
||||
result.error = f"{type(exc).__name__}: {exc}"
|
||||
|
||||
result.total_ms = (time.monotonic() - t0) * 1000
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dry-run: mock SSE stream for offline testing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MOCK_SSE = """\
|
||||
event: sources
|
||||
data: [{"title":"Test","url":"/t/test"}]
|
||||
|
||||
event: token
|
||||
data: Hello
|
||||
|
||||
event: token
|
||||
data: world
|
||||
|
||||
event: token
|
||||
data: !
|
||||
|
||||
event: done
|
||||
data: {"cascade_tier":"global","conversation_id":"test-123","fallback_used":false}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
async def run_dry_run() -> list[ChatResult]:
|
||||
"""Parse a canned SSE response to verify the parsing logic works."""
|
||||
result = ChatResult(request_id=0, status_code=200)
|
||||
t0 = time.monotonic()
|
||||
|
||||
raw_lines = _MOCK_SSE.strip().splitlines()
|
||||
|
||||
for line in raw_lines:
|
||||
if result.ttft_ms is None and line.startswith("event: token"):
|
||||
result.ttft_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
for event_type, data_str in parse_sse_lines(raw_lines):
|
||||
if event_type == "token":
|
||||
result.token_count += 1
|
||||
elif event_type == "done":
|
||||
try:
|
||||
done = json.loads(data_str)
|
||||
result.fallback_used = done.get("fallback_used")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif event_type == "error":
|
||||
result.error = data_str
|
||||
|
||||
result.total_ms = (time.monotonic() - t0) * 1000
|
||||
return [result]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load test orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_load_test(
|
||||
url: str,
|
||||
concurrency: int,
|
||||
query: str,
|
||||
auth_token: str | None = None,
|
||||
) -> list[ChatResult]:
|
||||
"""Fire concurrent chat requests and collect results."""
|
||||
import httpx
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
tasks = [
|
||||
run_single_chat(client, url, query, i)
|
||||
for i in range(concurrency)
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return list(results)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Statistics & reporting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
"""Return the p-th percentile of a sorted list (0–100 scale)."""
|
||||
if not values:
|
||||
return 0.0
|
||||
k = (len(values) - 1) * (p / 100)
|
||||
f = int(k)
|
||||
c = f + 1 if f + 1 < len(values) else f
|
||||
d = k - f
|
||||
return values[f] + d * (values[c] - values[f])
|
||||
|
||||
|
||||
def print_stats(results: list[ChatResult]) -> None:
|
||||
"""Print summary statistics and per-request table."""
|
||||
successes = [r for r in results if r.error is None]
|
||||
errors = [r for r in results if r.error is not None]
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Chat Load Test Results ({len(results)} requests)")
|
||||
print(f"{'='*60}")
|
||||
print(f" Successes: {len(successes)} | Errors: {len(errors)}")
|
||||
|
||||
if successes:
|
||||
totals = sorted(r.total_ms for r in successes)
|
||||
ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
|
||||
tokens = [r.token_count for r in successes]
|
||||
|
||||
print(f"\n Total Response Time (ms):")
|
||||
print(f" min={totals[0]:.0f} p50={percentile(totals, 50):.0f}"
|
||||
f" p95={percentile(totals, 95):.0f} max={totals[-1]:.0f}")
|
||||
|
||||
if ttfts:
|
||||
print(f" Time to First Token (ms):")
|
||||
print(f" min={ttfts[0]:.0f} p50={percentile(ttfts, 50):.0f}"
|
||||
f" p95={percentile(ttfts, 95):.0f} max={ttfts[-1]:.0f}")
|
||||
|
||||
print(f" Tokens per response:")
|
||||
print(f" min={min(tokens)} avg={statistics.mean(tokens):.1f}"
|
||||
f" max={max(tokens)}")
|
||||
|
||||
fallback_count = sum(1 for r in successes if r.fallback_used)
|
||||
if fallback_count:
|
||||
print(f" Fallback used: {fallback_count}/{len(successes)}")
|
||||
|
||||
# Per-request table
|
||||
print(f"\n {'#':>3} {'Status':>6} {'TTFT':>8} {'Total':>8} {'Tokens':>6} Error")
|
||||
print(f" {'-'*3} {'-'*6} {'-'*8} {'-'*8} {'-'*6} {'-'*20}")
|
||||
for r in results:
|
||||
status = str(r.status_code or "---")
|
||||
ttft = f"{r.ttft_ms:.0f}ms" if r.ttft_ms is not None else "---"
|
||||
total = f"{r.total_ms:.0f}ms"
|
||||
err = (r.error or "")[:40]
|
||||
print(f" {r.request_id:>3} {status:>6} {ttft:>8} {total:>8} {r.token_count:>6} {err}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def write_json_output(results: list[ChatResult], path: str) -> None:
|
||||
"""Write results to a JSON file."""
|
||||
data = {
|
||||
"results": [asdict(r) for r in results],
|
||||
"summary": {},
|
||||
}
|
||||
successes = [r for r in results if r.error is None]
|
||||
if successes:
|
||||
totals = sorted(r.total_ms for r in successes)
|
||||
ttfts = sorted(r.ttft_ms for r in successes if r.ttft_ms is not None)
|
||||
data["summary"] = {
|
||||
"total_requests": len(results),
|
||||
"successes": len(successes),
|
||||
"errors": len(results) - len(successes),
|
||||
"total_ms": {
|
||||
"min": totals[0],
|
||||
"p50": percentile(totals, 50),
|
||||
"p95": percentile(totals, 95),
|
||||
"max": totals[-1],
|
||||
},
|
||||
}
|
||||
if ttfts:
|
||||
data["summary"]["ttft_ms"] = {
|
||||
"min": ttfts[0],
|
||||
"p50": percentile(ttfts, 50),
|
||||
"p95": percentile(ttfts, 95),
|
||||
"max": ttfts[-1],
|
||||
}
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
print(f"Results written to {path}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Load test the Chrysopedia chat SSE endpoint.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
default="http://localhost:8096",
|
||||
help="Base URL of the Chrysopedia API (default: http://localhost:8096)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrency", "-c",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of concurrent chat requests (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query", "-q",
|
||||
default="What are common compression techniques?",
|
||||
help="Chat query to send",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auth-token",
|
||||
default=None,
|
||||
help="Bearer token for authenticated requests (avoids IP rate limit)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
default=None,
|
||||
help="Write results as JSON to this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Parse a mock SSE response without making network requests",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
print("Dry-run mode: parsing mock SSE response...")
|
||||
results = asyncio.run(run_dry_run())
|
||||
else:
|
||||
print(f"Running load test: {args.concurrency} concurrent requests → {args.url}")
|
||||
results = asyncio.run(
|
||||
run_load_test(args.url, args.concurrency, args.query, args.auth_token)
|
||||
)
|
||||
|
||||
print_stats(results)
|
||||
|
||||
if args.output:
|
||||
write_json_output(results, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Reference in a new issue