"""Minimal Gradio chat demo for ``baidu/ERNIE-4.5-0.3B-PT``.

The model and tokenizer are loaded once at import time. Each user message is
turned into a plain-text ``User:`` / ``Assistant:`` transcript (including a few
recent exchanges), and the generated continuation is trimmed at the first sign
of a new turn.
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load model and tokenizer once at startup.
model_name = "baidu/ERNIE-4.5-0.3B-PT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Substrings that mark the start of a new (model-invented) turn. Generation is
# not stopped on them, so the reply is truncated at them after the fact.
stop_tokens = ["User:", "Assistant:", "\nUser", "\nAssistant"]

# Default number of previous (user, assistant) exchanges put into the prompt.
MAX_HISTORY_TURNS = 3

# Sampling settings, defined once so the pipeline defaults and the per-call
# arguments cannot drift apart.
GENERATION_KWARGS = {
    "max_new_tokens": 64,  # short, conservative replies
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.92,
    "pad_token_id": tokenizer.eos_token_id,
}

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    **GENERATION_KWARGS,
)


def _build_prompt(message, history, max_turns=MAX_HISTORY_TURNS):
    """Return a User/Assistant transcript ending with an open assistant turn.

    Args:
        message: The new user message.
        history: Sequence of ``(user_text, assistant_text)`` pairs, oldest first.
        max_turns: How many of the most recent pairs to include; ``<= 0`` means
            none (a plain ``history[-0:]`` would wrongly return everything).

    Returns:
        The prompt string, ending in ``"Assistant:"`` so the model continues it.
    """
    recent = history[-max_turns:] if max_turns > 0 else []
    parts = [
        f"User: {human}\nAssistant: {assistant}\n" for human, assistant in recent
    ]
    parts.append(f"User: {message}\nAssistant:")
    return "".join(parts)


def _truncate_at_stop_tokens(text, stops=stop_tokens):
    """Cut *text* before any stop substring, stripping whitespace after each cut.

    Stops are applied one after another, so the result is the text preceding
    the earliest occurrence of any of them.
    """
    for stop in stops:
        if stop in text:
            text = text.partition(stop)[0].strip()
    return text


def chat_function(message, history, max_turns=MAX_HISTORY_TURNS):
    """Generate a single assistant reply to *message*.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` pairs as kept by the Gradio
            chatbot; ``None`` is treated as an empty history.
        max_turns: Number of recent exchanges to include in the prompt
            (defaults to ``MAX_HISTORY_TURNS``, i.e. 3).

    Returns:
        The cleaned reply text (may be empty if the model immediately started
        a new turn).
    """
    prompt = _build_prompt(message, history or [], max_turns)

    outputs = pipe(prompt, return_full_text=False, **GENERATION_KWARGS)
    response = outputs[0]["generated_text"].strip()

    # The model often keeps writing the next "User:" turn itself; drop it.
    response = _truncate_at_stop_tokens(response)

    # Remove any trailing colons (e.g. a dangling speaker label) and whitespace.
    return response.rstrip(":").strip()


# Gradio Interface
with gr.Blocks(title="baidu/ERNIE-4.5-0.3B-PT Chat") as demo:
    gr.Markdown("# 🤖 baidu/ERNIE-4.5-0.3B-PT Simple Chat")
    gr.Markdown(
        "A minimal chat interface using `baidu/ERNIE-4.5-0.3B-PT`. "
        "Optimized for clean single-turn responses."
    )
    # History is stored as a list of (user, bot) tuples.
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Type your message", placeholder="Say something...")
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        """Append the user's message and the model reply; clear the textbox."""
        # After "Clear" the chatbot value is None.
        chat_history = chat_history or []
        # Nothing to answer: don't spend a generation call on blank input.
        if not message or not message.strip():
            return "", chat_history
        bot_message = chat_function(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)


if __name__ == "__main__":
    demo.launch()