|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import marimo |
|
|
|
|
|
# Version string recorded for this notebook file (presumably the marimo
# release that wrote it).
__generated_with = "0.11.14"

# Notebook application object; the @app.cell functions below attach to it.
app = marimo.App()
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # Markdown heading for the notebook.
    mo.md("""### Fast labelling demo""")
    return
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo, use_default_switch):
    # The upload area only exists while the default dataset is not selected;
    # with the switch on there is no widget at all (None).
    if use_default_switch.value:
        uploaded_file = None
    else:
        uploaded_file = mo.ui.file(kind="area")
    uploaded_file
    return (uploaded_file,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # Switch that selects the default dataset; it starts in the off position.
    use_default_switch = mo.ui.switch(
        False,
        label="Use default dataset",
    )
    use_default_switch
    return (use_default_switch,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # Text boxes holding the two class names; other cells read them via `.value`.
    pos_label = mo.ui.text(
        "pos",
        placeholder="positive label name",
        label="positive class name",
    )
    neg_label = mo.ui.text(
        "neg",
        placeholder="negative label name",
        label="negative class name",
    )
    return neg_label, pos_label
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(uploaded_file, use_default_switch):
    # Downstream cells must wait only when the default dataset is not selected
    # and nothing has been uploaded yet. `uploaded_file` is not touched at all
    # while the switch is on (it is None in that case).
    if use_default_switch.value:
        should_stop = False
    else:
        should_stop = len(uploaded_file.value) == 0
    return (should_stop,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo, pl, should_stop, uploaded_file, use_default_switch):
    # Halt here until a dataset has been chosen.
    mo.stop(should_stop, mo.md("**Submit a dataset or use default one to continue.**"))

    if use_default_switch.value:
        df = pl.read_csv("spam.csv")
    else:
        # Only the first uploaded file is read; its `contents` go straight to polars.
        df = pl.read_csv(uploaded_file.value[0].contents)

    # An uploaded CSV may not have the expected layout: stop with a readable
    # message instead of failing on the column lookup below.
    mo.stop("text" not in df.columns, mo.md("**The dataset needs a `text` column.**"))

    texts = df["text"].to_list()
    return df, texts
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(StaticModel, mo):
    # Load the pretrained embedding model while a spinner is shown.
    with mo.status.spinner(subtitle="Loading model ..."):
        tfm = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
    return (tfm,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo, should_stop):
    mo.stop(should_stop)

    # Reference sentence used as the search query, pre-filled with an example.
    text_input = mo.ui.text_area(
        "you will win a free ringtone!",
        label="Reference sentences",
    )
    # The text area is embedded in a markdown template and wrapped in a form.
    _template = mo.md("""{text_input}""")
    form = _template.batch(text_input=text_input).form()
    form
    return form, text_input
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo, texts, tfm):
    # Encode every dataset text once, showing a spinner meanwhile.
    with mo.status.spinner(subtitle="Creating embeddings ..."):
        X = tfm.encode(texts)
    return (X,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(add_label, get_example, mo, neg_label, pos_label, undo):
    def _annotate_button(label_input, shortcut):
        # The caption uses the class name as it is when this cell runs; the
        # click handler reads the name again at click time.
        return mo.ui.button(
            label=f"Annotate {label_input.value}",
            on_click=lambda _event: add_label(get_example(), label_input.value),
            keyboard_shortcut=shortcut,
        )

    # Naming quirk kept from the outputs: btn_spam carries the negative class
    # name, btn_ham the positive one.
    btn_spam = _annotate_button(neg_label, "Ctrl-L")
    btn_ham = _annotate_button(pos_label, "Ctrl-K")
    btn_undo = mo.ui.button(
        label="Undo",
        on_click=lambda _event: undo(),
        keyboard_shortcut="Ctrl-U",
    )
    return btn_ham, btn_spam, btn_undo
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(gen, get_label, set_example, set_label):
    def add_label(text, lab):
        # Store `text` under class `lab`, then advance to the next candidate.
        # A new list is built instead of mutating the one held in state.
        current_labels = get_label()
        set_label(current_labels + [{"text": text, "label": lab}])
        # NOTE(review): next() raises StopIteration once `gen` is exhausted.
        set_example(next(gen))

    def undo():
        # Remove only the most recent annotation (one click undoes one label);
        # slicing an empty list simply yields an empty list.
        current_labels = get_label()
        set_label(current_labels[:-1])

    return add_label, undo
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _():
    # `br` from mohtml is placed between rows in the layout cell.
    from mohtml import br
    return (br,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(br, btn_ham, btn_spam, btn_undo, example, mo, neg_label, p, pos_label):
    # Layout, top to bottom: class-name inputs, annotation buttons, then the
    # example currently shown for labelling.
    _name_row = mo.hstack([pos_label, neg_label])
    _gap_above_buttons = br()
    _button_row = mo.hstack([btn_ham, btn_spam, btn_undo])
    _gap_below_buttons = br()
    _caption = p("Current example:", klass="font-bold")
    mo.vstack([
        _name_row,
        _gap_above_buttons,
        _button_row,
        _gap_below_buttons,
        _caption,
        example,
    ])
    return
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # State holding the annotations made so far; it starts as an empty list.
    _label_state = mo.state([])
    get_label, set_label = _label_state
    return get_label, set_label
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(gen, mo):
    # Seed the "current example" state with the first candidate from `gen`.
    # NOTE(review): next() raises StopIteration if `gen` yields nothing.
    _first_candidate = next(gen)
    get_example, set_example = mo.state(_first_candidate)
    return get_example, set_example
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _():
    # HTML helpers from mohtml; `div` and `p` are consumed by other cells.
    from mohtml import div, p, tailwind_css

    # Called once here as the cell's final expression.
    tailwind_css()
    return div, p, tailwind_css
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(get_label, mo):
    import json

    # Current annotations, serialised to UTF-8 encoded JSON for the download.
    data = get_label()
    _payload = json.dumps(data).encode("utf-8")
    json_download = mo.download(
        data=_payload,
        filename="data.json",
        mimetype="application/json",
        label="Download JSON",
    )
    # NOTE(review): the widget is returned but not displayed by this cell.
    return data, json, json_download
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(X, cosine_similarity, form, get_label, mo, pl, texts, tfm):
    # Nothing to search for until the form holds a non-empty reference text.
    mo.stop(not form.value, "Need a text input to fetch example")
    mo.stop(not form.value.get("text_input", None), "Need a text input to fetch example")

    # Embed the reference text and score every dataset row against it.
    query = tfm.encode([form.value["text_input"]])
    similarity = cosine_similarity(query, X)[0]

    # One row per text, highest similarity first. The score column is attached
    # directly; no placeholder column is needed beforehand.
    df_emb = (
        pl.DataFrame({
            "index": range(X.shape[0]),
            "text": texts,
        })
        .with_columns(sim=similarity)
        .sort(pl.col("sim"), descending=True)
    )

    label_texts = [_row["text"] for _row in get_label()]
    # Set lookup keeps the filter below cheap as the number of labels grows.
    _already_labelled = set(label_texts)
    # Lazily yields the texts of the 100 best matches that have no label yet.
    gen = (
        _row["text"]
        for _row in df_emb.head(100).to_dicts()
        if _row["text"] not in _already_labelled
    )
    return df_emb, gen, label_texts, query, similarity
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(div, get_example, p):
    # Styled container wrapping the text of the current example.
    _paragraph = p(get_example())
    example = div(_paragraph, klass="bg-gray-100 p-4 rounded-lg")
    return (example,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(get_label, mo, pl, should_stop):
    mo.stop(should_stop)

    # Table of the annotations in reversed order (latest entry first).
    _annotations = pl.DataFrame(get_label())
    _annotations.reverse()
    return
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # Third-party imports are done inside a spinner and handed to other cells
    # through the return value.
    with mo.status.spinner(subtitle="Loading libraries ..."):
        import polars as pl
        import numpy as np
        from sklearn.metrics.pairwise import cosine_similarity
    return cosine_similarity, np, pl
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _(mo):
    # model2vec is imported inside a spinner; only StaticModel is exposed.
    with mo.status.spinner(subtitle="Loading model2vec ..."):
        from model2vec import StaticModel
    return (StaticModel,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _():
    # Provides `mo`, the marimo module that the other cells take as a parameter.
    import marimo as mo
    return (mo,)
|
|
|
|
|
|
|
|
@app.cell |
|
|
def _():
    # Empty cell: nothing is computed or displayed.
    return
|
|
|
|
|
|
|
|
# Run the notebook application when this file is executed as a script.
if __name__ == "__main__":
    app.run()
|
|
|