import os
from typing import List, Dict, Optional

import google.generativeai as genai


class GeminiProvider:
    """Thin wrapper around the google.generativeai chat API."""

    _inited = False

    @classmethod
    def _ensure_init(cls):
        # Configure the SDK once per process, using the key from the environment.
        if cls._inited:
            return
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("Missing GEMINI_API_KEY (set it in Spaces Secrets).")
        genai.configure(api_key=api_key)
        cls._inited = True

    @staticmethod
    def _messages_to_history(messages: List[Dict[str, str]], system_prompt: Optional[str]):
        # Convert OpenAI-style {"role", "content"} messages into Gemini chat history.
        # Gemini uses "model" for assistant turns; the system prompt is injected
        # as a leading user turn.
        history = []
        if system_prompt:
            history.append({"role": "user", "parts": system_prompt})
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            if role == "assistant":
                history.append({"role": "model", "parts": content})
            else:
                history.append({"role": "user", "parts": content})
        return history

    @classmethod
    def generate(
        cls,
        model_name: str,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 8192,
        temperature: float = 0.7,
    ) -> str:
        """Run a single chat turn and return the model's text reply."""
        cls._ensure_init()
        model = genai.GenerativeModel(model_name)
        history = cls._messages_to_history(messages, system_prompt)
        # Seed the chat with everything except the latest turn, then send the
        # latest turn as the new message.
        chat = model.start_chat(history=history[:-1] if len(history) > 1 else [])
        user_turn = history[-1]["parts"] if history else ""
        resp = chat.send_message(
            user_turn,
            generation_config=genai.types.GenerationConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        return resp.text or ""
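

# Minimal usage sketch. This is illustrative only: it assumes GEMINI_API_KEY is
# set in the environment and that the "gemini-1.5-flash" model name is available
# to your key; substitute whichever Gemini model you actually use.
if __name__ == "__main__":
    reply = GeminiProvider.generate(
        model_name="gemini-1.5-flash",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        system_prompt="You are a concise assistant.",
        max_tokens=256,
        temperature=0.3,
    )
    print(reply)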