|
|
class ChatState:
    """
    Manages the conversation history for a turn-based chatbot.

    The turn structure is modelled on the Gemma turn-based conversation
    guidelines documented at https://ai.google.dev/gemma/docs/formatting:
    user and model turns alternate, and the prompt ends with an open model
    turn for the model to complete.

    Note, however, that the markers used here are a plain-text
    "Instruction:" / "Response:" template, not the ``<start_of_turn>`` /
    ``<end_of_turn>`` control tokens from that guide. Change the class
    constants below if the target model expects a different template.
    """

    # Prefix that opens every user turn.
    __START_TURN_USER__ = "Instruction:\n"
    # Prefix that opens every model turn. It is also appended to the prompt
    # so the model continues from an open "Response:" section.
    __START_TURN_MODEL__ = "\n\nResponse:\n"
    # Suffix that closes a user turn; empty for this template.
    __END_TURN__ = ""

    def __init__(self, model, system=""):
        """
        Initializes the chat state.

        Args:
            model: The language model to use for generating responses. It is
                called as ``model.generate(prompt, max_length=...)`` and must
                return a str.
            system: (Optional) System instructions or bot description,
                prepended to every prompt when non-empty.
        """
        self.model = model
        self.system = system
        # Rendered turn strings, in conversation order.
        self.history = []

    def add_to_history_as_user(self, message):
        """
        Adds a user message to the history, wrapped in the user start/end
        turn markers.
        """
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN__
        )

    def add_to_history_as_model(self, message):
        """
        Adds a model response to the history.

        The response is prefixed with the model start-turn marker and
        terminated with a newline, so the next user turn starts on its own
        line.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + "\n")

    def get_history(self):
        """Returns the entire chat history as a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """
        Builds the prompt for the language model.

        The prompt consists of three parts, in order:
        - the optional system description followed by a newline;
        - the rendered history;
        - an open model turn for the model to complete.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if self.system:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message, max_length=4096):
        """
        Handles sending a user message and getting a model response.

        Args:
            message: The user's message.
            max_length: Maximum length forwarded to ``model.generate``.
                Defaults to 4096.

        Returns:
            The model's response, with any leading echo of the prompt removed.

        If ``model.generate`` raises, the user turn added by this call is
        removed again before the exception propagates. This way the history
        never holds a user turn without a matching model turn.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        try:
            response = self.model.generate(prompt, max_length=max_length)
        except BaseException:
            # Roll back the dangling user turn, then re-raise unchanged.
            self.history.pop()
            raise
        # The model may echo the prompt before its completion. Strip only
        # that leading echo: str.replace would also delete any occurrence
        # of the prompt text inside the generated reply itself.
        if response.startswith(prompt):
            response = response[len(prompt):]
        self.add_to_history_as_model(response)
        return response
|
|
|