Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import gradio as gr | |
# Hugging Face Hub repository that provides both the weights and the tokenizer.
MODEL_ID = "MediaTek-Research/Breeze-7B-Instruct-v1_0"

# System prompt used for every conversation.
SYS_PROMPT = (
    "You are a helpful AI assistant built by MediaTek Research. "
    "The user you are helping speaks Traditional Chinese and comes from Taiwan."
)

# Running conversation history: a list of {"role": ..., "content": ...} dicts.
chat_history = []

# Load the instruct model in bfloat16 with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Load the tokenizer that matches the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Maximum number of messages kept in chat_history (five question/answer pairs).
_MAX_HISTORY_MESSAGES = 10


def generate_response(user_input):
    """Generate the assistant's reply to ``user_input`` and record the turn.

    The prompt is built from the system prompt, the stored ``chat_history``
    and the new user message. Only the newly generated tokens are decoded,
    so the returned text does not repeat the prompt.

    Args:
        user_input: The user's message text.

    Returns:
        The assistant's reply as a string.
    """
    # NOTE(review): chat_history is module-level state, so all callers of
    # this function share a single conversation.
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": user_input},
    ]

    # Let the chat template build the whole prompt instead of wrapping its
    # output in a second, hand-written "<s>... [INST] ... [/INST]" layer.
    # NOTE(review): assumes the model's chat template accepts a leading
    # "system" message -- confirm against the model card.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # The templated prompt is expected to carry its own BOS token, so do not
    # let the tokenizer prepend another one.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128,
        do_sample=True,  # sampling-based decoding
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        repetition_penalty=1.1,
    )

    # generate() returns prompt + continuation; drop the prompt tokens so the
    # reply (and the stored history) contains only the model's answer.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()

    # Record the turn only after generation succeeded, so a failure cannot
    # leave an unanswered user message behind in the history.
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": generated_text})

    # Keep only the most recent messages; trimming in place keeps the same
    # list object bound to the module-level name.
    del chat_history[:-_MAX_HISTORY_MESSAGES]

    return generated_text
# Gradio callback that connects the textbox to the model.
def chat_interface(user_input, history):
    """Answer ``user_input`` and append the exchange to the chatbot history.

    Args:
        user_input: Text submitted from the textbox.
        history: The chatbot's message list; it is extended in place with
            the user message and the assistant reply.

    Returns:
        A tuple of an empty string (new textbox value) and the updated
        history.
    """
    reply = generate_response(user_input)
    history.extend(
        [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": reply},
        ]
    )
    return "", history
# Build the UI: a heading, the chat transcript and a single-line input box.
with gr.Blocks() as iface:
    gr.Markdown(
        "# 醫療問答助手\n這是一個基於 MediaTek-Research/Breeze-7B-Instruct-v1_0 模型的醫療問答助手。"
    )
    # The transcript uses OpenAI-style {"role", "content"} message dicts.
    chatbot = gr.Chatbot(type="messages")
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="請輸入你的問題...", lines=1)
    # Submitting the textbox runs the callback, which returns the new textbox
    # value and the updated transcript.
    txt.submit(chat_interface, inputs=[txt, chatbot], outputs=[txt, chatbot])

# Start the Gradio app and request a public share link.
iface.launch(share=True)