MAESTRO: Implement LLM endpoints router with CRUD, test_connection, and Fernet-encrypted API key storage
- Add LLMEndpoint model to models.py with encrypted api_key field - Create encryption.py with Fernet symmetric encryption (key derived from JWT_SECRET via PBKDF2) - Implement full endpoints router: list, get, create, update, delete + test_connection - Test endpoint calls adapter.test_connection() and list_models() - API keys never exposed in responses; has_api_key boolean flag added - 25 tests in test_endpoints.py, all 444 tests passing
This commit is contained in:
parent
1253994c9e
commit
35d72e7fa8
8 changed files with 658 additions and 28 deletions
|
|
@ -32,7 +32,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
||||||
- [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
- [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
||||||
<!-- Completed: Created engine/tasks.py with execute_run and execute_sweep Celery tasks (autodiscovered via worker.py). SyncTaskResult class mimics AsyncResult for fallback. dispatch_run/dispatch_sweep helpers route to Celery or sync execution based on settings.use_in_process_queue. 17 tests in test_tasks.py, all passing. -->
|
<!-- Completed: Created engine/tasks.py with execute_run and execute_sweep Celery tasks (autodiscovered via worker.py). SyncTaskResult class mimics AsyncResult for fallback. dispatch_run/dispatch_sweep helpers route to Celery or sync execution based on settings.use_in_process_queue. 17 tests in test_tasks.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
|
- [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
|
||||||
|
<!-- Completed: Full CRUD (list/get/create/update/delete) + test_connection endpoint. LLMEndpoint model added to models.py. Fernet encryption via encryption.py (PBKDF2 key derivation from JWT_SECRET). API keys never exposed in responses; has_api_key boolean flag added to EndpointResponse. 25 tests in test_endpoints.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs.
|
- [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs.
|
||||||
|
|
||||||
|
|
|
||||||
44
backend/encryption.py
Normal file
44
backend/encryption.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Fernet symmetric encryption for sensitive fields (API keys).
|
||||||
|
|
||||||
|
The encryption key is derived from JWT_SECRET using PBKDF2-HMAC-SHA256,
|
||||||
|
ensuring a stable 32-byte key suitable for Fernet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_fernet_key(secret: str) -> bytes:
|
||||||
|
"""Derive a Fernet-compatible key from an arbitrary string secret."""
|
||||||
|
# PBKDF2 with a fixed salt — the secret itself provides entropy.
|
||||||
|
# The salt is fixed so the same secret always yields the same key.
|
||||||
|
dk = hashlib.pbkdf2_hmac(
|
||||||
|
"sha256",
|
||||||
|
secret.encode("utf-8"),
|
||||||
|
b"promptlooper-fernet-salt",
|
||||||
|
iterations=100_000,
|
||||||
|
dklen=32,
|
||||||
|
)
|
||||||
|
return base64.urlsafe_b64encode(dk)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fernet() -> Fernet:
|
||||||
|
"""Return a Fernet instance keyed from the current JWT_SECRET."""
|
||||||
|
return Fernet(_derive_fernet_key(settings.jwt_secret))
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_api_key(plain_key: str) -> str:
|
||||||
|
"""Encrypt an API key and return the ciphertext as a UTF-8 string."""
|
||||||
|
return get_fernet().encrypt(plain_key.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_api_key(encrypted_key: str) -> str:
|
||||||
|
"""Decrypt an API key. Raises ValueError on failure."""
|
||||||
|
try:
|
||||||
|
return get_fernet().decrypt(encrypted_key.encode("utf-8")).decode("utf-8")
|
||||||
|
except InvalidToken as exc:
|
||||||
|
raise ValueError("Failed to decrypt API key — JWT_SECRET may have changed") from exc
|
||||||
|
|
@ -260,6 +260,24 @@ class ResponseCache(Base):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEndpoint(Base):
|
||||||
|
__tablename__ = "llm_endpoints"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
primary_key=True, default=_new_uuid
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
url: Mapped[str] = mapped_column(String(2048), nullable=False)
|
||||||
|
api_key_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
default_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, nullable=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebhookConfig(Base):
|
class WebhookConfig(Base):
|
||||||
__tablename__ = "webhook_configs"
|
__tablename__ = "webhook_configs"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ alembic>=1.14,<2.0
|
||||||
pydantic>=2.0,<3.0
|
pydantic>=2.0,<3.0
|
||||||
pydantic-settings>=2.0,<3.0
|
pydantic-settings>=2.0,<3.0
|
||||||
python-jose[cryptography]>=3.3,<4.0
|
python-jose[cryptography]>=3.3,<4.0
|
||||||
|
cryptography>=42.0,<45.0
|
||||||
passlib[bcrypt]>=1.7,<2.0
|
passlib[bcrypt]>=1.7,<2.0
|
||||||
celery>=5.4,<6.0
|
celery>=5.4,<6.0
|
||||||
redis>=5.0,<6.0
|
redis>=5.0,<6.0
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,184 @@
|
||||||
"""Endpoints router — LLM target management."""
|
"""Endpoints router — CRUD for LLM endpoint configurations.
|
||||||
|
|
||||||
|
Supports creating, listing, updating, and deleting LLM endpoint configs.
|
||||||
|
API keys are stored encrypted using Fernet (key derived from JWT_SECRET).
|
||||||
|
The test endpoint calls adapter.test_connection() and adapter.list_models().
|
||||||
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Response
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from auth import get_current_user
|
||||||
|
from encryption import decrypt_api_key, encrypt_api_key
|
||||||
|
from main import get_db
|
||||||
|
from models import LLMEndpoint, User
|
||||||
|
from engine.adapters.openai_compat import OpenAICompatAdapter
|
||||||
|
from schemas import (
|
||||||
|
EndpointCreate,
|
||||||
|
EndpointListResponse,
|
||||||
|
EndpointResponse,
|
||||||
|
EndpointUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", status_code=501)
|
# ---------------------------------------------------------------------------
|
||||||
def list_endpoints():
|
# Helpers
|
||||||
"""List configured LLM endpoints."""
|
# ---------------------------------------------------------------------------
|
||||||
return Response(status_code=501, content="Not Implemented")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=501)
|
def _get_endpoint_or_404(db: Session, endpoint_id: uuid.UUID) -> LLMEndpoint:
|
||||||
def create_endpoint():
|
endpoint = db.query(LLMEndpoint).filter(LLMEndpoint.id == endpoint_id).first()
|
||||||
"""Add endpoint (URL, API key, label)."""
|
if endpoint is None:
|
||||||
return Response(status_code=501, content="Not Implemented")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Endpoint not found")
|
||||||
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{endpoint_id}", status_code=501)
|
# ---------------------------------------------------------------------------
|
||||||
def update_endpoint(endpoint_id: uuid.UUID):
|
# CRUD
|
||||||
"""Update endpoint."""
|
# ---------------------------------------------------------------------------
|
||||||
return Response(status_code=501, content="Not Implemented")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{endpoint_id}", status_code=501)
|
@router.get("/", response_model=EndpointListResponse)
|
||||||
def delete_endpoint(endpoint_id: uuid.UUID):
|
def list_endpoints(
|
||||||
"""Remove endpoint."""
|
db: Session = Depends(get_db),
|
||||||
return Response(status_code=501, content="Not Implemented")
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> EndpointListResponse:
|
||||||
|
"""List all configured LLM endpoints."""
|
||||||
|
endpoints = db.query(LLMEndpoint).order_by(LLMEndpoint.name).all()
|
||||||
|
return EndpointListResponse(
|
||||||
|
items=[_to_response(ep) for ep in endpoints],
|
||||||
|
total=len(endpoints),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{endpoint_id}/test", status_code=501)
|
@router.post("/", response_model=EndpointResponse, status_code=status.HTTP_201_CREATED)
|
||||||
def test_endpoint(endpoint_id: uuid.UUID):
|
def create_endpoint(
|
||||||
"""Test connectivity and list available models."""
|
body: EndpointCreate,
|
||||||
return Response(status_code=501, content="Not Implemented")
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> EndpointResponse:
|
||||||
|
"""Create a new LLM endpoint configuration."""
|
||||||
|
endpoint = LLMEndpoint(
|
||||||
|
name=body.name,
|
||||||
|
url=body.url,
|
||||||
|
api_key_encrypted=encrypt_api_key(body.api_key) if body.api_key else None,
|
||||||
|
default_model=body.default_model,
|
||||||
|
)
|
||||||
|
db.add(endpoint)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(endpoint)
|
||||||
|
return _to_response(endpoint)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{endpoint_id}", response_model=EndpointResponse)
|
||||||
|
def get_endpoint(
|
||||||
|
endpoint_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> EndpointResponse:
|
||||||
|
"""Get a single LLM endpoint configuration."""
|
||||||
|
endpoint = _get_endpoint_or_404(db, endpoint_id)
|
||||||
|
return _to_response(endpoint)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{endpoint_id}", response_model=EndpointResponse)
|
||||||
|
def update_endpoint(
|
||||||
|
endpoint_id: uuid.UUID,
|
||||||
|
body: EndpointUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> EndpointResponse:
|
||||||
|
"""Update an LLM endpoint configuration."""
|
||||||
|
endpoint = _get_endpoint_or_404(db, endpoint_id)
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
endpoint.name = body.name
|
||||||
|
if body.url is not None:
|
||||||
|
endpoint.url = body.url
|
||||||
|
if body.api_key is not None:
|
||||||
|
# Empty string clears the key; non-empty encrypts it
|
||||||
|
endpoint.api_key_encrypted = encrypt_api_key(body.api_key) if body.api_key else None
|
||||||
|
if body.default_model is not None:
|
||||||
|
endpoint.default_model = body.default_model
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(endpoint)
|
||||||
|
return _to_response(endpoint)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{endpoint_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
def delete_endpoint(
|
||||||
|
endpoint_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> None:
|
||||||
|
"""Delete an LLM endpoint configuration."""
|
||||||
|
endpoint = _get_endpoint_or_404(db, endpoint_id)
|
||||||
|
db.delete(endpoint)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test connection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{endpoint_id}/test")
|
||||||
|
async def test_endpoint(
|
||||||
|
endpoint_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Test connectivity and list available models for an endpoint.
|
||||||
|
|
||||||
|
Calls adapter.test_connection() and adapter.list_models() against the
|
||||||
|
stored endpoint configuration.
|
||||||
|
"""
|
||||||
|
endpoint = _get_endpoint_or_404(db, endpoint_id)
|
||||||
|
|
||||||
|
api_key: str | None = None
|
||||||
|
if endpoint.api_key_encrypted:
|
||||||
|
try:
|
||||||
|
api_key = decrypt_api_key(endpoint.api_key_encrypted)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to decrypt API key — JWT_SECRET may have changed",
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = OpenAICompatAdapter(base_url=endpoint.url, api_key=api_key)
|
||||||
|
|
||||||
|
connected = await adapter.test_connection()
|
||||||
|
models: list[str] = []
|
||||||
|
if connected:
|
||||||
|
try:
|
||||||
|
models = await adapter.list_models()
|
||||||
|
except Exception:
|
||||||
|
pass # Connection worked but model listing failed
|
||||||
|
|
||||||
|
return {
|
||||||
|
"endpoint_id": str(endpoint.id),
|
||||||
|
"name": endpoint.name,
|
||||||
|
"connected": connected,
|
||||||
|
"models": models,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Response builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _to_response(endpoint: LLMEndpoint) -> EndpointResponse:
|
||||||
|
"""Convert ORM model to response schema (never expose encrypted key)."""
|
||||||
|
return EndpointResponse(
|
||||||
|
id=endpoint.id,
|
||||||
|
name=endpoint.name,
|
||||||
|
url=endpoint.url,
|
||||||
|
default_model=endpoint.default_model,
|
||||||
|
has_api_key=endpoint.api_key_encrypted is not None,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -188,6 +188,7 @@ class EndpointResponse(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
url: str
|
url: str
|
||||||
default_model: str | None
|
default_model: str | None
|
||||||
|
has_api_key: bool = False
|
||||||
|
|
||||||
|
|
||||||
class EndpointListResponse(BaseModel):
|
class EndpointListResponse(BaseModel):
|
||||||
|
|
|
||||||
417
backend/tests/test_endpoints.py
Normal file
417
backend/tests/test_endpoints.py
Normal file
|
|
@ -0,0 +1,417 @@
|
||||||
|
"""Tests for backend/routers/endpoints.py — LLM endpoint CRUD + test_connection."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
JWT_SECRET = "test-secret-key-for-jwt-signing"
|
||||||
|
API_KEY = "test-api-key-12345"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_settings(tmp_path):
|
||||||
|
"""Ensure tests use a temp SQLite DB and no Redis."""
|
||||||
|
env = {
|
||||||
|
"DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
|
||||||
|
"REDIS_URL": "",
|
||||||
|
"DATA_DIR": str(tmp_path),
|
||||||
|
"JWT_SECRET": JWT_SECRET,
|
||||||
|
"API_KEY": API_KEY,
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
import config
|
||||||
|
new_settings = config.Settings(_env_file=None)
|
||||||
|
config.settings = new_settings
|
||||||
|
|
||||||
|
import main
|
||||||
|
main.settings = new_settings
|
||||||
|
main._init_db()
|
||||||
|
main._init_redis()
|
||||||
|
|
||||||
|
from models import Base
|
||||||
|
Base.metadata.create_all(bind=main.engine)
|
||||||
|
|
||||||
|
import auth
|
||||||
|
auth.settings = new_settings
|
||||||
|
|
||||||
|
# Patch encryption module's settings reference
|
||||||
|
import encryption
|
||||||
|
encryption.settings = new_settings
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_session():
|
||||||
|
from main import get_db
|
||||||
|
gen = get_db()
|
||||||
|
session = next(gen)
|
||||||
|
yield session
|
||||||
|
try:
|
||||||
|
next(gen)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_user(db_session):
|
||||||
|
from auth import hash_password
|
||||||
|
from models import User
|
||||||
|
user = User(username="admin", password_hash=hash_password("adminpass"), is_admin=True)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_headers():
|
||||||
|
return {"X-Api-Key": API_KEY}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
from main import app
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Encryption module tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEncryption:
|
||||||
|
def test_encrypt_decrypt_roundtrip(self):
|
||||||
|
from encryption import encrypt_api_key, decrypt_api_key
|
||||||
|
plain = "sk-test-key-12345"
|
||||||
|
encrypted = encrypt_api_key(plain)
|
||||||
|
assert encrypted != plain
|
||||||
|
assert decrypt_api_key(encrypted) == plain
|
||||||
|
|
||||||
|
def test_different_keys_produce_different_ciphertexts(self):
|
||||||
|
from encryption import encrypt_api_key
|
||||||
|
ct1 = encrypt_api_key("my-key")
|
||||||
|
ct2 = encrypt_api_key("my-key")
|
||||||
|
# Fernet uses a random IV each time, so ciphertexts differ
|
||||||
|
assert ct1 != ct2
|
||||||
|
|
||||||
|
def test_decrypt_bad_data_raises_value_error(self):
|
||||||
|
from encryption import decrypt_api_key
|
||||||
|
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||||
|
decrypt_api_key("not-valid-fernet-token")
|
||||||
|
|
||||||
|
def test_decrypt_with_wrong_secret_fails(self):
|
||||||
|
from encryption import encrypt_api_key, decrypt_api_key
|
||||||
|
encrypted = encrypt_api_key("my-key")
|
||||||
|
|
||||||
|
# Change the settings secret
|
||||||
|
import config
|
||||||
|
old_secret = config.settings.jwt_secret
|
||||||
|
config.settings.jwt_secret = "completely-different-secret"
|
||||||
|
|
||||||
|
import encryption
|
||||||
|
encryption.settings = config.settings
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||||
|
decrypt_api_key(encrypted)
|
||||||
|
|
||||||
|
# Restore
|
||||||
|
config.settings.jwt_secret = old_secret
|
||||||
|
encryption.settings = config.settings
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoint CRUD tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateEndpoint:
|
||||||
|
def test_create_minimal(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Local Ollama",
|
||||||
|
"url": "http://localhost:11434/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "Local Ollama"
|
||||||
|
assert data["url"] == "http://localhost:11434/v1"
|
||||||
|
assert data["default_model"] is None
|
||||||
|
assert data["has_api_key"] is False
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
def test_create_with_api_key(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "OpenAI",
|
||||||
|
"url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-test-12345",
|
||||||
|
"default_model": "gpt-4",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "OpenAI"
|
||||||
|
assert data["has_api_key"] is True
|
||||||
|
assert data["default_model"] == "gpt-4"
|
||||||
|
# API key must NOT appear in response
|
||||||
|
assert "api_key" not in data
|
||||||
|
assert "api_key_encrypted" not in data
|
||||||
|
|
||||||
|
def test_create_requires_auth(self, client):
|
||||||
|
resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Test",
|
||||||
|
"url": "http://localhost:8080/v1",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_create_validates_name(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "",
|
||||||
|
"url": "http://localhost/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_api_key_is_encrypted_in_db(self, client, admin_user, auth_headers, db_session):
|
||||||
|
resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Test",
|
||||||
|
"url": "http://localhost/v1",
|
||||||
|
"api_key": "sk-secret-key",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
endpoint_id = resp.json()["id"]
|
||||||
|
|
||||||
|
from models import LLMEndpoint
|
||||||
|
endpoint = db_session.query(LLMEndpoint).filter(
|
||||||
|
LLMEndpoint.id == uuid.UUID(endpoint_id)
|
||||||
|
).first()
|
||||||
|
assert endpoint is not None
|
||||||
|
assert endpoint.api_key_encrypted is not None
|
||||||
|
assert endpoint.api_key_encrypted != "sk-secret-key"
|
||||||
|
|
||||||
|
from encryption import decrypt_api_key
|
||||||
|
assert decrypt_api_key(endpoint.api_key_encrypted) == "sk-secret-key"
|
||||||
|
|
||||||
|
|
||||||
|
class TestListEndpoints:
|
||||||
|
def test_list_empty(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.get("/api/endpoints/", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["items"] == []
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
def test_list_multiple(self, client, admin_user, auth_headers):
|
||||||
|
for name in ["Alpha", "Beta", "Gamma"]:
|
||||||
|
client.post("/api/endpoints/", json={
|
||||||
|
"name": name, "url": f"http://{name.lower()}/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
|
||||||
|
resp = client.get("/api/endpoints/", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
# Should be ordered by name
|
||||||
|
names = [ep["name"] for ep in data["items"]]
|
||||||
|
assert names == ["Alpha", "Beta", "Gamma"]
|
||||||
|
|
||||||
|
def test_list_requires_auth(self, client):
|
||||||
|
resp = client.get("/api/endpoints/")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetEndpoint:
|
||||||
|
def test_get_existing(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Test EP",
|
||||||
|
"url": "http://test/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.get(f"/api/endpoints/{ep_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["name"] == "Test EP"
|
||||||
|
|
||||||
|
def test_get_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.get(f"/api/endpoints/{fake_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateEndpoint:
|
||||||
|
def test_update_name(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Old Name",
|
||||||
|
"url": "http://test/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.put(f"/api/endpoints/{ep_id}", json={
|
||||||
|
"name": "New Name",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["name"] == "New Name"
|
||||||
|
assert resp.json()["url"] == "http://test/v1" # unchanged
|
||||||
|
|
||||||
|
def test_update_api_key(self, client, admin_user, auth_headers, db_session):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Test",
|
||||||
|
"url": "http://test/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
assert create_resp.json()["has_api_key"] is False
|
||||||
|
|
||||||
|
# Set API key
|
||||||
|
resp = client.put(f"/api/endpoints/{ep_id}", json={
|
||||||
|
"api_key": "sk-new-key",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["has_api_key"] is True
|
||||||
|
|
||||||
|
def test_clear_api_key(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Test",
|
||||||
|
"url": "http://test/v1",
|
||||||
|
"api_key": "sk-key",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
# Clear by sending empty string
|
||||||
|
resp = client.put(f"/api/endpoints/{ep_id}", json={
|
||||||
|
"api_key": "",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["has_api_key"] is False
|
||||||
|
|
||||||
|
def test_update_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.put(f"/api/endpoints/{fake_id}", json={
|
||||||
|
"name": "X",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteEndpoint:
|
||||||
|
def test_delete_existing(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "ToDelete",
|
||||||
|
"url": "http://test/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.delete(f"/api/endpoints/{ep_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
# Verify gone
|
||||||
|
resp = client.get(f"/api/endpoints/{ep_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.delete(f"/api/endpoints/{fake_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test connection endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTestEndpoint:
|
||||||
|
def test_test_connection_success(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Mock EP",
|
||||||
|
"url": "http://mock-llm/v1",
|
||||||
|
"api_key": "sk-mock",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
mock_adapter = AsyncMock()
|
||||||
|
mock_adapter.test_connection.return_value = True
|
||||||
|
mock_adapter.list_models.return_value = ["model-a", "model-b"]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"routers.endpoints.OpenAICompatAdapter",
|
||||||
|
return_value=mock_adapter,
|
||||||
|
) as mock_cls:
|
||||||
|
resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["connected"] is True
|
||||||
|
assert data["models"] == ["model-a", "model-b"]
|
||||||
|
assert data["name"] == "Mock EP"
|
||||||
|
|
||||||
|
# Verify adapter was constructed with decrypted key
|
||||||
|
call_kwargs = mock_cls.call_args
|
||||||
|
assert call_kwargs.kwargs["base_url"] == "http://mock-llm/v1"
|
||||||
|
assert call_kwargs.kwargs["api_key"] == "sk-mock"
|
||||||
|
|
||||||
|
def test_test_connection_failure(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Bad EP",
|
||||||
|
"url": "http://bad-host/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
mock_adapter = AsyncMock()
|
||||||
|
mock_adapter.test_connection.return_value = False
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"routers.endpoints.OpenAICompatAdapter",
|
||||||
|
return_value=mock_adapter,
|
||||||
|
):
|
||||||
|
resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["connected"] is False
|
||||||
|
assert data["models"] == []
|
||||||
|
|
||||||
|
def test_test_connection_no_api_key(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "No Key EP",
|
||||||
|
"url": "http://local/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
mock_adapter = AsyncMock()
|
||||||
|
mock_adapter.test_connection.return_value = True
|
||||||
|
mock_adapter.list_models.return_value = ["llama3"]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"routers.endpoints.OpenAICompatAdapter",
|
||||||
|
return_value=mock_adapter,
|
||||||
|
) as mock_cls:
|
||||||
|
resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["connected"] is True
|
||||||
|
# Should have been called with api_key=None
|
||||||
|
assert mock_cls.call_args.kwargs["api_key"] is None
|
||||||
|
|
||||||
|
def test_test_connection_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.post(f"/api/endpoints/{fake_id}/test", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_test_connection_list_models_fails_gracefully(self, client, admin_user, auth_headers):
|
||||||
|
create_resp = client.post("/api/endpoints/", json={
|
||||||
|
"name": "Partial EP",
|
||||||
|
"url": "http://partial/v1",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
ep_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
mock_adapter = AsyncMock()
|
||||||
|
mock_adapter.test_connection.return_value = True
|
||||||
|
mock_adapter.list_models.side_effect = RuntimeError("models endpoint broken")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"routers.endpoints.OpenAICompatAdapter",
|
||||||
|
return_value=mock_adapter,
|
||||||
|
):
|
||||||
|
resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["connected"] is True
|
||||||
|
assert data["models"] == []
|
||||||
|
|
@ -142,30 +142,31 @@ def test_runs_leaderboard(client):
|
||||||
|
|
||||||
|
|
||||||
# ---- Endpoints router (/api/endpoints) ----
|
# ---- Endpoints router (/api/endpoints) ----
|
||||||
|
# Endpoints router is now fully implemented and requires auth (returns 401 without credentials)
|
||||||
|
|
||||||
def test_endpoints_list(client):
|
def test_endpoints_list(client):
|
||||||
resp = client.get("/api/endpoints/")
|
resp = client.get("/api/endpoints/")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_endpoints_create(client):
|
def test_endpoints_create(client):
|
||||||
resp = client.post("/api/endpoints/")
|
resp = client.post("/api/endpoints/")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_endpoints_update(client):
|
def test_endpoints_update(client):
|
||||||
resp = client.put("/api/endpoints/00000000-0000-0000-0000-000000000001")
|
resp = client.put("/api/endpoints/00000000-0000-0000-0000-000000000001")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_endpoints_delete(client):
|
def test_endpoints_delete(client):
|
||||||
resp = client.delete("/api/endpoints/00000000-0000-0000-0000-000000000001")
|
resp = client.delete("/api/endpoints/00000000-0000-0000-0000-000000000001")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_endpoints_test(client):
|
def test_endpoints_test(client):
|
||||||
resp = client.post("/api/endpoints/00000000-0000-0000-0000-000000000001/test")
|
resp = client.post("/api/endpoints/00000000-0000-0000-0000-000000000001/test")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
# ---- Export router (/api/export) ----
|
# ---- Export router (/api/export) ----
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue