Cam-Bin / app.py
CamBinDEV's picture
Update app.py
e91dc29 verified
import os
import re

import gradio as gr
import huggingface_hub
from datasets import Dataset, load_dataset
from PIL import Image
from transformers import pipeline
# Authenticate with the HF token stored as a Space secret. Space secrets are
# exposed to the app as environment variables, so pass the secret's *value*
# (not its name) to login. Skip login when no token is configured.
hf_token = os.environ.get("HFTOKEN") or os.environ.get("HF_TOKEN")
if hf_token:
    huggingface_hub.login(token=hf_token)

# Load the shared usage-counter dataset, creating it on the very first run.
dataset_name = "CamBinDEV/global-usage"
try:
    ds = load_dataset(dataset_name, split="train")
except FileNotFoundError:
    # datasets raises DatasetNotFoundError (a FileNotFoundError subclass) when
    # the repo does not exist yet. Other failures (network, auth) are left to
    # propagate on purpose: seeding here would overwrite the real global count.
    ds = Dataset.from_dict({"count": [0]})
    ds.push_to_hub(dataset_name)
# -------------------------
# Categories & Keyword Map
# -------------------------
# Bin names; also used as the candidate labels for zero-shot text classification.
categories = ["Recycle", "Compost", "Trash"]

# Keyword shortcuts consulted before any model is used. The dictionary's order
# matters: classify_text walks it in order, so order can decide which bin wins
# when keywords from several bins match the same description.
keyword_map = {
    "compost": [
        "banana",
        "apple",
        "orange",
        "peel",
        "fruit",
        "vegetable",
        "coffee grounds",
    ],
    "recycle": [
        "can",
        "bottle",
        "paper",
        "plastic",
        "aluminum",
        "cardboard",
    ],
    "trash": [
        "chip bag",
        "styrofoam",
        "napkin",
        "candy wrapper",
    ],
}
# -------------------------
# Lazy Model Loading
# -------------------------
# Nothing is loaded at import time; load_models() fills these slots on demand.
_IMAGE_MODEL_ID = "microsoft/resnet-50"
_TEXT_MODEL_ID = "facebook/bart-large-mnli"

image_model = None
text_model = None


def load_models():
    """Build the image and text pipelines the first time they are needed.

    Each module-level slot is populated only while it is still None, so
    repeated calls never rebuild a pipeline that already exists. The image
    pipeline is created before the text pipeline.
    """
    global image_model, text_model

    if image_model is None:
        image_model = pipeline(task="image-classification", model=_IMAGE_MODEL_ID)

    if text_model is None:
        text_model = pipeline(task="zero-shot-classification", model=_TEXT_MODEL_ID)
# -------------------------
# Text Classification
# -------------------------
def classify_text(description, keywords=None):
    """Suggest a bin ("Recycle", "Compost" or "Trash") for a text description.

    Keyword shortcuts are tried first. A keyword only counts when it appears
    as a whole word (a trailing "s"/"es" plural is accepted), so "can" matches
    "soda can" and "cans" but not "candy" or "scan". When several keywords
    match, the longest (most specific) one wins; ties go to the category that
    comes first in the mapping. Only if no keyword matches is the zero-shot
    text model consulted.

    Args:
        description: Text describing the item. Callers should pass non-empty
            text, since it is forwarded to the zero-shot model when no
            keyword matches.
        keywords: Optional ``{category: [keyword, ...]}`` mapping; defaults to
            the module-level ``keyword_map``.

    Returns:
        The category name with its first letter capitalized for a keyword hit,
        otherwise the top label returned by the zero-shot model.
    """
    if keywords is None:
        keywords = keyword_map

    desc = description.lower()
    best_category, best_length = None, 0
    for category, words in keywords.items():
        for word in words:
            # Skip keywords that cannot beat the current best match; strict
            # ">" keeps the earlier category on equal-length ties.
            if len(word) <= best_length:
                continue
            pattern = rf"\b{re.escape(word.lower())}(?:s|es)?\b"
            if re.search(pattern, desc):
                best_category, best_length = category, len(word)

    if best_category is not None:
        return best_category.capitalize()

    load_models()
    text_pred = text_model(description, candidate_labels=categories)
    return text_pred["labels"][0]
# -------------------------
# Main Function
# -------------------------
def _increment_global_usage():
    """Bump the shared usage counter, push it to the Hub, and return the new total.

    The push is best-effort: a failed upload is reported but does not stop the
    caller from returning a sorting suggestion.
    """
    global ds
    count = ds["count"][0] + 1
    ds = Dataset.from_dict({"count": [count]})
    try:
        ds.push_to_hub(dataset_name)
    except Exception as err:  # network/auth failures must not break the UI
        print(f"Warning: could not push usage count to the Hub: {err}")
    return count


def classify_trash(image, description):
    """Suggest a bin for an item from a photo, a text description, or both.

    Args:
        image: PIL image from the Gradio upload widget, or None.
        description: Free-text description of the item; may be empty or None.

    Returns:
        A display string with the suggested bin, prefixed with
        "✅ Confident:" when photo and description agree, followed by the
        global usage total. If neither input is given, returns a prompt
        instead and does not count a use.
    """
    description = (description or "").strip()
    if image is None and not description:
        # Nothing to classify: don't load models or bump the counter.
        return "Please upload a photo or describe the item."

    load_models()

    # The image model returns a class name (e.g. "banana"), not a bin name,
    # so map it onto a bin the same way a typed description is mapped.
    image_label = image_model(image)[0]["label"].lower() if image is not None else ""
    image_bin = classify_text(image_label) if image_label else None
    text_bin = classify_text(description) if description else None

    count = _increment_global_usage()

    if image_bin and text_bin and image_bin == text_bin:
        classification_result = f"✅ Confident: {text_bin}"
    else:
        # Prefer the typed description; fall back to the photo when no text.
        classification_result = f"{text_bin or image_bin}"

    # Display global usage
    classification_result += f"\n\n🌎 Total items tested globally: {count}"
    return classification_result
# -------------------------
# Gradio Interface
# -------------------------
# One photo input and one description input, both passed to classify_trash;
# its string result is shown in a single output textbox.
app = gr.Interface(
    fn=classify_trash,
    inputs=[
        gr.Image(type="pil", label="Take or Upload a Photo"),
        gr.Textbox(label="Describe the Item (e.g. banana peel, soda can)"),
    ],
    outputs=gr.Textbox(label="Sorting Suggestion"),
    title="♻️ Cam-Bin",
    description="Upload an item — see where it should go and track global usage!",
)

if __name__ == "__main__":
    app.launch()