diff --git a/hum2inst.py b/hum2inst.py
index 710cb44..530c2c9 100644
--- a/hum2inst.py
+++ b/hum2inst.py
@@ -9,6 +9,7 @@ Usage:
 """
 
 import argparse
+import json
 import math
 import os
 import shutil
@@ -80,6 +81,12 @@ def parse_args() -> argparse.Namespace:
         required=True,
         help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
     )
+    parser.add_argument(
+        "--caption",
+        type=str,
+        default=None,
+        help="Override the auto-generated caption (default: built from instrument name)",
+    )
     parser.add_argument(
         "--output",
         type=str,
@@ -89,8 +96,14 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--strength",
         type=float,
-        default=0.9,
-        help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
+        default=0.3,
+        help="Audio cover strength (0.0-1.0, fraction of steps using cover conditioning, default: 0.3)",
+    )
+    parser.add_argument(
+        "--noise-strength",
+        type=float,
+        default=0.0,
+        help="Cover noise strength (0.0-1.0, 0=pure noise start, 1=closest to source audio, default: 0.0)",
     )
     parser.add_argument(
         "--duration",
@@ -98,6 +111,55 @@ def parse_args() -> argparse.Namespace:
         default=None,
         help="Override output duration in seconds (default: auto-detect from input WAV)",
     )
+    parser.add_argument(
+        "--guidance",
+        type=float,
+        default=5.0,
+        help="Classifier-free guidance scale — higher = follows caption more strictly (default: 5.0)",
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=50,
+        help="Number of inference/denoising steps (default: 50)",
+    )
+    parser.add_argument(
+        "--shift",
+        type=float,
+        default=1.0,
+        help="Timestep shift factor — warps denoising schedule (default: 1.0)",
+    )
+    parser.add_argument(
+        "--sampler",
+        type=str,
+        default="euler",
+        choices=["euler", "heun"],
+        help="Sampler mode: euler (fast) or heun (2nd-order, more accurate) (default: euler)",
+    )
+    parser.add_argument(
+        "--vel-clamp",
+        type=float,
+        default=0.0,
+        help="Velocity norm threshold — clamps outlier predictions to reduce artifacts (0=off, try 2.0) (default: 0.0)",
+    )
+    parser.add_argument(
+        "--vel-ema",
+        type=float,
+        default=0.0,
+        help="Velocity EMA smoothing — smooths predictions across steps (0=off, try 0.1) (default: 0.0)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Random seed for reproducible results (default: random)",
+    )
+    parser.add_argument(
+        "--takes",
+        type=int,
+        default=1,
+        help="Number of takes to generate with different seeds (default: 1)",
+    )
 
     return parser.parse_args()
 
@@ -140,7 +202,7 @@ def main() -> None:
     # -----------------------------------------------------------------------
     # Build caption
     # -----------------------------------------------------------------------
-    caption = build_caption(args.instrument)
+    caption = args.caption if args.caption else build_caption(args.instrument)
     print(f"Caption: {caption}")
 
     # -----------------------------------------------------------------------
@@ -178,88 +240,125 @@ def main() -> None:
     print("Model loaded successfully.")
 
     # -----------------------------------------------------------------------
-    # Configure generation
+    # Generate takes
     # -----------------------------------------------------------------------
-    params = GenerationParams(
-        task_type="cover",
-        src_audio=input_path,
-        caption=caption,
-        lyrics="",
-        instrumental=True,
-        duration=duration,
-        bpm=120,
-        audio_cover_strength=args.strength,
-        inference_steps=50,
-        guidance_scale=5.0,
-        thinking=False,
-    )
+    import random
 
-    config = GenerationConfig(
-        batch_size=1,
-        use_random_seed=True,
-        audio_format="wav",
-    )
-
-    # Use a temporary directory for ACE-Step's UUID-named output
-    temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
-
-    # -----------------------------------------------------------------------
-    # Generate
-    # -----------------------------------------------------------------------
-    print(f"Generating {args.instrument} cover...")
-    try:
-        result = generate_music(
-            dit_handler, llm_handler, params, config, save_dir=temp_save_dir
-        )
-    except Exception as e:
-        print(f"ERROR: Generation failed: {e}", file=sys.stderr)
-        sys.exit(1)
-
-    # -----------------------------------------------------------------------
-    # Check result
-    # -----------------------------------------------------------------------
-    if not result.success or not result.audios:
-        error_msg = getattr(result, "error", None) or "no audio produced"
-        print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
-        sys.exit(1)
-
-    # Get the generated file path
-    generated_path = result.audios[0]["path"]
-
-    if not os.path.isfile(generated_path):
-        print(
-            f"ERROR: Expected output file not found: {generated_path}",
-            file=sys.stderr,
-        )
-        sys.exit(1)
-
-    # -----------------------------------------------------------------------
-    # Silence detection
-    # -----------------------------------------------------------------------
-    if check_silence(generated_path):
-        print(
-            "WARNING: Output audio appears to be near-silent. "
-            "The generation may have failed.",
-            file=sys.stderr,
-        )
-
-    # -----------------------------------------------------------------------
-    # Rename output to user-friendly filename
-    # -----------------------------------------------------------------------
     instrument_clean = args.instrument.lower().strip().replace(" ", "-")
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    final_filename = f"{instrument_clean}_{timestamp}.wav"
-    final_path = os.path.join(output_dir, final_filename)
+    num_takes = args.takes
+    takes_saved = 0  # successful takes; decides the exit status after the loop
 
-    shutil.copy2(generated_path, final_path)
+    if args.seed is not None:
+        # Explicit seed: use it for take 1, derive sequential seeds for additional takes
+        seeds = [args.seed + i for i in range(num_takes)]
+    else:
+        seeds = [random.randint(0, 2**31 - 1) for _ in range(num_takes)]
 
-    # Clean up temp directory
-    try:
-        shutil.rmtree(temp_save_dir)
-    except OSError:
-        pass
+    for take_idx, seed in enumerate(seeds):
+        take_label = f"[take {take_idx + 1}/{num_takes}]" if num_takes > 1 else ""
 
-    print(f"Output saved: {final_path}")
+        params = GenerationParams(
+            task_type="cover",
+            src_audio=input_path,
+            caption=caption,
+            lyrics="",
+            instrumental=True,
+            duration=duration,
+            bpm=120,
+            audio_cover_strength=args.strength,
+            cover_noise_strength=args.noise_strength,
+            inference_steps=args.steps,
+            guidance_scale=args.guidance,
+            shift=args.shift,
+            sampler_mode=args.sampler,
+            velocity_norm_threshold=args.vel_clamp,
+            velocity_ema_factor=args.vel_ema,
+            thinking=False,
+            seed=seed,
+        )
+
+        config = GenerationConfig(
+            batch_size=1,
+            use_random_seed=False,
+            audio_format="wav",
+        )
+
+        # Per-take temporary directory for ACE-Step's UUID-named output
+        temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
+        try:
+            print(f"Generating {args.instrument} cover (seed={seed}) {take_label}...")
+            try:
+                result = generate_music(
+                    dit_handler, llm_handler, params, config, save_dir=temp_save_dir
+                )
+            except Exception as e:
+                print(f"ERROR: Generation failed: {e}", file=sys.stderr)
+                if num_takes == 1:
+                    sys.exit(1)
+                continue
+
+            if not result.success or not result.audios:
+                error_msg = getattr(result, "error", None) or "no audio produced"
+                print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
+                if num_takes == 1:
+                    sys.exit(1)
+                continue
+
+            generated_path = result.audios[0]["path"]
+
+            if not os.path.isfile(generated_path):
+                print(f"ERROR: Expected output file not found: {generated_path}", file=sys.stderr)
+                if num_takes == 1:
+                    sys.exit(1)
+                continue
+
+            if check_silence(generated_path):
+                print(
+                    "WARNING: Output audio appears to be near-silent. "
+                    "The generation may have failed.",
+                    file=sys.stderr,
+                )
+
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            final_filename = f"{instrument_clean}_{timestamp}_s{seed}.wav"
+            final_path = os.path.join(output_dir, final_filename)
+
+            shutil.copy2(generated_path, final_path)
+        finally:
+            # Best-effort cleanup on every exit path (success, `continue`, sys.exit)
+            shutil.rmtree(temp_save_dir, ignore_errors=True)
+
+        # Swap only the extension; str.replace() would also hit ".wav" in directory names
+        log_path = os.path.splitext(final_path)[0] + ".json"
+        run_log = {
+            "timestamp": timestamp,
+            "input": input_path,
+            "output": final_path,
+            "instrument": args.instrument,
+            "caption": caption,
+            "duration": duration,
+            "strength": args.strength,
+            "noise_strength": args.noise_strength,
+            "inference_steps": args.steps,
+            "guidance_scale": args.guidance,
+            "shift": args.shift,
+            "sampler": args.sampler,
+            "vel_clamp": args.vel_clamp,
+            "vel_ema": args.vel_ema,
+            "seed": seed,
+            "take": take_idx + 1,
+            "total_takes": num_takes,
+        }
+        with open(log_path, "w", encoding="utf-8") as f:
+            json.dump(run_log, f, indent=2)
+
+        takes_saved += 1
+        print(f"Output saved: {final_path}")
+        print(f"Run log: {log_path}")
+
+    if takes_saved == 0:
+        print("ERROR: No takes were generated successfully.", file=sys.stderr)
+        sys.exit(1)
 
 
 if __name__ == "__main__":