Update src/processor/message_processor.py

src/processor/message_processor.py
CHANGED
@@ -1,20 +1,97 @@
-# SPDX-FileCopyrightText: Hadad <[email protected]>
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import List, Dict
-from config import INSTRUCTIONS_START, INSTRUCTIONS_END
-from src.providers.gemini import GeminiProvider
-
-"""
-messages: [{"role": "user"/"assistant", "content": "..."}]
-"""
-
-reply = GeminiProvider.generate(
-    model_name="gemini-1.5-flash"
-    messages=messages,
-    system_prompt=system_prompt,
-)
-return reply
+# SPDX-License-Identifier: Apache-2.0
+"""
+Universal message processor that is compatible with Gradio ChatInterface passing:
+- (message: str, history: list[tuple[str, str]] or list[dict])
+- (messages: list[dict]) when type="messages"
+It normalizes into OpenAI/Gemini-style [{"role":"user"|"assistant","content":"..."}].
+Then it calls the Gemini provider to generate a reply.
+"""
+
+from typing import List, Dict, Tuple, Any, Optional
+from config import MODEL, INSTRUCTIONS_START, INSTRUCTIONS_END, MAX_TOKENS  # MODEL comes from config
+from src.providers.gemini import GeminiProvider
+
+# --------- helpers ---------
+def _tuple_history_to_messages(history: List[Tuple[Optional[str], Optional[str]]]) -> List[Dict[str, str]]:
+    """
+    Convert ChatInterface tuple history [(user, assistant), ...] -> role/content messages.
+    """
+    msgs: List[Dict[str, str]] = []
+    for ua in history or []:
+        if not isinstance(ua, (list, tuple)) or len(ua) != 2:
+            # already normalized or unexpected; skip
+            continue
+        user_msg, assistant_msg = ua
+        if user_msg:
+            msgs.append({"role": "user", "content": str(user_msg)})
+        if assistant_msg:
+            msgs.append({"role": "assistant", "content": str(assistant_msg)})
+    return msgs
+
+
+def _dict_history_to_messages(history: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+    """
+    Pass-through for ChatInterface type='messages' history.
+    """
+    msgs: List[Dict[str, str]] = []
+    for m in history or []:
+        role = m.get("role")
+        content = m.get("content")
+        if role and content is not None:
+            msgs.append({"role": str(role), "content": str(content)})
+    return msgs
+
+
+def _normalize_inputs(*args: Any) -> List[Dict[str, str]]:
+    """
+    Accepts either:
+    - (messages: list[dict])         # type='messages'
+    - (message: str, history: list)  # default type
+    Returns list[{"role", "content"}].
+    """
+    # Case A: single argument
+    if len(args) == 1:
+        payload = args[0]
+        # Already a messages list (type="messages")
+        if isinstance(payload, list) and (not payload or isinstance(payload[0], dict)):
+            return _dict_history_to_messages(payload)
+        # Single string -> single-turn chat
+        if isinstance(payload, str):
+            return [{"role": "user", "content": payload}]
+        # Fallback
+        return [{"role": "user", "content": str(payload)}]
+
+    # Case B: two arguments (message, history)
+    if len(args) >= 2:
+        message, history = args[0], args[1]
+        msgs: List[Dict[str, str]] = []
+        if isinstance(history, list):
+            if history and isinstance(history[0], dict):
+                msgs.extend(_dict_history_to_messages(history))
+            else:
+                msgs.extend(_tuple_history_to_messages(history))
+        # Append the current user message if present
+        if message is not None and message != "":
+            msgs.append({"role": "user", "content": str(message)})
+        return msgs
+
+    # Fallback (no args)
+    return [{"role": "user", "content": ""}]
+
+
+# --------- main entry called by Gradio ChatInterface (see app.py) ---------
+def process_user_request(*args: Any) -> str:
+    """
+    Gradio will call with either (message, history) or (messages).
+    We normalize to messages and call Gemini to generate a reply.
+    """
+    messages = _normalize_inputs(*args)
+
+    # System prompt: reuse the mandatory tool / cross-validation rules from config
+    system_prompt = (INSTRUCTIONS_START or "") + "\n" + (INSTRUCTIONS_END or "")
+
+    reply = GeminiProvider.generate(
+        model_name=MODEL,             # e.g., "gemini-1.5-flash"
+        messages=messages,            # normalized messages
+        system_prompt=system_prompt,  # prepend system rules
+        max_tokens=MAX_TOKENS,        # internal output cap
+        temperature=0.7,
+    )
+    return reply
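The docstrings above describe the two calling conventions _normalize_inputs accepts; a short illustration of the message lists it produces (the strings here are made up, shown only to make the normalization concrete):

# Illustrative only: what _normalize_inputs returns for the two calling
# conventions described in the docstrings (example strings are invented).
from src.processor.message_processor import _normalize_inputs

# Default tuple-style history plus the new user message.
assert _normalize_inputs(
    "And in Japan?",
    [("Capital of France?", "Paris.")],
) == [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And in Japan?"},
]

# type="messages" style: an already-normalized list passes through unchanged.
assert _normalize_inputs(
    [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}]
) == [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
]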
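The "see app.py" comment points at the Gradio entry point, which is not part of this diff. The following is only a minimal sketch of how process_user_request could be wired up, assuming a current Gradio release where gr.ChatInterface with type="messages" calls the function as (message, history) with dict-style history; the actual app.py may differ.

# Hypothetical app.py wiring (not shown in this diff): gr.ChatInterface calls
# the fn as (message, history); with type="messages" the history arrives as
# role/content dicts, which _normalize_inputs handles via its two-argument path.
import gradio as gr

from src.processor.message_processor import process_user_request

demo = gr.ChatInterface(
    fn=process_user_request,  # invoked as process_user_request(message, history)
    type="messages",
)

if __name__ == "__main__":
    demo.launch()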