# -*- coding: utf-8 -*-
"""notebook9de6b64a65 (4).ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1c2KnGGRls-UP7uEOg2uV-FX5EjJkhtte
"""

# Commented out IPython magic to ensure Python compatibility.
# %%capture output
# %pip install unsloth
# %pip install -qU "langchain-chroma>=0.1.2" langchain-huggingface langchain-core
# %pip install -U gradio pillow datasets
# %pip install assemblyai PyMuPDF

import os

from huggingface_hub import snapshot_download
import gradio as gr
import assemblyai as aai
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch
from langchain_core.vectorstores import InMemoryVectorStore
from transformers import TextIteratorStreamer
from PIL import Image

# Pull the prebuilt MedQuAD Chroma index from the Hub.
snapshot_download(repo_id="pranavupadhyaya52/lavita-MedQuAD-embeddings", repo_type="dataset", local_dir="./chroma_langchain_db")

# Read the AssemblyAI key from the environment rather than hardcoding a secret
# (set ASSEMBLYAI_API_KEY before running).
aai.settings.api_key = os.environ["ASSEMBLYAI_API_KEY"]

config = aai.TranscriptionConfig(speech_model=aai.SpeechModel.best)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda:0"})

vector_store = Chroma(
    collection_name="mediwiki_rag",
    embedding_function=embeddings,
    persist_directory="./chroma_langchain_db",  # where the downloaded embeddings live
)

retriever = vector_store.as_retriever()
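# Quick retrieval sanity check (a sketch; the query string is illustrative,
# not part of the original notebook):
"""
docs = retriever.invoke("What are the symptoms of anemia?")
print(docs[0].metadata.get("source"), docs[0].page_content[:200])
"""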

"""
from datasets import load_dataset
from langchain_core.documents import Document

data = load_dataset("lavita/MedQuAD")

vector_store.add_documents(documents=[Document(page_content=str(i['answer']),metadata={"source":i['document_source']}, ) for k, i in  zip(range(41000), data["train"])])
"""

model, tokenizer = FastVisionModel.from_pretrained(
    "pranavupadhyaya52/llama-3.2-11b-vision-instruct-mediwiki3",
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth",
)
FastVisionModel.for_inference(model)  # enable Unsloth's inference mode before generating

similarity_store = InMemoryVectorStore(embeddings)


class unsloth_agent:
    def __init__(self, model, tokenizer, sys_prompt, name):
        self.name = name
        self.model, self.tokenizer = model, tokenizer
        self.sys_prompt = sys_prompt

    def no_image_prompt(self, prompt):
        # Text-only template: the system prompt is sent as a leading user turn.
        message = self.tokenizer.apply_chat_template([
            {"role": "user", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "text", "text": prompt}
            ]},
        ], add_generation_prompt=True)
        return message

    def yes_image_prompt(self, prompt):
        # Image + text template: the image placeholder is filled in later by the processor.
        message = self.tokenizer.apply_chat_template([
            {"role": "user", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ]},
        ], add_generation_prompt=True)
        return message
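# Sketch of how the prompt builders are used (the question is illustrative;
# apply_chat_template on the vision processor returns the formatted prompt string here):
"""
demo_agent = unsloth_agent(model, tokenizer, "You are a medical assistant.", "text_agent")
print(demo_agent.no_image_prompt("What is hypertension?"))
"""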

class unsloth_agent_supervisor:

    def __init__(self, agents: list, user_input: str, file=None):
        if user_input is None and file is None:
            user_input = "prompt user for input"
        elif isinstance(file, list):
            user_input = "Tell user that only one multimedia object is allowed."

        _agents = [str(x.name) for x in agents]
        current_agent = None

        # Route on the file extension (last three characters; "peg" matches ".jpeg"),
        # or on "text" when no file was supplied.
        if file is not None:
            _agent_name = self.similarity_finder(_agents, file[-3:])
        else:
            _agent_name = self.similarity_finder(_agents, "text")

        for agent in agents:
            if agent.name == _agent_name:
                current_agent = agent

        if file is None and user_input is not None:
            image = None
            message = current_agent.no_image_prompt(user_input)
        elif file[-3:] in ["mp3", "wav"]:
            # Transcribe the audio with AssemblyAI, then treat it as a text query.
            image = None
            input_text = aai.Transcriber(config=config).transcribe(file)
            message = current_agent.no_image_prompt(input_text.text)
        elif file[-3:] in ["jpg", "peg", "png", "bmp"]:
            image = Image.open(file)
            message = current_agent.yes_image_prompt(user_input)
        else:
            image = None
            message = current_agent.no_image_prompt("Prompt the user to enter at least one input.")

        inputs = current_agent.tokenizer(
            image,
            message,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")

        # generate() runs to completion here; iterating the streamer afterwards replays the tokens.
        text_streamer = TextIteratorStreamer(current_agent.tokenizer, skip_prompt=True)
        _ = current_agent.model.generate(**inputs, streamer=text_streamer, max_new_tokens=128, use_cache=True, temperature=1.5, min_p=0.1)
        self.streamer = text_streamer

    def similarity_finder(self, keywords: list, sentence: str):
        # Embed the agent names, pick the closest match to the routing hint, then
        # delete the added entries so they do not accumulate across calls
        # (InMemoryVectorStore.delete() without ids is a no-op).
        ids = similarity_store.add_texts(keywords)
        match = similarity_store.similarity_search(sentence, k=1)
        similarity_store.delete(ids=ids)
        return match[0].page_content
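# Routing sketch (illustrative): an extension hint like "png" should pick
# image_agent, "wav" audio_agent, and "text" text_agent.
"""
names = ["text_agent", "image_agent", "audio_agent"]
_ids = similarity_store.add_texts(names)
print(similarity_store.similarity_search("png", k=1)[0].page_content)  # expected: "image_agent"
similarity_store.delete(ids=_ids)
"""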

text_agent = unsloth_agent(model=model,
                           tokenizer=tokenizer,
                           sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice", name="text_agent")

image_agent = unsloth_agent(model=model,
                           tokenizer=tokenizer,
                           sys_prompt="You are a medical assistant. Describe the image in five sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice",name="image_agent")
audio_agent = unsloth_agent(model=model,
                           tokenizer=tokenizer,
                           sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice",name="audio_agent")

def gradio_chat(messages, history):
    if len(messages["files"]) == 0 and len(messages["text"]) == 0:
        return "Please enter a valid input"

    # Text-only or audio input: one generation pass, with RAG context prepended to text queries.
    elif len(messages["files"]) == 0 or messages["files"][0][-3:] in ["mp3", "wav"]:
        output_text = ""
        if len(messages["files"]) == 0:
            input_prompt = f"""{messages["text"]}, context : {retriever.invoke(messages["text"])}"""
            input_file = None
        elif len(messages["text"]) == 0:
            input_prompt = None
            input_file = messages["files"][0]
        else:
            input_prompt = f"""{messages["text"]}, context : {retriever.invoke(messages["text"])}"""
            input_file = messages["files"][0]
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], input_prompt, input_file)
        for chat in supervisor_agent.streamer:
            output_text += chat
        return output_text

    # Image input: describe the image first, then refine the answer with retrieved context.
    elif len(messages["text"]) == 0 or messages["files"][0][-3:] in ["jpg", "peg", "png", "bmp"]:
        output_text, final_text = "", ""
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], messages["text"], file=messages["files"][0])
        for chat in supervisor_agent.streamer:
            output_text += chat
        context = retriever.invoke(output_text)
        final_supervisor = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], f"{output_text} context={context[0].page_content}")
        for final_chat in final_supervisor.streamer:
            final_text += final_chat
        return final_text

    else:
        return "Invalid Input"


app = gr.ChatInterface(fn=gradio_chat, type="messages", title="Medical Assistant", multimodal=True)
if __name__ == "__main__":
    app.launch()