# app.py
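# Gradio demo comparing a stateful Reactive Transformer (RxT) with a stateless
# decoder-only LLM baseline trained on the same data. Each user query is
# streamed through both models side by side.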
import os
import gradio as gr
import torch
import spaces
from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

HF_TOKEN = os.environ.get("HF_TOKEN")

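# The stateful Reactive Transformer (RxT) model and its tokenizer.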
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()

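# The stateless decoder-only baseline, trained on the same data for comparison.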
llm_tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/rc-RxT-Beta-Base', token=HF_TOKEN)
llm_model = DecoderOnlyTransformer.from_pretrained('ReactiveAI/SQA-Transformer-Beta-SFT', token=HF_TOKEN)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
llm_model.to(device)

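# Snapshot of the freshly initialized short-term memory (STM); every new
# session starts from a clone of this tensor.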
initial_stm = model.export_stm_state().cpu()

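# Context limits: RxT only ever sees a single interaction, while the baseline
# must fit the whole conversation history in its context window.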
seq_len = 1024
llm_seq_len = 4096

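# Streaming chat handler: first streams the RxT answer (updating its STM),
# then streams the baseline LLM answer conditioned on the full chat history.
# Written as a generator so Gradio can render tokens as they arrive.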
@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list, temperature: float, top_p: float):
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)

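    # Restore this session's short-term memory before the interaction.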
    model.load_stm_state(stm_state)
    
    response = ""
    llm_response = ""

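    # Phase 1: stream the RxT answer token by token under bfloat16 autocast.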
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            response += model.stringify_token(token_id, show_memory_update=True)
            yield history + [[message, response]], stm_state, llm_history

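    # Phase 2: the stateless baseline re-tokenizes the entire conversation
    # plus the new message on every turn.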
    llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)

    llm_chat_history = {
        'input_ids': llm_chat_history['input_ids'].to(device),
        'attention_mask': llm_chat_history['attention_mask'].to(device)
    }
    
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
            llm_response += llm_model.stringify_token(llm_tokenizer, token_id)
            yield history + [[message, response]], stm_state, llm_history + [[message, llm_response]]
    
    # Yield (not return) the final state: in a generator, a bare return value
    # never reaches Gradio's outputs, so the updated STM would be lost.
    yield history + [[message, response]], model.export_stm_state().cpu(), llm_history + [[message, llm_response]]

with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-Supervised 290M vs Stateless LLM Reference 275M
    Compare the experimental Reactive Transformer with a stateless LLM reference, both trained on the same limited real-world data.

    Both models were pre-trained on 10B tokens from English Wikipedia and FineWeb-Edu, then fine-tuned on 1.1M single interactions
    and 30k filtered multi-turn conversations.

    That is a very small amount of pre-training data compared to the 1T-2T tokens used for production small LLMs. The experiment
    is designed to show that RxT learns faster and achieves better results, even after very short training.

    Next-token prediction accuracy in multi-turn conversation training (validation dataset):
    - RxT 88%
    - LLM 60%

    ## Limitations
    The supervised version of the model is still at an intermediate stage and will be further improved
    in the Reinforcement Learning stages (this demo will be updated accordingly), so the model may generate
    inaccurate answers and its memory retention is weak.
    """)

    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT', type='tuples')
        llm_chatbot = gr.Chatbot(height=600, label='LLM', type='tuples')
    
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask Models...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    
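    # Per-session memory state for the RxT model (the baseline keeps no state
    # beyond its visible chat history).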
    stm_state = gr.State(initial_stm.clone())
    
    # Enter in the textbox and the Send button run the same streaming handler,
    # then clear the input field.
    submit_kwargs = dict(
        fn=chat,
        inputs=[msg, chatbot, stm_state, llm_chatbot, temp, top_p],
        outputs=[chatbot, stm_state, llm_chatbot],
        queue=True,
    )
    msg.submit(**submit_kwargs).then(lambda: gr.update(value=""), outputs=msg)
    send_btn.click(**submit_kwargs).then(lambda: gr.update(value=""), outputs=msg)
    
    # Reset both chat panes and restore the pristine STM snapshot.
    clear.click(lambda: ([], [], initial_stm.clone()), None, [chatbot, llm_chatbot, stm_state])

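# Optional smoke test that drives one turn through both models outside the UI.
# This is a sketch using only the objects defined above; it is not invoked by
# the demo itself.
def _smoke_test(query: str = "Hello!"):
    final = None
    for final in chat(query, [], initial_stm.clone(), [], 0.7, 0.9):
        pass
    rxt_history, _, llm_history = final
    print("RxT:", rxt_history[-1][1])
    print("LLM:", llm_history[-1][1])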

if __name__ == "__main__":
    demo.queue()
    demo.launch()