# lionguard-demo / utils.py
# Author: gabrielchua — "update-demo (#1)", commit 833aed9 (verified)
"""
utils.py
"""
# Standard imports
import os
from typing import List, Optional, Tuple

# Third party imports
import numpy as np
from google import genai
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
# Module-level OpenAI client used by the "openai" embedding strategy.
# Key comes from OPENAI_API_KEY (None if unset; API calls will then fail).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Maximum tokens for text-embedding-3-small
# NOTE(review): the default embedding model used below is text-embedding-3-large,
# not -small — confirm this limit/comment still applies to the model in use.
MAX_TOKENS = 8191 # We don't have access to the tokenizer for text-embedding-3-small, and just assume 1 character = 1 token here
def get_embeddings(
    texts: List[str], model: str = "text-embedding-3-large"
) -> np.ndarray:
    """
    Generate embeddings for a list of texts using the OpenAI API synchronously.

    Args:
        texts: List of strings to embed.
        model: OpenAI embedding model to use (default: text-embedding-3-large).

    Returns:
        A 2D numpy array of shape (len(texts), embedding_dim), one row per
        input text, in input order.

    Raises:
        Exception: If the OpenAI API call fails.
    """
    # Truncate texts to the max token limit, approximating 1 character = 1 token
    # (no tokenizer is available for this embedding model).
    truncated_texts = [text[:MAX_TOKENS] for text in texts]
    # Make the API call
    response = client.embeddings.create(input=truncated_texts, model=model)
    # Extract embeddings from the response into a single 2D array
    embeddings = np.array([data.embedding for data in response.data])
    return embeddings
# Registry of available classifier models and how to embed inputs for each:
# - label: human-readable display name
# - repo_id: Hugging Face repo the classifier is loaded from
# - embedding_strategy: embedding backend selector ("openai",
#   "sentence_transformer", or "gemini") — dispatched in get_model_embeddings
# - embedding_model: model name passed to that backend
MODEL_CONFIGS = {
    "lionguard-2": {
        "label": "LionGuard 2",
        "repo_id": "govtech/lionguard-2",
        "embedding_strategy": "openai",
        "embedding_model": "text-embedding-3-large",
    },
    "lionguard-2-lite": {
        "label": "LionGuard 2 Lite",
        "repo_id": "govtech/lionguard-2-lite",
        "embedding_strategy": "sentence_transformer",
        "embedding_model": "google/embeddinggemma-300m",
    },
    "lionguard-2.1": {
        "label": "LionGuard 2.1",
        "repo_id": "govtech/lionguard-2.1",
        "embedding_strategy": "gemini",
        "embedding_model": "gemini-embedding-001",
    },
}
# Key into MODEL_CONFIGS used when no explicit model is requested.
DEFAULT_MODEL_KEY = "lionguard-2.1"
# Cache of loaded classifier models, keyed by MODEL_CONFIGS key.
MODEL_CACHE = {}
# Cache of loaded SentenceTransformer embedders, keyed by embedding model name.
EMBEDDING_MODEL_CACHE = {}
# Currently selected model key; mutated by set_active_model.
current_model_choice = DEFAULT_MODEL_KEY
# Lazily-initialized shared Gemini client (see get_gemini_client).
GEMINI_CLIENT = None
def resolve_model_key(model_key: Optional[str] = None) -> str:
    """
    Resolve an optional model key to a valid MODEL_CONFIGS key.

    Args:
        model_key: Explicit model key, or None (or empty string) to fall back
            to the currently selected model (current_model_choice).

    Returns:
        A key guaranteed to exist in MODEL_CONFIGS.

    Raises:
        ValueError: If the resolved key is not in MODEL_CONFIGS.
    """
    key = model_key or current_model_choice
    if key not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model selection: {key}")
    return key
def load_model_instance(model_key: str):
    """Return the classifier for ``model_key``, loading and caching it on first use."""
    key = resolve_model_key(model_key)
    cached = MODEL_CACHE.get(key)
    if cached is None:
        # First request for this model: download/load it once, then reuse.
        repo_id = MODEL_CONFIGS[key]["repo_id"]
        cached = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        MODEL_CACHE[key] = cached
    return cached
def get_sentence_transformer(model_name: str):
    """Return a cached SentenceTransformer for ``model_name``, creating it once."""
    try:
        return EMBEDDING_MODEL_CACHE[model_name]
    except KeyError:
        # Not loaded yet — construct the embedder and memoize it.
        embedder = SentenceTransformer(model_name)
        EMBEDDING_MODEL_CACHE[model_name] = embedder
        return embedder
def get_gemini_client():
    """Return the shared Gemini client, creating it lazily on first call.

    Raises:
        EnvironmentError: If the GEMINI_API_KEY environment variable is unset.
    """
    global GEMINI_CLIENT
    if GEMINI_CLIENT is not None:
        return GEMINI_CLIENT
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise EnvironmentError(
            "GEMINI_API_KEY environment variable is required for LionGuard 2.1."
        )
    GEMINI_CLIENT = genai.Client(api_key=api_key)
    return GEMINI_CLIENT
def get_model_embeddings(model_key: str, texts: List[str]) -> np.ndarray:
    """Embed ``texts`` using the embedding backend configured for ``model_key``.

    Returns:
        A 2D numpy array of embeddings, one row per input text.

    Raises:
        ValueError: If the configured embedding strategy is not recognized.
    """
    key = resolve_model_key(model_key)
    config = MODEL_CONFIGS[key]
    strategy = config["embedding_strategy"]
    model_name = config.get("embedding_model")

    if strategy == "openai":
        return get_embeddings(texts, model=model_name)

    if strategy == "sentence_transformer":
        embedder = get_sentence_transformer(model_name)
        # Wrap each text in the classification prompt format — presumably the
        # prefix EmbeddingGemma expects; confirm against the model card.
        prompts = [f"task: classification | query: {text}" for text in texts]
        return np.array(embedder.encode(prompts))

    if strategy == "gemini":
        gemini = get_gemini_client()
        result = gemini.models.embed_content(model=model_name, contents=texts)
        return np.array([item.values for item in result.embeddings])

    raise ValueError(f"Unsupported embedding strategy: {strategy}")
def predict_with_model(texts: List[str], model_key: Optional[str] = None) -> Tuple[dict, str]:
    """
    Run the selected classifier over a batch of texts.

    Args:
        texts: Input strings to classify.
        model_key: Optional MODEL_CONFIGS key; None selects the currently
            active model.

    Returns:
        A tuple of (prediction output from the model, resolved model key).

    Raises:
        ValueError: If the model key cannot be resolved to a known config.
    """
    key = resolve_model_key(model_key)
    # Embed first so a bad embedding config fails before the classifier loads.
    embeddings = get_model_embeddings(key, texts)
    model = load_model_instance(key)
    return model.predict(embeddings), key
def set_active_model(model_key: str) -> str:
    """Switch the active model, pre-load it, and return a user-facing status string."""
    global current_model_choice
    if model_key not in MODEL_CONFIGS:
        return f"⚠️ Unknown model {model_key}"
    current_model_choice = model_key
    # Eagerly load so the first prediction with the new model is fast.
    load_model_instance(model_key)
    label = MODEL_CONFIGS[model_key]["label"]
    return f"🦁 Using {label} ({model_key})"