feat(01-02): add generation tuning params, seed control, multi-take, and run logging
Tuning session revealed strength=0.3 as optimal default for cover mode. Added --caption, --noise-strength, --guidance, --steps, --shift, --sampler, --vel-clamp, --vel-ema, --seed, --takes flags. Each output now has a JSON sidecar logging all parameters for reproducibility.
This commit is contained in:
parent
f40ca2b3fb
commit
45f6863131
1 changed files with 173 additions and 79 deletions
152
hum2inst.py
152
hum2inst.py
|
|
@ -9,6 +9,7 @@ Usage:
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
|
|
@ -80,6 +81,12 @@ def parse_args() -> argparse.Namespace:
|
|||
required=True,
|
||||
help="Target instrument name (e.g., piano, guitar, saxophone, violin, flute)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override the auto-generated caption (default: built from instrument name)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
|
|
@ -89,8 +96,14 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Audio cover strength (0.8-1.0, higher = more faithful to input melody, default: 0.9)",
|
||||
default=0.3,
|
||||
help="Audio cover strength (0.0-1.0, fraction of steps using cover conditioning, default: 0.3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise-strength",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Cover noise strength (0.0-1.0, 0=pure noise start, 1=closest to source audio, default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
|
|
@ -98,6 +111,55 @@ def parse_args() -> argparse.Namespace:
|
|||
default=None,
|
||||
help="Override output duration in seconds (default: auto-detect from input WAV)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="Classifier-free guidance scale — higher = follows caption more strictly (default: 5.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of inference/denoising steps (default: 50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Timestep shift factor — warps denoising schedule (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
type=str,
|
||||
default="euler",
|
||||
choices=["euler", "heun"],
|
||||
help="Sampler mode: euler (fast) or heun (2nd-order, more accurate) (default: euler)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vel-clamp",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Velocity norm threshold — clamps outlier predictions to reduce artifacts (0=off, try 2.0) (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vel-ema",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Velocity EMA smoothing — smooths predictions across steps (0=off, try 0.1) (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Random seed for reproducible results (default: random)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--takes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of takes to generate with different seeds (default: 1)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
|
@ -140,7 +202,7 @@ def main() -> None:
|
|||
# -----------------------------------------------------------------------
|
||||
# Build caption
|
||||
# -----------------------------------------------------------------------
|
||||
caption = build_caption(args.instrument)
|
||||
caption = args.caption if args.caption else build_caption(args.instrument)
|
||||
print(f"Caption: {caption}")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
|
|
@ -178,8 +240,23 @@ def main() -> None:
|
|||
print("Model loaded successfully.")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Configure generation
|
||||
# Generate takes
|
||||
# -----------------------------------------------------------------------
|
||||
import random
|
||||
|
||||
instrument_clean = args.instrument.lower().strip().replace(" ", "-")
|
||||
num_takes = args.takes
|
||||
seeds = []
|
||||
|
||||
if args.seed is not None:
|
||||
# Explicit seed: use it for take 1, derive sequential seeds for additional takes
|
||||
seeds = [args.seed + i for i in range(num_takes)]
|
||||
else:
|
||||
seeds = [random.randint(0, 2**31 - 1) for _ in range(num_takes)]
|
||||
|
||||
for take_idx, seed in enumerate(seeds):
|
||||
take_label = f"[take {take_idx + 1}/{num_takes}]" if num_takes > 1 else ""
|
||||
|
||||
params = GenerationParams(
|
||||
task_type="cover",
|
||||
src_audio=input_path,
|
||||
|
|
@ -189,53 +266,51 @@ def main() -> None:
|
|||
duration=duration,
|
||||
bpm=120,
|
||||
audio_cover_strength=args.strength,
|
||||
inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
cover_noise_strength=args.noise_strength,
|
||||
inference_steps=args.steps,
|
||||
guidance_scale=args.guidance,
|
||||
shift=args.shift,
|
||||
sampler_mode=args.sampler,
|
||||
velocity_norm_threshold=args.vel_clamp,
|
||||
velocity_ema_factor=args.vel_ema,
|
||||
thinking=False,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
config = GenerationConfig(
|
||||
batch_size=1,
|
||||
use_random_seed=True,
|
||||
use_random_seed=False,
|
||||
audio_format="wav",
|
||||
)
|
||||
|
||||
# Use a temporary directory for ACE-Step's UUID-named output
|
||||
temp_save_dir = tempfile.mkdtemp(prefix="hum2inst_")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Generate
|
||||
# -----------------------------------------------------------------------
|
||||
print(f"Generating {args.instrument} cover...")
|
||||
print(f"Generating {args.instrument} cover (seed={seed}) {take_label}...")
|
||||
try:
|
||||
result = generate_music(
|
||||
dit_handler, llm_handler, params, config, save_dir=temp_save_dir
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Generation failed: {e}", file=sys.stderr)
|
||||
if num_takes == 1:
|
||||
sys.exit(1)
|
||||
continue
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Check result
|
||||
# -----------------------------------------------------------------------
|
||||
if not result.success or not result.audios:
|
||||
error_msg = getattr(result, "error", None) or "no audio produced"
|
||||
print(f"ERROR: Generation failed: {error_msg}", file=sys.stderr)
|
||||
if num_takes == 1:
|
||||
sys.exit(1)
|
||||
continue
|
||||
|
||||
# Get the generated file path
|
||||
generated_path = result.audios[0]["path"]
|
||||
|
||||
if not os.path.isfile(generated_path):
|
||||
print(
|
||||
f"ERROR: Expected output file not found: {generated_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(f"ERROR: Expected output file not found: {generated_path}", file=sys.stderr)
|
||||
if num_takes == 1:
|
||||
sys.exit(1)
|
||||
continue
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Silence detection
|
||||
# -----------------------------------------------------------------------
|
||||
if check_silence(generated_path):
|
||||
print(
|
||||
"WARNING: Output audio appears to be near-silent. "
|
||||
|
|
@ -243,23 +318,42 @@ def main() -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Rename output to user-friendly filename
|
||||
# -----------------------------------------------------------------------
|
||||
instrument_clean = args.instrument.lower().strip().replace(" ", "-")
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
final_filename = f"{instrument_clean}_{timestamp}.wav"
|
||||
final_filename = f"{instrument_clean}_{timestamp}_s{seed}.wav"
|
||||
final_path = os.path.join(output_dir, final_filename)
|
||||
|
||||
shutil.copy2(generated_path, final_path)
|
||||
|
||||
# Clean up temp directory
|
||||
log_path = final_path.replace(".wav", ".json")
|
||||
run_log = {
|
||||
"timestamp": timestamp,
|
||||
"input": input_path,
|
||||
"output": final_path,
|
||||
"instrument": args.instrument,
|
||||
"caption": caption,
|
||||
"duration": duration,
|
||||
"strength": args.strength,
|
||||
"noise_strength": args.noise_strength,
|
||||
"inference_steps": args.steps,
|
||||
"guidance_scale": args.guidance,
|
||||
"shift": args.shift,
|
||||
"sampler": args.sampler,
|
||||
"vel_clamp": args.vel_clamp,
|
||||
"vel_ema": args.vel_ema,
|
||||
"seed": seed,
|
||||
"take": take_idx + 1,
|
||||
"total_takes": num_takes,
|
||||
}
|
||||
with open(log_path, "w") as f:
|
||||
json.dump(run_log, f, indent=2)
|
||||
|
||||
try:
|
||||
shutil.rmtree(temp_save_dir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
print(f"Output saved: {final_path}")
|
||||
print(f"Run log: {log_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue