Update audiocraft/models/loaders.py
Browse files- audiocraft/models/loaders.py +11 -9
audiocraft/models/loaders.py
CHANGED
|
@@ -104,15 +104,17 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
|
|
| 104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 106 |
cfg.device = str(device)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
| 117 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
| 118 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
|
|
|
| 104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 106 |
cfg.device = str(device)
|
| 107 |
+
|
| 108 |
+
cfg.dtype = 'float32'
|
| 109 |
+
#if cfg.device == 'cpu':
|
| 110 |
+
# cfg.dtype = 'float32'
|
| 111 |
+
#else:
|
| 112 |
+
# if depth=='float32':
|
| 113 |
+
# cfg.dtype = 'float32'
|
| 114 |
+
# if depth=='bfloat16':
|
| 115 |
+
# cfg.dtype = 'bfloat16'
|
| 116 |
+
# if depth=='float16':
|
| 117 |
+
# cfg.dtype = 'float16'
|
| 118 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
| 119 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
| 120 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|