# FireRedChat-turn-detector / inference_onnx.py
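"""ONNX inference script for the FireRedChat turn-detector (stop-judgment) model."""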
import sys
import logging
from typing import List, Dict, Tuple

import numpy as np
from tqdm import tqdm
import onnxruntime as ort
from transformers import AutoTokenizer

class StopJudgmentONNXInference:
    def __init__(self, onnx_model_path: str, tokenizer_path: str, device: str = 'auto'):
        """
        ONNX inference wrapper for the stop-judgment (turn-detection) model.

        Args:
            onnx_model_path: path to the ONNX model file
            tokenizer_path: path to the tokenizer directory
            device: device type ('auto', 'cuda', 'cpu'); 'cpu' disables the CUDA provider
        """
        self.onnx_model_path = onnx_model_path
        self.tokenizer_path = tokenizer_path
        self.device = device
        self.setup_logging()
        self.load_model_and_tokenizer()
def setup_logging(self):
"""设置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def load_model_and_tokenizer(self):
"""加载ONNX模型和tokenizer"""
# 加载tokenizer
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, local_files_only=True)
self.logger.info("Tokenizer loaded successfully")
except Exception as e:
self.logger.error(f"Failed to load tokenizer: {e}")
raise
        # Build the execution-provider list
        providers = []
        # Use the CUDA provider only if it is available and not explicitly disabled
        available_providers = ort.get_available_providers()
        if self.device != 'cpu' and 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
            self.logger.info("CUDA provider is available and will be used")
        providers.append('CPUExecutionProvider')  # always keep CPU as a fallback
try:
self.ort_session = ort.InferenceSession(self.onnx_model_path, providers=providers)
self.logger.info(f"ONNX model loaded successfully with providers: {self.ort_session.get_providers()}")
except Exception as e:
self.logger.error(f"Failed to load ONNX model: {e}")
raise
        # Cache model input/output names for later inference calls
        self.input_names = [inp.name for inp in self.ort_session.get_inputs()]
        self.output_names = [out.name for out in self.ort_session.get_outputs()]
self.logger.info(f"Input names: {self.input_names}")
self.logger.info(f"Output names: {self.output_names}")
def preprocess_text(self, texts: List[str], max_length: int = 128) -> Dict[str, np.ndarray]:
"""
预处理文本数据
Args:
texts: 文本列表
max_length: 最大长度
Returns:
包含input_ids和attention_mask的字典
"""
encoding = self.tokenizer(
texts,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='np' # 返回numpy数组
)
return {
'input_ids': encoding['input_ids'].astype(np.int64),
'attention_mask': encoding['attention_mask'].astype(np.int64)
}
def predict_single(self, text: str, max_length: int = 128) -> Tuple[int, float]:
"""单个文本预测"""
inputs = self.preprocess_text([text], max_length)
        # ONNX inference
ort_inputs = {
self.input_names[0]: inputs['input_ids'],
self.input_names[1]: inputs['attention_mask']
}
ort_outputs = self.ort_session.run(self.output_names, ort_inputs)
logits = ort_outputs[0]
        # Compute class probabilities and the predicted label
probabilities = self.softmax(logits)
prediction = np.argmax(probabilities[0])
confidence = probabilities[0][prediction]
return int(prediction), float(confidence)
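    # Note: the returned label is the argmax class index of the classifier head;
    # its meaning (e.g. "stop speaking" vs. "keep listening") depends on the label
    # mapping used during training, which is not documented in this file.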
def predict_batch(self, texts: List[str], max_length: int = 128,
batch_size: int = 32) -> Tuple[List[int], List[float]]:
"""批量预测"""
all_predictions = []
all_confidences = []
for i in tqdm(range(0, len(texts), batch_size), desc="ONNX Predicting"):
batch_texts = texts[i:i + batch_size]
inputs = self.preprocess_text(batch_texts, max_length)
            # ONNX inference
ort_inputs = {
self.input_names[0]: inputs['input_ids'],
self.input_names[1]: inputs['attention_mask']
}
ort_outputs = self.ort_session.run(self.output_names, ort_inputs)
logits = ort_outputs[0]
            # Compute class probabilities and predictions
probabilities = self.softmax(logits)
predictions = np.argmax(probabilities, axis=1)
confidences = [probabilities[j][pred] for j, pred in enumerate(predictions)]
all_predictions.extend(predictions.tolist())
all_confidences.extend(confidences)
return all_predictions, all_confidences
@staticmethod
def softmax(x):
"""Softmax函数"""
exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
return exp_x / np.sum(exp_x, axis=1, keepdims=True)
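
# Example usage (illustrative sketch; the model/tokenizer paths below are
# placeholders, not files shipped in this folder):
#
#   infer = StopJudgmentONNXInference("model.onnx", "./tokenizer")
#   pred, conf = infer.predict_single("好的我知道了")
#   preds, confs = infer.predict_batch(["你好", "请稍等一下"], batch_size=16)
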
def main():
"""主函数"""
if len(sys.argv) < 3:
print("Usage: python validate_onnx.py <tokenizer_path> <onnx_model_path> [test_sentence]")
sys.exit(1)
tokenizer_path = sys.argv[1]
onnx_model_path = sys.argv[2]
    # Default test sentence (Chinese): "Welcome to test this stop-judgment model; feel free to suggest corrections."
    test_sentence = sys.argv[3] if len(sys.argv) > 3 else "欢迎测试本判停模型有修正建议请随时提出"
print("\n ONNX Model Inference...")
onnx_inferencer = StopJudgmentONNXInference(onnx_model_path, tokenizer_path)
prediction, confidence = onnx_inferencer.predict_single(
test_sentence, max_length=128
)
    print(f"prediction={prediction}, confidence={confidence:.4f}")
if __name__ == "__main__":
main()