Updated Code to run on Colab Notebook without Errors
Browse filesFixed the code to run the code on colab without errors. There was an issue with SNAC earlier which is fixed
README.md
CHANGED
|
@@ -153,6 +153,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 153 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 154 |
return None
|
| 155 |
|
|
|
|
|
|
|
|
|
|
| 156 |
# De-interleave tokens into 3 hierarchical levels
|
| 157 |
codes_lvl = [[] for _ in range(3)]
|
| 158 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
@@ -172,7 +175,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 172 |
# Convert to tensors for SNAC decoder
|
| 173 |
hierarchical_codes = []
|
| 174 |
for lvl_codes in codes_lvl:
|
| 175 |
-
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=
|
| 176 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 177 |
raise ValueError("Invalid SNAC token values")
|
| 178 |
hierarchical_codes.append(tensor)
|
|
|
|
| 153 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 154 |
return None
|
| 155 |
|
| 156 |
+
# Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
|
| 157 |
+
snac_device = next(snac_model.parameters()).device
|
| 158 |
+
|
| 159 |
# De-interleave tokens into 3 hierarchical levels
|
| 160 |
codes_lvl = [[] for _ in range(3)]
|
| 161 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
|
|
| 175 |
# Convert to tensors for SNAC decoder
|
| 176 |
hierarchical_codes = []
|
| 177 |
for lvl_codes in codes_lvl:
|
| 178 |
+
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
|
| 179 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 180 |
raise ValueError("Invalid SNAC token values")
|
| 181 |
hierarchical_codes.append(tensor)
|