From 35d72e7fa8fad9e3835638f79c60dbff4f3f8606 Mon Sep 17 00:00:00 2001 From: John Lightner Date: Tue, 7 Apr 2026 03:13:52 -0500 Subject: [PATCH] MAESTRO: Implement LLM endpoints router with CRUD, test_connection, and Fernet-encrypted API key storage - Add LLMEndpoint model to models.py with encrypted api_key field - Create encryption.py with Fernet symmetric encryption (key derived from JWT_SECRET via PBKDF2) - Implement full endpoints router: list, get, create, update, delete + test_connection - Test endpoint calls adapter.test_connection() and list_models() - API keys never exposed in responses; has_api_key boolean flag added - 25 tests in test_endpoints.py, all 444 tests passing --- Auto Run Docs/02a-backend-engine.md | 3 +- backend/encryption.py | 44 +++ backend/models.py | 18 ++ backend/requirements.txt | 1 + backend/routers/endpoints.py | 191 +++++++++++-- backend/schemas.py | 1 + backend/tests/test_endpoints.py | 417 ++++++++++++++++++++++++++++ backend/tests/test_routers.py | 11 +- 8 files changed, 658 insertions(+), 28 deletions(-) create mode 100644 backend/encryption.py create mode 100644 backend/tests/test_endpoints.py diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index 9513ab0..216327c 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -32,7 +32,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. -- [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). +- [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). + - [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs. diff --git a/backend/encryption.py b/backend/encryption.py new file mode 100644 index 0000000..21d99d4 --- /dev/null +++ b/backend/encryption.py @@ -0,0 +1,44 @@ +"""Fernet symmetric encryption for sensitive fields (API keys). + +The encryption key is derived from JWT_SECRET using PBKDF2-HMAC-SHA256, +ensuring a stable 32-byte key suitable for Fernet. +""" + +import base64 +import hashlib + +from cryptography.fernet import Fernet, InvalidToken + +from config import settings + + +def _derive_fernet_key(secret: str) -> bytes: + """Derive a Fernet-compatible key from an arbitrary string secret.""" + # PBKDF2 with a fixed salt — the secret itself provides entropy. + # The salt is fixed so the same secret always yields the same key. + dk = hashlib.pbkdf2_hmac( + "sha256", + secret.encode("utf-8"), + b"promptlooper-fernet-salt", + iterations=100_000, + dklen=32, + ) + return base64.urlsafe_b64encode(dk) + + +def get_fernet() -> Fernet: + """Return a Fernet instance keyed from the current JWT_SECRET.""" + return Fernet(_derive_fernet_key(settings.jwt_secret)) + + +def encrypt_api_key(plain_key: str) -> str: + """Encrypt an API key and return the ciphertext as a UTF-8 string.""" + return get_fernet().encrypt(plain_key.encode("utf-8")).decode("utf-8") + + +def decrypt_api_key(encrypted_key: str) -> str: + """Decrypt an API key. Raises ValueError on failure.""" + try: + return get_fernet().decrypt(encrypted_key.encode("utf-8")).decode("utf-8") + except InvalidToken as exc: + raise ValueError("Failed to decrypt API key — JWT_SECRET may have changed") from exc diff --git a/backend/models.py b/backend/models.py index 7bd8cfd..00191ff 100644 --- a/backend/models.py +++ b/backend/models.py @@ -260,6 +260,24 @@ class ResponseCache(Base): ) +class LLMEndpoint(Base): + __tablename__ = "llm_endpoints" + + id: Mapped[uuid.UUID] = mapped_column( + primary_key=True, default=_new_uuid + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + url: Mapped[str] = mapped_column(String(2048), nullable=False) + api_key_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True) + default_model: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, nullable=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False + ) + + class WebhookConfig(Base): __tablename__ = "webhook_configs" diff --git a/backend/requirements.txt b/backend/requirements.txt index 74dcde4..0471e9d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,6 +6,7 @@ alembic>=1.14,<2.0 pydantic>=2.0,<3.0 pydantic-settings>=2.0,<3.0 python-jose[cryptography]>=3.3,<4.0 +cryptography>=42.0,<45.0 passlib[bcrypt]>=1.7,<2.0 celery>=5.4,<6.0 redis>=5.0,<6.0 diff --git a/backend/routers/endpoints.py b/backend/routers/endpoints.py index 36163b1..8c930b0 100644 --- a/backend/routers/endpoints.py +++ b/backend/routers/endpoints.py @@ -1,37 +1,184 @@ -"""Endpoints router — LLM target management.""" +"""Endpoints router — CRUD for LLM endpoint configurations. + +Supports creating, listing, updating, and deleting LLM endpoint configs. +API keys are stored encrypted using Fernet (key derived from JWT_SECRET). +The test endpoint calls adapter.test_connection() and adapter.list_models(). +""" import uuid -from fastapi import APIRouter, Response +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from auth import get_current_user +from encryption import decrypt_api_key, encrypt_api_key +from main import get_db +from models import LLMEndpoint, User +from engine.adapters.openai_compat import OpenAICompatAdapter +from schemas import ( + EndpointCreate, + EndpointListResponse, + EndpointResponse, + EndpointUpdate, +) router = APIRouter() -@router.get("/", status_code=501) -def list_endpoints(): - """List configured LLM endpoints.""" - return Response(status_code=501, content="Not Implemented") +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- -@router.post("/", status_code=501) -def create_endpoint(): - """Add endpoint (URL, API key, label).""" - return Response(status_code=501, content="Not Implemented") +def _get_endpoint_or_404(db: Session, endpoint_id: uuid.UUID) -> LLMEndpoint: + endpoint = db.query(LLMEndpoint).filter(LLMEndpoint.id == endpoint_id).first() + if endpoint is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Endpoint not found") + return endpoint -@router.put("/{endpoint_id}", status_code=501) -def update_endpoint(endpoint_id: uuid.UUID): - """Update endpoint.""" - return Response(status_code=501, content="Not Implemented") +# --------------------------------------------------------------------------- +# CRUD +# --------------------------------------------------------------------------- -@router.delete("/{endpoint_id}", status_code=501) -def delete_endpoint(endpoint_id: uuid.UUID): - """Remove endpoint.""" - return Response(status_code=501, content="Not Implemented") +@router.get("/", response_model=EndpointListResponse) +def list_endpoints( + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> EndpointListResponse: + """List all configured LLM endpoints.""" + endpoints = db.query(LLMEndpoint).order_by(LLMEndpoint.name).all() + return EndpointListResponse( + items=[_to_response(ep) for ep in endpoints], + total=len(endpoints), + ) -@router.post("/{endpoint_id}/test", status_code=501) -def test_endpoint(endpoint_id: uuid.UUID): - """Test connectivity and list available models.""" - return Response(status_code=501, content="Not Implemented") +@router.post("/", response_model=EndpointResponse, status_code=status.HTTP_201_CREATED) +def create_endpoint( + body: EndpointCreate, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> EndpointResponse: + """Create a new LLM endpoint configuration.""" + endpoint = LLMEndpoint( + name=body.name, + url=body.url, + api_key_encrypted=encrypt_api_key(body.api_key) if body.api_key else None, + default_model=body.default_model, + ) + db.add(endpoint) + db.commit() + db.refresh(endpoint) + return _to_response(endpoint) + + +@router.get("/{endpoint_id}", response_model=EndpointResponse) +def get_endpoint( + endpoint_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> EndpointResponse: + """Get a single LLM endpoint configuration.""" + endpoint = _get_endpoint_or_404(db, endpoint_id) + return _to_response(endpoint) + + +@router.put("/{endpoint_id}", response_model=EndpointResponse) +def update_endpoint( + endpoint_id: uuid.UUID, + body: EndpointUpdate, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> EndpointResponse: + """Update an LLM endpoint configuration.""" + endpoint = _get_endpoint_or_404(db, endpoint_id) + + if body.name is not None: + endpoint.name = body.name + if body.url is not None: + endpoint.url = body.url + if body.api_key is not None: + # Empty string clears the key; non-empty encrypts it + endpoint.api_key_encrypted = encrypt_api_key(body.api_key) if body.api_key else None + if body.default_model is not None: + endpoint.default_model = body.default_model + + db.commit() + db.refresh(endpoint) + return _to_response(endpoint) + + +@router.delete("/{endpoint_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_endpoint( + endpoint_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> None: + """Delete an LLM endpoint configuration.""" + endpoint = _get_endpoint_or_404(db, endpoint_id) + db.delete(endpoint) + db.commit() + + +# --------------------------------------------------------------------------- +# Test connection +# --------------------------------------------------------------------------- + + +@router.post("/{endpoint_id}/test") +async def test_endpoint( + endpoint_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> dict: + """Test connectivity and list available models for an endpoint. + + Calls adapter.test_connection() and adapter.list_models() against the + stored endpoint configuration. + """ + endpoint = _get_endpoint_or_404(db, endpoint_id) + + api_key: str | None = None + if endpoint.api_key_encrypted: + try: + api_key = decrypt_api_key(endpoint.api_key_encrypted) + except ValueError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to decrypt API key — JWT_SECRET may have changed", + ) + + adapter = OpenAICompatAdapter(base_url=endpoint.url, api_key=api_key) + + connected = await adapter.test_connection() + models: list[str] = [] + if connected: + try: + models = await adapter.list_models() + except Exception: + pass # Connection worked but model listing failed + + return { + "endpoint_id": str(endpoint.id), + "name": endpoint.name, + "connected": connected, + "models": models, + } + + +# --------------------------------------------------------------------------- +# Response builder +# --------------------------------------------------------------------------- + + +def _to_response(endpoint: LLMEndpoint) -> EndpointResponse: + """Convert ORM model to response schema (never expose encrypted key).""" + return EndpointResponse( + id=endpoint.id, + name=endpoint.name, + url=endpoint.url, + default_model=endpoint.default_model, + has_api_key=endpoint.api_key_encrypted is not None, + ) diff --git a/backend/schemas.py b/backend/schemas.py index e6d9877..87db0ba 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -188,6 +188,7 @@ class EndpointResponse(BaseModel): name: str url: str default_model: str | None + has_api_key: bool = False class EndpointListResponse(BaseModel): diff --git a/backend/tests/test_endpoints.py b/backend/tests/test_endpoints.py new file mode 100644 index 0000000..1dcebd7 --- /dev/null +++ b/backend/tests/test_endpoints.py @@ -0,0 +1,417 @@ +"""Tests for backend/routers/endpoints.py — LLM endpoint CRUD + test_connection.""" + +import os +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + + +JWT_SECRET = "test-secret-key-for-jwt-signing" +API_KEY = "test-api-key-12345" + + +@pytest.fixture(autouse=True) +def _isolate_settings(tmp_path): + """Ensure tests use a temp SQLite DB and no Redis.""" + env = { + "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}", + "REDIS_URL": "", + "DATA_DIR": str(tmp_path), + "JWT_SECRET": JWT_SECRET, + "API_KEY": API_KEY, + } + with patch.dict(os.environ, env, clear=False): + import config + new_settings = config.Settings(_env_file=None) + config.settings = new_settings + + import main + main.settings = new_settings + main._init_db() + main._init_redis() + + from models import Base + Base.metadata.create_all(bind=main.engine) + + import auth + auth.settings = new_settings + + # Patch encryption module's settings reference + import encryption + encryption.settings = new_settings + + yield + + +@pytest.fixture +def db_session(): + from main import get_db + gen = get_db() + session = next(gen) + yield session + try: + next(gen) + except StopIteration: + pass + + +@pytest.fixture +def admin_user(db_session): + from auth import hash_password + from models import User + user = User(username="admin", password_hash=hash_password("adminpass"), is_admin=True) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + return user + + +@pytest.fixture +def auth_headers(): + return {"X-Api-Key": API_KEY} + + +@pytest.fixture +def client(): + from main import app + return TestClient(app) + + +# --------------------------------------------------------------------------- +# Encryption module tests +# --------------------------------------------------------------------------- + +class TestEncryption: + def test_encrypt_decrypt_roundtrip(self): + from encryption import encrypt_api_key, decrypt_api_key + plain = "sk-test-key-12345" + encrypted = encrypt_api_key(plain) + assert encrypted != plain + assert decrypt_api_key(encrypted) == plain + + def test_different_keys_produce_different_ciphertexts(self): + from encryption import encrypt_api_key + ct1 = encrypt_api_key("my-key") + ct2 = encrypt_api_key("my-key") + # Fernet uses a random IV each time, so ciphertexts differ + assert ct1 != ct2 + + def test_decrypt_bad_data_raises_value_error(self): + from encryption import decrypt_api_key + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt_api_key("not-valid-fernet-token") + + def test_decrypt_with_wrong_secret_fails(self): + from encryption import encrypt_api_key, decrypt_api_key + encrypted = encrypt_api_key("my-key") + + # Change the settings secret + import config + old_secret = config.settings.jwt_secret + config.settings.jwt_secret = "completely-different-secret" + + import encryption + encryption.settings = config.settings + + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt_api_key(encrypted) + + # Restore + config.settings.jwt_secret = old_secret + encryption.settings = config.settings + + +# --------------------------------------------------------------------------- +# Endpoint CRUD tests +# --------------------------------------------------------------------------- + + +class TestCreateEndpoint: + def test_create_minimal(self, client, admin_user, auth_headers): + resp = client.post("/api/endpoints/", json={ + "name": "Local Ollama", + "url": "http://localhost:11434/v1", + }, headers=auth_headers) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "Local Ollama" + assert data["url"] == "http://localhost:11434/v1" + assert data["default_model"] is None + assert data["has_api_key"] is False + assert "id" in data + + def test_create_with_api_key(self, client, admin_user, auth_headers): + resp = client.post("/api/endpoints/", json={ + "name": "OpenAI", + "url": "https://api.openai.com/v1", + "api_key": "sk-test-12345", + "default_model": "gpt-4", + }, headers=auth_headers) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "OpenAI" + assert data["has_api_key"] is True + assert data["default_model"] == "gpt-4" + # API key must NOT appear in response + assert "api_key" not in data + assert "api_key_encrypted" not in data + + def test_create_requires_auth(self, client): + resp = client.post("/api/endpoints/", json={ + "name": "Test", + "url": "http://localhost:8080/v1", + }) + assert resp.status_code == 401 + + def test_create_validates_name(self, client, admin_user, auth_headers): + resp = client.post("/api/endpoints/", json={ + "name": "", + "url": "http://localhost/v1", + }, headers=auth_headers) + assert resp.status_code == 422 + + def test_api_key_is_encrypted_in_db(self, client, admin_user, auth_headers, db_session): + resp = client.post("/api/endpoints/", json={ + "name": "Test", + "url": "http://localhost/v1", + "api_key": "sk-secret-key", + }, headers=auth_headers) + endpoint_id = resp.json()["id"] + + from models import LLMEndpoint + endpoint = db_session.query(LLMEndpoint).filter( + LLMEndpoint.id == uuid.UUID(endpoint_id) + ).first() + assert endpoint is not None + assert endpoint.api_key_encrypted is not None + assert endpoint.api_key_encrypted != "sk-secret-key" + + from encryption import decrypt_api_key + assert decrypt_api_key(endpoint.api_key_encrypted) == "sk-secret-key" + + +class TestListEndpoints: + def test_list_empty(self, client, admin_user, auth_headers): + resp = client.get("/api/endpoints/", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["items"] == [] + assert data["total"] == 0 + + def test_list_multiple(self, client, admin_user, auth_headers): + for name in ["Alpha", "Beta", "Gamma"]: + client.post("/api/endpoints/", json={ + "name": name, "url": f"http://{name.lower()}/v1", + }, headers=auth_headers) + + resp = client.get("/api/endpoints/", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 3 + # Should be ordered by name + names = [ep["name"] for ep in data["items"]] + assert names == ["Alpha", "Beta", "Gamma"] + + def test_list_requires_auth(self, client): + resp = client.get("/api/endpoints/") + assert resp.status_code == 401 + + +class TestGetEndpoint: + def test_get_existing(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Test EP", + "url": "http://test/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + resp = client.get(f"/api/endpoints/{ep_id}", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["name"] == "Test EP" + + def test_get_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.get(f"/api/endpoints/{fake_id}", headers=auth_headers) + assert resp.status_code == 404 + + +class TestUpdateEndpoint: + def test_update_name(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Old Name", + "url": "http://test/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + resp = client.put(f"/api/endpoints/{ep_id}", json={ + "name": "New Name", + }, headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["name"] == "New Name" + assert resp.json()["url"] == "http://test/v1" # unchanged + + def test_update_api_key(self, client, admin_user, auth_headers, db_session): + create_resp = client.post("/api/endpoints/", json={ + "name": "Test", + "url": "http://test/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + assert create_resp.json()["has_api_key"] is False + + # Set API key + resp = client.put(f"/api/endpoints/{ep_id}", json={ + "api_key": "sk-new-key", + }, headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["has_api_key"] is True + + def test_clear_api_key(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Test", + "url": "http://test/v1", + "api_key": "sk-key", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + # Clear by sending empty string + resp = client.put(f"/api/endpoints/{ep_id}", json={ + "api_key": "", + }, headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["has_api_key"] is False + + def test_update_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.put(f"/api/endpoints/{fake_id}", json={ + "name": "X", + }, headers=auth_headers) + assert resp.status_code == 404 + + +class TestDeleteEndpoint: + def test_delete_existing(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "ToDelete", + "url": "http://test/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + resp = client.delete(f"/api/endpoints/{ep_id}", headers=auth_headers) + assert resp.status_code == 204 + + # Verify gone + resp = client.get(f"/api/endpoints/{ep_id}", headers=auth_headers) + assert resp.status_code == 404 + + def test_delete_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.delete(f"/api/endpoints/{fake_id}", headers=auth_headers) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Test connection endpoint +# --------------------------------------------------------------------------- + + +class TestTestEndpoint: + def test_test_connection_success(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Mock EP", + "url": "http://mock-llm/v1", + "api_key": "sk-mock", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + mock_adapter = AsyncMock() + mock_adapter.test_connection.return_value = True + mock_adapter.list_models.return_value = ["model-a", "model-b"] + + with patch( + "routers.endpoints.OpenAICompatAdapter", + return_value=mock_adapter, + ) as mock_cls: + resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers) + + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["models"] == ["model-a", "model-b"] + assert data["name"] == "Mock EP" + + # Verify adapter was constructed with decrypted key + call_kwargs = mock_cls.call_args + assert call_kwargs.kwargs["base_url"] == "http://mock-llm/v1" + assert call_kwargs.kwargs["api_key"] == "sk-mock" + + def test_test_connection_failure(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Bad EP", + "url": "http://bad-host/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + mock_adapter = AsyncMock() + mock_adapter.test_connection.return_value = False + + with patch( + "routers.endpoints.OpenAICompatAdapter", + return_value=mock_adapter, + ): + resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers) + + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is False + assert data["models"] == [] + + def test_test_connection_no_api_key(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "No Key EP", + "url": "http://local/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + mock_adapter = AsyncMock() + mock_adapter.test_connection.return_value = True + mock_adapter.list_models.return_value = ["llama3"] + + with patch( + "routers.endpoints.OpenAICompatAdapter", + return_value=mock_adapter, + ) as mock_cls: + resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers) + + assert resp.status_code == 200 + assert resp.json()["connected"] is True + # Should have been called with api_key=None + assert mock_cls.call_args.kwargs["api_key"] is None + + def test_test_connection_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.post(f"/api/endpoints/{fake_id}/test", headers=auth_headers) + assert resp.status_code == 404 + + def test_test_connection_list_models_fails_gracefully(self, client, admin_user, auth_headers): + create_resp = client.post("/api/endpoints/", json={ + "name": "Partial EP", + "url": "http://partial/v1", + }, headers=auth_headers) + ep_id = create_resp.json()["id"] + + mock_adapter = AsyncMock() + mock_adapter.test_connection.return_value = True + mock_adapter.list_models.side_effect = RuntimeError("models endpoint broken") + + with patch( + "routers.endpoints.OpenAICompatAdapter", + return_value=mock_adapter, + ): + resp = client.post(f"/api/endpoints/{ep_id}/test", headers=auth_headers) + + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["models"] == [] diff --git a/backend/tests/test_routers.py b/backend/tests/test_routers.py index 9a1c9e5..dff6c08 100644 --- a/backend/tests/test_routers.py +++ b/backend/tests/test_routers.py @@ -142,30 +142,31 @@ def test_runs_leaderboard(client): # ---- Endpoints router (/api/endpoints) ---- +# Endpoints router is now fully implemented and requires auth (returns 401 without credentials) def test_endpoints_list(client): resp = client.get("/api/endpoints/") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_endpoints_create(client): resp = client.post("/api/endpoints/") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_endpoints_update(client): resp = client.put("/api/endpoints/00000000-0000-0000-0000-000000000001") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_endpoints_delete(client): resp = client.delete("/api/endpoints/00000000-0000-0000-0000-000000000001") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_endpoints_test(client): resp = client.post("/api/endpoints/00000000-0000-0000-0000-000000000001/test") - assert resp.status_code == 501 + assert resp.status_code == 401 # ---- Export router (/api/export) ----