Spaces:
Running
Running
| """ | |
| utils.py | |
| """ | |
# Standard imports
import os
from typing import List, Optional, Tuple

# Third party imports
import numpy as np
from google import genai
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
| client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) | |
| # Maximum tokens for text-embedding-3-small | |
| MAX_TOKENS = 8191 # We don't have access to the tokenizer for text-embedding-3-small, and just assume 1 character = 1 token here | |
| def get_embeddings( | |
| texts: List[str], model: str = "text-embedding-3-large" | |
| ) -> List[List[float]]: | |
| """ | |
| Generate embeddings for a list of texts using OpenAI API synchronously. | |
| Args: | |
| texts: List of strings to embed. | |
| model: OpenAI embedding model to use (default: text-embedding-3-large). | |
| Returns: | |
| A list of embeddings (each embedding is a list of floats). | |
| Raises: | |
| Exception: If the OpenAI API call fails. | |
| """ | |
| # Truncate texts to max token limit | |
| truncated_texts = [text[:MAX_TOKENS] for text in texts] | |
| # Make the API call | |
| response = client.embeddings.create(input=truncated_texts, model=model) | |
| # Extract embeddings from response | |
| embeddings = np.array([data.embedding for data in response.data]) | |
| return embeddings | |
| MODEL_CONFIGS = { | |
| "lionguard-2": { | |
| "label": "LionGuard 2", | |
| "repo_id": "govtech/lionguard-2", | |
| "embedding_strategy": "openai", | |
| "embedding_model": "text-embedding-3-large", | |
| }, | |
| "lionguard-2-lite": { | |
| "label": "LionGuard 2 Lite", | |
| "repo_id": "govtech/lionguard-2-lite", | |
| "embedding_strategy": "sentence_transformer", | |
| "embedding_model": "google/embeddinggemma-300m", | |
| }, | |
| "lionguard-2.1": { | |
| "label": "LionGuard 2.1", | |
| "repo_id": "govtech/lionguard-2.1", | |
| "embedding_strategy": "gemini", | |
| "embedding_model": "gemini-embedding-001", | |
| }, | |
| } | |
| DEFAULT_MODEL_KEY = "lionguard-2.1" | |
| MODEL_CACHE = {} | |
| EMBEDDING_MODEL_CACHE = {} | |
| current_model_choice = DEFAULT_MODEL_KEY | |
| GEMINI_CLIENT = None | |
| def resolve_model_key(model_key: str = None) -> str: | |
| key = model_key or current_model_choice | |
| if key not in MODEL_CONFIGS: | |
| raise ValueError(f"Unknown model selection: {key}") | |
| return key | |
| def load_model_instance(model_key: str): | |
| key = resolve_model_key(model_key) | |
| if key not in MODEL_CACHE: | |
| repo_id = MODEL_CONFIGS[key]["repo_id"] | |
| MODEL_CACHE[key] = AutoModel.from_pretrained(repo_id, trust_remote_code=True) | |
| return MODEL_CACHE[key] | |
| def get_sentence_transformer(model_name: str): | |
| if model_name not in EMBEDDING_MODEL_CACHE: | |
| EMBEDDING_MODEL_CACHE[model_name] = SentenceTransformer(model_name) | |
| return EMBEDDING_MODEL_CACHE[model_name] | |
| def get_gemini_client(): | |
| global GEMINI_CLIENT | |
| if GEMINI_CLIENT is None: | |
| api_key = os.getenv("GEMINI_API_KEY") | |
| if not api_key: | |
| raise EnvironmentError( | |
| "GEMINI_API_KEY environment variable is required for LionGuard 2.1." | |
| ) | |
| GEMINI_CLIENT = genai.Client(api_key=api_key) | |
| return GEMINI_CLIENT | |
| def get_model_embeddings(model_key: str, texts: List[str]) -> np.ndarray: | |
| key = resolve_model_key(model_key) | |
| config = MODEL_CONFIGS[key] | |
| strategy = config["embedding_strategy"] | |
| model_name = config.get("embedding_model") | |
| if strategy == "openai": | |
| return get_embeddings(texts, model=model_name) | |
| if strategy == "sentence_transformer": | |
| embedder = get_sentence_transformer(model_name) | |
| formatted_texts = [f"task: classification | query: {text}" for text in texts] | |
| embeddings = embedder.encode(formatted_texts) | |
| return np.array(embeddings) | |
| if strategy == "gemini": | |
| client = get_gemini_client() | |
| result = client.models.embed_content(model=model_name, contents=texts) | |
| return np.array([embedding.values for embedding in result.embeddings]) | |
| raise ValueError(f"Unsupported embedding strategy: {strategy}") | |
| def predict_with_model(texts: List[str], model_key: str = None) -> Tuple[dict, str]: | |
| key = resolve_model_key(model_key) | |
| embeddings = get_model_embeddings(key, texts) | |
| model = load_model_instance(key) | |
| return model.predict(embeddings), key | |
| def set_active_model(model_key: str) -> str: | |
| if model_key not in MODEL_CONFIGS: | |
| return f"⚠️ Unknown model {model_key}" | |
| global current_model_choice | |
| current_model_choice = model_key | |
| load_model_instance(model_key) | |
| label = MODEL_CONFIGS[model_key]["label"] | |
| return f"🦁 Using {label} ({model_key})" | |