"""Tests for engine/tasks.py — Celery task definitions and sync fallback."""

import asyncio
import sys
import uuid
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

# Ensure backend is on sys.path so the project imports below resolve no matter
# which directory pytest is launched from.
_backend_dir = str(Path(__file__).resolve().parents[1])
if _backend_dir not in sys.path:
    sys.path.insert(0, _backend_dir)

from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from engine.runner import EventBus, run_single
from engine.tasks import (
    SyncTaskResult,
    _do_execute_run,
    _do_execute_sweep,
    dispatch_run,
    dispatch_sweep,
)
from models import (
    Base,
    Experiment,
    ExperimentStatus,
    Project,
    Run,
    RunStatus,
    User,
)


# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------


class MockAdapter(BaseAdapter):
    """Test adapter that returns fixed responses.

    Every ``complete`` call returns the text and token counts given to the
    constructor, with a constant ``latency_ms`` of 42.0.
    """

    def __init__(self, text: str = "mock output", tokens_in: int = 10, tokens_out: int = 5):
        self._text = text
        self._tokens_in = tokens_in
        self._tokens_out = tokens_out

    @property
    def name(self) -> str:
        return "mock"

    async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse:
        """Return the canned response; *prompt*, *model* and *params* are ignored."""
        return AdapterResponse(
            text=self._text,
            tokens_in=self._tokens_in,
            tokens_out=self._tokens_out,
            latency_ms=42.0,
        )

    async def list_models(self) -> list[str]:
        return ["mock-model"]

    async def test_connection(self) -> bool:
        return True


@pytest.fixture
def db_engine():
    """In-memory SQLite engine with all ``Base`` tables created.

    The engine is disposed on teardown so its pooled connection (and with it
    the in-memory database) is released deterministically rather than
    lingering until garbage collection.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    yield engine
    engine.dispose()


@pytest.fixture
def db(db_engine):
    """ORM session bound to ``db_engine``; closed when the test finishes."""
    with Session(db_engine) as session:
        yield session


@pytest.fixture
def user(db):
    """A persisted ``User`` row."""
    u = User(username="testuser", password_hash="fakehash")
    db.add(u)
    db.commit()
    db.refresh(u)
    return u


@pytest.fixture
def project(db, user):
    """A persisted ``Project`` owned by ``user``."""
    p = Project(name="Test Project", owner_id=user.id)
    db.add(p)
    db.commit()
    db.refresh(p)
    return p


@pytest.fixture
def experiment(db, project):
    """A draft ``Experiment`` with one pipeline stage and an empty grid space."""
    exp = Experiment(
        project_id=project.id,
        name="Test Experiment",
        pipeline_stages=[
            {"prompt_template": "Say hello", "model": "test-model", "params": {}}
        ],
        parameter_space={"type": "grid", "params": {}},
        status=ExperimentStatus.draft,
    )
    db.add(exp)
    db.commit()
    db.refresh(exp)
    return exp


def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run:
    """Persist and return a ``RunStatus.pending`` run for *experiment_id*.

    Args:
        db: Session used to add and commit the run.
        experiment_id: Parent experiment id.
        config: Run config. When ``None``, a single-stage "Say hello" config
            is used. An explicitly passed ``{}`` is kept as-is; it is not
            treated like ``None``.

    The run's ``config_hash`` is derived with ``compute_config_hash`` from the
    config's ``prompt``, ``model`` and ``params`` (defaulting to empty values).
    """
    if config is None:
        config = {
            "prompt": "Say hello",
            "model": "test-model",
            "params": {},
            "pipeline_stages": [
                {"prompt_template": "Say hello", "model": "test-model", "params": {}}
            ],
        }
    config_hash = compute_config_hash(
        config.get("prompt", ""), config.get("model", ""), config.get("params", {})
    )
    run = Run(
        experiment_id=experiment_id,
        config_hash=config_hash,
        config=config,
        status=RunStatus.pending,
    )
    db.add(run)
    db.commit()
    db.refresh(run)
    return run


# ---------------------------------------------------------------------------
# SyncTaskResult tests
# ---------------------------------------------------------------------------


class TestSyncTaskResult:
    def test_successful_result(self):
        r = SyncTaskResult(result={"status": "ok"}, task_id="test-123")
        assert r.id == "test-123"
        assert r.status == "SUCCESS"
        assert r.state == "SUCCESS"
        assert r.ready() is True
        assert r.successful() is True
        assert r.failed() is False
        assert r.get() == {"status": "ok"}

    def test_failed_result(self):
        err = ValueError("boom")
        r = SyncTaskResult(error=err)
        assert r.status == "FAILURE"
        assert r.ready() is True
        assert r.successful() is False
        assert r.failed() is True
        # get() must re-raise the stored error, mirroring Celery's AsyncResult.get().
        with pytest.raises(ValueError, match="boom"):
            r.get()

    def test_auto_generates_task_id(self):
        r = SyncTaskResult(result="data")
        assert r.id is not None
        assert len(r.id) > 0
# ---------------------------------------------------------------------------
# _do_execute_run tests
# ---------------------------------------------------------------------------


class TestDoExecuteRun:
    def test_executes_run_successfully(self, db, experiment):
        run = _make_run(db, experiment.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        # Patch the task module's dependency getters so the run uses the test
        # session and the deterministic mock adapter instead of real services.
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
        ):
            result = _do_execute_run(str(run.id))

        assert result["run_id"] == str(run.id)
        assert result["status"] == "completed"
        # Token counts come straight from MockAdapter's defaults (10 in / 5 out).
        assert result["tokens_in"] == 10
        assert result["tokens_out"] == 5
        assert result["duration_ms"] is not None

    def test_raises_for_missing_run(self, db):
        fake_id = str(uuid.uuid4())
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_run(fake_id)

    def test_passes_endpoint_config(self, db, experiment):
        run = _make_run(db, experiment.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()
        endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"}

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter) as mock_get_adapter,
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
        ):
            _do_execute_run(str(run.id), endpoint_config=endpoint_cfg)

        # The endpoint config must be forwarded verbatim to adapter construction.
        mock_get_adapter.assert_called_once_with(endpoint_cfg)


# ---------------------------------------------------------------------------
# _do_execute_sweep tests
# ---------------------------------------------------------------------------


class TestDoExecuteSweep:
    def test_executes_sweep_successfully(self, db, experiment):
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        # A None redis client is patched in so the sweep runs without Redis.
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            result = _do_execute_sweep(str(experiment.id))

        assert result["experiment_id"] == str(experiment.id)
        assert result["total_runs"] >= 1
        assert result["completed_runs"] >= 1
        assert result["stopped_reason"] == "completed"

    def test_raises_for_missing_experiment(self, db):
        fake_id = str(uuid.uuid4())
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_sweep(fake_id)

    def test_override_sweep_config(self, db, experiment):
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()
        # Overrides the experiment's stored grid space; n_trials pins the run count.
        custom_sweep = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2}

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            result = _do_execute_sweep(str(experiment.id), sweep_config=custom_sweep)

        assert result["total_runs"] == 2
        assert result["completed_runs"] == 2


# ---------------------------------------------------------------------------
# dispatch_run / dispatch_sweep tests (sync fallback)
# ---------------------------------------------------------------------------


class TestDispatchRunSync:
    def test_sync_fallback_success(self, db, experiment):
        run = _make_run(db, experiment.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
        ):
            # Force the in-process path so no Celery broker is needed.
            mock_settings.use_in_process_queue = True
            task_result = dispatch_run(str(run.id))

        assert task_result.successful()
        result = task_result.get()
        assert result["status"] == "completed"

    def test_sync_fallback_error(self, db):
        fake_id = str(uuid.uuid4())
        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            mock_settings.use_in_process_queue = True
            task_result = dispatch_run(fake_id)

        # The error is captured in the result object rather than raised by dispatch.
        assert task_result.failed()
        with pytest.raises(ValueError):
            task_result.get()


class TestDispatchSweepSync:
    def test_sync_fallback_success(self, db, experiment):
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=event_bus),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            mock_settings.use_in_process_queue = True
            # Numeric settings must be real values, not MagicMock attributes.
            mock_settings.max_concurrent_runs = 2
            mock_settings.max_tokens_per_sweep = 0
            task_result = dispatch_sweep(str(experiment.id))

        assert task_result.successful()
        result = task_result.get()
        assert result["completed_runs"] >= 1

    def test_sync_fallback_error(self, db):
        fake_id = str(uuid.uuid4())
        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            mock_settings.use_in_process_queue = True
            task_result = dispatch_sweep(fake_id)

        assert task_result.failed()
        with pytest.raises(ValueError):
            task_result.get()


# ---------------------------------------------------------------------------
# Celery task registration tests
# ---------------------------------------------------------------------------


class TestCeleryTaskRegistration:
    def test_execute_run_task_exists(self):
        """execute_run should be registered as a Celery task or None if Celery unavailable."""
        from engine.tasks import execute_run

        # In test environment, Celery may or may not be available.
        # If available, it should be a Celery task with the correct name.
        if execute_run is not None:
            assert execute_run.name == "engine.execute_run"

    def test_execute_sweep_task_exists(self):
        from engine.tasks import execute_sweep

        if execute_sweep is not None:
            assert execute_sweep.name == "engine.execute_sweep"


# ---------------------------------------------------------------------------
# dispatch with Celery available
# ---------------------------------------------------------------------------


class TestDispatchWithCelery:
    def test_dispatch_run_uses_celery_when_redis_available(self):
        """When use_in_process_queue is False and Celery tasks exist, should use .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks.execute_run", mock_task),
        ):
            mock_settings.use_in_process_queue = False
            result = dispatch_run("some-run-id", endpoint_config={"url": "http://x"})

        mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"})
        # The caller should receive the Celery handle, identified by its task id.
        assert result.id == "celery-task-id"

    def test_dispatch_sweep_uses_celery_when_redis_available(self):
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks.execute_sweep", mock_task),
        ):
            mock_settings.use_in_process_queue = False
            result = dispatch_sweep("some-exp-id", sweep_config={"type": "grid"})

        mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)
        assert result.id == "celery-task-id"