Update README.md
README.md CHANGED
@@ -7,16 +7,19 @@ pipeline_tag: text-generation
 
 # MAMBA (2.8B) 🐍 fine-tuned on H4/no_robots dataset for chat / instruction
 
-
+Model Card is still WIP!
 
 ## Usage
 
 ```py
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
 CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 eos_token = "<|endoftext|>"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.eos_token = eos_token
@@ -24,7 +27,7 @@ tokenizer.pad_token = tokenizer.eos_token
 tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template
 
 model = MambaLMHeadModel.from_pretrained(
-    model_name, device=
+    model_name, device=device, dtype=torch.float16)
 
 history_dict: list[dict[str, str]] = []
 prompt = "Tell me 5 sites to visit in Spain"
@@ -32,7 +35,7 @@ history_dict.append(dict(role="user", content=prompt))
 
 input_ids = tokenizer.apply_chat_template(
     history_dict, return_tensors="pt", add_generation_prompt=True
-)
+).to(device)
 
 out = model.generate(
     input_ids=input_ids,
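For convenience, here is a consolidated, runnable version of the snippet this commit converges on. It is a sketch, not the card's verbatim code: the visible hunks never define `model_name`, so the repo id below is a placeholder, and the last hunk cuts off in the middle of the `generate` call, so every sampling argument below is an illustrative assumption. The unused `AutoModelForCausalLM` import is also dropped, since generation goes through `mamba_ssm`'s own `MambaLMHeadModel.generate()` loop rather than `transformers`.

```py
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Borrow Zephyr's chat template; the checkpoint ships without one of its own.
CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
# Placeholder: the visible hunks never define model_name. Substitute this
# model's actual repo id.
model_name = "<repo-id-of-this-model>"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The GPT-NeoX tokenizer used by Mamba ends sequences with <|endoftext|>.
eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

# mamba_ssm loads the weights directly; fp16 halves the memory footprint.
model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16)

history_dict: list[dict[str, str]] = []
prompt = "Tell me 5 sites to visit in Spain"
history_dict.append(dict(role="user", content=prompt))

input_ids = tokenizer.apply_chat_template(
    history_dict, return_tensors="pt", add_generation_prompt=True
).to(device)

# Assumed sampling arguments: the diff ends mid-call, so these values are
# illustrative. mamba_ssm's generate() accepts max_length / top_k / top_p /
# temperature / eos_token_id; its default top_k=1 means greedy decoding, so
# top_k is raised here for the other sampling knobs to take effect.
out = model.generate(
    input_ids=input_ids,
    max_length=256,
    top_k=50,
    top_p=0.7,
    temperature=0.9,
    eos_token_id=tokenizer.eos_token_id,
)

# Zephyr's template closes the prompt with "<|assistant|>\n", so the reply is
# everything after the last such marker.
decoded = tokenizer.batch_decode(out)
print(decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, ""))
```

With Zephyr's template, the encoded conversation renders roughly as `<|user|>\nTell me 5 sites to visit in Spain</s>\n<|assistant|>\n`, which is why the assistant's reply can be split off at the `<|assistant|>` marker after decoding.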