Bharath Kumar Kakumani theshresthshukla commited on
Commit
8b770f9
·
verified ·
1 Parent(s): dc45b1c

Updated Code to run on Colab Notebook without Errors (#13)

Browse files

- Updated Code to run on Colab Notebook without Errors (34eb501a93ab974b0b0ff45d6bba22669d51cea4)


Co-authored-by: Shresth Shukla <[email protected]>

Files changed (1) hide show
  1. README.md +4 -1
README.md CHANGED
@@ -152,6 +152,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
152
  if not snac_tokens or len(snac_tokens) % 7 != 0:
153
  return None
154
 
 
 
 
155
  # De-interleave tokens into 3 hierarchical levels
156
  codes_lvl = [[] for _ in range(3)]
157
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
@@ -171,7 +174,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
171
  # Convert to tensors for SNAC decoder
172
  hierarchical_codes = []
173
  for lvl_codes in codes_lvl:
174
- tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_model.device).unsqueeze(0)
175
  if torch.any((tensor < 0) | (tensor > 4095)):
176
  raise ValueError("Invalid SNAC token values")
177
  hierarchical_codes.append(tensor)
 
152
  if not snac_tokens or len(snac_tokens) % 7 != 0:
153
  return None
154
 
155
+ # Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
156
+ snac_device = next(snac_model.parameters()).device
157
+
158
  # De-interleave tokens into 3 hierarchical levels
159
  codes_lvl = [[] for _ in range(3)]
160
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
 
174
  # Convert to tensors for SNAC decoder
175
  hierarchical_codes = []
176
  for lvl_codes in codes_lvl:
177
+ tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
178
  if torch.any((tensor < 0) | (tensor > 4095)):
179
  raise ValueError("Invalid SNAC token values")
180
  hierarchical_codes.append(tensor)