Spaces:
Sleeping
Sleeping
IlayMalinyak
commited on
Commit
·
99dc7bf
1
Parent(s):
81f7a68
3layer kan
Browse files
tasks/models/frugal_2025-01-26/frugal_kan_2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f5652f3be0033b5a249ded449c1f7c40dd2f32d649b07ca7c2f6158a4b57cb5
|
| 3 |
+
size 1710980
|
tasks/models/frugal_2025-01-27/CNNEncoder_frugal_2.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tasks/models/frugal_2025-01-27/frugal_kan_2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f520cff8b9531981e16a8b009b6a55fb8ca98573fc4d3dc6806df60b07a49c2
|
| 3 |
+
size 1710980
|
tasks/run.py
CHANGED
|
@@ -9,6 +9,7 @@ import yaml
|
|
| 9 |
import datetime
|
| 10 |
import json
|
| 11 |
import numpy as np
|
|
|
|
| 12 |
|
| 13 |
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 14 |
current_date = datetime.date.today().strftime("%Y-%m-%d")
|
|
@@ -56,6 +57,13 @@ model = CNNKan(model_args, conformer_args, kan_args.get_dict())
|
|
| 56 |
# model.kan.speed()
|
| 57 |
# model = KanEncoder(kan_args.get_dict())
|
| 58 |
model = model.to(local_rank)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
| 60 |
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 61 |
print(f"Number of parameters: {num_params}")
|
|
|
|
| 9 |
import datetime
|
| 10 |
import json
|
| 11 |
import numpy as np
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
|
| 14 |
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
current_date = datetime.date.today().strftime("%Y-%m-%d")
|
|
|
|
| 57 |
# model.kan.speed()
|
| 58 |
# model = KanEncoder(kan_args.get_dict())
|
| 59 |
model = model.to(local_rank)
|
| 60 |
+
state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
|
| 61 |
+
new_state_dict = OrderedDict()
|
| 62 |
+
for key, value in state_dict.items():
|
| 63 |
+
if key.startswith('module.'):
|
| 64 |
+
key = key[7:]
|
| 65 |
+
new_state_dict[key] = value
|
| 66 |
+
missing, unexpected = model.load_state_dict(new_state_dict)
|
| 67 |
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
| 68 |
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 69 |
print(f"Number of parameters: {num_params}")
|
tasks/utils/config.yaml
CHANGED
|
@@ -12,7 +12,7 @@ Data:
|
|
| 12 |
max_days_lc: 270
|
| 13 |
lc_freq: 0.0208
|
| 14 |
create_umap: True
|
| 15 |
-
checkpoint_path: 'tasks/models/frugal_2025-01-
|
| 16 |
|
| 17 |
CNNEncoder:
|
| 18 |
# Model
|
|
@@ -31,7 +31,7 @@ CNNEncoder:
|
|
| 31 |
avg_output: False
|
| 32 |
|
| 33 |
KAN:
|
| 34 |
-
layers_hidden: [1125,32,8,
|
| 35 |
grid_min: -1.2
|
| 36 |
grid_max: 1.2
|
| 37 |
num_grids: 8
|
|
|
|
| 12 |
max_days_lc: 270
|
| 13 |
lc_freq: 0.0208
|
| 14 |
create_umap: True
|
| 15 |
+
checkpoint_path: 'tasks/models/frugal_2025-01-27/frugal_kan_2.pth'
|
| 16 |
|
| 17 |
CNNEncoder:
|
| 18 |
# Model
|
|
|
|
| 31 |
avg_output: False
|
| 32 |
|
| 33 |
KAN:
|
| 34 |
+
layers_hidden: [1125,32,8,1]
|
| 35 |
grid_min: -1.2
|
| 36 |
grid_max: 1.2
|
| 37 |
num_grids: 8
|