AdamF92 commited on
Commit
fc90fb4
·
verified ·
1 Parent(s): 1df071f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -41,6 +41,11 @@ def chat(message: str, history: list, stm_state: torch.Tensor, llm_history: list
41
 
42
  llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
43
 
 
 
 
 
 
44
  with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
45
  for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
46
  llm_response += model.stringify_token(token_id, show_memory_update=False)
 
41
 
42
  llm_chat_history = llm_model.tokenize_chat_template(llm_tokenizer, llm_history, message, max_seq_len=llm_seq_len, use_simplified_format=True)
43
 
44
+ llm_chat_history = {
45
+ 'input_ids': llm_chat_history['input_ids'].to(device),
46
+ 'attention_mask': llm_chat_history['attention_mask'].to(device)
47
+ }
48
+
49
  with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
50
  for token_id in llm_model.generate(**llm_chat_history, max_seq_len=llm_seq_len, temperature=temperature, top_p=top_p):
51
  llm_response += model.stringify_token(token_id, show_memory_update=False)