feat(01-02): add generation tuning params, seed control, multi-take, and run logging
Tuning session revealed strength=0.3 as optimal default for cover mode. Added --caption, --noise-strength, --guidance, --steps, --shift, --sampler, --vel-clamp, --vel-ema, --seed, --takes flags. Each output now has a JSON sidecar logging all parameters for reproducibility.
This commit is contained in:
parent
f40ca2b3fb
commit
45f6863131
1 changed files with 173 additions and 79 deletions
252
hum2inst.py
252
hum2inst.py
|
|
@ -9,6 +9,7 @@ Usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
@ -80,6 +81,12 @@ def parse_args() -> argparse.Namespace:
|
||||||
required=True,
|
required=True,
|
||||||
help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
|
help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--caption",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override the auto-generated caption (default: built from instrument name)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -89,8 +96,14 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--strength",
|
"--strength",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.9,
|
default=0.3,
|
||||||
help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
|
help="Audio cover strength (0.0-1.0, fraction of steps using cover conditioning, default: 0.3)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise-strength",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Cover noise strength (0.0-1.0, 0=pure noise start, 1=closest to source audio, default: 0.0)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--duration",
|
"--duration",
|
||||||
|
|
@ -98,6 +111,55 @@ def parse_args() -> argparse.Namespace:
|
||||||
default=None,
|
default=None,
|
||||||
help="Override output duration in seconds (default: auto-detect from input WAV)",
|
help="Override output duration in seconds (default: auto-detect from input WAV)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--guidance",
|
||||||
|
type=float,
|
||||||
|
default=5.0,
|
||||||
|
help="Classifier-free guidance scale — higher = follows caption more strictly (default: 5.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Number of inference/denoising steps (default: 50)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--shift",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Timestep shift factor — warps denoising schedule (default: 1.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampler",
|
||||||
|
type=str,
|
||||||
|
default="euler",
|
||||||
|
choices=["euler", "heun"],
|
||||||
|
help="Sampler mode: euler (fast) or heun (2nd-order, more accurate) (default: euler)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vel-clamp",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Velocity norm threshold — clamps outlier predictions to reduce artifacts (0=off, try 2.0) (default: 0.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vel-ema",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Velocity EMA smoothing — smooths predictions across steps (0=off, try 0.1) (default: 0.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Random seed for reproducible results (default: random)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--takes",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of takes to generate with different seeds (default: 1)",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,7 +202,7 @@ def main() -> None:
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
# Build caption
|
# Build caption
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
caption = build_caption(args.instrument)
|
caption = args.caption if args.caption else build_caption(args.instrument)
|
||||||
print(f"Caption: {caption}")
|
print(f"Caption: {caption}")
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
|
|
@ -178,88 +240,120 @@ def main() -> None:
|
||||||
print("Model loaded successfully.")
|
print("Model loaded successfully.")
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
# Configure generation
|
# Generate takes
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
params = GenerationParams(
|
import random
|
||||||
task_type="cover",
|
|
||||||
src_audio=input_path,
|
|
||||||
caption=caption,
|
|
||||||
lyrics="",
|
|
||||||
instrumental=True,
|
|
||||||
duration=duration,
|
|
||||||
bpm=120,
|
|
||||||
audio_cover_strength=args.strength,
|
|
||||||
inference_steps=50,
|
|
||||||
guidance_scale=5.0,
|
|
||||||
thinking=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = GenerationConfig(
|
|
||||||
batch_size=1,
|
|
||||||
use_random_seed=True,
|
|
||||||
audio_format="wav",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use a temporary directory for ACE-Step's UUID-named output
|
|
||||||
temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Generate
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
print(f"Generating {args.instrument} cover...")
|
|
||||||
try:
|
|
||||||
result = generate_music(
|
|
||||||
dit_handler, llm_handler, params, config, save_dir=temp_save_dir
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Generation failed: {e}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Check result
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
if not result.success or not result.audios:
|
|
||||||
error_msg = getattr(result, "error", None) or "no audio produced"
|
|
||||||
print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Get the generated file path
|
|
||||||
generated_path = result.audios[0]["path"]
|
|
||||||
|
|
||||||
if not os.path.isfile(generated_path):
|
|
||||||
print(
|
|
||||||
f"ERROR: Expected output file not found: {generated_path}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Silence detection
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
if check_silence(generated_path):
|
|
||||||
print(
|
|
||||||
"WARNING: Output audio appears to be near-silent. "
|
|
||||||
"The generation may have failed.",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Rename output to user-friendly filename
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
instrument_clean = args.instrument.lower().strip().replace(" ", "-")
|
instrument_clean = args.instrument.lower().strip().replace(" ", "-")
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
num_takes = args.takes
|
||||||
final_filename = f"{instrument_clean}_{timestamp}.wav"
|
seeds = []
|
||||||
final_path = os.path.join(output_dir, final_filename)
|
|
||||||
|
|
||||||
shutil.copy2(generated_path, final_path)
|
if args.seed is not None:
|
||||||
|
# Explicit seed: use it for take 1, derive sequential seeds for additional takes
|
||||||
|
seeds = [args.seed + i for i in range(num_takes)]
|
||||||
|
else:
|
||||||
|
seeds = [random.randint(0, 2**31 - 1) for _ in range(num_takes)]
|
||||||
|
|
||||||
# Clean up temp directory
|
for take_idx, seed in enumerate(seeds):
|
||||||
try:
|
take_label = f"[take {take_idx + 1}/{num_takes}]" if num_takes > 1 else ""
|
||||||
shutil.rmtree(temp_save_dir)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print(f"Output saved: {final_path}")
|
params = GenerationParams(
|
||||||
|
task_type="cover",
|
||||||
|
src_audio=input_path,
|
||||||
|
caption=caption,
|
||||||
|
lyrics="",
|
||||||
|
instrumental=True,
|
||||||
|
duration=duration,
|
||||||
|
bpm=120,
|
||||||
|
audio_cover_strength=args.strength,
|
||||||
|
cover_noise_strength=args.noise_strength,
|
||||||
|
inference_steps=args.steps,
|
||||||
|
guidance_scale=args.guidance,
|
||||||
|
shift=args.shift,
|
||||||
|
sampler_mode=args.sampler,
|
||||||
|
velocity_norm_threshold=args.vel_clamp,
|
||||||
|
velocity_ema_factor=args.vel_ema,
|
||||||
|
thinking=False,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = GenerationConfig(
|
||||||
|
batch_size=1,
|
||||||
|
use_random_seed=False,
|
||||||
|
audio_format="wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
|
||||||
|
|
||||||
|
print(f"Generating {args.instrument} cover (seed={seed}) {take_label}...")
|
||||||
|
try:
|
||||||
|
result = generate_music(
|
||||||
|
dit_handler, llm_handler, params, config, save_dir=temp_save_dir
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Generation failed: {e}", file=sys.stderr)
|
||||||
|
if num_takes == 1:
|
||||||
|
sys.exit(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not result.success or not result.audios:
|
||||||
|
error_msg = getattr(result, "error", None) or "no audio produced"
|
||||||
|
print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
|
||||||
|
if num_takes == 1:
|
||||||
|
sys.exit(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
generated_path = result.audios[0]["path"]
|
||||||
|
|
||||||
|
if not os.path.isfile(generated_path):
|
||||||
|
print(f"ERROR: Expected output file not found: {generated_path}", file=sys.stderr)
|
||||||
|
if num_takes == 1:
|
||||||
|
sys.exit(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if check_silence(generated_path):
|
||||||
|
print(
|
||||||
|
"WARNING: Output audio appears to be near-silent. "
|
||||||
|
"The generation may have failed.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
final_filename = f"{instrument_clean}_{timestamp}_s{seed}.wav"
|
||||||
|
final_path = os.path.join(output_dir, final_filename)
|
||||||
|
|
||||||
|
shutil.copy2(generated_path, final_path)
|
||||||
|
|
||||||
|
log_path = final_path.replace(".wav", ".json")
|
||||||
|
run_log = {
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"input": input_path,
|
||||||
|
"output": final_path,
|
||||||
|
"instrument": args.instrument,
|
||||||
|
"caption": caption,
|
||||||
|
"duration": duration,
|
||||||
|
"strength": args.strength,
|
||||||
|
"noise_strength": args.noise_strength,
|
||||||
|
"inference_steps": args.steps,
|
||||||
|
"guidance_scale": args.guidance,
|
||||||
|
"shift": args.shift,
|
||||||
|
"sampler": args.sampler,
|
||||||
|
"vel_clamp": args.vel_clamp,
|
||||||
|
"vel_ema": args.vel_ema,
|
||||||
|
"seed": seed,
|
||||||
|
"take": take_idx + 1,
|
||||||
|
"total_takes": num_takes,
|
||||||
|
}
|
||||||
|
with open(log_path, "w") as f:
|
||||||
|
json.dump(run_log, f, indent=2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(temp_save_dir)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
print(f"Output saved: {final_path}")
|
||||||
|
print(f"Run log: {log_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue