theshresthshukla commited on
Commit
34eb501
·
verified ·
1 Parent(s): 1ae1f0f

Updated Code to run on Colab Notebook without Errors

Browse files

Fixed the code to run the code on colab without errors. There was an issue with SNAC earlier which is fixed

Files changed (1) hide show
  1. README.md +4 -1
README.md CHANGED
@@ -153,6 +153,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
153
  if not snac_tokens or len(snac_tokens) % 7 != 0:
154
  return None
155
 
 
 
 
156
  # De-interleave tokens into 3 hierarchical levels
157
  codes_lvl = [[] for _ in range(3)]
158
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
@@ -172,7 +175,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
172
  # Convert to tensors for SNAC decoder
173
  hierarchical_codes = []
174
  for lvl_codes in codes_lvl:
175
- tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_model.device).unsqueeze(0)
176
  if torch.any((tensor < 0) | (tensor > 4095)):
177
  raise ValueError("Invalid SNAC token values")
178
  hierarchical_codes.append(tensor)
 
153
  if not snac_tokens or len(snac_tokens) % 7 != 0:
154
  return None
155
 
156
+ # Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
157
+ snac_device = next(snac_model.parameters()).device
158
+
159
  # De-interleave tokens into 3 hierarchical levels
160
  codes_lvl = [[] for _ in range(3)]
161
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
 
175
  # Convert to tensors for SNAC decoder
176
  hierarchical_codes = []
177
  for lvl_codes in codes_lvl:
178
+ tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
179
  if torch.any((tensor < 0) | (tensor > 4095)):
180
  raise ValueError("Invalid SNAC token values")
181
  hierarchical_codes.append(tensor)