# Spaces: Sleeping
from threading import Thread
from typing import Any, Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
# Static HTML and CSS that decorate the Gradio page.
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.1</center></h1>"

DESCRIPTION = (
    "<h3><center>Visit "
    "<a href='https://huggingface.co/hiyouga/PaliGemma-3B-Chat-v0.1' target='_blank'>our model page</a>"
    " for details.</center></h3>"
)

# Rules for elements tagged with the "duplicate-button" class:
# auto margins, white text on black, fully rounded corners.
CSS = "\n".join(
    [
        "",
        ".duplicate-button {",
        "margin: auto !important;",
        "color: white !important;",
        "background: black !important;",
        "border-radius: 100vh !important;",
        "}",
        "",
    ]
)
# Hub checkpoint shared by the tokenizer, the processor and the model.
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"

# Loaded tokenizer first, then processor (tuple elements evaluate left to right).
# NOTE(review): `processor` is not referenced by the rest of the visible code.
tokenizer, processor = (
    AutoTokenizer.from_pretrained(model_id),
    AutoProcessor.from_pretrained(model_id),
)

# Both "auto" settings leave dtype and device placement to transformers.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
def _build_conversation(message: Any, history: list) -> list:
    """Convert Gradio chat inputs into a chat-template conversation list.

    ``message`` is either a plain string or a dict (as produced by a
    ``multimodal=True`` chat box) whose ``"text"`` entry holds the typed prompt.
    History rows are kept only when both the user and assistant parts are
    strings. Other rows are skipped whole, so user/assistant roles keep
    alternating. Presumably these are file uploads, which Gradio's tuple-format
    history stores as tuples, or turns with no reply yet (TODO confirm against
    the installed Gradio version).
    """
    conversation = []
    for prompt, answer in history:
        if isinstance(prompt, str) and isinstance(answer, str):
            conversation.append({"role": "user", "content": prompt})
            conversation.append({"role": "assistant", "content": answer})

    text = message.get("text", "") if isinstance(message, dict) else message
    conversation.append({"role": "user", "content": text})
    return conversation


def stream_chat(
    message: Dict[str, Any],
    history: list,
    temperature: float = 0.7,
    max_new_tokens: int = 1024,
):
    """Stream the model's reply to ``message``, given the earlier ``history``.

    Yields the accumulated reply text after each decoded chunk, so the chat UI
    can show a growing partial answer.

    Args:
        message: The new user message. This is a dict with a ``"text"`` entry
            from a multimodal chat box; a plain string is also accepted.
        history: Earlier turns as ``[user, assistant]`` pairs.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    The defaults for ``temperature`` and ``max_new_tokens`` apply when the
    caller supplies no extra inputs.

    NOTE(review): uploaded images are not sent to the model. ``Image`` and
    ``processor`` exist in this file but are unused here, so replies are
    text-only. ``spaces`` is imported without an ``@spaces.GPU`` decorator;
    confirm whether this Space runs on ZeroGPU hardware.
    """
    conversation = _build_conversation(message, history)
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # timeout bounds how long iterating the streamer waits for each new chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens)
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding: a non-positive temperature is not meaningful for sampling.
        generate_kwargs["do_sample"] = False

    # generate() runs in a background thread while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    thread.join()
# The chat display is created up front and passed into the ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, then the model-page link.
    for header_html in (TITLE, DESCRIPTION):
        gr.HTML(header_html)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",  # matches the .duplicate-button rule in CSS
    )
    # Chat UI driven by stream_chat; multimodal input is enabled, example caching is off.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        fill_height=True,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()