Update README.md
Browse files
Add batch inference example.
README.md
CHANGED
|
@@ -78,6 +78,58 @@ with torch.inference_mode():
|
|
| 78 |
print(f'Output:\n{output}')
|
| 79 |
```
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
## Citation
|
| 82 |
If you find Ovis useful, please cite the paper
|
| 83 |
```
|
|
|
|
| 78 |
print(f'Output:\n{output}')
|
| 79 |
```
|
| 80 |
|
| 81 |
+
<details>
|
| 82 |
+
<summary>Batch inference</summary>
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
batch_inputs = [
|
| 86 |
+
('example_image1.jpeg', 'Describe the content of this image.'),
|
| 87 |
+
('example_image2.jpeg', 'What is the equation in the image?')
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
batch_input_ids = []
|
| 91 |
+
batch_attention_mask = []
|
| 92 |
+
batch_pixel_values = []
|
| 93 |
+
|
| 94 |
+
for image_path, text in batch_inputs:
|
| 95 |
+
image = Image.open(image_path)
|
| 96 |
+
query = f'<image>\n{text}'
|
| 97 |
+
prompt, input_ids, pixel_values = model.preprocess_inputs(query, [image])
|
| 98 |
+
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
|
| 99 |
+
input_ids = input_ids.unsqueeze(0).to(device=model.device)
|
| 100 |
+
attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
|
| 101 |
+
pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
|
| 102 |
+
batch_input_ids.append(input_ids.squeeze())
|
| 103 |
+
batch_attention_mask.append(attention_mask.squeeze())
|
| 104 |
+
batch_pixel_values.append(pixel_values)
|
| 105 |
+
|
| 106 |
+
pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_input_ids],batch_first=True, padding_value=0.0).flip(dims=[1])
|
| 107 |
+
pad_batch_input_ids = pad_batch_input_ids[:,-model.config.multimodal_max_length:]
|
| 108 |
+
pad_batch_attention_mask = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_attention_mask],batch_first=True, padding_value=False).flip(dims=[1])
|
| 109 |
+
pad_batch_attention_mask = pad_batch_attention_mask[:,-model.config.multimodal_max_length:]
|
| 110 |
+
pad_batch_pixel_values = [item for sublist in batch_pixel_values for item in sublist]
|
| 111 |
+
|
| 112 |
+
# generate output
|
| 113 |
+
with torch.inference_mode():
|
| 114 |
+
gen_kwargs = dict(
|
| 115 |
+
max_new_tokens=1024,
|
| 116 |
+
do_sample=False,
|
| 117 |
+
top_p=None,
|
| 118 |
+
top_k=None,
|
| 119 |
+
temperature=None,
|
| 120 |
+
repetition_penalty=None,
|
| 121 |
+
eos_token_id=model.generation_config.eos_token_id,
|
| 122 |
+
pad_token_id=text_tokenizer.pad_token_id,
|
| 123 |
+
use_cache=True
|
| 124 |
+
)
|
| 125 |
+
output_ids = model.generate(pad_batch_input_ids, pixel_values=pad_batch_pixel_values, attention_mask=pad_batch_attention_mask, **gen_kwargs)
|
| 126 |
+
|
| 127 |
+
for i in range(len(batch_input_ids)):
|
| 128 |
+
output = text_tokenizer.decode(output_ids[i], skip_special_tokens=True)
|
| 129 |
+
print(f'Output_{i}:\n{output}')
|
| 130 |
+
```
|
| 131 |
+
</details>
|
| 132 |
+
|
| 133 |
## Citation
|
| 134 |
If you find Ovis useful, please cite the paper
|
| 135 |
```
|