Spaces: Running on Zero
initial commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +7 -0
- README.md +3 -0
- app.py +400 -0
- data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt +3 -0
- data/image_embeddings/Black_Tern_0101_144331.jpg.pt +3 -0
- data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt +3 -0
- data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt +3 -0
- data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt +3 -0
- data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt +3 -0
- data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt +3 -0
- data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt +3 -0
- data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt +3 -0
- data/image_embeddings/House_Wren_0137_187273.jpg.pt +3 -0
- data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt +3 -0
- data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt +3 -0
- data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt +3 -0
- data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt +3 -0
- data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt +3 -0
- data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt +3 -0
- data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt +3 -0
- data/image_embeddings/Western_Grebe_0064_36613.jpg.pt +3 -0
- data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt +3 -0
- data/image_embeddings/Winter_Wren_0048_189683.jpg.pt +3 -0
- data/images/boxes/American_Goldfinch_0123_32505_all.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_back.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_beak.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_belly.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_breast.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_crown.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_legs.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_nape.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_tail.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_throat.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_visible.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_wings.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_all.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_back.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_beak.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_belly.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_breast.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_crown.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_eyes.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_forehead.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_legs.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_nape.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_tail.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_throat.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_visible.jpg +0 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
temp*


# python temp files
__pycache__
*.pyc
.vscode
README.md
ADDED
@@ -0,0 +1,3 @@
---
license: mit
---
app.py
ADDED
@@ -0,0 +1,400 @@
import io
import os

import json
import base64
import random
import numpy as np
import pandas as pd
import gradio as gr
from pathlib import Path
from PIL import Image

from plots import get_pre_define_colors
from utils.load_model import load_xclip
from utils.predict import xclip_pred


DEVICE = "cpu"
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
IMAGES_FOLDER = "data/images"
# XCLIP_RESULTS is required below by update_selected_image() for the ground-truth labels
XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
# correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]

# get the intersection of sachit and xclip (revised)
# INTERSECTION = []
# IMAGE_RES = 400 * 400 # minimum resolution
# TOTAL_SAMPLES = 20
# for file_name in XCLIP_RESULTS:
#     image = Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB')
#     w, h = image.size
#     if w * h < IMAGE_RES:
#         continue
#     else:
#         INTERSECTION.append(file_name)

# IMAGE_FILE_LIST = random.sample(INTERSECTION, TOTAL_SAMPLES)
IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))
# IMAGE_FILE_LIST = IMAGE_FILE_LIST[:19]
# IMAGE_FILE_LIST.append('Eastern_Bluebird.jpg')
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]

ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"
# CUB_BOXES = json.load(open("data/jsons/cub_boxes_owlvit_large.json", "r"))
VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))

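# Data layout assumed by the paths above and by the files added in this commit
# (an inference, not documented anywhere in the repo):
#   data/jsons/             description, file-list, result, and visibility JSONs loaded above
#   data/images/org/        original gallery images, keyed by file name
#   data/images/boxes/      per-part crops named "{image stem}_{part}.jpg" (see load_part_images below)
#   data/image_embeddings/  precomputed "{file name}.pt" tensors tracked with git-lfs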
# --- Image related functions ---
def img_to_base64(img):
    img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    buffered = io.BytesIO()
    img_pil.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode()

def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Create a blank image of the given size and color."""
    return np.array(Image.new("RGB", (width, height), color))

# Convert RGB colors to hex
def rgb_to_hex(rgb):
    return f"#{''.join(f'{x:02x}' for x in rgb)}"

def load_part_images(file_name: str) -> dict:
    part_images = {}
    # start_time = time.time()
    for part_name in ORDERED_PARTS:
        base_name = Path(file_name).stem
        part_image_path = os.path.join(IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
        if not Path(part_image_path).exists():
            continue
        image = np.array(Image.open(part_image_path))
        part_images[part_name] = img_to_base64(image)
    # print(f"Time cost to load 12 images: {time.time() - start_time}")
    # This takes less than 0.01 seconds. So the loading time is not the bottleneck.
    return part_images

def generate_xclip_explanations(result_dict: dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1]*12))):
    """
    The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
    descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
    pred_scores: {part_name1: score_1, part_name2: score_2, ...}
    file_name: str
    """

    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    # Add a row for each visible bird part
    y_offset = 0
    for part in ORDERED_PARTS:
        if visibility[part] and part_mask[part]:
            # Calculate the length of the bar (scaled to fit within the SVG)
            part_score = max(result_dict['pred_scores'][part], 0)
            bar_length = part_score * exp_length

            # Modify the overlay image's opacity on mouseover and mouseout
            mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
            mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"

            combined_mouseover = f"javascript: {mouseover_action1};"
            combined_mouseout = f"javascript: {mouseout_action1};"

            # Add the description
            num_lines = len(descriptions[part]) // MAX_LENGTH + 1
            for line in range(num_lines):
                desc_line = descriptions[part][line*MAX_LENGTH:(line+1)*MAX_LENGTH]
                y_offset += fontsize
                svg_parts.append(f"""
                    <text x="0" y="{y_offset}" font-size="{fontsize}"
                        onmouseover="{combined_mouseover}"
                        onmouseout="{combined_mouseout}">
                        {desc_line}
                    </text>
                """)

            # Add the bars
            svg_parts.append(f"""
                <rect x="0" y="{y_offset + 3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
                    onmouseover="{combined_mouseover}"
                    onmouseout="{combined_mouseout}">
                </rect>
            """)
            # Add the scores
            svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')

            y_offset += fontsize + 3
    svg_parts.extend(("</svg>", "</div>"))
    # Join everything into a single string
    html = "".join(svg_parts)

    return html

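# Illustrative use of generate_xclip_explanations (hypothetical values; the dict shape
# follows the docstring above, and the real values come from xclip_pred() in the
# Gradio callbacks further below):
#
#   result_dict = {
#       'descriptions': {part: f"{part}: some description" for part in ORDERED_PARTS},
#       'pred_scores': {part: 0.1 for part in ORDERED_PARTS},
#       'file_name': 'American_Goldfinch_0123_32505.jpg',
#   }
#   html = generate_xclip_explanations(result_dict, VISIBILITY_DICT[result_dict['file_name']])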
def generate_sachit_explanations(result_dict: dict):
    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    descriptions = zip(scores, descriptions)
    descriptions = sorted(descriptions, key=lambda x: x[0], reverse=True)

    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    # Add a row for each description
    y_offset = 0
    for score, desc in descriptions:

        # Calculate the length of the bar (scaled to fit within the SVG)
        part_score = max(score, 0)
        bar_length = part_score * exp_length

        # Split the description into two lines if it's too long
        num_lines = len(desc) // MAX_LENGTH + 1
        for line in range(num_lines):
            desc_line = desc[line*MAX_LENGTH:(line+1)*MAX_LENGTH]
            y_offset += fontsize
            svg_parts.append(f"""
                <text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
                    {desc_line}
                </text>
            """)

        # Add the bar
        svg_parts.append(f"""
            <rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
            </rect>
        """)

        # Add the score (fixed: fontsize was not interpolated into the f-string)
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{SACHIT_COLOR}">{part_score:.2f}</text>')

        y_offset += fontsize + 3


    svg_parts.extend(("</svg>", "</div>"))
    # Join everything into a single string
    html = "".join(svg_parts)

    return html

# --- Constants created by the functions above ---
BLANK_OVERLAY = img_to_base64(create_blank_image())
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}

# --- Gradio Functions ---
def update_selected_image(event: gr.SelectData):
    image_height = 400
    index = event.index

    image_name = IMAGE_FILE_LIST[index]
    current_image.state = image_name
    org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    img_base64 = f"""
    <div style="position: relative; height: {image_height}px; display: inline-block;">
        <img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
        <img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
    </div>
    """
    gt_label = XCLIP_RESULTS[image_name]['ground_truth']
    gt_class.state = gt_label

    # --- for initial value only ---
    out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    clip_pred_scores = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))
    # --- end of initial value ---

    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {clip_pred_scores:.4f}</span>
    """

    gt_label = f"""
    ## {gt_label}
    """
    current_predicted_class.state = xclip_label

    # Populate the textbox with current descriptions
    custom_class_name = "class name: custom"
    descs = XCLIP_DESC[xclip_label]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = [custom_class_name] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
    # modified_exp = gr.HTML.update(value="", visible=True)
    return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox

def on_edit_button_click_xclip():
    empty_exp = gr.HTML.update(visible=False)

    # Populate the textbox with current descriptions
    descs = XCLIP_DESC[current_predicted_class.state]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = ["class name: custom"] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)

    return textbox, empty_exp

def convert_input_text_to_xclip_format(textbox_input: str):

    # Split the descriptions by newline to get individual descriptions for each part
    descriptions_list = textbox_input.split(";\n")
    # the first line should be "class name: xxx"
    class_name_line = descriptions_list[0]
    new_class_name = class_name_line.split(":")[1].strip()

    descriptions_list = descriptions_list[1:]

    # construct description dict with part name as key
    descriptions_dict = {}
    for desc in descriptions_list:
        if desc.strip() == "":
            continue
        part_name, _ = desc.split(":", 1)  # split on the first colon only, in case the description itself contains one
        descriptions_dict[part_name.strip()] = desc
    # fill with empty string if the part is not in the descriptions
    part_mask = {}
    for part in ORDERED_PARTS:
        if part not in descriptions_dict:
            descriptions_dict[part] = ""
            part_mask[part] = 0
        else:
            part_mask[part] = 1
    return descriptions_dict, part_mask, new_class_name

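# Illustrative textbox input (a hedged example; the format follows the UI instructions
# rendered below: entries separated by ";\n", the first entry "class name: {some name}",
# the rest "{part name}: {description}"):
#
#   class name: custom;
#   crown: a bright yellow crown;
#   beak: a short conical orange beak;
#   wings: black wings with a white wing bar
#
# For this input, convert_input_text_to_xclip_format() returns new_class_name == "custom",
# a descriptions_dict keyed by part name, and a part_mask with 1 for crown, beak, and wings
# and 0 for the nine parts that were left out.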
def on_predict_button_click_xclip(textbox_input: str):
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)

    # Get the new predictions and explanations
    out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    custom_label = out_dict['modified_class']
    custom_pred_score = out_dict['modified_score']
    custom_part_scores = out_dict['modified_desc_scores']

    # construct a result dict to generate xclip explanations
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
    modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
    modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)

    xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {xclip_pred_score:.4f}</span>
    """
    custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
    custom_pred_markdown = f"""
    ### <span style='color:{custom_color}'>XCLIP: {custom_label} {custom_pred_score:.4f}</span>
    """
    textbox = gr.Textbox.update(visible=False)
    # return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation

    modified_exp = gr.HTML.update(value=modified_explanation, visible=True)
    return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp


custom_css = """
html, body {
    margin: 0;
    padding: 0;
}

#container {
    position: relative;
    width: 400px;
    height: 400px;
    border: 1px solid #000;
    margin: 0 auto; /* This will center the container horizontally */
}

#canvas {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    object-fit: cover;
}

"""

# Define the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")

    with gr.Column():
        title_text = gr.Markdown("# PEEB - demo")
        gr.Markdown(
            "- In this demo, you can edit the descriptions of a class and see how the model reacts to them."
        )

    # display the gallery of images
    with gr.Column():

        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        gr.Markdown("### Custom descriptions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remaining descriptions, please use **;** to separate the descriptions for each part, and use the format **{part name}: {descriptions}**. \n Note that you can delete a part completely; in that case, the corresponding part is removed from all descriptions.")

    with gr.Row():
        with gr.Column():
            image_label = gr.Markdown("### Class Name")
            org_image = gr.HTML()

        with gr.Column():
            with gr.Row():
                # xclip_predict_button = gr.Button(label="Predict", value="Predict")
                xclip_predict_button = gr.Button(value="Predict")
                xclip_pred_label = gr.Markdown("### XCLIP:")
            xclip_explanation = gr.HTML()

        with gr.Column():
            # xclip_edit_button = gr.Button(label="Edit", value="Reset Descriptions")
            xclip_edit_button = gr.Button(value="Reset Descriptions")
            custom_pred_label = gr.Markdown(
                "### Custom Descriptions:"
            )
            xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
            # ai_explanation = gr.Image(type="numpy", visible=True, show_label=False, height=500)
            custom_explanation = gr.HTML()

    gr.HTML("<br>")

    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])

demo.launch(server_port=5000, share=True)
data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4405b6dfc87741cf87aa4887f77308aee46209877a7dcf29caacb4dae12459d5
size 1770910

data/image_embeddings/Black_Tern_0101_144331.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:218995c5e9d3256313ead069ff11c89a52ce616221880070d722f27c4227ffe2
size 1770875

data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c493ed75f6dad68a1336ae3142deea98acb2eec30fbb5345aa1c545660eef4bb
size 1770900

data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c051c80027beeebfabab679b596f5a2b7536c016c2c966a5736b03a980b96a5
size 1770895

data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0c34e05f759b6244ad50ca5529002e26a9370c9db07d22df91e476f827b7724
size 1770890

data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d91e1fd22664d4dbad771f214ae943b60c26a0e52aeefc156eddbddde8cb0fb
size 1770890

data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99e85d16d9b4b0d62e92926a7cefce6fbd5298daa1632df02d1d2bc1c812ccf4
size 1770900

data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e02ea920306d2a41b2f0a46c3205691e1373d3a443714ba31c67bd46fa0baae8
size 1770880

data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51ecf397a13ffc0ef481b029c7c54498dd9c0dda7db709f9335dba01faebdc65
size 1770885

data/image_embeddings/House_Wren_0137_187273.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3fab5144fff8e0ff975f9064337dc032d39918bf777d149e02e4952a6ed10d8b
size 1770875

data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:129b38324da3899caa7182fa0a251c81eba2a8ba8e71995139e269d479456e75
size 1770870

data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bd735f0756b810b8c74628ca2285311411cb6fb14639277728a60260e64cda9
size 1770925

data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c48503ff01eb8af79b86315ab9b6abe7d215c32ab37eb5acc54dd99b9877574
size 1770885

data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54ec8a9edf3bc0e5e21a989596469efec44815f9ac30a0cdbde4f5d1f1952619
size 1770930

data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d02487c6d3b10c2bc193547a3ad863b6b02710071e93b2a99e9be17931c9e785
size 1770910

data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:294ae723107b6cc26f467ef19018f7d0c27befe0ddbf46ea1432a4440cf538c7
size 1770890

data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c3b58049302546a0f19e1a0da37d85ee3841d1f34674a6263b4972229539806
size 1770895

data/image_embeddings/Western_Grebe_0064_36613.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66a4a4c3d9e8c61c729eef180dca7c06dc19748be507798548bb629fb8283645
size 1770885

data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31f5601dd90778785d90da4b079faa4e8082da814b0edb75c46c27f7a59bb0c3
size 1770905

data/image_embeddings/Winter_Wren_0048_189683.jpg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa44fb0827d907160d964837908b8d313bce096d02062be2ea7192e6c2903543
size 1770880
data/images/boxes/American_Goldfinch_0123_32505_all.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_back.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_beak.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_belly.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_breast.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_crown.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_legs.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_nape.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_tail.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_throat.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_visible.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_wings.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_all.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_back.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_beak.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_belly.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_breast.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_crown.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_eyes.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_forehead.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_legs.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_nape.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_tail.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_throat.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_visible.jpg
ADDED