mirror of
https://github.com/xpltdco/media-rip.git
synced 2026-04-03 02:53:58 -06:00
159 lines
4.9 KiB
Python
159 lines
4.9 KiB
Python
"""Tests for the aiosqlite database layer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
from app.core.database import (
|
|
close_db,
|
|
create_job,
|
|
delete_job,
|
|
get_job,
|
|
get_jobs_by_session,
|
|
init_db,
|
|
update_job_progress,
|
|
update_job_status,
|
|
)
|
|
from app.models.job import Job, JobStatus
|
|
|
|
|
|
def _make_job(session_id: str = "sess-1", **overrides) -> Job:
    """Build a ``Job`` for use in tests.

    Each call generates a fresh random UUID for ``id`` and the current UTC
    time (ISO-8601 string) for ``created_at``; the job starts out in the
    ``queued`` status with a fixed example URL.

    Args:
        session_id: Session the job is attached to.
        **overrides: Field values that replace (or extend) the defaults
            before they are passed to the ``Job`` constructor.

    Returns:
        The newly constructed ``Job``.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "session_id": session_id,
        "url": "https://example.com/video",
        "status": JobStatus.queued,
        "created_at": datetime.now(timezone.utc).isoformat(),
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
    return Job(**fields)
|
|
|
|
|
|
class TestInitDb:
    """Schema and PRAGMA checks on an initialised database.

    Every test receives ``db`` as an argument (presumably a pytest fixture
    yielding an open connection -- confirm in conftest) and inspects it with
    raw SQL.
    """

    async def test_creates_all_tables(self, db):
        """All expected tables are listed in ``sqlite_master``."""
        cursor = await db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        )
        rows = await cursor.fetchall()
        table_names = set(row[0] for row in rows)
        for expected in ("sessions", "jobs", "config", "unsupported_urls"):
            assert expected in table_names

    async def test_wal_mode_enabled(self, db):
        """``PRAGMA journal_mode`` reports ``wal``."""
        cursor = await db.execute("PRAGMA journal_mode")
        journal_mode = (await cursor.fetchone())[0]
        assert journal_mode == "wal"

    async def test_busy_timeout_set(self, db):
        """``PRAGMA busy_timeout`` reports 5000."""
        cursor = await db.execute("PRAGMA busy_timeout")
        busy_timeout = (await cursor.fetchone())[0]
        assert busy_timeout == 5000

    async def test_indexes_created(self, db):
        """The expected ``idx_``-prefixed indexes are listed in ``sqlite_master``."""
        cursor = await db.execute(
            "SELECT name FROM sqlite_master WHERE type='index' AND name LIKE 'idx_%'"
        )
        rows = await cursor.fetchall()
        index_names = set(row[0] for row in rows)
        for expected in (
            "idx_jobs_session_status",
            "idx_jobs_completed",
            "idx_sessions_last_seen",
        ):
            assert expected in index_names
|
|
|
|
|
|
class TestJobCrud:
    """Create / read / update / delete behaviour of the job helpers.

    Each test builds its own jobs with ``_make_job`` and works through the
    ``db`` argument it is given (presumably a pytest fixture -- confirm in
    conftest).
    """

    async def test_create_and_get_roundtrip(self, db):
        """A created job can be fetched back with the same id, url and status."""
        original = _make_job()

        stored = await create_job(db, original)
        assert stored.id == original.id

        loaded = await get_job(db, original.id)
        assert loaded is not None
        assert loaded.id == original.id
        assert loaded.url == original.url
        assert loaded.status == JobStatus.queued

    async def test_get_nonexistent_returns_none(self, db):
        """Looking up an unknown id yields ``None``."""
        assert await get_job(db, "no-such-id") is None

    async def test_get_jobs_by_session(self, db):
        """Jobs are returned only for the session that was asked for."""
        # Two jobs for sess-A, one for sess-B; build them all before inserting.
        new_jobs = [
            _make_job(session_id=sid) for sid in ("sess-A", "sess-A", "sess-B")
        ]
        for new_job in new_jobs:
            await create_job(db, new_job)

        for_a = await get_jobs_by_session(db, "sess-A")
        assert len(for_a) == 2
        assert all(found.session_id == "sess-A" for found in for_a)

        for_b = await get_jobs_by_session(db, "sess-B")
        assert len(for_b) == 1

    async def test_update_job_status(self, db):
        """A status update persists both the new status and the error message."""
        target = _make_job()
        await create_job(db, target)

        await update_job_status(db, target.id, "failed", error_message="404 not found")

        reloaded = await get_job(db, target.id)
        assert reloaded is not None
        assert reloaded.status == JobStatus.failed
        assert reloaded.error_message == "404 not found"

    async def test_update_job_progress(self, db):
        """A progress update persists every progress field that was supplied."""
        target = _make_job()
        await create_job(db, target)

        # The same mapping drives both the update call and the checks below.
        progress = {
            "progress_percent": 42.5,
            "speed": "1.2 MiB/s",
            "eta": "2m30s",
            "filename": "video.mp4",
        }
        await update_job_progress(db, target.id, **progress)

        reloaded = await get_job(db, target.id)
        assert reloaded is not None
        for field_name, expected in progress.items():
            assert getattr(reloaded, field_name) == expected

    async def test_delete_job(self, db):
        """A deleted job can no longer be fetched."""
        doomed = _make_job()
        await create_job(db, doomed)

        await delete_job(db, doomed.id)

        remaining = await get_job(db, doomed.id)
        assert remaining is None
|
|
|
|
|
|
class TestConcurrentWrites:
    """Verify WAL mode handles concurrent writers without SQLITE_BUSY."""

    async def test_three_concurrent_inserts(self, tmp_db_path):
        """Launch 3 simultaneous create_job calls via asyncio.gather.

        The test opens its own connection with ``init_db(tmp_db_path)``
        and always closes it again, even when an assertion fails, so a
        failing run does not leave the connection open.

        NOTE(review): all three ``create_job`` calls share the single ``db``
        connection opened here -- confirm that is the intended scenario.
        """
        db = await init_db(tmp_db_path)
        try:
            jobs = [_make_job(session_id="concurrent") for _ in range(3)]

            # return_exceptions=True so a failed insert shows up as a value
            # in ``results`` and is reported by the isinstance check below.
            results = await asyncio.gather(
                *(create_job(db, j) for j in jobs),
                return_exceptions=True,
            )

            # No exceptions -- all three succeeded
            for r in results:
                assert isinstance(r, Job), f"Expected Job, got {type(r).__name__}: {r}"

            # Verify all three exist
            all_jobs = await get_jobs_by_session(db, "concurrent")
            assert len(all_jobs) == 3
        finally:
            # Release the connection regardless of the test outcome.
            await close_db(db)
|