Use model config directly (#1)
Browse files
- Use model config directly (7eb2d5255f6968846ee0dae38472cb2405841bcb)
Co-authored-by: Joshua <[email protected]>
README.md
CHANGED
|
@@ -28,24 +28,22 @@ import requests
|
|
| 28 |
import onnxruntime as ort
|
| 29 |
from PIL import Image
|
| 30 |
from io import BytesIO
|
| 31 |
-
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer
|
| 32 |
|
| 33 |
# Command line arguments
|
| 34 |
model_path = sys.argv[1]
|
| 35 |
onnx_path = sys.argv[2]
|
| 36 |
|
| 37 |
-
# Initialize model and tokenizer
|
| 38 |
-
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 39 |
-
model_path, torch_dtype=torch.float32, device_map='mps'
|
| 40 |
-
)
|
| 41 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 42 |
|
| 43 |
# Model configuration
|
| 44 |
max_length = 1024
|
| 45 |
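-
num_attention_heads = model.config.num_attention_heads
|
| 46 |
-
num_key_value_heads = model.config.num_key_value_heads
|
| 47 |
-
head_dim = model.config.hidden_size // num_attention_heads
|
| 48 |
-
num_layers = model.config.num_hidden_layers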
|
| 49 |
|
| 50 |
# Setup ONNX sessions
|
| 51 |
session_options = ort.SessionOptions()
|
|
|
|
| 28 |
import onnxruntime as ort
|
| 29 |
from PIL import Image
|
| 30 |
from io import BytesIO
|
| 31 |
+
from transformers import Qwen2VLConfig, AutoTokenizer
|
| 32 |
|
| 33 |
# Command line arguments
|
| 34 |
model_path = sys.argv[1]
|
| 35 |
onnx_path = sys.argv[2]
|
| 36 |
|
| 37 |
+
# Initialize model config and tokenizer
|
| 38 |
+
model_config = Qwen2VLConfig.from_pretrained(model_path)
|
|
|
|
|
|
|
| 39 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 40 |
|
| 41 |
# Model configuration
|
| 42 |
max_length = 1024
|
| 43 |
+
num_attention_heads = model_config.num_attention_heads
|
| 44 |
+
num_key_value_heads = model_config.num_key_value_heads
|
| 45 |
+
head_dim = model_config.hidden_size // num_attention_heads
|
| 46 |
+
num_layers = model_config.num_hidden_layers
|
| 47 |
|
| 48 |
# Setup ONNX sessions
|
| 49 |
session_options = ort.SessionOptions()
|