diff --git a/Auto Run Docs/02b-frontend-dashboard.md b/Auto Run Docs/02b-frontend-dashboard.md index 9080650..5a3a34f 100644 --- a/Auto Run Docs/02b-frontend-dashboard.md +++ b/Auto Run Docs/02b-frontend-dashboard.md @@ -20,7 +20,8 @@ Build the React frontend: setup wizard, experiment builder, real-time observabil - [x] Build the prompt template editor component (frontend/src/components/PromptEditor.tsx). Use a code editor library (CodeMirror or Monaco, loaded from CDN). Support Jinja2 template syntax highlighting. Show available template variables in a sidebar (input_data, previous_stage_output, etc.). Include a "Preview" button that renders the template with sample data. -- [ ] Build the model selector component (frontend/src/components/ModelSelector.tsx). Dropdown grouped by endpoint. Each option shows model name + endpoint label. Include a "refresh models" button that calls the endpoint test API to refresh available models. Show a connectivity indicator (green dot = reachable, red = error). +- [x] Build the model selector component (frontend/src/components/ModelSelector.tsx). Dropdown grouped by endpoint. Each option shows model name + endpoint label. Include a "refresh models" button that calls the endpoint test API to refresh available models. Show a connectivity indicator (green dot = reachable, red = error). + - [ ] Implement the Live Observability page (frontend/src/pages/Live.tsx). This is the star of the show — the real-time dashboard during active sweeps. Layout: left column (60%) shows the activity timeline and current run details, right column (40%) shows the leaderboard and steering controls. Connect via WebSocket to /ws/experiments/{id}. Everything updates in real-time without page refresh. diff --git a/frontend/src/components/ModelSelector.test.tsx b/frontend/src/components/ModelSelector.test.tsx new file mode 100644 index 0000000..3d29d5e --- /dev/null +++ b/frontend/src/components/ModelSelector.test.tsx @@ -0,0 +1,235 @@ +import { render, screen, waitFor, within } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import ModelSelector, { ModelSelectorProps } from "./ModelSelector"; +import * as client from "../api/client"; + +// --------------------------------------------------------------------------- +// Mocks +// --------------------------------------------------------------------------- + +const MOCK_ENDPOINTS: client.EndpointResponse[] = [ + { + id: "ep-1", + name: "Local vLLM", + url: "http://localhost:8080", + default_model: "llama-3-70b", + }, + { + id: "ep-2", + name: "OpenAI", + url: "https://api.openai.com", + default_model: "gpt-4", + }, + { + id: "ep-3", + name: "Ollama", + url: "http://localhost:11434", + default_model: null, + }, +]; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function renderSelector(overrides: Partial = {}) { + const defaultProps: ModelSelectorProps = { + value: "::", + onChange: vi.fn(), + endpoints: MOCK_ENDPOINTS, + "data-testid": "model-selector", + ...overrides, + }; + return { + ...render(), + props: defaultProps, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("ModelSelector", () => { + beforeEach(() => { + vi.restoreAllMocks(); + }); + + it("renders a select element with placeholder option", () => { + renderSelector(); + const select = screen.getByTestId("model-select"); + expect(select).toBeInTheDocument(); + expect(screen.getByText("Select a model...")).toBeInTheDocument(); + }); + + it("groups options by endpoint using optgroups", () => { + renderSelector(); + const select = screen.getByTestId("model-select"); + const groups = select.querySelectorAll("optgroup"); + expect(groups).toHaveLength(3); + expect(groups[0]).toHaveAttribute("label", "Local vLLM"); + expect(groups[1]).toHaveAttribute("label", "OpenAI"); + expect(groups[2]).toHaveAttribute("label", "Ollama"); + }); + + it("shows model name + endpoint label in each option", () => { + renderSelector(); + expect(screen.getByText(/llama-3-70b.*\(Local vLLM\)/)).toBeInTheDocument(); + expect(screen.getByText(/gpt-4.*\(OpenAI\)/)).toBeInTheDocument(); + }); + + it("falls back to endpoint name when default_model is null", () => { + renderSelector(); + expect(screen.getByText(/Ollama.*\(Ollama\)/)).toBeInTheDocument(); + }); + + it("calls onChange when a model is selected", async () => { + const onChange = vi.fn(); + renderSelector({ onChange }); + const user = userEvent.setup(); + const select = screen.getByTestId("model-select"); + + await user.selectOptions(select, "ep-2::gpt-4"); + + expect(onChange).toHaveBeenCalledWith("ep-2::gpt-4"); + }); + + it("renders the refresh button", () => { + renderSelector(); + expect(screen.getByTestId("refresh-models-btn")).toBeInTheDocument(); + expect(screen.getByText("Refresh")).toBeInTheDocument(); + }); + + it("shows 'Testing...' text while refreshing", async () => { + // Make endpoint.test hang + vi.spyOn(client.endpoints, "test").mockImplementation( + () => new Promise(() => {}), + ); + renderSelector(); + const user = userEvent.setup(); + + await user.click(screen.getByTestId("refresh-models-btn")); + + expect(screen.getByText("Testing...")).toBeInTheDocument(); + }); + + it("shows green dots for reachable endpoints after refresh", async () => { + vi.spyOn(client.endpoints, "test").mockResolvedValue({ ok: true }); + vi.spyOn(client.endpoints, "list").mockResolvedValue({ + items: MOCK_ENDPOINTS, + total: MOCK_ENDPOINTS.length, + }); + + renderSelector(); + const user = userEvent.setup(); + await user.click(screen.getByTestId("refresh-models-btn")); + + await waitFor(() => { + expect(screen.getByTestId("status-ok-ep-1")).toBeInTheDocument(); + expect(screen.getByTestId("status-ok-ep-2")).toBeInTheDocument(); + expect(screen.getByTestId("status-ok-ep-3")).toBeInTheDocument(); + }); + }); + + it("shows red dots for unreachable endpoints after refresh", async () => { + vi.spyOn(client.endpoints, "test").mockRejectedValue( + new Error("Connection refused"), + ); + vi.spyOn(client.endpoints, "list").mockResolvedValue({ + items: MOCK_ENDPOINTS, + total: MOCK_ENDPOINTS.length, + }); + + renderSelector(); + const user = userEvent.setup(); + await user.click(screen.getByTestId("refresh-models-btn")); + + await waitFor(() => { + expect(screen.getByTestId("status-error-ep-1")).toBeInTheDocument(); + expect(screen.getByTestId("status-error-ep-2")).toBeInTheDocument(); + expect(screen.getByTestId("status-error-ep-3")).toBeInTheDocument(); + }); + }); + + it("shows mixed status when some endpoints pass and some fail", async () => { + vi.spyOn(client.endpoints, "test").mockImplementation((id: string) => { + if (id === "ep-1") return Promise.resolve({ ok: true }); + return Promise.reject(new Error("fail")); + }); + vi.spyOn(client.endpoints, "list").mockResolvedValue({ + items: MOCK_ENDPOINTS, + total: MOCK_ENDPOINTS.length, + }); + + renderSelector(); + const user = userEvent.setup(); + await user.click(screen.getByTestId("refresh-models-btn")); + + await waitFor(() => { + expect(screen.getByTestId("status-ok-ep-1")).toBeInTheDocument(); + expect(screen.getByTestId("status-error-ep-2")).toBeInTheDocument(); + expect(screen.getByTestId("status-error-ep-3")).toBeInTheDocument(); + }); + }); + + it("calls onEndpointsRefreshed with updated list after refresh", async () => { + const onEndpointsRefreshed = vi.fn(); + const updatedEndpoints = [MOCK_ENDPOINTS[0]]; + vi.spyOn(client.endpoints, "test").mockResolvedValue({ ok: true }); + vi.spyOn(client.endpoints, "list").mockResolvedValue({ + items: updatedEndpoints, + total: 1, + }); + + renderSelector({ onEndpointsRefreshed }); + const user = userEvent.setup(); + await user.click(screen.getByTestId("refresh-models-btn")); + + await waitFor(() => { + expect(onEndpointsRefreshed).toHaveBeenCalledWith(updatedEndpoints); + }); + }); + + it("does not show connectivity indicators before refresh", () => { + renderSelector(); + expect( + screen.queryByTestId("connectivity-indicators"), + ).not.toBeInTheDocument(); + }); + + it("disables refresh button while refreshing", async () => { + vi.spyOn(client.endpoints, "test").mockImplementation( + () => new Promise(() => {}), + ); + renderSelector(); + const user = userEvent.setup(); + const btn = screen.getByTestId("refresh-models-btn"); + + await user.click(btn); + + expect(btn).toBeDisabled(); + }); + + it("renders with a pre-selected value", () => { + renderSelector({ value: "ep-1::llama-3-70b" }); + const select = screen.getByTestId("model-select") as HTMLSelectElement; + expect(select.value).toBe("ep-1::llama-3-70b"); + }); + + it("renders correctly with no endpoints", () => { + renderSelector({ endpoints: [] }); + const select = screen.getByTestId("model-select"); + const groups = select.querySelectorAll("optgroup"); + expect(groups).toHaveLength(0); + // Only placeholder option + const options = select.querySelectorAll("option"); + expect(options).toHaveLength(1); + expect(options[0].textContent).toBe("Select a model..."); + }); + + it("applies custom data-testid", () => { + renderSelector({ "data-testid": "custom-selector" }); + expect(screen.getByTestId("custom-selector")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/ModelSelector.tsx b/frontend/src/components/ModelSelector.tsx new file mode 100644 index 0000000..cd98a7f --- /dev/null +++ b/frontend/src/components/ModelSelector.tsx @@ -0,0 +1,212 @@ +import { useState, useCallback } from "react"; +import { EndpointResponse, endpoints as endpointsApi } from "../api/client"; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface ModelSelectorProps { + /** Currently selected value as "endpointId::modelName" */ + value: string; + /** Called when user selects a new model */ + onChange: (value: string) => void; + /** Available endpoints to populate the dropdown */ + endpoints: EndpointResponse[]; + /** Called after endpoints are refreshed so parent can update its list */ + onEndpointsRefreshed?: (updated: EndpointResponse[]) => void; + /** Test id */ + "data-testid"?: string; +} + +/** Per-endpoint connectivity state tracked internally. */ +interface EndpointStatus { + status: "unknown" | "reachable" | "error"; + loading: boolean; +} + +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- + +export default function ModelSelector({ + value, + onChange, + endpoints, + onEndpointsRefreshed, + "data-testid": testId, +}: ModelSelectorProps) { + const [statusMap, setStatusMap] = useState>( + {}, + ); + const [refreshing, setRefreshing] = useState(false); + + // -- helpers --------------------------------------------------------------- + + function getStatus(endpointId: string): EndpointStatus { + return statusMap[endpointId] ?? { status: "unknown", loading: false }; + } + + function statusDot(endpointId: string) { + const s = getStatus(endpointId); + if (s.loading) { + return ( + + ); + } + if (s.status === "reachable") { + return ( + + ); + } + if (s.status === "error") { + return ( + + ); + } + return ( + + ); + } + + // -- refresh models -------------------------------------------------------- + + const refreshModels = useCallback(async () => { + if (refreshing) return; + setRefreshing(true); + + const newStatusMap: Record = {}; + + // Mark all as loading + for (const ep of endpoints) { + newStatusMap[ep.id] = { status: "unknown", loading: true }; + } + setStatusMap({ ...newStatusMap }); + + // Test each endpoint in parallel + const results = await Promise.allSettled( + endpoints.map(async (ep) => { + try { + await endpointsApi.test(ep.id); + return { id: ep.id, status: "reachable" as const }; + } catch { + return { id: ep.id, status: "error" as const }; + } + }), + ); + + const finalStatusMap: Record = {}; + for (const result of results) { + if (result.status === "fulfilled") { + finalStatusMap[result.value.id] = { + status: result.value.status, + loading: false, + }; + } + } + setStatusMap(finalStatusMap); + + // Refresh endpoint list from server + if (onEndpointsRefreshed) { + try { + const list = await endpointsApi.list(); + onEndpointsRefreshed(list.items); + } catch { + // Silently fail — status indicators already show the issue + } + } + + setRefreshing(false); + }, [endpoints, onEndpointsRefreshed, refreshing]); + + // -- group by endpoint ----------------------------------------------------- + + const groups = endpoints.reduce< + Record + >((acc, ep) => { + acc[ep.id] = { endpoint: ep }; + return acc; + }, {}); + + // -- render ---------------------------------------------------------------- + + return ( +
+
+ + + +
+ + {/* Connectivity indicators */} + {endpoints.length > 0 && Object.keys(statusMap).length > 0 && ( +
+ {endpoints.map((ep) => ( + + {statusDot(ep.id)} + {ep.name} + + ))} +
+ )} +
+ ); +} diff --git a/frontend/src/pages/ExperimentPage.tsx b/frontend/src/pages/ExperimentPage.tsx index 54853d4..e4dd32d 100644 --- a/frontend/src/pages/ExperimentPage.tsx +++ b/frontend/src/pages/ExperimentPage.tsx @@ -11,6 +11,7 @@ import type { EndpointResponse, } from "../api/client"; import PromptEditor from "../components/PromptEditor"; +import ModelSelector from "../components/ModelSelector"; // --------------------------------------------------------------------------- // Types @@ -348,25 +349,19 @@ function PipelineStageCard({ - + endpoints={endpointList} + data-testid={`model-selector-${index}`} + /> {/* Preview is now integrated in PromptEditor above */}