Update README.md
README.md CHANGED
````diff
@@ -35,23 +35,14 @@ from heron.models.git_llm.git_llama import GitLlamaConfig, GitLlamaForCausalLM
 device_id = 0
 
 # prepare a pretrained model
-MODEL_NAME = 'turing-motors/heron-chat-git-Llama-2-7b-v0'
-
-git_config = GitLlamaConfig.from_pretrained(MODEL_NAME)
-git_config.set_vision_configs(
-    num_image_with_embedding=1, vision_model_name=git_config.vision_model_name
-)
 model = GitLlamaForCausalLM.from_pretrained(
-
+    'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
 )
-
-
-model = GitLlamaForCausalLM.from_pretrained(MODEL_NAME)
 model.eval()
 model.to(f"cuda:{device_id}")
 
 # prepare a processor
-processor = AutoProcessor.from_pretrained(
+processor = AutoProcessor.from_pretrained('turing-motors/heron-chat-git-Llama-2-7b-v0')
 
 # prepare inputs
 url = "https://www.barnorama.com/wp-content/uploads/2016/12/03-Confusing-Pictures.jpg"
@@ -79,7 +70,7 @@ with torch.no_grad():
 out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list)
 
 # print result
-print(processor.tokenizer.batch_decode(out))
+print(processor.tokenizer.batch_decode(out)[0])
 ```
 
 
````
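After this change, the README's quick-start loads the fp16 checkpoint straight from the Hub and prints only the first decoded sequence: `batch_decode` returns a list of strings, one per generated sequence, so `[0]` yields the answer itself rather than a one-element list. A minimal end-to-end sketch of the updated flow follows. Only the model and processor loading, the `generate(...)` arguments, and the decoding line come from the changed README; the `transformers` import for `AutoProcessor`, the image preprocessing, the prompt string, and the `eos_token_id_list` construction are illustrative assumptions not shown in this diff.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor  # assumption: the processor is resolved via transformers

from heron.models.git_llm.git_llama import GitLlamaForCausalLM

device_id = 0

# prepare a pretrained model (loaded in half precision, as in the updated README)
model = GitLlamaForCausalLM.from_pretrained(
    'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
)
model.eval()
model.to(f"cuda:{device_id}")

# prepare a processor
processor = AutoProcessor.from_pretrained('turing-motors/heron-chat-git-Llama-2-7b-v0')

# prepare inputs
url = "https://www.barnorama.com/wp-content/uploads/2016/12/03-Confusing-Pictures.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# NOTE: the prompt format and preprocessing below are assumptions, not part of this diff
text = "##human: What is unusual about this image?\n##gpt: "
inputs = processor(text=text, images=image, return_tensors="pt")
inputs = {k: v.to(f"cuda:{device_id}") for k, v in inputs.items()}
# cast image tensors to match the fp16 weights (assumed to be needed here)
inputs["pixel_values"] = inputs["pixel_values"].half()

# assumption: stop generation at the tokenizer's EOS token
eos_token_id_list = [processor.tokenizer.eos_token_id]

# do inference
with torch.no_grad():
    out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0.,
                         eos_token_id=eos_token_id_list)

# print result: batch_decode returns a list of strings, [0] selects the single sample
print(processor.tokenizer.batch_decode(out)[0])
```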