File size: 3,276 Bytes
08a6c8d
5f6343c
08a6c8d
f7151f4
eaadaf2
08a6c8d
 
 
f7151f4
08a6c8d
e67e492
08a6c8d
e67e492
08a6c8d
 
 
 
 
 
 
 
f7151f4
 
 
7c83f89
08a6c8d
 
 
 
f7151f4
eaadaf2
08a6c8d
e325f49
659f477
e325f49
 
 
 
 
 
bbbe5a6
a2f5d42
5516e3f
bbbe5a6
 
 
5516e3f
bbbe5a6
 
 
 
e325f49
a2f5d42
e325f49
a2f5d42
e325f49
 
2ca3c06
e325f49
 
659f477
 
 
08a6c8d
 
 
f7151f4
a9aaecd
659f477
 
 
 
 
f7151f4
08a6c8d
f7151f4
08a6c8d
 
bca5e76
08a6c8d
659f477
1ae1376
08a6c8d
f7151f4
08a6c8d
 
f7151f4
08a6c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7151f4
 
 
6b6f47e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from threading import Thread
from typing import Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer


# Page header shown above the chat. It names and links the checkpoint this
# Space actually loads (keep in sync with `model_id` below); previously it
# advertised BUAADreamer/PaliGemma-3B-Chat-v0.2, which is not the model served.
TITLE = "<h1><center>Chat with odia-paligemma-2b-5000-v1.0</center></h1>"

DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/sam2ai/odia-paligemma-2b-5000-v1.0' target='_blank'>the model page</a> for details.</center></h3>"

# Styles elements carrying the "duplicate-button" class (the DuplicateButton
# in the UI): centered, white text on a black, fully rounded background.
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""


# Checkpoint served by this Space; the tokenizer, the processor and the
# weights are all loaded from this same Hub repository.
model_id = "sam2ai/odia-paligemma-2b-5000-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# torch_dtype="auto" uses the dtype recorded for the checkpoint instead of
# float32; device_map="auto" lets the library place the weights on the
# available device(s).
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)


def _upload_path(upload):
    """Return the filesystem path of one entry of a message's ``files`` list.

    The entry may be a plain path string (as in the example in ``stream_chat``),
    a dict with a ``"path"`` key, or an object with a ``.path`` attribute; the
    shape depends on the Gradio version, so all three are accepted.
    """
    if isinstance(upload, str):
        return upload
    if isinstance(upload, dict):
        return upload["path"]
    return upload.path


def _split_history(history):
    """Split tuple-format chat history into the latest image path and text turns.

    With ``multimodal=True`` an uploaded file appears in history as its own
    entry whose user slot is a tuple such as ``('image-xxx.jpg',)`` and whose
    assistant slot is ``None``. Such entries are not conversation turns: they
    are left out of ``turns`` and the path of the last one seen is returned.

    Returns:
        ``(image_path, turns)``: the most recently uploaded image path, or
        ``None`` if there is none, and the ``(prompt, answer)`` text pairs in
        their original order.
    """
    image_path = None
    turns = []
    for prompt, answer in history:
        if isinstance(prompt, tuple):
            image_path = prompt[0]
        else:
            turns.append((prompt, answer))
    return image_path, turns


@spaces.GPU
def stream_chat(message: Dict[str, object], history: list):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Gradio multimodal payload, e.g.
            ``{'text': 'what is this', 'files': ['image-xxx.jpg']}`` on the
            first turn or ``{'text': 'continue?', 'files': []}`` later.
        history: Gradio tuple-format history, e.g.
            ``[[('image-xxx.jpg',), None], ['what is this', 'a image.']]``
            (empty on the first turn).

    Yields:
        The reply decoded so far; each yield extends the previous one by the
        next streamed chunk.
    """
    history_image, turns = _split_history(history)

    # An image attached to the current message wins over images from earlier
    # turns (previously an old history image silently replaced a new upload).
    files = message.get("files") or []
    image_path = _upload_path(files[0]) if files else history_image

    if image_path is not None:
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    else:
        # generate() below is always given pixel_values, so fall back to a
        # blank white image when the conversation has none.
        image = Image.new("RGB", (100, 100), (255, 255, 255))

    pixel_values = processor(images=[image], return_tensors="pt")["pixel_values"]
    # Match the weights' device and dtype; torch_dtype="auto" may have loaded
    # them in a precision other than the processor's float output.
    pixel_values = pixel_values.to(model.device, dtype=model.dtype)

    conversation = []
    for prompt, answer in turns:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    # `message` is a dict (see .get above); attribute access `.text` would raise.
    conversation.append({"role": "user", "content": message.get("text") or ""})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Prepend `image_seq_length` copies of the <image> token ahead of the text
    # prompt, presumably the placeholder slots the model fills with image
    # features (PaliGemma-style prompt layout; confirm for this checkpoint).
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence: every position is attended to.
        attention_mask=torch.ones_like(input_ids),
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )

    # generate() blocks until completion, so run it in a worker thread and
    # relay the streamer's text chunks to Gradio as they arrive.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for chunk in streamer:
        output += chunk
        yield output
    worker.join()


# Chat history display, handed to the ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Header: title, model link, and a button to duplicate this Space.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # Multimodal chat box (text plus file uploads) wired to the streaming
    # generator `stream_chat`.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )


if __name__ == "__main__":
    demo.launch()