# app.py
import os

import gradio as gr
import torch
import spaces

from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

HF_TOKEN = os.environ.get("HF_TOKEN")

# Reactive Transformer (RxT) model and its tokenizer
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()

# Stateless decoder-only LLM used as the reference baseline
llm_tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/rc-RxT-Beta-Base', token=HF_TOKEN)
llm_model = DecoderOnlyTransformer.from_pretrained('ReactiveAI/SQA-Transformer-Beta-SFT', token=HF_TOKEN)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
llm_model.to(device)

# Initial short-term memory (STM) state, kept on CPU and cloned per session
initial_stm = model.export_stm_state().cpu()

seq_len = 1024
llm_seq_len = 4096


@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list, temperature: float, top_p: float):
    # Reactive model: tokenize only the current query and restore this session's STM state
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    model.load_stm_state(stm_state)

    response = ""
    llm_response = ""

    # Stream the reactive model's answer token by token
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            response += model.stringify_token(token_id, show_memory_update=True)
            yield history + [[message, response]], stm_state, llm_history

    # Stateless LLM: re-tokenize the full chat history plus the new message on every turn
    llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
    llm_chat_history = {
        'input_ids': llm_chat_history['input_ids'].to(device),
        'attention_mask': llm_chat_history['attention_mask'].to(device)
    }

    # Stream the reference LLM's answer token by token
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
            llm_response += llm_model.stringify_token(llm_tokenizer, token_id)
            yield history + [[message, response]], stm_state, llm_history + [[message, llm_response]]

    # Return the updated STM state so the next turn continues from the new memory
    return history + [[message, response]], model.export_stm_state().cpu(), llm_history + [[message, llm_response]]


with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
# RxT-Beta-Micro-Supervised 290M vs Stateless LLM Reference 275M
Compare the experimental Reactive Transformer with a stateless LLM reference, both trained on the same limited 10B-token dataset.
## Limitations
The supervised version of the model is still at an intermediate stage and will be further improved in the Reinforcement Learning stages (the demo will be updated continuously), so the model may generate inaccurate answers and its memory retention is weak. However, it should still demonstrate the architecture's advantages, especially infinite context and no inter-turn delays (small delays are caused by Spaces ZeroGPU allocation).
""") with gr.Row(): chatbot = gr.Chatbot(height=600, label='RxT', type='tuples') llm_chatbot = gr.Chatbot(height=600, label='LLM', type='tuples') with gr.Row(): msg = gr.Textbox(placeholder="Ask Models...", label="Query", scale=4) send_btn = gr.Button("Send", scale=1) clear = gr.Button("Clear & Reset STM", scale=1) with gr.Row(): temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p") stm_state = gr.State(initial_stm.clone()) msg.submit(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then( lambda: gr.update(value=""), outputs=msg ) send_btn.click(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then( lambda: gr.update(value=""), outputs=msg ) clear.click(lambda: ([], [], initial_stm.clone()), None, [chatbot, llm_chatbot, stm_state]) if __name__ == "__main__": demo.queue() demo.launch()