Update audiocraft/models/loaders.py
Browse files
audiocraft/models/loaders.py
CHANGED
|
@@ -121,8 +121,10 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
|
|
| 121 |
model = builders.get_lm_model(cfg)
|
| 122 |
if depth=='bfloat16':
|
| 123 |
model = model.to(torch.bfloat16)
|
|
|
|
| 124 |
if depth=='float16':
|
| 125 |
model = model.to(torch.float16)
|
|
|
|
| 126 |
model.load_state_dict(pkg['best_state'])
|
| 127 |
model.eval()
|
| 128 |
model.cfg = cfg
|
|
|
|
| 121 |
model = builders.get_lm_model(cfg)
|
| 122 |
if depth=='bfloat16':
|
| 123 |
model = model.to(torch.bfloat16)
|
| 124 |
+
cfg.dtype = 'bfloat16'
|
| 125 |
if depth=='float16':
|
| 126 |
model = model.to(torch.float16)
|
| 127 |
+
cfg.dtype = 'float16'
|
| 128 |
model.load_state_dict(pkg['best_state'])
|
| 129 |
model.eval()
|
| 130 |
model.cfg = cfg
|