pszemraj committed 839b37f (verified) · Parent: 2adf6da

Update README.md

Files changed (1): README.md (+166, -3)

README.md CHANGED

---
license: mit
---


# nGPT-enwiki8

A small nGPT model trained on enwiki8 for testing purposes with [nGPT-pytorch](https://github.com/lucidrains/nGPT-pytorch).
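
The model is character-level, with a vocabulary of just the 256 possible byte values (`num_tokens=256` in the script below), so prompts are encoded with `ord` and outputs decoded with `chr`, exactly as the inference script does. A minimal sketch of that round trip:

```py
text = "Once upon a time"
ids = [ord(c) for c in text]  # encode: one token id per character
decoded = "".join(chr(max(32, i)) for i in ids)  # decode, clamping control bytes to a space
assert decoded == text
```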

## inference

1. download a weights file from this repo
2. install dependencies

```sh
# assuming you already have torch
pip install fire nGPT-pytorch
```

save the script below as `inference.py`, then run inference with `python inference.py /path/to/model.pt "Once upon a time"`

```py
# inference.py
import json
import sys
from pathlib import Path

import fire
import torch
from nGPT_pytorch import nGPT


def decode_token(token):
    # map a token id back to a printable character (ids below 32 become a space)
    return str(chr(max(32, token)))


def decode_tokens(tokens):
    return "".join(map(decode_token, tokens))


def log(t, eps=1e-20):
    # numerically safe log, used by the Gumbel noise below
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    # temperature-scaled sampling via the Gumbel-max trick
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(
        dim=dim, keepdim=keepdim
    )


def min_p_filter(logits, min_p=0.1):
    # mask out tokens whose probability is below min_p times the top token's probability
    probs = logits.softmax(dim=-1)
    max_probs = probs.amax(dim=-1, keepdim=True)
    limit = min_p * max_probs
    return torch.where(probs < limit, float("-inf"), logits)


def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in range(sample_num_times):
        logits = net(out)
        logits = logits[:, -1]

        logits = min_p_filter(logits, min_p=min_p)
        sample = gumbel_sample(logits, temperature=temperature, dim=-1)

        out = torch.cat((out, sample), dim=-1)

    # return only the newly generated tokens
    return out[..., prompt_seq_len:]


def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model."""

    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Get config from the checkpoint, falling back to a config.json next to it
    config = checkpoint.get("config", {})
    if not config and checkpoint_path.parent.joinpath("config.json").exists():
        with open(checkpoint_path.parent.joinpath("config.json")) as f:
            config = json.load(f)

    use_parametrize = config.get("use_parametrize", True)

    # Initialize the model with the same hyperparameters used in training
    model = nGPT(
        num_tokens=256,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f"  Temperature: {temperature}")
    print(f"  Min-p: {min_p}")
    print(f"  Max new tokens: {max_new_tokens}")

    # Encode the prompt: one token id per character
    prompt_tensor = torch.tensor(
        [ord(c) for c in prompt], dtype=torch.long, device=device
    )
    prompt_tensor = prompt_tensor.unsqueeze(0)

    # Generate; seq_len is the total length, so include the prompt length
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            seq_len=prompt_tensor.shape[-1] + max_new_tokens,
            temperature=temperature,
            min_p=min_p,
        )

    generated = decode_tokens(sampled[0])

    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)

    return generated


if __name__ == "__main__":
    fire.Fire(main)
```
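
since the entry point is `fire.Fire(main)`, the arguments of `main` are also exposed as command-line flags; for example (flag names mirror the `main` signature above):

```sh
python inference.py /path/to/model.pt "Once upon a time" --max_new_tokens 200 --temperature 0.8 --min_p 0.05 --device cpu
```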