Single-file CLI wrapping ACE-Step XL-SFT cover mode. Features: argparse flags --instrument, --output, --strength and --duration; automatic detection of the input WAV's duration via torchaudio; caption templates for piano, guitar, saxophone, violin and flute; a CUDA GPU check with a clear error message; silence detection on the output audio; and user-friendly output file names built from the instrument name and a timestamp.
273 lines
9.2 KiB
Python
273 lines
9.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
hum2inst.py - Convert humming to instrument audio using ACE-Step XL-SFT cover mode.
|
|
|
|
Usage:
|
|
python hum2inst.py input.wav --instrument piano
|
|
python hum2inst.py input.wav --instrument guitar --strength 0.85 --output ./my_output/
|
|
python hum2inst.py input.wav --instrument saxophone --duration 15
|
|
"""
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Built-in captions (ACE-Step text prompts), keyed by lowercase instrument name.
# Instruments not listed here get a generic caption from build_caption().
# ---------------------------------------------------------------------------
CAPTION_TEMPLATES = dict(
    piano="solo acoustic piano, gentle melody, warm tone, clear and expressive",
    guitar="solo acoustic guitar, fingerpicked melody, warm and intimate",
    saxophone="solo saxophone, smooth jazz melody, soulful and expressive",
    violin="solo violin, classical melody, rich and emotional",
    flute="solo flute, gentle melody, airy and delicate",
)
|
|
|
|
|
|
def build_caption(instrument: str) -> str:
    """Return the ACE-Step caption (text prompt) for *instrument*.

    The name is lower-cased and stripped of surrounding whitespace, then looked
    up in ``CAPTION_TEMPLATES``. Names without a template get a generic
    ``"solo <name>, clear and expressive melody, warm tone"`` caption.
    """
    key = instrument.lower().strip()
    generic = f"solo {key}, clear and expressive melody, warm tone"
    return CAPTION_TEMPLATES.get(key, generic)
|
|
|
|
|
|
def get_wav_duration(wav_path: str) -> float:
    """Return the duration of an audio file in seconds.

    Plain PCM WAV files are read with the standard-library ``wave`` module,
    which needs neither torch nor an audio backend. Files ``wave`` cannot parse
    (e.g. IEEE-float or compressed WAV, other containers) fall back to
    ``torchaudio.info``. Newer torchaudio releases deprecate or remove that
    API, so the stdlib path keeps the common case working.

    Args:
        wav_path: Path to the audio file.

    Returns:
        Duration in seconds, computed as frame count / sample rate.

    Raises:
        ValueError: If the file reports a non-positive sample rate.
    """
    import wave

    try:
        with wave.open(wav_path, "rb") as wav:
            num_frames = wav.getnframes()
            sample_rate = wav.getframerate()
    except (wave.Error, EOFError):
        # Not a PCM WAV the stdlib understands; let torchaudio try.
        import torchaudio

        info = torchaudio.info(wav_path)
        num_frames, sample_rate = info.num_frames, info.sample_rate

    # Guard the division: a corrupt header would otherwise raise ZeroDivisionError.
    if sample_rate <= 0:
        raise ValueError(f"Invalid sample rate {sample_rate} in {wav_path}")
    return num_frames / sample_rate
|
|
|
|
|
|
def check_silence(audio_path: str, threshold_db: float = -60.0) -> bool:
    """Report whether the audio at *audio_path* is effectively silent.

    The RMS over every sample of every channel (as returned by
    ``torchaudio.load``) is converted to decibels and compared with
    *threshold_db*.

    Args:
        audio_path: File to analyse.
        threshold_db: Level in dB below which the audio counts as silent.

    Returns:
        True if the RMS level is strictly below *threshold_db*, i.e. the
        output is likely silent or the generation failed. A signal whose RMS is
        not positive (all zeros) is treated as -inf dB and therefore silent.
    """
    import torch
    import torchaudio

    samples = torchaudio.load(audio_path)[0]
    level = samples.pow(2).mean().sqrt()
    level_db = 20 * torch.log10(level).item() if level > 0 else float("-inf")
    return level_db < threshold_db
|
|
|
|
|
|
def _finite_float(text: str) -> float:
    """argparse ``type=`` converter: parse *text* as a finite float (no nan/inf)."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # float() happily accepts "nan" and "inf"; neither is a usable setting here.
    if not math.isfinite(value):
        raise argparse.ArgumentTypeError(f"must be a finite number, got {text!r}")
    return value


def _cover_strength(text: str) -> float:
    """argparse ``type=`` converter: a finite float in the closed range [0.0, 1.0]."""
    value = _finite_float(text)
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0.0 and 1.0, got {value}")
    return value


def _positive_seconds(text: str) -> float:
    """argparse ``type=`` converter: a finite float strictly greater than zero."""
    value = _finite_float(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be greater than 0, got {value}")
    return value


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``,
            which matches the previous zero-argument behaviour.

    Returns:
        Namespace with ``input``, ``instrument``, ``output``, ``strength`` and
        ``duration`` (``None`` unless overridden).

    Raises:
        SystemExit: Via argparse on ``--help`` or on invalid arguments
            (non-numeric, nan/inf, ``--strength`` outside [0, 1], or
            ``--duration`` <= 0).
    """
    parser = argparse.ArgumentParser(
        description="Convert a hummed melody to an instrument rendition using ACE-Step.",
        epilog="Example: python hum2inst.py humming.wav --instrument piano",
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to the input WAV file containing humming",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        required=True,
        help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./output/",
        help="Output directory for the generated audio (default: ./output/)",
    )
    # Defaults are non-strings, so argparse does not run them through the
    # validating converters; only user-supplied values are checked.
    parser.add_argument(
        "--strength",
        type=_cover_strength,
        default=0.9,
        help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
    )
    parser.add_argument(
        "--duration",
        type=_positive_seconds,
        default=None,
        help="Override output duration in seconds (default: auto-detect from input WAV)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def _fail(message: str) -> None:
    """Print ``ERROR: <message>`` to stderr and exit with status 1."""
    print(f"ERROR: {message}", file=sys.stderr)
    sys.exit(1)


def _resolve_duration(override, input_path: str) -> int:
    """Return the generation length in whole seconds (the override, or the input's length)."""
    if override is not None:
        # round() raises OverflowError/ValueError on inf/nan, so reject them up front.
        if not math.isfinite(override):
            _fail(f"Invalid --duration value: {override}")
        duration = round(override)
        print(f"Using override duration: {duration}s")
    else:
        raw_duration = get_wav_duration(input_path)
        duration = round(raw_duration)
        print(f"Input duration: {raw_duration:.1f}s (rounded to {duration}s)")
    # Clips shorter than ~0.5 s (or non-positive overrides) would ask the model
    # for a zero/negative-length output.
    if duration < 1:
        _fail(f"Duration rounds to {duration}s; at least 1 second is required.")
    return duration


def _clean_instrument_name(instrument: str) -> str:
    """Reduce a free-form instrument name to a safe single filename component."""
    import re

    # Any run of characters other than word characters or '-' (including path
    # separators and dots) becomes a single '-', so the name can neither
    # create subdirectories nor climb out of the output directory.
    slug = re.sub(r"[^\w-]+", "-", instrument.lower().strip()).strip("-")
    return slug or "instrument"


def _unique_output_path(output_dir: str, stem: str) -> str:
    """Return ``output_dir/stem.wav``, appending ``_1``, ``_2``, ... if that file exists."""
    candidate = os.path.join(output_dir, f"{stem}.wav")
    counter = 1
    # Two runs within the same second share a timestamp; never overwrite.
    while os.path.exists(candidate):
        candidate = os.path.join(output_dir, f"{stem}_{counter}.wav")
        counter += 1
    return candidate


def main() -> None:
    """Run the hum-to-instrument pipeline end to end.

    Steps:
        1. Validate the input file and instrument name.
        2. Require a CUDA GPU.
        3. Pick the generation duration.
        4. Load ACE-Step (``acestep-v15-xl-sft``) from the ``ace-step`` directory
           next to this script.
        5. Run a ``cover`` generation conditioned on the input audio.
        6. Warn if the result is near-silent.
        7. Copy the result to ``<output>/<instrument>_<YYYYmmdd_HHMMSS>.wav``.

    The temporary directory used for ACE-Step's own output is removed on every
    exit path. Every failure it detects prints an ``ERROR:`` line to stderr and
    exits with status 1.
    """
    args = parse_args()

    # -----------------------------------------------------------------------
    # Validate input
    # -----------------------------------------------------------------------
    input_path = os.path.abspath(args.input)
    if not os.path.isfile(input_path):
        _fail(f"Input file not found: {input_path}")
    if not args.instrument.strip():
        _fail("--instrument must not be empty")

    # -----------------------------------------------------------------------
    # CUDA check
    # -----------------------------------------------------------------------
    print("Checking GPU availability...")
    import torch

    if not torch.cuda.is_available():
        _fail("CUDA GPU required. No CUDA-capable GPU detected.")
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")

    duration = _resolve_duration(args.duration, input_path)

    caption = build_caption(args.instrument)
    print(f"Caption: {caption}")

    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------
    # Initialize ACE-Step
    # -----------------------------------------------------------------------
    print("Loading ACE-Step model...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    ace_step_dir = os.path.join(script_dir, "ace-step")
    # Put the ace-step directory first on sys.path so its `acestep` package is imported.
    sys.path.insert(0, ace_step_dir)
    try:
        from acestep.handler import AceStepHandler
        from acestep.llm_inference import LLMHandler
        from acestep.inference import GenerationParams, GenerationConfig, generate_music
        from acestep.gpu_config import get_gpu_config, set_global_gpu_config
    except ImportError as exc:
        _fail(f"Could not import ACE-Step (looked in {ace_step_dir}): {exc}")

    set_global_gpu_config(get_gpu_config())

    dit_handler = AceStepHandler()
    # NOTE(review): the LLM handler is never initialised here; presumably it is
    # unused because thinking=False below — confirm against ACE-Step.
    llm_handler = LLMHandler()

    # NOTE(review): the return value of initialize_service is ignored; confirm
    # whether it can report failure without raising.
    dit_handler.initialize_service(
        project_root=ace_step_dir,
        config_path="acestep-v15-xl-sft",
        device="cuda",
        use_flash_attention=dit_handler.is_flash_attention_available("cuda"),
    )
    print("Model loaded successfully.")

    # -----------------------------------------------------------------------
    # Configure generation
    # -----------------------------------------------------------------------
    params = GenerationParams(
        task_type="cover",
        src_audio=input_path,
        caption=caption,
        lyrics="",
        instrumental=True,
        duration=duration,
        bpm=120,  # NOTE(review): fixed tempo; confirm how cover mode uses bpm.
        audio_cover_strength=args.strength,
        inference_steps=50,
        guidance_scale=5.0,
        thinking=False,
    )
    config = GenerationConfig(
        batch_size=1,
        use_random_seed=True,
        audio_format="wav",
    )

    # ACE-Step writes UUID-named files into save_dir; use a scratch directory
    # and always remove it, including when a check below calls sys.exit().
    temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
    try:
        print(f"Generating {args.instrument} cover...")
        try:
            result = generate_music(
                dit_handler, llm_handler, params, config, save_dir=temp_save_dir
            )
        except Exception as e:
            _fail(f"Generation failed: {e}")

        if not result.success or not result.audios:
            error_msg = getattr(result, "error", None) or "no audio produced"
            _fail(f"Generation failed: {error_msg}")

        generated_path = result.audios[0].get("path")
        if not generated_path or not os.path.isfile(generated_path):
            _fail(f"Expected output file not found: {generated_path}")

        if check_silence(generated_path):
            print(
                "WARNING: Output audio appears to be near-silent. "
                "The generation may have failed.",
                file=sys.stderr,
            )

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        stem = f"{_clean_instrument_name(args.instrument)}_{timestamp}"
        final_path = _unique_output_path(output_dir, stem)
        shutil.copy2(generated_path, final_path)
    finally:
        # Best-effort cleanup: a leftover temp dir is not worth failing the run.
        shutil.rmtree(temp_save_dir, ignore_errors=True)

    print(f"Output saved: {final_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Top-level boundary: turn interrupts and unexpected exceptions into a
    # one-line message on stderr and exit status 1 instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Include the exception type: str(e) alone is empty or cryptic for
        # many exceptions (e.g. KeyError('path') prints just "'path'").
        print(f"ERROR: Unexpected error: {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(1)
|