Tuning session revealed strength=0.3 as optimal default for cover mode. Added --caption, --noise-strength, --guidance, --steps, --shift, --sampler, --vel-clamp, --vel-ema, --seed, --takes flags. Each output now has a JSON sidecar logging all parameters for reproducibility.
367 lines
12 KiB
Python
367 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
hum2inst.py - Convert humming to instrument audio using ACE-Step XL-SFT cover mode.
|
|
|
|
Usage:
|
|
python hum2inst.py input.wav --instrument piano
|
|
python hum2inst.py input.wav --instrument guitar --strength 0.85 --output ./my_output/
|
|
python hum2inst.py input.wav --instrument saxophone --duration 15
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Caption templates for common instruments
|
|
# ---------------------------------------------------------------------------
|
|
# Maps a normalized (lower-case, stripped) instrument name to the caption
# text used for that instrument. Names not listed here get a generic caption
# from build_caption().
CAPTION_TEMPLATES = {
    "piano": "solo acoustic piano, gentle melody, warm tone, clear and expressive",
    "guitar": "solo acoustic guitar, fingerpicked melody, warm and intimate",
    "saxophone": "solo saxophone, smooth jazz melody, soulful and expressive",
    "violin": "solo violin, classical melody, rich and emotional",
    "flute": "solo flute, gentle melody, airy and delicate",
}
|
|
|
|
|
|
def build_caption(instrument: str) -> str:
    """Return the caption text to use for ``instrument``.

    The name is lower-cased and stripped of surrounding whitespace before the
    lookup. A name present in ``CAPTION_TEMPLATES`` yields its template;
    any other name is substituted into a generic "solo ..." caption.
    """
    key = instrument.lower().strip()
    try:
        return CAPTION_TEMPLATES[key]
    except KeyError:
        # No hand-written template for this instrument: use the generic form.
        return f"solo {key}, clear and expressive melody, warm tone"
|
|
|
|
|
|
def get_wav_duration(wav_path: str) -> float:
    """Return the length of the audio file at ``wav_path`` in seconds.

    The value is the frame count divided by the sample rate, both taken from
    the metadata reported by ``torchaudio.info``.
    """
    # Imported inside the function so torchaudio is only loaded when needed.
    import torchaudio

    metadata = torchaudio.info(wav_path)
    frame_count = metadata.num_frames
    frames_per_second = metadata.sample_rate
    return frame_count / frames_per_second
|
|
|
|
|
|
def check_silence(audio_path: str, threshold_db: float = -60.0) -> bool:
    """Report whether the audio file at ``audio_path`` is near-silent.

    The RMS of all samples in the loaded waveform is converted to decibels
    (``20 * log10(rms)``) and compared with ``threshold_db``.

    Returns True when the level is below the threshold (likely silent or a
    failed generation), False otherwise.
    """
    import torch
    import torchaudio

    samples, _sample_rate = torchaudio.load(audio_path)
    level = torch.sqrt(torch.mean(samples ** 2))
    # log10 is only taken for a strictly positive RMS; anything else is
    # treated as negative infinity dB.
    level_db = 20 * torch.log10(level).item() if level > 0 else -float("inf")
    return level_db < threshold_db
|
|
|
|
|
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so existing callers that
            pass nothing behave as before.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) for unparseable
            arguments, and by the range checks below when ``--strength`` or
            ``--noise-strength`` is outside 0.0-1.0, ``--duration`` is not a
            positive finite number, or ``--steps`` / ``--takes`` is below 1.
    """
    parser = argparse.ArgumentParser(
        description="Convert a hummed melody to an instrument rendition using ACE-Step.",
        epilog="Example: python hum2inst.py humming.wav --instrument piano",
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to the input WAV file containing humming",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        required=True,
        help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
    )
    parser.add_argument(
        "--caption",
        type=str,
        default=None,
        help="Override the auto-generated caption (default: built from instrument name)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./output/",
        help="Output directory for the generated audio (default: ./output/)",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.3,
        help="Audio cover strength (0.0-1.0, fraction of steps using cover conditioning, default: 0.3)",
    )
    parser.add_argument(
        "--noise-strength",
        type=float,
        default=0.0,
        help="Cover noise strength (0.0-1.0, 0=pure noise start, 1=closest to source audio, default: 0.0)",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=None,
        help="Override output duration in seconds (default: auto-detect from input WAV)",
    )
    parser.add_argument(
        "--guidance",
        type=float,
        default=5.0,
        help="Classifier-free guidance scale — higher = follows caption more strictly (default: 5.0)",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="Number of inference/denoising steps (default: 50)",
    )
    parser.add_argument(
        "--shift",
        type=float,
        default=1.0,
        help="Timestep shift factor — warps denoising schedule (default: 1.0)",
    )
    parser.add_argument(
        "--sampler",
        type=str,
        default="euler",
        choices=["euler", "heun"],
        help="Sampler mode: euler (fast) or heun (2nd-order, more accurate) (default: euler)",
    )
    parser.add_argument(
        "--vel-clamp",
        type=float,
        default=0.0,
        help="Velocity norm threshold — clamps outlier predictions to reduce artifacts (0=off, try 2.0) (default: 0.0)",
    )
    parser.add_argument(
        "--vel-ema",
        type=float,
        default=0.0,
        help="Velocity EMA smoothing — smooths predictions across steps (0=off, try 0.1) (default: 0.0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible results (default: random)",
    )
    parser.add_argument(
        "--takes",
        type=int,
        default=1,
        help="Number of takes to generate with different seeds (default: 1)",
    )

    args = parser.parse_args(argv)

    # Reject values the help text documents as out of range, so the user gets
    # a usage error immediately. The comparisons are written so that NaN also
    # fails them.
    if not 0.0 <= args.strength <= 1.0:
        parser.error("--strength must be between 0.0 and 1.0")
    if not 0.0 <= args.noise_strength <= 1.0:
        parser.error("--noise-strength must be between 0.0 and 1.0")
    if args.duration is not None and not (
        math.isfinite(args.duration) and args.duration > 0
    ):
        parser.error("--duration must be a positive number of seconds")
    if args.steps < 1:
        parser.error("--steps must be at least 1")
    # With zero or negative takes no seeds would be produced, so the run
    # would generate nothing.
    if args.takes < 1:
        parser.error("--takes must be at least 1")

    return args
|
|
|
|
|
|
def main() -> None:
    """Run the humming-to-instrument conversion from the command line.

    Steps, in order:
      1. Parse arguments and check that the input file exists.
      2. Require a CUDA GPU (exits with status 1 otherwise).
      3. Resolve the output duration (``--duration`` or the input's length)
         and the caption (``--caption`` or ``build_caption``).
      4. Load ACE-Step from the ``ace-step`` directory next to this script.
      5. For each take, run ``generate_music`` in cover mode, copy the result
         into the output directory and write a JSON sidecar with all
         parameters used.

    Exit behaviour: a failed take terminates the process with status 1 when
    only one take was requested; with several takes the failure is reported
    and the remaining takes still run, and the process exits with status 1
    only if none of them produced an output.
    """
    args = parse_args()

    # -----------------------------------------------------------------------
    # Validate input file
    # -----------------------------------------------------------------------
    input_path = os.path.abspath(args.input)
    if not os.path.isfile(input_path):
        print(f"ERROR: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    # -----------------------------------------------------------------------
    # CUDA check
    # -----------------------------------------------------------------------
    print("Checking GPU availability...")
    # Imported here rather than at module top, so argument parsing and the
    # input-file check above run before torch is loaded.
    import torch

    if not torch.cuda.is_available():
        print(
            "ERROR: CUDA GPU required. No CUDA-capable GPU detected.",
            file=sys.stderr,
        )
        sys.exit(1)
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")

    # -----------------------------------------------------------------------
    # Detect input duration
    # -----------------------------------------------------------------------
    # NOTE(review): round() yields 0 for anything shorter than 0.5 s; whether
    # ACE-Step accepts a zero duration is not visible from here — confirm.
    if args.duration is not None:
        duration = round(args.duration)
        print(f"Using override duration: {duration}s")
    else:
        raw_duration = get_wav_duration(input_path)
        duration = round(raw_duration)
        print(f"Input duration: {raw_duration:.1f}s (rounded to {duration}s)")

    # -----------------------------------------------------------------------
    # Build caption
    # -----------------------------------------------------------------------
    caption = args.caption if args.caption else build_caption(args.instrument)
    print(f"Caption: {caption}")

    # -----------------------------------------------------------------------
    # Create output directory
    # -----------------------------------------------------------------------
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------
    # Initialize ACE-Step
    # -----------------------------------------------------------------------
    print("Loading ACE-Step model...")

    # The acestep package is imported from the "ace-step" directory that sits
    # beside this script, so that directory is put first on sys.path.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    ace_step_dir = os.path.join(script_dir, "ace-step")
    sys.path.insert(0, ace_step_dir)

    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.gpu_config import get_gpu_config, set_global_gpu_config

    gpu_config = get_gpu_config()
    set_global_gpu_config(gpu_config)

    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()

    dit_handler.initialize_service(
        project_root=ace_step_dir,
        config_path="acestep-v15-xl-sft",
        device="cuda",
        use_flash_attention=dit_handler.is_flash_attention_available("cuda"),
    )
    print("Model loaded successfully.")

    # -----------------------------------------------------------------------
    # Generate takes
    # -----------------------------------------------------------------------
    import random

    instrument_clean = args.instrument.lower().strip().replace(" ", "-")
    num_takes = args.takes

    if args.seed is not None:
        # Explicit seed: use it for take 1, derive sequential seeds for additional takes
        seeds = [args.seed + i for i in range(num_takes)]
    else:
        seeds = [random.randint(0, 2**31 - 1) for _ in range(num_takes)]

    succeeded = 0

    for take_idx, seed in enumerate(seeds):
        take_label = f"[take {take_idx + 1}/{num_takes}]" if num_takes > 1 else ""

        params = GenerationParams(
            task_type="cover",
            src_audio=input_path,
            caption=caption,
            lyrics="",
            instrumental=True,
            duration=duration,
            bpm=120,
            audio_cover_strength=args.strength,
            cover_noise_strength=args.noise_strength,
            inference_steps=args.steps,
            guidance_scale=args.guidance,
            shift=args.shift,
            sampler_mode=args.sampler,
            velocity_norm_threshold=args.vel_clamp,
            velocity_ema_factor=args.vel_ema,
            thinking=False,
            seed=seed,
        )

        config = GenerationConfig(
            batch_size=1,
            use_random_seed=False,
            audio_format="wav",
        )

        temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")

        print(f"Generating {args.instrument} cover (seed={seed}) {take_label}...")
        try:
            try:
                result = generate_music(
                    dit_handler, llm_handler, params, config, save_dir=temp_save_dir
                )
            except Exception as e:
                print(f"ERROR: Generation failed: {e}", file=sys.stderr)
                if num_takes == 1:
                    sys.exit(1)
                continue

            if not result.success or not result.audios:
                error_msg = getattr(result, "error", None) or "no audio produced"
                print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
                if num_takes == 1:
                    sys.exit(1)
                continue

            generated_path = result.audios[0]["path"]

            if not os.path.isfile(generated_path):
                print(f"ERROR: Expected output file not found: {generated_path}", file=sys.stderr)
                if num_takes == 1:
                    sys.exit(1)
                continue

            # A near-silent result is only warned about; the file is still kept.
            if check_silence(generated_path):
                print(
                    "WARNING: Output audio appears to be near-silent. "
                    "The generation may have failed.",
                    file=sys.stderr,
                )

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            final_filename = f"{instrument_clean}_{timestamp}_s{seed}.wav"
            final_path = os.path.join(output_dir, final_filename)

            shutil.copy2(generated_path, final_path)

            # Swap only the file extension. str.replace(".wav", ".json") would
            # also rewrite ".wav" occurring in a directory name of the path.
            log_path = os.path.splitext(final_path)[0] + ".json"
            run_log = {
                "timestamp": timestamp,
                "input": input_path,
                "output": final_path,
                "instrument": args.instrument,
                "caption": caption,
                "duration": duration,
                "strength": args.strength,
                "noise_strength": args.noise_strength,
                "inference_steps": args.steps,
                "guidance_scale": args.guidance,
                "shift": args.shift,
                "sampler": args.sampler,
                "vel_clamp": args.vel_clamp,
                "vel_ema": args.vel_ema,
                "seed": seed,
                "take": take_idx + 1,
                "total_takes": num_takes,
            }
            with open(log_path, "w", encoding="utf-8") as f:
                json.dump(run_log, f, indent=2)

            succeeded += 1

            print(f"Output saved: {final_path}")
            print(f"Run log: {log_path}")
        finally:
            # Best-effort removal of the scratch directory on every path out
            # of this iteration, including the `continue` and sys.exit()
            # failure branches above.
            shutil.rmtree(temp_save_dir, ignore_errors=True)

    # A multi-take run that produced nothing must not report success.
    if num_takes > 1 and succeeded == 0:
        print(f"ERROR: All {num_takes} takes failed.", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Ctrl-C and any unexpected exception are reported on
    # stderr and turned into exit status 1. A SystemExit raised inside main()
    # is not caught here, so it keeps its own status.
    failure_message = None
    try:
        main()
    except KeyboardInterrupt:
        failure_message = "\nInterrupted."
    except Exception as exc:
        failure_message = f"ERROR: Unexpected error: {exc}"

    if failure_message is not None:
        print(failure_message, file=sys.stderr)
        sys.exit(1)
|