Spaces: Running on Zero
```python
# app.py
import os

import gradio as gr
import torch
import spaces  # ZeroGPU helper for Spaces running on Zero hardware

from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

HF_TOKEN = os.environ.get("HF_TOKEN")

# Reactive Transformer (RxT) model with shared components
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()

# Stateless decoder-only LLM used as the reference baseline
llm_tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/rc-RxT-Beta-Base', token=HF_TOKEN)
llm_model = DecoderOnlyTransformer.from_pretrained('ReactiveAI/SQA-Transformer-Beta-SFT', token=HF_TOKEN)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
llm_model.to(device)

# Initial (empty) short-term memory state, cloned for every new session
initial_stm = model.export_stm_state().cpu()

seq_len = 1024
llm_seq_len = 4096


@spaces.GPU  # allocate ZeroGPU hardware for the duration of each call
def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list, temperature: float, top_p: float):
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    model.load_stm_state(stm_state)

    response = ""
    llm_response = ""

    # Stream the RxT answer token by token
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            response += model.stringify_token(token_id, show_memory_update=True)
            yield history + [[message, response]], stm_state, llm_history

    # Stream the stateless LLM answer, conditioned on the full chat history
    llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
    llm_chat_history = {
        'input_ids': llm_chat_history['input_ids'].to(device),
        'attention_mask': llm_chat_history['attention_mask'].to(device)
    }

    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
            llm_response += llm_model.stringify_token(llm_tokenizer, token_id)
            yield history + [[message, response]], stm_state, llm_history + [[message, llm_response]]

    # Final yield persists the updated STM state into the gr.State component
    # (Gradio does not use the return value of a generator function)
    yield history + [[message, response]], model.export_stm_state().cpu(), llm_history + [[message, llm_response]]
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-Supervised 290M vs Stateless LLM Reference 275M

    Compare the experimental Reactive Transformer (RxT) with a stateless LLM reference, both trained on the same limited real-world data.
    Both models were pre-trained on 10B tokens from English Wikipedia and FineWeb-Edu, then fine-tuned on 1.1M single interactions
    and on 30k filtered multi-turn conversations.

    That is a very small amount of pre-training data compared to the 1T/2T tokens used for production small LLMs. The experiment is
    designed to demonstrate that RxT learns faster and achieves better results, even after very short training.

    Accuracy (next-token prediction) on the multi-turn conversation validation dataset:
    - RxT: 88%
    - LLM: 60%

    ## Limitations
    The supervised version of the model is still at an intermediate stage and will be further improved in the
    Reinforcement Learning stages (the demo will be updated continuously), so the model may generate
    inaccurate answers and memory retention is still weak.
    """)

    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT', type='tuples')
        llm_chatbot = gr.Chatbot(height=600, label='LLM', type='tuples')

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask Models...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Per-session short-term memory state, initialised from the empty STM
    stm_state = gr.State(initial_stm.clone())

    msg.submit(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, llm_chatbot, temp, top_p], [chatbot, stm_state, llm_chatbot], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    # Clear both chats and reset the STM to its initial state
    clear.click(lambda: ([], [], initial_stm.clone()), None, [chatbot, llm_chatbot, stm_state])


if __name__ == "__main__":
    demo.queue()
    demo.launch()
```
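
The only piece of state the RxT side has to carry between turns is the exported STM tensor; the full chat history is never re-fed to the model. The sketch below is a minimal illustration of that export/load round-trip outside of Gradio, assuming the same `model`, `device` and `seq_len` objects defined in `app.py` above; the prompts are hypothetical.

```python
# Minimal multi-turn loop reusing the RxT objects from app.py (model, device, seq_len).
# Illustrative sketch only, not part of the Space.
import torch

stm = model.export_stm_state().cpu()  # start the conversation from the initial, empty memory

for message in ["Hi, my name is Ada.", "What is my name?"]:  # hypothetical prompts
    model.load_stm_state(stm)  # restore the memory produced by the previous turn
    query = model.tokenize_query(message, max_seq_len=seq_len, device=device)

    response = ""
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**query, max_seq_len=seq_len, temperature=0.7, top_p=0.9):
            response += model.stringify_token(token_id, show_memory_update=True)

    stm = model.export_stm_state().cpu()  # persist the updated memory for the next turn
    print(f"> {message}\n{response}\n")
```

The Gradio app does the same thing, except that the STM tensor lives in a `gr.State` component so that each browser session keeps its own memory, and the "Clear & Reset STM" button simply swaps in a clone of the initial state.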