feat(01-02): add generation tuning params, seed control, multi-take, and run logging

Tuning session revealed strength=0.3 as optimal default for cover mode.
Added --caption, --noise-strength, --guidance, --steps, --shift, --sampler,
--vel-clamp, --vel-ema, --seed, --takes flags. Each output now has a JSON
sidecar logging all parameters for reproducibility.
This commit is contained in:
John Lightner 2026-04-11 03:24:27 -05:00
parent f40ca2b3fb
commit 45f6863131

View file

@ -9,6 +9,7 @@ Usage:
""" """
import argparse import argparse
import json
import math import math
import os import os
import shutil import shutil
@ -80,6 +81,12 @@ def parse_args() -> argparse.Namespace:
required=True, required=True,
help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)", help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
) )
parser.add_argument(
"--caption",
type=str,
default=None,
help="Override the auto-generated caption (default: built from instrument name)",
)
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
@ -89,8 +96,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--strength", "--strength",
type=float, type=float,
default=0.9, default=0.3,
help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)", help="Audio cover strength (0.0-1.0, fraction of steps using cover conditioning, default: 0.3)",
)
parser.add_argument(
"--noise-strength",
type=float,
default=0.0,
help="Cover noise strength (0.0-1.0, 0=pure noise start, 1=closest to source audio, default: 0.0)",
) )
parser.add_argument( parser.add_argument(
"--duration", "--duration",
@ -98,6 +111,55 @@ def parse_args() -> argparse.Namespace:
default=None, default=None,
help="Override output duration in seconds (default: auto-detect from input WAV)", help="Override output duration in seconds (default: auto-detect from input WAV)",
) )
parser.add_argument(
"--guidance",
type=float,
default=5.0,
help="Classifier-free guidance scale — higher = follows caption more strictly (default: 5.0)",
)
parser.add_argument(
"--steps",
type=int,
default=50,
help="Number of inference/denoising steps (default: 50)",
)
parser.add_argument(
"--shift",
type=float,
default=1.0,
help="Timestep shift factor — warps denoising schedule (default: 1.0)",
)
parser.add_argument(
"--sampler",
type=str,
default="euler",
choices=["euler", "heun"],
help="Sampler mode: euler (fast) or heun (2nd-order, more accurate) (default: euler)",
)
parser.add_argument(
"--vel-clamp",
type=float,
default=0.0,
help="Velocity norm threshold — clamps outlier predictions to reduce artifacts (0=off, try 2.0) (default: 0.0)",
)
parser.add_argument(
"--vel-ema",
type=float,
default=0.0,
help="Velocity EMA smoothing — smooths predictions across steps (0=off, try 0.1) (default: 0.0)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducible results (default: random)",
)
parser.add_argument(
"--takes",
type=int,
default=1,
help="Number of takes to generate with different seeds (default: 1)",
)
return parser.parse_args() return parser.parse_args()
@ -140,7 +202,7 @@ def main() -> None:
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
# Build caption # Build caption
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
caption = build_caption(args.instrument) caption = args.caption if args.caption else build_caption(args.instrument)
print(f"Caption: {caption}") print(f"Caption: {caption}")
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
@ -178,88 +240,120 @@ def main() -> None:
print("Model loaded successfully.") print("Model loaded successfully.")
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
# Configure generation # Generate takes
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
params = GenerationParams( import random
task_type="cover",
src_audio=input_path,
caption=caption,
lyrics="",
instrumental=True,
duration=duration,
bpm=120,
audio_cover_strength=args.strength,
inference_steps=50,
guidance_scale=5.0,
thinking=False,
)
config = GenerationConfig(
batch_size=1,
use_random_seed=True,
audio_format="wav",
)
# Use a temporary directory for ACE-Step's UUID-named output
temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
# -----------------------------------------------------------------------
# Generate
# -----------------------------------------------------------------------
print(f"Generating {args.instrument} cover...")
try:
result = generate_music(
dit_handler, llm_handler, params, config, save_dir=temp_save_dir
)
except Exception as e:
print(f"ERROR: Generation failed: {e}", file=sys.stderr)
sys.exit(1)
# -----------------------------------------------------------------------
# Check result
# -----------------------------------------------------------------------
if not result.success or not result.audios:
error_msg = getattr(result, "error", None) or "no audio produced"
print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
sys.exit(1)
# Get the generated file path
generated_path = result.audios[0]["path"]
if not os.path.isfile(generated_path):
print(
f"ERROR: Expected output file not found: {generated_path}",
file=sys.stderr,
)
sys.exit(1)
# -----------------------------------------------------------------------
# Silence detection
# -----------------------------------------------------------------------
if check_silence(generated_path):
print(
"WARNING: Output audio appears to be near-silent. "
"The generation may have failed.",
file=sys.stderr,
)
# -----------------------------------------------------------------------
# Rename output to user-friendly filename
# -----------------------------------------------------------------------
instrument_clean = args.instrument.lower().strip().replace(" ", "-") instrument_clean = args.instrument.lower().strip().replace(" ", "-")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") num_takes = args.takes
final_filename = f"{instrument_clean}_{timestamp}.wav" seeds = []
final_path = os.path.join(output_dir, final_filename)
shutil.copy2(generated_path, final_path) if args.seed is not None:
# Explicit seed: use it for take 1, derive sequential seeds for additional takes
seeds = [args.seed + i for i in range(num_takes)]
else:
seeds = [random.randint(0, 2**31 - 1) for _ in range(num_takes)]
# Clean up temp directory for take_idx, seed in enumerate(seeds):
try: take_label = f"[take {take_idx + 1}/{num_takes}]" if num_takes > 1 else ""
shutil.rmtree(temp_save_dir)
except OSError:
pass
print(f"Output saved: {final_path}") params = GenerationParams(
task_type="cover",
src_audio=input_path,
caption=caption,
lyrics="",
instrumental=True,
duration=duration,
bpm=120,
audio_cover_strength=args.strength,
cover_noise_strength=args.noise_strength,
inference_steps=args.steps,
guidance_scale=args.guidance,
shift=args.shift,
sampler_mode=args.sampler,
velocity_norm_threshold=args.vel_clamp,
velocity_ema_factor=args.vel_ema,
thinking=False,
seed=seed,
)
config = GenerationConfig(
batch_size=1,
use_random_seed=False,
audio_format="wav",
)
temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
print(f"Generating {args.instrument} cover (seed={seed}) {take_label}...")
try:
result = generate_music(
dit_handler, llm_handler, params, config, save_dir=temp_save_dir
)
except Exception as e:
print(f"ERROR: Generation failed: {e}", file=sys.stderr)
if num_takes == 1:
sys.exit(1)
continue
if not result.success or not result.audios:
error_msg = getattr(result, "error", None) or "no audio produced"
print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
if num_takes == 1:
sys.exit(1)
continue
generated_path = result.audios[0]["path"]
if not os.path.isfile(generated_path):
print(f"ERROR: Expected output file not found: {generated_path}", file=sys.stderr)
if num_takes == 1:
sys.exit(1)
continue
if check_silence(generated_path):
print(
"WARNING: Output audio appears to be near-silent. "
"The generation may have failed.",
file=sys.stderr,
)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
final_filename = f"{instrument_clean}_{timestamp}_s{seed}.wav"
final_path = os.path.join(output_dir, final_filename)
shutil.copy2(generated_path, final_path)
log_path = final_path.replace(".wav", ".json")
run_log = {
"timestamp": timestamp,
"input": input_path,
"output": final_path,
"instrument": args.instrument,
"caption": caption,
"duration": duration,
"strength": args.strength,
"noise_strength": args.noise_strength,
"inference_steps": args.steps,
"guidance_scale": args.guidance,
"shift": args.shift,
"sampler": args.sampler,
"vel_clamp": args.vel_clamp,
"vel_ema": args.vel_ema,
"seed": seed,
"take": take_idx + 1,
"total_takes": num_takes,
}
with open(log_path, "w") as f:
json.dump(run_log, f, indent=2)
try:
shutil.rmtree(temp_save_dir)
except OSError:
pass
print(f"Output saved: {final_path}")
print(f"Run log: {log_path}")
if __name__ == "__main__": if __name__ == "__main__":