Updated Code to run on Colab Notebook without Errors (#13)
Browse files- Updated Code to run on Colab Notebook without Errors (34eb501a93ab974b0b0ff45d6bba22669d51cea4)
Co-authored-by: Shresth Shukla <[email protected]>
README.md
CHANGED
|
@@ -152,6 +152,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 152 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 153 |
return None
|
| 154 |
|
|
|
|
|
|
|
|
|
|
| 155 |
# De-interleave tokens into 3 hierarchical levels
|
| 156 |
codes_lvl = [[] for _ in range(3)]
|
| 157 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
@@ -171,7 +174,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 171 |
# Convert to tensors for SNAC decoder
|
| 172 |
hierarchical_codes = []
|
| 173 |
for lvl_codes in codes_lvl:
|
| 174 |
-
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=
|
| 175 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 176 |
raise ValueError("Invalid SNAC token values")
|
| 177 |
hierarchical_codes.append(tensor)
|
|
|
|
| 152 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 153 |
return None
|
| 154 |
|
| 155 |
+
# Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
|
| 156 |
+
snac_device = next(snac_model.parameters()).device
|
| 157 |
+
|
| 158 |
# De-interleave tokens into 3 hierarchical levels
|
| 159 |
codes_lvl = [[] for _ in range(3)]
|
| 160 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
|
|
| 174 |
# Convert to tensors for SNAC decoder
|
| 175 |
hierarchical_codes = []
|
| 176 |
for lvl_codes in codes_lvl:
|
| 177 |
+
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
|
| 178 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 179 |
raise ValueError("Invalid SNAC token values")
|
| 180 |
hierarchical_codes.append(tensor)
|