# app.py
import os
import gradio as gr
import torch
import spaces
from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub
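
# This Space runs two models side by side: a stateful Reactive Transformer (RxT)
# that keeps a short-term memory (STM) between turns, and a stateless decoder-only
# LLM baseline that re-reads the full chat history on every turn.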
HF_TOKEN = os.environ.get("HF_TOKEN")

# Reactive Transformer (RxT) model and its tokenizer
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()

# Stateless decoder-only LLM reference and its tokenizer
llm_tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/rc-RxT-Beta-Base', token=HF_TOKEN)
llm_model = DecoderOnlyTransformer.from_pretrained('ReactiveAI/SQA-Transformer-Beta-SFT', token=HF_TOKEN)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
llm_model.to(device)

# Snapshot of the initial STM state, used to start and reset conversations
initial_stm = model.export_stm_state().cpu()

seq_len = 1024
llm_seq_len = 4096
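
# Generator-based chat handler: it yields partial responses so both chatbots stream,
# and carries the RxT short-term memory through the gr.State value passed in and out.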
@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list, temperature: float, top_p: float):
    # Restore this conversation's STM state and tokenize the new query for RxT
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    model.load_stm_state(stm_state)

    response = ""
    llm_response = ""

    # Stream the RxT answer token by token
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            response += model.stringify_token(token_id, show_memory_update=True)
            yield history + [[message, response]], stm_state, llm_history

    # Build the full chat-template prompt for the stateless LLM baseline
    llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
    llm_chat_history = {
        'input_ids': llm_chat_history['input_ids'].to(device),
        'attention_mask': llm_chat_history['attention_mask'].to(device)
    }

    # Stream the LLM answer token by token
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
            llm_response += llm_model.stringify_token(llm_tokenizer, token_id)
            yield history + [[message, response]], stm_state, llm_history + [[message, llm_response]]

    # Yield (not return) the final result: Gradio ignores a generator's return value,
    # so the updated STM state must be yielded to persist across turns
    yield history + [[message, response]], model.export_stm_state().cpu(), llm_history + [[message, llm_response]]
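
# Gradio UI: two chatbots side by side (RxT vs LLM), a shared query box, sampling
# controls, and a per-session gr.State holding the RxT short-term memory.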
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
# RxT-Beta-Micro-Supervised 290M vs Stateless LLM Reference 275M
Compare the experimental Reactive Transformer with a stateless LLM reference, both trained on the same limited 10B-token dataset.
## Limitations
The supervised version of the model is still at an intermediate stage and will be further improved
in the Reinforcement Learning stages (the demo will be updated accordingly), so the model may generate
inaccurate answers and its memory retention is weak. However, it should still demonstrate the architecture's
advantages, especially infinite context and the lack of per-turn delays (small delays are caused by Spaces ZeroGPU allocation).
""")

    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT', type='tuples')
        llm_chatbot = gr.Chatbot(height=600, label='LLM', type='tuples')

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask Models...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    stm_state = gr.State(initial_stm.clone())

    msg.submit(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    clear.click(lambda: ([], [], initial_stm.clone()), None, [chatbot, llm_chatbot, stm_state])
if __name__ == "__main__":
    demo.queue()
    demo.launch()
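
# Local usage sketch (assumptions: the ReactiveAI repos above may be gated, and the
# exact pip package names for `rxlm` and `spaces` are not confirmed here):
#   export HF_TOKEN=...        # token with access to the ReactiveAI models
#   pip install gradio torch spaces rxlm
#   python app.py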