# app.py
import os

import gradio as gr
import torch
import spaces

from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

HF_TOKEN = os.environ.get("HF_TOKEN")
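
# Reactive Transformer (RxT): an event-driven model that keeps a fixed-size
# short-term memory (STM) instead of re-reading the whole conversation.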
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()
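
# Stateless decoder-only baseline trained on the same data; it has no memory,
# so it must re-process the full chat history on every turn.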
llm_tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/rc-RxT-Beta-Base', token=HF_TOKEN)
llm_model = DecoderOnlyTransformer.from_pretrained('ReactiveAI/SQA-Transformer-Beta-SFT', token=HF_TOKEN)
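
# Device setup; on HF Spaces ZeroGPU, the @spaces.GPU decorator below attaches
# a GPU for the duration of each call.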
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
llm_model.to(device)
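
# Snapshot the freshly initialized STM on CPU so every session can start from,
# and be reset to, the same clean state (see gr.State below).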
initial_stm = model.export_stm_state().cpu()
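
# Max sequence lengths: RxT only needs room for a single interaction, while the
# stateless LLM needs a context window large enough for the whole conversation.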
seq_len = 1024
llm_seq_len = 4096

@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list, temperature: float, top_p: float):
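    """Stream both models' answers for one user turn.

    Yields partial (rxt_history, stm_state, llm_history) tuples so Gradio can
    stream tokens; the final return carries the updated STM snapshot as the
    last output update.
    """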
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    # Restore this session's memory state before generating.
    model.load_stm_state(stm_state)

    response = ""
    llm_response = ""

    # First pass: stream RxT's answer token by token.
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            response += model.stringify_token(token_id, show_memory_update=True)
            yield history + [[message, response]], stm_state, llm_history
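
    # Second pass: the stateless LLM re-reads the entire conversation (as a
    # simplified chat template) to answer the same query.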
    llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
    llm_chat_history = {
        'input_ids': llm_chat_history['input_ids'].to(device),
        'attention_mask': llm_chat_history['attention_mask'].to(device),
    }
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
            llm_response += llm_model.stringify_token(llm_tokenizer, token_id)
            yield history + [[message, response]], stm_state, llm_history + [[message, llm_response]]

    # Final update: persist the STM that was updated during model.interact().
    return history + [[message, response]], model.export_stm_state().cpu(), llm_history + [[message, llm_response]]
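

# Gradio UI: the two chatbots side by side, with shared sampling controls.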
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-Supervised 290M vs Stateless LLM Reference 275M
    Compare the experimental Reactive Transformer (RxT) with a stateless LLM reference, both trained on the same limited real-world data.
    Both models were pre-trained on 10B tokens from English Wikipedia and FineWeb-edu, then fine-tuned on 1.1M single interactions
    and 30k filtered multi-turn conversations.
    That is a very small amount of pre-training data compared to the 1T-2T tokens used for production small LLMs. The experiment is
    designed to show that RxT learns faster and achieves better results, even after very short training.

    Next-token prediction accuracy on the multi-turn conversation validation dataset:
    - RxT: 88%
    - LLM: 60%

    ## Limitations
    The supervised version of the model is still at an intermediate stage and will be further improved
    in the Reinforcement Learning stages (this demo will be updated as training progresses), so the model
    may generate inaccurate answers and its memory retention is weak.
    """)
    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT', type='tuples')
        llm_chatbot = gr.Chatbot(height=600, label='LLM', type='tuples')

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask Models...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
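
    # Per-session STM tensor, initialized from the clean snapshot.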
    stm_state = gr.State(initial_stm.clone())
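
    # Both submit paths stream from the chat generator, then clear the textbox.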
    msg.submit(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
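
    # Reset both chat histories and restore the initial STM.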
    clear.click(lambda: ([], [], initial_stm.clone()), None, [chatbot, llm_chatbot, stm_state])

if __name__ == "__main__":
    demo.queue()
    demo.launch()