diff --git a/hum2inst.py b/hum2inst.py new file mode 100644 index 0000000..710cb44 --- /dev/null +++ b/hum2inst.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +""" +hum2inst.py - Convert humming to instrument audio using ACE-Step XL-SFT cover mode. + +Usage: + python hum2inst.py input.wav --instrument piano + python hum2inst.py input.wav --instrument guitar --strength 0.85 --output ./my_output/ + python hum2inst.py input.wav --instrument saxophone --duration 15 +""" + +import argparse +import math +import os +import shutil +import sys +import tempfile +from datetime import datetime +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Caption templates for common instruments +# --------------------------------------------------------------------------- +CAPTION_TEMPLATES = { + "piano": "solo acoustic piano, gentle melody, warm tone, clear and expressive", + "guitar": "solo acoustic guitar, fingerpicked melody, warm and intimate", + "saxophone": "solo saxophone, smooth jazz melody, soulful and expressive", + "violin": "solo violin, classical melody, rich and emotional", + "flute": "solo flute, gentle melody, airy and delicate", +} + + +def build_caption(instrument: str) -> str: + """Build an ACE-Step caption from the instrument name.""" + instrument_lower = instrument.lower().strip() + if instrument_lower in CAPTION_TEMPLATES: + return CAPTION_TEMPLATES[instrument_lower] + return f"solo {instrument_lower}, clear and expressive melody, warm tone" + + +def get_wav_duration(wav_path: str) -> float: + """Return the duration of a WAV file in seconds using torchaudio.""" + import torchaudio + + info = torchaudio.info(wav_path) + return info.num_frames / info.sample_rate + + +def check_silence(audio_path: str, threshold_db: float = -60.0) -> bool: + """Check if the output audio is near-silent. + + Returns True if the audio is below the threshold (likely silent/failed). + """ + import torch + import torchaudio + + waveform, _sr = torchaudio.load(audio_path) + rms = torch.sqrt(torch.mean(waveform ** 2)) + if rms > 0: + rms_db = 20 * torch.log10(rms).item() + else: + rms_db = -float("inf") + return rms_db < threshold_db + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Convert a hummed melody to an instrument rendition using ACE-Step.", + epilog="Example: python hum2inst.py humming.wav --instrument piano", + ) + parser.add_argument( + "input", + type=str, + help="Path to the input WAV file containing humming", + ) + parser.add_argument( + "--instrument", + type=str, + required=True, + help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)", + ) + parser.add_argument( + "--output", + type=str, + default="./output/", + help="Output directory for the generated audio (default: ./output/)", + ) + parser.add_argument( + "--strength", + type=float, + default=0.9, + help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)", + ) + parser.add_argument( + "--duration", + type=float, + default=None, + help="Override output duration in seconds (default: auto-detect from input WAV)", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + # ----------------------------------------------------------------------- + # Validate input file + # ----------------------------------------------------------------------- + input_path = os.path.abspath(args.input) + if not os.path.isfile(input_path): + print(f"ERROR: Input file not found: {input_path}", file=sys.stderr) + sys.exit(1) + + # ----------------------------------------------------------------------- + # CUDA check + # ----------------------------------------------------------------------- + print("Checking GPU availability...") + import torch + + if not torch.cuda.is_available(): + print( + "ERROR: CUDA GPU required. No CUDA-capable GPU detected.", + file=sys.stderr, + ) + sys.exit(1) + print(f"GPU detected: {torch.cuda.get_device_name(0)}") + + # ----------------------------------------------------------------------- + # Detect input duration + # ----------------------------------------------------------------------- + if args.duration is not None: + duration = round(args.duration) + print(f"Using override duration: {duration}s") + else: + raw_duration = get_wav_duration(input_path) + duration = round(raw_duration) + print(f"Input duration: {raw_duration:.1f}s (rounded to {duration}s)") + + # ----------------------------------------------------------------------- + # Build caption + # ----------------------------------------------------------------------- + caption = build_caption(args.instrument) + print(f"Caption: {caption}") + + # ----------------------------------------------------------------------- + # Create output directory + # ----------------------------------------------------------------------- + output_dir = os.path.abspath(args.output) + os.makedirs(output_dir, exist_ok=True) + + # ----------------------------------------------------------------------- + # Initialize ACE-Step + # ----------------------------------------------------------------------- + print("Loading ACE-Step model...") + + script_dir = os.path.dirname(os.path.abspath(__file__)) + ace_step_dir = os.path.join(script_dir, "ace-step") + sys.path.insert(0, ace_step_dir) + + from acestep.handler import AceStepHandler + from acestep.llm_inference import LLMHandler + from acestep.inference import GenerationParams, GenerationConfig, generate_music + from acestep.gpu_config import get_gpu_config, set_global_gpu_config + + gpu_config = get_gpu_config() + set_global_gpu_config(gpu_config) + + dit_handler = AceStepHandler() + llm_handler = LLMHandler() + + dit_handler.initialize_service( + project_root=ace_step_dir, + config_path="acestep-v15-xl-sft", + device="cuda", + use_flash_attention=dit_handler.is_flash_attention_available("cuda"), + ) + print("Model loaded successfully.") + + # ----------------------------------------------------------------------- + # Configure generation + # ----------------------------------------------------------------------- + params = GenerationParams( + task_type="cover", + src_audio=input_path, + caption=caption, + lyrics="", + instrumental=True, + duration=duration, + bpm=120, + audio_cover_strength=args.strength, + inference_steps=50, + guidance_scale=5.0, + thinking=False, + ) + + config = GenerationConfig( + batch_size=1, + use_random_seed=True, + audio_format="wav", + ) + + # Use a temporary directory for ACE-Step's UUID-named output + temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_") + + # ----------------------------------------------------------------------- + # Generate + # ----------------------------------------------------------------------- + print(f"Generating {args.instrument} cover...") + try: + result = generate_music( + dit_handler, llm_handler, params, config, save_dir=temp_save_dir + ) + except Exception as e: + print(f"ERROR: Generation failed: {e}", file=sys.stderr) + sys.exit(1) + + # ----------------------------------------------------------------------- + # Check result + # ----------------------------------------------------------------------- + if not result.success or not result.audios: + error_msg = getattr(result, "error", None) or "no audio produced" + print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr) + sys.exit(1) + + # Get the generated file path + generated_path = result.audios[0]["path"] + + if not os.path.isfile(generated_path): + print( + f"ERROR: Expected output file not found: {generated_path}", + file=sys.stderr, + ) + sys.exit(1) + + # ----------------------------------------------------------------------- + # Silence detection + # ----------------------------------------------------------------------- + if check_silence(generated_path): + print( + "WARNING: Output audio appears to be near-silent. " + "The generation may have failed.", + file=sys.stderr, + ) + + # ----------------------------------------------------------------------- + # Rename output to user-friendly filename + # ----------------------------------------------------------------------- + instrument_clean = args.instrument.lower().strip().replace(" ", "-") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + final_filename = f"{instrument_clean}_{timestamp}.wav" + final_path = os.path.join(output_dir, final_filename) + + shutil.copy2(generated_path, final_path) + + # Clean up temp directory + try: + shutil.rmtree(temp_save_dir) + except OSError: + pass + + print(f"Output saved: {final_path}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted.", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"ERROR: Unexpected error: {e}", file=sys.stderr) + sys.exit(1)