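"""ONNX inference validation for the stop-judgment model.

Loads an exported ONNX sequence-classification model and its tokenizer, then
runs single-text or batched predictions; main() provides the CLI entry point.
"""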
import sys
import logging
from typing import List, Dict, Tuple

import numpy as np
from tqdm import tqdm
import onnxruntime as ort
from transformers import AutoTokenizer


class StopJudgmentONNXInference:

    def __init__(self, onnx_model_path: str, tokenizer_path: str, device: str = 'auto'):
        """
        ONNX inference class for the stop-judgment model.

        Args:
            onnx_model_path: Path to the ONNX model.
            tokenizer_path: Path to the tokenizer.
            device: Device type ('auto', 'cuda', 'cpu').
        """
        self.onnx_model_path = onnx_model_path
        self.tokenizer_path = tokenizer_path
        self.device = device  # Consulted below when selecting execution providers.
        self.setup_logging()
        self.load_model_and_tokenizer()

    def setup_logging(self):
        """Configure logging."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def load_model_and_tokenizer(self):
        """Load the ONNX model and the tokenizer."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, local_files_only=True)
            self.logger.info("Tokenizer loaded successfully")
        except Exception as e:
            self.logger.error(f"Failed to load tokenizer: {e}")
            raise

        # Build the execution-provider list, honoring the requested device;
        # CPU is always appended as a fallback.
        providers = []
        available_providers = ort.get_available_providers()
        if self.device in ('auto', 'cuda') and 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
            self.logger.info("CUDA provider is available and will be used")
        providers.append('CPUExecutionProvider')

        try:
            self.ort_session = ort.InferenceSession(self.onnx_model_path, providers=providers)
            self.logger.info(f"ONNX model loaded successfully with providers: {self.ort_session.get_providers()}")
        except Exception as e:
            self.logger.error(f"Failed to load ONNX model: {e}")
            raise

        # Cache the graph's input/output names for later session.run() calls.
        self.input_names = [inp.name for inp in self.ort_session.get_inputs()]
        self.output_names = [out.name for out in self.ort_session.get_outputs()]

        self.logger.info(f"Input names: {self.input_names}")
        self.logger.info(f"Output names: {self.output_names}")

    def preprocess_text(self, texts: List[str], max_length: int = 128) -> Dict[str, np.ndarray]:
        """
        Preprocess text data.

        Args:
            texts: List of input texts.
            max_length: Maximum sequence length.

        Returns:
            Dict with input_ids and attention_mask arrays.
        """
        encoding = self.tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='np'
        )

        # ONNX Runtime expects int64 tensors for these inputs.
        return {
            'input_ids': encoding['input_ids'].astype(np.int64),
            'attention_mask': encoding['attention_mask'].astype(np.int64)
        }
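        # For example, with padding='max_length' both returned arrays have
        # shape (len(texts), max_length), e.g. (1, 128) for a single text.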

    def predict_single(self, text: str, max_length: int = 128) -> Tuple[int, float]:
        """Predict a single text; returns (predicted class, confidence)."""
        inputs = self.preprocess_text([text], max_length)

        # Map tensors positionally: the exported graph is expected to take
        # input_ids first and attention_mask second.
        ort_inputs = {
            self.input_names[0]: inputs['input_ids'],
            self.input_names[1]: inputs['attention_mask']
        }

        ort_outputs = self.ort_session.run(self.output_names, ort_inputs)
        logits = ort_outputs[0]

        # Convert logits to probabilities and pick the top class.
        probabilities = self.softmax(logits)
        prediction = np.argmax(probabilities[0])
        confidence = probabilities[0][prediction]

        return int(prediction), float(confidence)

    def predict_batch(self, texts: List[str], max_length: int = 128,
                      batch_size: int = 32) -> Tuple[List[int], List[float]]:
        """Predict a batch of texts; returns (predictions, confidences)."""
        all_predictions = []
        all_confidences = []

        for i in tqdm(range(0, len(texts), batch_size), desc="ONNX Predicting"):
            batch_texts = texts[i:i + batch_size]
            inputs = self.preprocess_text(batch_texts, max_length)

            # Same positional input mapping as predict_single.
            ort_inputs = {
                self.input_names[0]: inputs['input_ids'],
                self.input_names[1]: inputs['attention_mask']
            }

            ort_outputs = self.ort_session.run(self.output_names, ort_inputs)
            logits = ort_outputs[0]

            # Row-wise argmax plus the probability of the chosen class.
            probabilities = self.softmax(logits)
            predictions = np.argmax(probabilities, axis=1)
            confidences = [float(probabilities[j][pred]) for j, pred in enumerate(predictions)]

            all_predictions.extend(predictions.tolist())
            all_confidences.extend(confidences)

        return all_predictions, all_confidences

    @staticmethod
    def softmax(x):
        """Numerically stable softmax over axis 1."""
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=1, keepdims=True)
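        # Illustration: for logits [[2.0, 0.0]] this returns roughly
        # [[0.881, 0.119]]; each row sums to 1.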


def main():
    """CLI entry point."""
    if len(sys.argv) < 3:
        print("Usage: python validate_onnx.py <tokenizer_path> <onnx_model_path> [test_sentence]")
        sys.exit(1)

    tokenizer_path = sys.argv[1]
    onnx_model_path = sys.argv[2]
    # Default test input: a sample Chinese sentence for the stop-judgment model
    # ("Welcome to test this stop-judgment model; please raise any correction
    # suggestions at any time").
    test_sentence = sys.argv[3] if len(sys.argv) > 3 else "欢迎测试本判停模型有修正建议请随时提出"

    print("\nONNX Model Inference...")
    onnx_inferencer = StopJudgmentONNXInference(onnx_model_path, tokenizer_path)
    prediction, confidence = onnx_inferencer.predict_single(
        test_sentence, max_length=128
    )
    print(f"Prediction: {prediction}, Confidence: {confidence:.4f}")
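
    # Optional batch check (a sketch; the repeated sentence is just a
    # placeholder input, not part of the original validation flow):
    # predictions, confidences = onnx_inferencer.predict_batch(
    #     [test_sentence] * 4, max_length=128, batch_size=2
    # )
    # print(predictions, confidences)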


if __name__ == "__main__":
    main()
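
# Example invocation (paths are placeholders; adjust to your local artifacts):
#   python validate_onnx.py ./tokenizer_dir ./model.onnx "测试句子"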