Commit
·
e53e879
1
Parent(s):
7e8ce88
Fix logging clear behavior
Browse files — chat_interface_preference.py: +15 −11
chat_interface_preference.py
CHANGED
|
@@ -39,6 +39,7 @@ from huggingface_hub import CommitScheduler
|
|
| 39 |
pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
|
| 40 |
|
| 41 |
PREFERENCE_TECHNIQUE_MAPPING = {"sft": "prompt", "dpo": "preference", "kto": "vibes"}
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
@document()
|
|
@@ -684,7 +685,7 @@ class ChatInterface(Blocks):
|
|
| 684 |
else:
|
| 685 |
response = None
|
| 686 |
if self._check_if_two_responses(response):
|
| 687 |
-
Info(
|
| 688 |
return history, history
|
| 689 |
else:
|
| 690 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
|
@@ -729,7 +730,7 @@ class ChatInterface(Blocks):
|
|
| 729 |
else:
|
| 730 |
response = None
|
| 731 |
if self._check_if_two_responses(response):
|
| 732 |
-
Info(
|
| 733 |
yield history, history
|
| 734 |
else:
|
| 735 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
|
@@ -814,19 +815,18 @@ class ChatInterface(Blocks):
|
|
| 814 |
list[list[str | tuple | None]],
|
| 815 |
]:
|
| 816 |
self._check_num_turns(history, generate=False)
|
|
|
|
| 817 |
history_as_openai_format = self._get_conversation_in_openai_format(history)
|
| 818 |
feedback = {"prompt": history_as_openai_format}
|
| 819 |
-
|
| 820 |
-
prompt, response = history[-1]
|
| 821 |
matches = self._check_if_two_responses(response)
|
|
|
|
| 822 |
if matches and log != "prompt":
|
| 823 |
option_a, option_b = matches[0], matches[1]
|
| 824 |
if log == "a":
|
| 825 |
chosen, rejected = option_a, option_b
|
| 826 |
-
Info("Logged preference: a")
|
| 827 |
elif log == "b":
|
| 828 |
chosen, rejected = option_b, option_a
|
| 829 |
-
Info("Logged preference: b")
|
| 830 |
elif log == "ab":
|
| 831 |
options = [option_a, option_b]
|
| 832 |
chosen, rejected = random.choice([options])
|
|
@@ -838,23 +838,26 @@ class ChatInterface(Blocks):
|
|
| 838 |
"rejected": [{"content": rejected, "role": "assistant"}],
|
| 839 |
}
|
| 840 |
)
|
|
|
|
| 841 |
self._save_feedback(feedback)
|
|
|
|
| 842 |
elif log == "ab":
|
| 843 |
self._save_feedback(feedback)
|
| 844 |
history[-1] = [prompt, chosen]
|
| 845 |
return history, message or "", history
|
| 846 |
elif log in ["conversation", "good", "bad"]:
|
| 847 |
-
feedback.update({"response": response})
|
| 848 |
if log == "good":
|
| 849 |
feedback.update({"label": True})
|
| 850 |
elif log == "bad":
|
| 851 |
feedback.update({"label": False})
|
| 852 |
-
Info("Logged conversation")
|
| 853 |
-
self._save_feedback(feedback)
|
| 854 |
|
|
|
|
|
|
|
| 855 |
return history, "", history
|
| 856 |
else:
|
| 857 |
-
|
|
|
|
|
|
|
| 858 |
|
| 859 |
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
| 860 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
|
@@ -906,7 +909,8 @@ class ChatInterface(Blocks):
|
|
| 906 |
if history:
|
| 907 |
_, response = history[-1]
|
| 908 |
if self._check_if_two_responses(response):
|
| 909 |
-
|
|
|
|
| 910 |
else:
|
| 911 |
await self._log_fn(message=message, history=history, log="prompt")
|
| 912 |
self._set_conversation_id()
|
|
|
|
| 39 |
pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
|
| 40 |
|
| 41 |
PREFERENCE_TECHNIQUE_MAPPING = {"sft": "prompt", "dpo": "preference", "kto": "vibes"}
|
| 42 |
+
_LOG_REQUIRED_MESSAGE = "First, provide preference, undo or clear to continue conversation."
|
| 43 |
|
| 44 |
|
| 45 |
@document()
|
|
|
|
| 685 |
else:
|
| 686 |
response = None
|
| 687 |
if self._check_if_two_responses(response):
|
| 688 |
+
Info(_LOG_REQUIRED_MESSAGE)
|
| 689 |
return history, history
|
| 690 |
else:
|
| 691 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
|
|
|
| 730 |
else:
|
| 731 |
response = None
|
| 732 |
if self._check_if_two_responses(response):
|
| 733 |
+
Info(_LOG_REQUIRED_MESSAGE)
|
| 734 |
yield history, history
|
| 735 |
else:
|
| 736 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
|
|
|
| 815 |
list[list[str | tuple | None]],
|
| 816 |
]:
|
| 817 |
self._check_num_turns(history, generate=False)
|
| 818 |
+
prompt, response = history[-1]
|
| 819 |
history_as_openai_format = self._get_conversation_in_openai_format(history)
|
| 820 |
feedback = {"prompt": history_as_openai_format}
|
| 821 |
+
feedback.update({"response": [{"content": response, "role": "assistant"}]})
|
|
|
|
| 822 |
matches = self._check_if_two_responses(response)
|
| 823 |
+
|
| 824 |
if matches and log != "prompt":
|
| 825 |
option_a, option_b = matches[0], matches[1]
|
| 826 |
if log == "a":
|
| 827 |
chosen, rejected = option_a, option_b
|
|
|
|
| 828 |
elif log == "b":
|
| 829 |
chosen, rejected = option_b, option_a
|
|
|
|
| 830 |
elif log == "ab":
|
| 831 |
options = [option_a, option_b]
|
| 832 |
chosen, rejected = random.choice([options])
|
|
|
|
| 838 |
"rejected": [{"content": rejected, "role": "assistant"}],
|
| 839 |
}
|
| 840 |
)
|
| 841 |
+
feedback["response"] = [{"content": chosen, "role": "assistant"}]
|
| 842 |
self._save_feedback(feedback)
|
| 843 |
+
Info("Logged succesfully")
|
| 844 |
elif log == "ab":
|
| 845 |
self._save_feedback(feedback)
|
| 846 |
history[-1] = [prompt, chosen]
|
| 847 |
return history, message or "", history
|
| 848 |
elif log in ["conversation", "good", "bad"]:
|
|
|
|
| 849 |
if log == "good":
|
| 850 |
feedback.update({"label": True})
|
| 851 |
elif log == "bad":
|
| 852 |
feedback.update({"label": False})
|
|
|
|
|
|
|
| 853 |
|
| 854 |
+
self._save_feedback(feedback)
|
| 855 |
+
Info("Logged succesfully")
|
| 856 |
return history, "", history
|
| 857 |
else:
|
| 858 |
+
self._save_feedback(feedback)
|
| 859 |
+
Info("Logged succesfully")
|
| 860 |
+
return history, "", history
|
| 861 |
|
| 862 |
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
| 863 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
|
|
|
| 909 |
if history:
|
| 910 |
_, response = history[-1]
|
| 911 |
if self._check_if_two_responses(response):
|
| 912 |
+
Info(_LOG_REQUIRED_MESSAGE)
|
| 913 |
+
return history, message or "", history
|
| 914 |
else:
|
| 915 |
await self._log_fn(message=message, history=history, log="prompt")
|
| 916 |
self._set_conversation_id()
|