File size: 2,671 Bytes
3ae84a3
 
 
f7ed480
650e271
3ae84a3
 
 
ff697e1
b8b9bf6
3ae84a3
 
45effd2
 
3ae84a3
 
45effd2
f7ed480
3ae84a3
 
0c9c8ed
e6fd470
b8b9bf6
f7ed480
0c9c8ed
e647581
44b4666
90bd6cb
ff697e1
45effd2
 
 
650e271
45effd2
 
 
 
 
650e271
45effd2
 
 
 
a92ca3d
3ae84a3
 
44b4666
45effd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44b4666
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pickle
import gradio as gr
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
import wikipedia


# One-time start-up work: everything below is loaded when the script starts
# and is then shared by every call to `query`.
MODEL_CHECKPOINT = "sasha/vit-base-butterflies"

# Search index unpickled from a local file; `query` calls `index.query` on it.
with open("./insectarium_768.pickle", "rb") as index_file:
    index = pickle.load(index_file)

# Feature extractor and model used to turn an uploaded image into an embedding.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(MODEL_CHECKPOINT)

# Candidate images; only the "train" split is used for look-ups.
dataset = load_dataset("sasha/insectarium-butterflies")
ds = dataset["train"]


# Shown when no Wikipedia article can be fetched for the best match.
_FALLBACK_DESCRIPTION = (
    "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
    " You can learn more about butterflies [here](https://en.wikipedia.org/wiki/Butterfly)!"
)


def _describe(label):
    """Return a one-sentence Markdown blurb for *label* with a Wikipedia link.

    The lookup is best-effort: any failure (missing or ambiguous page,
    network error, ...) yields a generic butterfly description instead.
    """
    try:
        summary = wikipedia.summary(label, sentences=1)
        url = wikipedia.page(label).url
    except Exception:
        # Deliberately broad: the description is decoration and must never
        # break the search itself. Unlike a bare `except:`, this still lets
        # KeyboardInterrupt / SystemExit propagate.
        return _FALLBACK_DESCRIPTION
    return f"### {summary} You can learn more about your butterfly [here]({url})!"


def query(image, top_k=4):
    """Find the dataset entries closest to *image*.

    Args:
        image: The uploaded image as delivered by the ``gr.Image`` input, or
            ``None`` when nothing has been uploaded.
        top_k: Number of neighbours requested from the index.

    Returns:
        A 3-tuple matching the UI outputs ``[outputs, labels, description]``:

        * a list of ``(image, label)`` pairs for the gallery,
        * a dict mapping each label to ``1 - value``, where ``value`` is the
          second array returned by ``index.query`` (presumably a distance --
          confirm against the index type); repeated labels keep the value of
          their last occurrence,
        * a Markdown description of the first match.

        When *image* is ``None`` the first two items are ``None`` and the
        description asks the user to upload an image.
    """
    if image is None:
        # The button can be pressed before anything is uploaded.
        return None, None, "### Please upload an image first."

    model_inputs = feature_extractor(image, return_tensors="pt")
    embedding = model(**model_inputs).pooler_output.detach()

    # Only a single embedding is queried, hence row 0 of each result array.
    results = index.query(embedding, k=top_k)
    neighbour_ids = results[0][0].tolist()
    distances = results[1][0].tolist()

    # Materialise the selection once and read both columns from it.
    matches = ds.select(neighbour_ids)
    images = matches["image"]
    captions = matches["label"]

    images_with_captions = list(zip(images, captions))
    labels_with_probs = {
        label: 1 - distance for label, distance in zip(captions, distances)
    }
    description = _describe(captions[0]) if captions else _FALLBACK_DESCRIPTION
    return images_with_captions, labels_with_probs, description


with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown(
        "## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!"
    )
    with gr.Row():
        with gr.Column(scale=1):
            inputs = gr.Image()
            btn = gr.Button("Find my butterfly!")
            description = gr.Markdown()

        with gr.Column(scale=2):
            outputs = gr.Gallery(rows=2)
            labels = gr.Label()

    # `query` returns three values, so every consumer lists all three
    # output components in the same order.
    gr.Markdown("### Image Examples")
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()