|
|
import re

import gradio as gr
import huggingface_hub
from datasets import Dataset, load_dataset
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
|
# Authenticate with the Hugging Face Hub so the usage counter can be pushed.
# NOTE(review): the literal string "HFTOKEN" is passed as the token here. It
# looks like a placeholder -- confirm the real token is injected (e.g. from a
# secret / environment variable) rather than committed to the source.
huggingface_hub.login("HFTOKEN")

# Hub dataset repo that stores the global usage counter as a single-row
# "count" column.
dataset_name = "CamBinDEV/global-usage"

try:
    ds = load_dataset(dataset_name, split="train")
except Exception:
    # Best-effort bootstrap: if the counter cannot be loaded, start from zero
    # and publish it. `Exception` (not a bare `except:`) keeps Ctrl-C and
    # interpreter shutdown from being swallowed.
    # NOTE(review): any load failure, not only "dataset missing", resets the
    # published counter to 0 -- confirm that is acceptable.
    ds = Dataset.from_dict({"count": [0]})
    ds.push_to_hub(dataset_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Bin names offered to the zero-shot text model as candidate labels.
categories = ["Recycle", "Compost", "Trash"]

# Keyword shortcuts per bin (lower-case). Categories are scanned in the order
# written here, so an earlier category wins when several keywords match.
keyword_map = dict(
    compost=[
        "banana",
        "apple",
        "orange",
        "peel",
        "fruit",
        "vegetable",
        "coffee grounds",
    ],
    recycle=[
        "can",
        "bottle",
        "paper",
        "plastic",
        "aluminum",
        "cardboard",
    ],
    trash=[
        "chip bag",
        "styrofoam",
        "napkin",
        "candy wrapper",
    ],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level pipeline cache; both slots stay None until load_models() runs.
image_model = text_model = None


def load_models():
    """Create the image and text pipelines on first use.

    Each pipeline is built only while its module-level slot is still ``None``,
    so repeated calls reuse the already-created objects. The image pipeline
    is created before the text pipeline.
    """
    global image_model, text_model

    if image_model is None:
        image_model = pipeline(
            "image-classification",
            model="microsoft/resnet-50",
        )

    if text_model is None:
        text_model = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_text(description):
    """Map a free-text item description to a bin category.

    The description is first checked against ``keyword_map``. Keywords are
    matched as whole words (an optional plural "s" is tolerated) rather than
    as raw substrings, so "can" no longer fires inside "candy wrapper".
    Categories are tried in ``keyword_map`` order and the first hit wins.
    Only when no keyword matches is the zero-shot text model consulted.

    Args:
        description: Free-text description of the item.

    Returns:
        The capitalized ``keyword_map`` key on a keyword hit; otherwise the
        first entry of the ``"labels"`` list returned by the text model for
        the candidate labels in ``categories``.
    """
    desc = description.lower()

    for category, keywords in keyword_map.items():
        # \b anchors keep short keywords from matching inside longer words.
        if any(
            re.search(rf"\b{re.escape(word)}s?\b", desc) for word in keywords
        ):
            return category.capitalize()

    # No keyword hit: fall back to the (lazily loaded) zero-shot classifier.
    load_models()
    text_pred = text_model(description, candidate_labels=categories)
    return text_pred["labels"][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_trash(image, description):
    """Suggest a bin for an item from a photo and/or a text description.

    The description (when given) is classified with ``classify_text``. The
    photo (when given) is run through the image model and its top label is
    classified with ``classify_text`` as well, so that both predictions are
    bin categories and can be compared. When both inputs are present and
    agree, the result is marked as confident. When only a photo is given,
    the photo's category is used as the suggestion.

    Every successful classification increments the global usage counter and
    pushes it to the Hub dataset named by ``dataset_name``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when no photo
            was supplied.
        description: Free-text description; ``None`` or blank means "not
            supplied".

    Returns:
        The suggestion text followed by the global usage count, or a short
        prompt when neither a photo nor a description was provided.
    """
    global ds

    description = (description or "").strip()
    has_image = image is not None

    if not has_image and not description:
        # Nothing to classify: skip model loading and leave the counter alone.
        return "Please upload a photo or describe the item."

    load_models()

    # Top-1 label from the image classifier, lower-cased.
    image_label = image_model(image)[0]["label"].lower() if has_image else ""

    # The image classifier emits object names from its own label set
    # (microsoft/resnet-50 is an ImageNet classifier), not the bin names in
    # `categories`, so translate its label into a bin before comparing.
    image_category = classify_text(image_label) if image_label else ""

    # Without a description, fall back to what the photo suggests.
    text_pred_label = classify_text(description) if description else image_category

    # Bump the in-memory counter and publish the new value.
    count = ds["count"][0] + 1
    ds = Dataset.from_dict({"count": [count]})
    ds.push_to_hub(dataset_name)

    # "Confident" requires two independent inputs that agree on the bin.
    sources_agree = (
        bool(description)
        and bool(image_category)
        and image_category.lower() == text_pred_label.lower()
    )
    if sources_agree:
        classification_result = f"✅ Confident: {text_pred_label}"
    else:
        classification_result = f"{text_pred_label}"

    classification_result += f"\n\n🌍 Total items tested globally: {count}"

    return classification_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a photo and a text box go in, and the sorting suggestion produced
# by classify_trash comes out as text.
app = gr.Interface(
    fn=classify_trash,
    inputs=[
        gr.Image(type="pil", label="Take or Upload a Photo"),
        gr.Textbox(label="Describe the Item (e.g. banana peel, soda can)"),
    ],
    outputs=gr.Textbox(label="Sorting Suggestion"),
    title="♻️ Cam-Bin",
    description="Upload an item → see where it should go and track global usage!",
)


# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    app.launch()
|
|
|
|
|
|