164 lines
6.4 KiB
Python
164 lines
6.4 KiB
Python
"""
|
|
MusicGen Melody Large - Hum to Instrument
|
|
==========================================
|
|
Feed it a WAV/MP3 of you humming + a text prompt describing the instrument,
|
|
and it outputs that melody played on the described instrument.
|
|
|
|
Usage:
|
|
python musicgen_melody.py --input hum.wav --prompt "solo acoustic piano, gentle and warm"
|
|
python musicgen_melody.py --input hum.wav --prompt "solo electric guitar, jazz improvisation" --duration 20
|
|
python musicgen_melody.py --input hum.wav --prompt "solo saxophone, smooth jazz" --output sax_output.wav
|
|
|
|
Without --input, generates from text prompt only (no melody conditioning).
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchaudio
|
|
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration
|
|
|
|
|
|
def load_audio(path: str, target_sr: int = 32000):
    """Read *path* as mono audio at *target_sr* Hz.

    Multi-channel input is downmixed by averaging over the first (channel)
    axis with ``keepdim=True``, so a leading channel axis of size 1 is kept.
    The signal is resampled only when the file's native rate differs from
    *target_sr*.

    Args:
        path: Audio file readable by ``torchaudio.load`` (e.g. WAV/MP3,
            depending on the installed backend).
        target_sr: Desired output sample rate in Hz.

    Returns:
        A ``(waveform, target_sr)`` tuple.
    """
    signal, native_rate = torchaudio.load(path)

    # Collapse stereo / multi-channel audio into a single averaged channel.
    num_channels = signal.shape[0]
    if num_channels > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Nothing to do when the file is already at the requested rate.
    if native_rate == target_sr:
        return signal, target_sr

    to_target = torchaudio.transforms.Resample(native_rate, target_sr)
    return to_target(signal), target_sr
|
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser for the hum-to-instrument script."""
    parser = argparse.ArgumentParser(description="MusicGen Melody - Hum to Instrument")
    parser.add_argument("--input", "-i", type=str, default=None,
                        help="Path to input audio (WAV/MP3) of humming/melody (8-30 seconds)")
    parser.add_argument("--prompt", "-p", type=str, required=True,
                        help="Text prompt describing desired instrument and style")
    parser.add_argument("--duration", "-d", type=int, default=None,
                        help="Output duration in seconds (default: input length + 2s, "
                             "or 30s without --input; max 30)")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="Output WAV path (default: auto-generated in output/musicgen/)")
    parser.add_argument("--guidance", "-g", type=float, default=3.0,
                        help="Classifier-free guidance scale (default: 3.0, range 1-5)")
    parser.add_argument("--top-k", type=int, default=250,
                        help="Top-k sampling (default: 250)")
    parser.add_argument("--top-p", type=float, default=0.0,
                        help="Top-p nucleus sampling (default: 0.0 = disabled)")
    parser.add_argument("--temperature", "-t", type=float, default=1.0,
                        help="Sampling temperature (default: 1.0)")
    parser.add_argument("--seed", "-s", type=int, default=-1,
                        help="Random seed (-1 = random)")
    return parser


def _safe_filename_fragment(prompt: str, max_len: int = 40) -> str:
    """Turn the first *max_len* characters of *prompt* into a file-name-safe string.

    Commas are dropped, and every other character that is not alphanumeric,
    '-' or '_' (spaces, path separators, quotes, colons, ...) becomes '_'.
    This prevents a prompt such as "lo-fi / jazz" from producing an invalid
    path or one outside the output directory.
    """
    fragment = prompt[:max_len].replace(",", "")
    return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in fragment)


def main(argv=None):
    """Run the hum-to-instrument pipeline from the command line.

    Loads ``facebook/musicgen-melody-large``, optionally conditions generation
    on a hummed melody (``--input``), generates audio for ``--prompt`` and
    writes the result as a WAV file.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default,
            None, parses the real command line.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    max_duration = 30  # hard cap (seconds) on both output length and melody input

    # Validate cheap things before spending time loading the model.
    if args.duration is not None and args.duration <= 0:
        parser.error("--duration must be a positive number of seconds")
    if args.input and not os.path.isfile(args.input):
        parser.error(f"input file not found: {args.input}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load model
    print("Loading MusicGen Melody Large...")
    t0 = time.time()
    processor = AutoProcessor.from_pretrained("facebook/musicgen-melody-large")
    model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody-large")
    model = model.to(device)
    print(f"Model loaded in {time.time() - t0:.1f}s")
    if device == "cuda":
        # GPU memory accounting is only meaningful when the weights are on a GPU.
        print(f"VRAM used: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

    # Audio codec parameters of the checkpoint.
    codec_config = model.config.audio_encoder
    sample_rate = codec_config.sampling_rate  # 32000
    # Codec frames per second of audio. Fall back to 50 (the previously
    # hard-coded rate) if this config does not expose it.
    frame_rate = getattr(codec_config, "frame_rate", 50)

    # Determine duration
    duration = min(args.duration or max_duration, max_duration)

    # Prepare inputs
    if args.input:
        print(f"Loading input audio: {args.input}")
        waveform, sr = load_audio(args.input, target_sr=sample_rate)
        input_duration = waveform.shape[1] / sr
        print(f"Input duration: {input_duration:.1f}s")

        if input_duration > max_duration:
            print(f"Warning: Input longer than {max_duration}s, truncating.")
            waveform = waveform[:, :max_duration * sr]
            input_duration = max_duration

        if args.duration is None:
            duration = min(int(input_duration) + 2, max_duration)  # slight padding
            print(f"Auto duration: {duration}s (input + 2s padding)")

        audio_np = waveform.squeeze(0).numpy().astype(np.float32)
        inputs = processor(
            audio=audio_np,
            sampling_rate=sample_rate,
            text=[args.prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)
        print(f"Melody conditioning active: input_features shape = {inputs['input_features'].shape}")
    else:
        print("No input audio - generating from text prompt only.")
        inputs = processor(
            text=[args.prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)

    # Convert seconds of audio into a generation-step budget
    # (frame_rate steps per second of audio, ~50 for MusicGen).
    max_new_tokens = int(duration * frame_rate)

    # Seed immediately before sampling so the RNG state seen by generate()
    # depends only on --seed, not on anything consumed during model loading.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        if device == "cuda":
            torch.cuda.manual_seed(args.seed)
        print(f"Seed: {args.seed}")

    print(f"\nGenerating {duration}s of audio...")
    print(f"  Prompt: {args.prompt}")
    print(f"  Guidance: {args.guidance}")
    print(f"  Temperature: {args.temperature}")
    print(f"  Top-k: {args.top_k}")

    t0 = time.time()
    with torch.inference_mode():
        audio_values = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            guidance_scale=args.guidance,
            do_sample=True,
            temperature=args.temperature,
            top_k=args.top_k,
            # 0 (the CLI default) means "disabled"; None leaves nucleus sampling off.
            top_p=args.top_p if args.top_p > 0 else None,
        )
    gen_time = time.time() - t0
    print(f"Generated in {gen_time:.1f}s (RTF: {gen_time/duration:.2f}x)")

    # Save output: first batch item, first channel.
    audio = audio_values[0, 0].cpu()  # [samples]

    if args.output:
        output_path = args.output
    else:
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output", "musicgen")
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        safe_prompt = _safe_filename_fragment(args.prompt)
        output_path = os.path.join(output_dir, f"{timestamp}_{safe_prompt}.wav")

    # Create the destination directory for both default and user-supplied paths,
    # so saving does not fail after a long generation.
    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

    torchaudio.save(output_path, audio.unsqueeze(0), sample_rate)
    print(f"\nSaved: {output_path}")
    print(f"Duration: {audio.shape[0] / sample_rate:.1f}s @ {sample_rate}Hz")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|