"""utils.py

Utilities for generating text embeddings (OpenAI, SentenceTransformer,
Gemini) and running the LionGuard content-moderation classifiers.
"""

# Standard imports
import os
from typing import List, Optional, Tuple

# Third party imports
import numpy as np
from google import genai
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from transformers import AutoModel

# Shared OpenAI client; reads the key from the environment at import time.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Maximum input tokens for the OpenAI embedding models. We don't have access
# to the model's tokenizer here, so truncation below assumes 1 character =
# 1 token (conservative: a token is almost always >= 1 character).
MAX_TOKENS = 8191


def get_embeddings(
    texts: List[str], model: str = "text-embedding-3-large"
) -> np.ndarray:
    """
    Generate embeddings for a list of texts using the OpenAI API synchronously.

    Args:
        texts: List of strings to embed.
        model: OpenAI embedding model to use (default: text-embedding-3-large).

    Returns:
        A 2D numpy array of shape (len(texts), embedding_dim), in input order.
        (Fixed annotation: the function has always returned an ndarray, not
        a list of lists.)

    Raises:
        Exception: If the OpenAI API call fails.
    """
    # Truncate texts to the max token limit (1 char ~= 1 token heuristic).
    truncated_texts = [text[:MAX_TOKENS] for text in texts]

    # Make the API call.
    response = client.embeddings.create(input=truncated_texts, model=model)

    # Extract embeddings from the response; the API preserves input order.
    embeddings = np.array([data.embedding for data in response.data])
    return embeddings


# Registry of supported LionGuard variants: which HF repo hosts the
# classifier and which embedding backend/model feeds it.
MODEL_CONFIGS = {
    "lionguard-2": {
        "label": "LionGuard 2",
        "repo_id": "govtech/lionguard-2",
        "embedding_strategy": "openai",
        "embedding_model": "text-embedding-3-large",
    },
    "lionguard-2-lite": {
        "label": "LionGuard 2 Lite",
        "repo_id": "govtech/lionguard-2-lite",
        "embedding_strategy": "sentence_transformer",
        "embedding_model": "google/embeddinggemma-300m",
    },
    "lionguard-2.1": {
        "label": "LionGuard 2.1",
        "repo_id": "govtech/lionguard-2.1",
        "embedding_strategy": "gemini",
        "embedding_model": "gemini-embedding-001",
    },
}

DEFAULT_MODEL_KEY = "lionguard-2.1"

# Caches so each classifier / embedder is only loaded once per process.
MODEL_CACHE = {}
EMBEDDING_MODEL_CACHE = {}

# Module-level active selection; mutated via set_active_model().
current_model_choice = DEFAULT_MODEL_KEY

# Lazily-initialized Gemini client (see get_gemini_client()).
GEMINI_CLIENT = None


def resolve_model_key(model_key: Optional[str] = None) -> str:
    """Return a validated model key, defaulting to the active selection.

    Args:
        model_key: Explicit key, or None to use current_model_choice.

    Returns:
        A key guaranteed to be present in MODEL_CONFIGS.

    Raises:
        ValueError: If the resolved key is not in MODEL_CONFIGS.
    """
    key = model_key or current_model_choice
    if key not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model selection: {key}")
    return key


def load_model_instance(model_key: str):
    """Load (and cache) the Hugging Face classifier for the given model key.

    Raises:
        ValueError: If the key cannot be resolved (via resolve_model_key).
    """
    key = resolve_model_key(model_key)
    if key not in MODEL_CACHE:
        repo_id = MODEL_CONFIGS[key]["repo_id"]
        # trust_remote_code: the LionGuard repos ship custom model code.
        MODEL_CACHE[key] = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    return MODEL_CACHE[key]


def get_sentence_transformer(model_name: str):
    """Load (and cache) a SentenceTransformer embedding model by name."""
    if model_name not in EMBEDDING_MODEL_CACHE:
        EMBEDDING_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return EMBEDDING_MODEL_CACHE[model_name]


def get_gemini_client():
    """Return the shared Gemini client, creating it on first use.

    Raises:
        EnvironmentError: If GEMINI_API_KEY is not set in the environment.
    """
    global GEMINI_CLIENT
    if GEMINI_CLIENT is None:
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise EnvironmentError(
                "GEMINI_API_KEY environment variable is required for LionGuard 2.1."
            )
        GEMINI_CLIENT = genai.Client(api_key=api_key)
    return GEMINI_CLIENT


def get_model_embeddings(model_key: str, texts: List[str]) -> np.ndarray:
    """Embed texts using the strategy configured for the given model key.

    Args:
        model_key: Key into MODEL_CONFIGS (or None for the active model).
        texts: Strings to embed.

    Returns:
        A 2D numpy array of embeddings, one row per input text.

    Raises:
        ValueError: If the model key or its embedding strategy is unknown.
    """
    key = resolve_model_key(model_key)
    config = MODEL_CONFIGS[key]
    strategy = config["embedding_strategy"]
    model_name = config.get("embedding_model")
    if strategy == "openai":
        return get_embeddings(texts, model=model_name)
    if strategy == "sentence_transformer":
        embedder = get_sentence_transformer(model_name)
        # EmbeddingGemma expects task-prefixed inputs for classification use.
        formatted_texts = [f"task: classification | query: {text}" for text in texts]
        embeddings = embedder.encode(formatted_texts)
        return np.array(embeddings)
    if strategy == "gemini":
        # Renamed local (was `client`) to avoid shadowing the module-level
        # OpenAI client.
        gemini_client = get_gemini_client()
        result = gemini_client.models.embed_content(model=model_name, contents=texts)
        return np.array([embedding.values for embedding in result.embeddings])
    raise ValueError(f"Unsupported embedding strategy: {strategy}")


def predict_with_model(texts: List[str], model_key: Optional[str] = None) -> Tuple[dict, str]:
    """Embed texts and run the selected LionGuard classifier on them.

    Args:
        texts: Strings to classify.
        model_key: Explicit model key, or None for the active selection.

    Returns:
        A tuple of (model predictions, resolved model key).
    """
    key = resolve_model_key(model_key)
    embeddings = get_model_embeddings(key, texts)
    model = load_model_instance(key)
    return model.predict(embeddings), key


def set_active_model(model_key: str) -> str:
    """Switch the module-level active model and eagerly pre-load it.

    Args:
        model_key: Key into MODEL_CONFIGS.

    Returns:
        A user-facing status message; a warning string on an unknown key
        (deliberately does not raise, matching the UI-facing contract).
    """
    if model_key not in MODEL_CONFIGS:
        return f"⚠️ Unknown model {model_key}"
    global current_model_choice
    current_model_choice = model_key
    # Pre-load so the first prediction with the new model is not slow.
    load_model_instance(model_key)
    label = MODEL_CONFIGS[model_key]["label"]
    return f"🦁 Using {label} ({model_key})"