sonosketch/hum2inst.py
John Lightner 5a233898c8 feat(01-01): create hum2inst.py CLI pipeline script
- Single-file CLI wrapping ACE-Step XL-SFT cover mode
- argparse with --instrument, --output, --strength, --duration flags
- Auto-detect input WAV duration via torchaudio
- Caption templates for piano, guitar, saxophone, violin, flute
- CUDA GPU check with clear error message
- Silence detection on output audio
- User-friendly output naming with instrument + timestamp
2026-04-11 02:11:53 -05:00

273 lines
9.2 KiB
Python

#!/usr/bin/env python3
"""
hum2inst.py - Convert humming to instrument audio using ACE-Step XL-SFT cover mode.
Usage:
python hum2inst.py input.wav --instrument piano
python hum2inst.py input.wav --instrument guitar --strength 0.85 --output ./my_output/
python hum2inst.py input.wav --instrument saxophone --duration 15
"""
import argparse
import math
import os
import shutil
import sys
import tempfile
from datetime import datetime
from pathlib import Path
# ---------------------------------------------------------------------------
# Caption templates for common instruments
# ---------------------------------------------------------------------------
# Maps a normalised (lower-case, stripped) instrument name to the text caption
# handed to the model for that instrument.
CAPTION_TEMPLATES = {
    "piano": "solo acoustic piano, gentle melody, warm tone, clear and expressive",
    "guitar": "solo acoustic guitar, fingerpicked melody, warm and intimate",
    "saxophone": "solo saxophone, smooth jazz melody, soulful and expressive",
    "violin": "solo violin, classical melody, rich and emotional",
    "flute": "solo flute, gentle melody, airy and delicate",
}
def build_caption(instrument: str) -> str:
    """Return the ACE-Step caption text for *instrument*.

    The name is lower-cased and stripped of surrounding whitespace before the
    lookup. Instruments with an entry in ``CAPTION_TEMPLATES`` get their
    hand-written caption; any other name is slotted into a generic
    "solo <name>" caption.
    """
    key = instrument.lower().strip()
    generic = f"solo {key}, clear and expressive melody, warm tone"
    return CAPTION_TEMPLATES.get(key, generic)
def get_wav_duration(wav_path: str) -> float:
    """Return the length of the audio file at *wav_path* in seconds.

    The value is the frame count divided by the sample rate, both taken from
    the metadata reported by ``torchaudio.info``.
    """
    # Imported lazily so that importing this module does not require torchaudio.
    import torchaudio

    metadata = torchaudio.info(wav_path)
    frame_count, sample_rate = metadata.num_frames, metadata.sample_rate
    return frame_count / sample_rate
def check_silence(audio_path: str, threshold_db: float = -60.0) -> bool:
    """Report whether the audio file at *audio_path* is near-silent.

    The RMS is taken over every sample of every channel and converted to
    decibels as ``20 * log10(rms)``. An RMS that is not greater than zero is
    treated as minus infinity dB.

    Returns True when the level is strictly below *threshold_db* (the output
    is likely silent, i.e. generation probably failed), False otherwise.
    """
    import torch
    import torchaudio

    samples, _sample_rate = torchaudio.load(audio_path)
    level = torch.sqrt(torch.mean(samples ** 2))
    # log10 is only evaluated for a strictly positive RMS.
    level_db = 20 * torch.log10(level).item() if level > 0 else -float("inf")
    return level_db < threshold_db
def _finite_float(text: str) -> float:
    """argparse ``type`` callable: parse *text* as a finite float.

    Rejects ``nan`` and ``inf``, which ``float()`` accepts but which cannot be
    rounded to an integer number of seconds later in the pipeline.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    if not math.isfinite(value):
        raise argparse.ArgumentTypeError(f"must be a finite number, got {text!r}")
    return value


def _cover_strength(text: str) -> float:
    """argparse ``type`` callable: parse *text* as a cover strength in [0.0, 1.0]."""
    value = _finite_float(text)
    # NOTE(review): the accepted range is taken to be 0.0-1.0 (the help text
    # recommends 0.8-1.0) - confirm against the ACE-Step documentation.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(
            f"must be between 0.0 and 1.0, got {text!r}"
        )
    return value


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as None keeps the original
            behaviour of reading the process command line.

    Returns:
        A namespace with ``input``, ``instrument``, ``output``, ``strength``
        and ``duration`` (None when no override was given).

    Raises:
        SystemExit: Raised by argparse (status 2) on missing or invalid
            arguments, including a non-finite ``--duration`` or a
            ``--strength`` outside 0.0-1.0.
    """
    parser = argparse.ArgumentParser(
        description="Convert a hummed melody to an instrument rendition using ACE-Step.",
        epilog="Example: python hum2inst.py humming.wav --instrument piano",
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to the input WAV file containing humming",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        required=True,
        help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./output/",
        help="Output directory for the generated audio (default: ./output/)",
    )
    parser.add_argument(
        "--strength",
        type=_cover_strength,
        default=0.9,
        help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
    )
    parser.add_argument(
        "--duration",
        # Only finiteness is enforced here; how non-positive values are
        # interpreted is left to the downstream pipeline.
        type=_finite_float,
        default=None,
        help="Override output duration in seconds (default: auto-detect from input WAV)",
    )
    return parser.parse_args(argv)
def _instrument_slug(instrument: str) -> str:
    """Return a filename-safe, lower-case slug for *instrument*.

    Every run of characters other than word characters and hyphens (spaces,
    path separators, dots, ...) collapses to a single hyphen, so the result
    can never name a sub-directory or escape the output directory. Falls back
    to "instrument" when nothing usable is left.
    """
    import re

    slug = re.sub(r"[^\w-]+", "-", instrument.lower().strip()).strip("-")
    return slug or "instrument"


def main() -> None:
    """Run the hum-to-instrument pipeline for one input file.

    Steps, in order: parse the command line, check that the input file exists,
    require a CUDA GPU, pick the target duration (override or detected from
    the input), build the caption, create the output directory, load ACE-Step
    from the ``ace-step`` directory next to this script, generate a cover into
    a temporary directory, warn if the result is near-silent, and copy it to
    ``<output>/<instrument>_<timestamp>.wav``.

    Raises:
        SystemExit: With status 1 when the input file is missing, no CUDA GPU
            is available, ACE-Step cannot be imported, generation fails, or
            the reported output file does not exist.
    """
    args = parse_args()

    # -----------------------------------------------------------------------
    # Validate input file
    # -----------------------------------------------------------------------
    input_path = os.path.abspath(args.input)
    if not os.path.isfile(input_path):
        print(f"ERROR: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    # -----------------------------------------------------------------------
    # CUDA check
    # -----------------------------------------------------------------------
    print("Checking GPU availability...")
    import torch

    if not torch.cuda.is_available():
        print(
            "ERROR: CUDA GPU required. No CUDA-capable GPU detected.",
            file=sys.stderr,
        )
        sys.exit(1)
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")

    # -----------------------------------------------------------------------
    # Detect input duration
    # -----------------------------------------------------------------------
    if args.duration is not None:
        duration = round(args.duration)
        print(f"Using override duration: {duration}s")
    else:
        raw_duration = get_wav_duration(input_path)
        duration = round(raw_duration)
        print(f"Input duration: {raw_duration:.1f}s (rounded to {duration}s)")

    # -----------------------------------------------------------------------
    # Build caption
    # -----------------------------------------------------------------------
    caption = build_caption(args.instrument)
    print(f"Caption: {caption}")

    # -----------------------------------------------------------------------
    # Create output directory
    # -----------------------------------------------------------------------
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------
    # Initialize ACE-Step
    # -----------------------------------------------------------------------
    print("Loading ACE-Step model...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    ace_step_dir = os.path.join(script_dir, "ace-step")
    # Make the bundled checkout importable ahead of any other installed copy.
    sys.path.insert(0, ace_step_dir)
    try:
        from acestep.handler import AceStepHandler
        from acestep.llm_inference import LLMHandler
        from acestep.inference import GenerationParams, GenerationConfig, generate_music
        from acestep.gpu_config import get_gpu_config, set_global_gpu_config
    except ImportError as exc:
        # Name the expected location instead of surfacing a bare
        # "No module named ..." through the generic top-level handler.
        print(
            f"ERROR: Could not import ACE-Step (expected in {ace_step_dir}): {exc}",
            file=sys.stderr,
        )
        sys.exit(1)

    gpu_config = get_gpu_config()
    set_global_gpu_config(gpu_config)
    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()
    dit_handler.initialize_service(
        project_root=ace_step_dir,
        config_path="acestep-v15-xl-sft",
        device="cuda",
        use_flash_attention=dit_handler.is_flash_attention_available("cuda"),
    )
    print("Model loaded successfully.")

    # -----------------------------------------------------------------------
    # Configure generation
    # -----------------------------------------------------------------------
    params = GenerationParams(
        task_type="cover",
        src_audio=input_path,
        caption=caption,
        lyrics="",
        instrumental=True,
        duration=duration,
        bpm=120,
        audio_cover_strength=args.strength,
        inference_steps=50,
        guidance_scale=5.0,
        thinking=False,
    )
    config = GenerationConfig(
        batch_size=1,
        use_random_seed=True,
        audio_format="wav",
    )

    # Use a temporary directory for ACE-Step's UUID-named output
    temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
    try:
        # -------------------------------------------------------------------
        # Generate
        # -------------------------------------------------------------------
        print(f"Generating {args.instrument} cover...")
        try:
            result = generate_music(
                dit_handler, llm_handler, params, config, save_dir=temp_save_dir
            )
        except Exception as e:
            print(f"ERROR: Generation failed: {e}", file=sys.stderr)
            sys.exit(1)

        # -------------------------------------------------------------------
        # Check result
        # -------------------------------------------------------------------
        if not result.success or not result.audios:
            error_msg = getattr(result, "error", None) or "no audio produced"
            print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
            sys.exit(1)

        # Get the generated file path
        generated_path = result.audios[0]["path"]
        if not os.path.isfile(generated_path):
            print(
                f"ERROR: Expected output file not found: {generated_path}",
                file=sys.stderr,
            )
            sys.exit(1)

        # -------------------------------------------------------------------
        # Silence detection
        # -------------------------------------------------------------------
        if check_silence(generated_path):
            print(
                "WARNING: Output audio appears to be near-silent. "
                "The generation may have failed.",
                file=sys.stderr,
            )

        # -------------------------------------------------------------------
        # Rename output to user-friendly filename
        # -------------------------------------------------------------------
        # The instrument name is free-form user input; slugging it keeps path
        # separators out of the filename so the copy cannot fail (or land
        # outside output_dir) after the expensive generation step.
        instrument_clean = _instrument_slug(args.instrument)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_filename = f"{instrument_clean}_{timestamp}.wav"
        final_path = os.path.join(output_dir, final_filename)
        shutil.copy2(generated_path, final_path)
    finally:
        # Best-effort cleanup that also runs on the sys.exit() paths above,
        # which previously left the temporary directory behind.
        shutil.rmtree(temp_save_dir, ignore_errors=True)

    print(f"Output saved: {final_path}")
if __name__ == "__main__":
    # Script entry point: turn Ctrl-C and any uncaught exception into a short
    # stderr message and exit status 1. SystemExit raised inside main() is not
    # an Exception subclass, so it passes through untouched.
    try:
        main()
    except KeyboardInterrupt:
        failure = "\nInterrupted."
    except Exception as exc:
        failure = f"ERROR: Unexpected error: {exc}"
    else:
        failure = None

    if failure is not None:
        print(failure, file=sys.stderr)
        sys.exit(1)