Update audiocraft/models/lm.py
Browse files — audiocraft/models/lm.py: +94 −0
audiocraft/models/lm.py
CHANGED
|
@@ -531,3 +531,97 @@ class LMModel(StreamingModule):
|
|
| 531 |
# ensure the returned codes are all valid
|
| 532 |
assert (out_codes >= 0).all() and (out_codes <= self.card).all()
|
| 533 |
return out_codes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
# ensure the returned codes are all valid
|
| 532 |
assert (out_codes >= 0).all() and (out_codes <= self.card).all()
|
| 533 |
return out_codes
|
| 534 |
+
@torch.no_grad()
|
| 535 |
+
def generate_segment(self,
|
| 536 |
+
segment: int,
|
| 537 |
+
prompt_text: str,
|
| 538 |
+
max_segment_len: int,
|
| 539 |
+
seed: tp.Optional[int] = None,
|
| 540 |
+
# Pass other generation params like temp, top_k, etc.
|
| 541 |
+
**kwargs
|
| 542 |
+
) -> tp.Tuple[torch.Tensor, int]:
|
| 543 |
+
"""
|
| 544 |
+
Generates audio segment by segment, saving state to the filesystem.
|
| 545 |
+
This mirrors the logic from the RealViz script for robust, persistent state.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
segment (int): The segment number to generate (starts at 1).
|
| 549 |
+
prompt_text (str): The text description for the music.
|
| 550 |
+
max_segment_len (int): The number of tokens to generate in this segment.
|
| 551 |
+
seed (int, optional): The seed for generation. If None and segment is 1,
|
| 552 |
+
a random seed is created.
|
| 553 |
+
**kwargs: Additional generation parameters (temp, top_k, cfg_coef).
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
A tuple containing:
|
| 557 |
+
- full_codes (torch.Tensor): The generated tokens for the ENTIRE song so far.
|
| 558 |
+
- seed (int): The seed used for the generation process.
|
| 559 |
+
"""
|
| 560 |
+
# Ensure a consistent seed across all segments of a song
|
| 561 |
+
if segment == 1:
|
| 562 |
+
if seed is None:
|
| 563 |
+
seed = random.randint(0, np.iinfo(np.int32).max)
|
| 564 |
+
print(f"Starting new generation with Seed: {seed}")
|
| 565 |
+
|
| 566 |
+
# --- This block runs only for the very first segment ---
|
| 567 |
+
conditions = [ConditioningAttributes(text={'description': prompt_text})]
|
| 568 |
+
# Start with an empty prompt tensor
|
| 569 |
+
prompt_codes = torch.zeros((1, self.num_codebooks, 0), dtype=torch.long, device=self.device)
|
| 570 |
+
self.clear_streaming_state() # Ensure model state is fresh
|
| 571 |
+
else:
|
| 572 |
+
# --- This block runs for all subsequent segments ---
|
| 573 |
+
state_file = f"musicgen_state_{segment-1}_{seed}.pt"
|
| 574 |
+
if not os.path.exists(state_file):
|
| 575 |
+
raise FileNotFoundError(f"State file not found! Cannot resume from segment {segment}. Please run segment {segment-1} first.")
|
| 576 |
+
|
| 577 |
+
print(f"Resuming from state file: {state_file}")
|
| 578 |
+
state = torch.load(state_file, map_location=self.device)
|
| 579 |
+
|
| 580 |
+
# Restore all necessary components from the saved state
|
| 581 |
+
seed = state['seed']
|
| 582 |
+
conditions = state['conditions']
|
| 583 |
+
# The prompt for the next segment is the full output from the previous one
|
| 584 |
+
prompt_codes = state['generated_tokens']
|
| 585 |
+
# CRITICAL: Restore the model's internal KV cache
|
| 586 |
+
self.set_streaming_state(state['model_state'])
|
| 587 |
+
|
| 588 |
+
# --- This part runs for EVERY segment ---
|
| 589 |
+
# The 'generate' function here refers to the original, non-chunking one.
|
| 590 |
+
# We are using it to generate just one segment's worth of audio.
|
| 591 |
+
# `remove_prompts=True` is vital to avoid re-generating the input prompt.
|
| 592 |
+
newly_generated_codes = self.generate(
|
| 593 |
+
prompt=prompt_codes,
|
| 594 |
+
conditions=conditions,
|
| 595 |
+
max_gen_len=prompt_codes.shape[-1] + max_segment_len, # Generate N more tokens
|
| 596 |
+
remove_prompts=True,
|
| 597 |
+
**kwargs
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# Combine the previous audio with the new segment
|
| 601 |
+
full_codes = torch.cat([prompt_codes, newly_generated_codes], dim=-1)
|
| 602 |
+
|
| 603 |
+
# --- Save the new state for the NEXT segment to use ---
|
| 604 |
+
print(f"Segment {segment} finished. Saving state...")
|
| 605 |
+
new_model_state = self.get_streaming_state()
|
| 606 |
+
|
| 607 |
+
# Move tensors to CPU before saving for portability
|
| 608 |
+
new_model_state.to('cpu')
|
| 609 |
+
|
| 610 |
+
new_state_to_save = {
|
| 611 |
+
'seed': seed,
|
| 612 |
+
'conditions': conditions,
|
| 613 |
+
'generated_tokens': full_codes.to('cpu'),
|
| 614 |
+
'model_state': new_model_state,
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
# Save the state dictionary to a file
|
| 618 |
+
new_state_file = f"musicgen_state_{segment}_{seed}.pt"
|
| 619 |
+
torch.save(new_state_to_save, new_state_file)
|
| 620 |
+
print(f"State for resuming at segment {segment + 1} saved to {new_state_file}")
|
| 621 |
+
|
| 622 |
+
return full_codes, seed
|
| 623 |
+
|
| 624 |
+
# You should also add the device property to your LMModel class if it's not there
|
| 625 |
+
@property
|
| 626 |
+
def device(self):
|
| 627 |
+
return next(self.parameters()).device
|