"""Gradio demo: semantic / cross-lingual search and sentence similarity with EmbeddingGemma.

Documents are pre-embedded (assets/doc_embs_*.npy) and searched with exact
inner-product FAISS indexes, one per Matryoshka dimension in MATRYOSHKA_DIMS.
Truncated (Matryoshka) vectors are re-normalised so that inner product equals
cosine similarity at every dimension.
"""
import os
import json
import re
import warnings
from typing import List, Optional, Tuple

import numpy as np
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer

HF_TOKEN = os.getenv("HF_TOKEN")

# ---------- Paths (expects files committed under ./assets) ----------
# abspath: dirname() of a bare relative __file__ is "", which would make
# ASSETS_DIR resolve against the current working directory instead.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
ASSETS_DIR = os.path.join(APP_DIR, "assets")
CACHE_DIR = "/mnt/data/eg_space_cache"  # runtime cache
os.makedirs(CACHE_DIR, exist_ok=True)

CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
EMB_FP32 = os.path.join(ASSETS_DIR, "doc_embs_fp32.npy")
EMB_FP16 = os.path.join(ASSETS_DIR, "doc_embs_fp16.npy")
FAISS_MAIN = os.path.join(ASSETS_DIR, "faiss_ip_768.index")

# ---------- Matryoshka dims ----------
MATRYOSHKA_DIMS = [768, 512, 256, 128]
DEFAULT_DIMS = 768

# ---------- Load corpus ----------
with open(CORPUS_JSON, "r", encoding="utf-8") as f:
    corpus = json.load(f)  # list of {"title","text"} in EXACT same order as embeddings

# ---------- Load embeddings ----------
if os.path.exists(EMB_FP32):
    doc_embs = np.load(EMB_FP32).astype(np.float32, copy=False)
elif os.path.exists(EMB_FP16):
    doc_embs = np.load(EMB_FP16).astype(np.float32)  # cast back for FAISS
else:
    raise FileNotFoundError("Expected assets/doc_embs_fp32.npy or assets/doc_embs_fp16.npy")

if doc_embs.ndim != 2 or doc_embs.shape[0] != len(corpus):
    raise ValueError("Embeddings shape mismatch vs corpus length.")

EMB_DIM = doc_embs.shape[1]  # should be 768
if EMB_DIM < max(MATRYOSHKA_DIMS):
    # Fail early with a readable message instead of a FAISS dimension assertion later.
    raise ValueError(
        f"Embeddings have {EMB_DIM} dims; at least {max(MATRYOSHKA_DIMS)} are needed "
        "to serve every entry of MATRYOSHKA_DIMS."
    )

# ---------- Model (for queries + sentence-level ops) ----------
model = SentenceTransformer("google/embeddinggemma-300m", token=HF_TOKEN)  # CPU is fine for queries


def _l2_normalize(x: np.ndarray) -> np.ndarray:
    """Return a float32 copy of ``x`` with every vector (last axis) scaled to unit L2 norm.

    All-zero vectors are returned unchanged rather than turned into NaN.
    """
    x = np.asarray(x, dtype=np.float32)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.where(norms > 0, norms, 1.0)


def _build_flat_ip(vectors: np.ndarray) -> "faiss.IndexFlatIP":
    """Build an exact inner-product FAISS index over the rows of a 2-D ``vectors`` array."""
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(np.ascontiguousarray(vectors, dtype=np.float32))
    return index


# ---------- FAISS indexes ----------
base_index_768 = None
if os.path.exists(FAISS_MAIN):
    base_index_768 = faiss.read_index(FAISS_MAIN)
    # A stale/foreign index would map hits to the wrong corpus rows; rebuild instead.
    if base_index_768.d != EMB_DIM or base_index_768.ntotal != len(corpus):
        warnings.warn(
            f"{FAISS_MAIN} (d={base_index_768.d}, ntotal={base_index_768.ntotal}) does not "
            f"match the embeddings ({len(corpus)} x {EMB_DIM}); rebuilding it from the embeddings."
        )
        base_index_768 = None
if base_index_768 is None:
    base_index_768 = _build_flat_ip(doc_embs)


# Build per-dimension flat IP indexes from the loaded embeddings
class MultiDimFaiss:
    """One exact inner-product FAISS index per Matryoshka dimension.

    The full-width dimension reuses ``full_index``. Every shorter dimension indexes
    the first ``d`` components of each document vector, re-normalised to unit
    length. Prefixes of unit vectors are not unit vectors, so without this step
    the inner product would no longer be a cosine similarity.
    """

    def __init__(self, doc_embs_full: np.ndarray, full_index: Optional["faiss.Index"] = None):
        """
        Args:
            doc_embs_full: (n_docs, full_dim) document embeddings, row-aligned with ``corpus``.
            full_index: optional prebuilt index over ``doc_embs_full``. When omitted,
                a flat IP index is built from ``doc_embs_full``.
        """
        self.full = doc_embs_full
        full_dim = doc_embs_full.shape[1]
        if full_index is None:
            full_index = _build_flat_ip(doc_embs_full)
        self.indexes = {}
        for d in MATRYOSHKA_DIMS:
            if d > full_dim:
                continue  # cannot truncate to more dims than we have
            if d == full_dim:
                self.indexes[d] = full_index
            else:
                self.indexes[d] = _build_flat_ip(_l2_normalize(self.full[:, :d]))

    def search(self, q_vec: np.ndarray, top_k: int, dims: int) -> Tuple[np.ndarray, np.ndarray]:
        """Search the ``dims``-wide index with the first ``dims`` components of ``q_vec``.

        The truncated query is re-normalised so that scores are cosine similarities.
        Returns FAISS's ``(scores, ids)`` pair, each of shape (1, top_k). Missing hits
        have id -1. Raises KeyError if ``dims`` has no index.
        """
        idx = self.indexes[dims]
        q = _l2_normalize(np.asarray(q_vec)[:dims])[None, :]
        return idx.search(q, top_k)


faiss_md = MultiDimFaiss(doc_embs, base_index_768)


# ---------- Core ops ----------
def _format_snippet(text: str, max_len: int = 380) -> str:
    """Truncate ``text`` to ``max_len`` characters, appending an ellipsis if it was cut."""
    return text[:max_len] + ("…" if len(text) > max_len else "")


def do_search(query: str, top_k: int = 5, dims: int = DEFAULT_DIMS) -> List[List[str]]:
    """Return up to ``top_k`` corpus hits for ``query`` as ``[score, title, snippet]`` rows.

    Empty/whitespace-only queries, or ``top_k < 1``, yield no rows.
    """
    if not query or not query.strip() or top_k < 1:
        return []
    q_emb = model.encode_query(
        query.strip(), normalize_embeddings=True, convert_to_numpy=True
    )
    scores, idxs = faiss_md.search(q_emb, top_k=top_k, dims=dims)
    rows = []
    for s, i in zip(scores[0].tolist(), idxs[0].tolist()):
        if i == -1:  # FAISS pads with -1 when fewer than top_k vectors are available
            continue
        doc = corpus[i]
        rows.append([f"{s:.4f}", doc["title"], _format_snippet(doc["text"])])
    return rows


def do_similarity(text_a: str, text_b: str, dims: int = DEFAULT_DIMS) -> float:
    """Cosine similarity of two texts using the first ``dims`` embedding components.

    Both texts are embedded in a single batch. The truncated vectors are
    re-normalised so the result is a true cosine in [-1, 1]. If either text is
    empty or whitespace-only, 0.0 is returned.
    """
    if not text_a or not text_a.strip() or not text_b or not text_b.strip():
        return 0.0
    embs = model.encode_document(
        [text_a, text_b], normalize_embeddings=True, convert_to_numpy=True
    )
    a, b = _l2_normalize(embs[:, :dims])
    # clip guards against tiny float overshoot beyond the advertised range
    return float(np.clip(np.dot(a, b), -1.0, 1.0))


# ---------- Gradio UI ----------
with gr.Blocks(title="EmbeddingGemma × Wikipedia (EN corpus)") as demo:
    gr.Markdown(
        """
# Demo: EmbeddingGemma × Wikipedia (EN corpus)

This Space showcases [Google DeepMind’s EmbeddingGemma models](https://huggingface.co/collections/google/embeddinggemma-68b9ae3a72a82f0562a80dc4), on a pre-indexed **random 10k sample** of [English Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).

You can try:
- **Semantic search** (English queries)
- **Cross-lingual search** (queries in other languages → English articles)
- **Sentence similarity** (compare two texts)

🔗 Learn more in the [EmbeddingGemma blog post](https://huggingface.co/blog/embeddinggemma).
"""
    )

    with gr.Tabs():
        # 1) Semantic Search (EN-only corpus)
        with gr.TabItem("Semantic Search (EN corpus)"):
            with gr.Row():
                q = gr.Textbox(label="Query", value="Who discovered penicillin?")
                topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dims = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            run = gr.Button("Search")
            out = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            run.click(lambda query, k, d: do_search(query, int(k), int(d)), [q, topk, dims], out)

        # 2) Cross-Lingual (queries in FR/ES/etc → EN corpus)
        with gr.TabItem("Cross-Lingual (EN corpus)"):
            gr.Markdown("Type your query in **French/Spanish/Arabic**. Results come from the **English-only** corpus.")
            with gr.Row():
                qx = gr.Textbox(label="Query", value="¿Quién descubrió la penicilina?")
                topkx = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dimsx = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            runx = gr.Button("Search")
            outx = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            runx.click(lambda query, k, d: do_search(query, int(k), int(d)), [qx, topkx, dimsx], outx)

        # 3) Similarity
        with gr.TabItem("Similarity"):
            with gr.Row():
                a = gr.Textbox(lines=5, label="Text A", value="Alexander Fleming observed a mold that killed bacteria in 1928.")
                b = gr.Textbox(lines=5, label="Text B", value="La penicilina fue descubierta por Alexander Fleming en 1928.")
            dims2 = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            sim_btn = gr.Button("Compute Similarity")
            sim_out = gr.Number(label="Cosine similarity (-1..1)")
            sim_btn.click(lambda x, y, d: do_similarity(x, y, int(d)), [a, b, dims2], sim_out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)