# Import spaces before torch so the ZeroGPU patches apply before CUDA is touched
import spaces

import gradio as gr
import torch
from transformers import pipeline

# Load the model pipeline once at startup
pipe = pipeline(
    "text-generation",
    model="google/vaultgemma-1b",
    device="cuda",
    torch_dtype=torch.float16,
)

# Define the chat function
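# @spaces.GPU requests a ZeroGPU device for this call, held for at most `duration` seconds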
@spaces.GPU(duration=120)
def chat(message, history):
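    # `history` is a list of (user_message, bot_message) tuples (ChatInterface's default format)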
    # Format the conversation history for the model
    prompt = ""
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"
    
    # Generate only the continuation (return_full_text=False strips the prompt)
    response = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,
    )
    assistant_response = response[0]["generated_text"].strip()

    # VaultGemma-1B is a base (not instruction-tuned) model, so it may keep
    # writing the dialogue; cut the reply off at the next "User:" turn
    assistant_response = assistant_response.split("User:")[0].strip()

    return assistant_response

# Create the Gradio chat interface
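# concurrency_limit=1 queues requests so only one generation runs at a time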
demo = gr.ChatInterface(
    fn=chat,
    title="VaultGemma-1B Chatbot",
    description="A chatbot powered by Google's VaultGemma-1B model.",
    theme="soft",
    examples=[
        "What is the capital of France?",
        "Tell me a joke.",
        "Explain quantum computing in simple terms."
    ],
    concurrency_limit=1
)

# Launch the app
demo.launch()