#!/usr/bin/env python3
"""
hum2inst.py - Convert humming to instrument audio using ACE-Step XL-SFT cover mode.

Usage:
    python hum2inst.py input.wav --instrument piano
    python hum2inst.py input.wav --instrument guitar --strength 0.85 --output ./my_output/
    python hum2inst.py input.wav --instrument saxophone --duration 15
"""

import argparse
import math
import os
import re
import shutil
import sys
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional

# ---------------------------------------------------------------------------
# Caption templates for common instruments
# ---------------------------------------------------------------------------
CAPTION_TEMPLATES = {
    "piano": "solo acoustic piano, gentle melody, warm tone, clear and expressive",
    "guitar": "solo acoustic guitar, fingerpicked melody, warm and intimate",
    "saxophone": "solo saxophone, smooth jazz melody, soulful and expressive",
    "violin": "solo violin, classical melody, rich and emotional",
    "flute": "solo flute, gentle melody, airy and delicate",
}


def build_caption(instrument: str) -> str:
    """Build an ACE-Step caption from the instrument name.

    The name is lower-cased and stripped of surrounding whitespace. Known
    instruments use their entry in ``CAPTION_TEMPLATES``; any other name is
    dropped into a generic "solo <instrument>" caption.
    """
    instrument_lower = instrument.lower().strip()
    if instrument_lower in CAPTION_TEMPLATES:
        return CAPTION_TEMPLATES[instrument_lower]
    return f"solo {instrument_lower}, clear and expressive melody, warm tone"


def build_output_filename(instrument: str, timestamp: Optional[datetime] = None) -> str:
    """Return a filesystem-safe ``<instrument>_<YYYYmmdd_HHMMSS>.wav`` name.

    The instrument text comes straight from the command line, so it may
    contain path separators or other characters that must not end up in a
    file name. Every character that is not a word character, ``.`` or ``-``
    is replaced with ``-`` (spaces therefore still become ``-``), and leading
    dots are removed so the result is never a hidden file or a ``..`` prefix.

    Args:
        instrument: Instrument name as supplied by the user.
        timestamp: Moment to embed in the name; defaults to ``datetime.now()``.

    Returns:
        The bare file name (no directory component).
    """
    cleaned = re.sub(r"[^\w.-]", "-", instrument.lower().strip()).lstrip(".")
    if not cleaned:
        # Nothing usable was left (e.g. blank input); keep the name readable.
        cleaned = "instrument"
    moment = datetime.now() if timestamp is None else timestamp
    return f"{cleaned}_{moment.strftime('%Y%m%d_%H%M%S')}.wav"


def get_wav_duration(wav_path: str) -> float:
    """Return the duration of a WAV file in seconds using torchaudio."""
    # Imported lazily so that importing this module does not require torchaudio.
    import torchaudio

    info = torchaudio.info(wav_path)
    return info.num_frames / info.sample_rate


def check_silence(audio_path: str, threshold_db: float = -60.0) -> bool:
    """Check if the output audio is near-silent.

    The RMS is computed over every sample of the loaded waveform and
    converted to decibels (``20 * log10(rms)``).

    Returns True if the audio is below the threshold (likely silent/failed).
    """
    import torch
    import torchaudio

    waveform, _sr = torchaudio.load(audio_path)
    rms = torch.sqrt(torch.mean(waveform ** 2))
    if rms > 0:
        rms_db = 20 * torch.log10(rms).item()
    else:
        # log10(0) is undefined; treat an all-zero signal as infinitely quiet.
        rms_db = -float("inf")
    return rms_db < threshold_db


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Convert a hummed melody to an instrument rendition using ACE-Step.",
        epilog="Example: python hum2inst.py humming.wav --instrument piano",
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to the input WAV file containing humming",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        required=True,
        help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./output/",
        help="Output directory for the generated audio (default: ./output/)",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.9,
        help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=None,
        help="Override output duration in seconds (default: auto-detect from input WAV)",
    )
    return parser.parse_args()


def main() -> None:
    """Run the humming-to-instrument conversion end to end.

    Validates the input file, requires a CUDA GPU, determines the target
    duration, loads the ACE-Step handlers, generates a cover into a temporary
    directory and copies the result into the output directory under a
    readable name. Failures are reported on stderr and exit with status 1.
    """
    args = parse_args()

    # -----------------------------------------------------------------------
    # Validate input file
    # -----------------------------------------------------------------------
    input_path = os.path.abspath(args.input)
    if not os.path.isfile(input_path):
        print(f"ERROR: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    # -----------------------------------------------------------------------
    # CUDA check
    # -----------------------------------------------------------------------
    print("Checking GPU availability...")
    import torch

    if not torch.cuda.is_available():
        print(
            "ERROR: CUDA GPU required. No CUDA-capable GPU detected.",
            file=sys.stderr,
        )
        sys.exit(1)
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")

    # -----------------------------------------------------------------------
    # Detect input duration
    # -----------------------------------------------------------------------
    if args.duration is not None:
        duration = round(args.duration)
        print(f"Using override duration: {duration}s")
    else:
        raw_duration = get_wav_duration(input_path)
        duration = round(raw_duration)
        print(f"Input duration: {raw_duration:.1f}s (rounded to {duration}s)")

    # -----------------------------------------------------------------------
    # Build caption
    # -----------------------------------------------------------------------
    caption = build_caption(args.instrument)
    print(f"Caption: {caption}")

    # -----------------------------------------------------------------------
    # Create output directory
    # -----------------------------------------------------------------------
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------
    # Initialize ACE-Step
    # -----------------------------------------------------------------------
    print("Loading ACE-Step model...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    ace_step_dir = os.path.join(script_dir, "ace-step")
    # The ACE-Step checkout is expected next to this script; make it importable.
    sys.path.insert(0, ace_step_dir)

    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.gpu_config import get_gpu_config, set_global_gpu_config

    gpu_config = get_gpu_config()
    set_global_gpu_config(gpu_config)

    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()

    dit_handler.initialize_service(
        project_root=ace_step_dir,
        config_path="acestep-v15-xl-sft",
        device="cuda",
        use_flash_attention=dit_handler.is_flash_attention_available("cuda"),
    )
    print("Model loaded successfully.")

    # -----------------------------------------------------------------------
    # Configure generation
    # -----------------------------------------------------------------------
    params = GenerationParams(
        task_type="cover",
        src_audio=input_path,
        caption=caption,
        lyrics="",
        instrumental=True,
        duration=duration,
        bpm=120,
        audio_cover_strength=args.strength,
        inference_steps=50,
        guidance_scale=5.0,
        thinking=False,
    )

    config = GenerationConfig(
        batch_size=1,
        use_random_seed=True,
        audio_format="wav",
    )

    # Use a temporary directory for ACE-Step's UUID-named output
    temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
    try:
        # -------------------------------------------------------------------
        # Generate
        # -------------------------------------------------------------------
        print(f"Generating {args.instrument} cover...")
        try:
            result = generate_music(
                dit_handler, llm_handler, params, config, save_dir=temp_save_dir
            )
        except Exception as e:
            print(f"ERROR: Generation failed: {e}", file=sys.stderr)
            sys.exit(1)

        # -------------------------------------------------------------------
        # Check result
        # -------------------------------------------------------------------
        if not result.success or not result.audios:
            error_msg = getattr(result, "error", None) or "no audio produced"
            print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
            sys.exit(1)

        # Get the generated file path
        generated_path = result.audios[0]["path"]
        if not os.path.isfile(generated_path):
            print(
                f"ERROR: Expected output file not found: {generated_path}",
                file=sys.stderr,
            )
            sys.exit(1)

        # -------------------------------------------------------------------
        # Silence detection
        # -------------------------------------------------------------------
        if check_silence(generated_path):
            print(
                "WARNING: Output audio appears to be near-silent. "
                "The generation may have failed.",
                file=sys.stderr,
            )

        # -------------------------------------------------------------------
        # Rename output to user-friendly filename
        # -------------------------------------------------------------------
        final_path = os.path.join(output_dir, build_output_filename(args.instrument))
        shutil.copy2(generated_path, final_path)
    finally:
        # Clean up the temp directory on every exit path, including the
        # sys.exit() calls above. Removal is best-effort: a failure to delete
        # scratch files must not mask the real outcome.
        shutil.rmtree(temp_save_dir, ignore_errors=True)

    print(f"Output saved: {final_path}")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any unexpected failure as a single line
        # and exit non-zero instead of dumping a traceback.
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        sys.exit(1)