import html
import os
import re

import gradio as gr
from PIL import Image
from transformers import pipeline
# -------------------------
# Categories & Keyword Map
# -------------------------
# Candidate labels handed to the zero-shot fallback in classify_text().
categories = ["Recycle", "Compost", "Trash"]

# Lower-case keyword lists per category, consulted before any model runs.
# classify_text() scans the categories in this dict order.
keyword_map = dict(
    compost=["banana", "apple", "orange", "peel", "fruit", "vegetable", "coffee grounds"],
    recycle=["can", "bottle", "paper", "plastic", "aluminum", "cardboard"],
    trash=["chip bag", "styrofoam", "napkin", "candy wrapper"],
)
# -------------------------
# Lazy Model Loading
# -------------------------
# Both pipelines start unset and are built on first use by load_models().
image_model = None
text_model = None


def load_models():
    """Create whichever Hugging Face pipelines have not been built yet.

    Sets the module globals ``image_model`` (image classification,
    microsoft/resnet-50) and ``text_model`` (zero-shot classification,
    facebook/bart-large-mnli). A pipeline that already exists is kept as-is,
    so calling this before every request only pays the loading cost once.
    """
    global image_model, text_model
    image_model = (
        image_model
        if image_model is not None
        else pipeline("image-classification", model="microsoft/resnet-50")
    )
    text_model = (
        text_model
        if text_model is not None
        else pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    )
# -------------------------
# Global Leaderboard (in-memory)
# -------------------------
# Parallel lists: users[i] has recorded counts[i] classifications. The data
# lives only in this process's memory.
leaderboard_data = {"users": [], "counts": []}


def update_leaderboard(username):
    """Record one classification for *username* and render the top three.

    Args:
        username: Name typed by the user. Surrounding whitespace is ignored;
            a blank or missing name is recorded as ``"Anonymous"``.

    Returns:
        An HTML fragment: a "Daily Leaderboard" heading followed by a table
        of the (up to) three users with the most recorded classifications.
        User names are HTML-escaped because they come straight from a text
        box and are rendered by ``gr.HTML``.
    """
    name = (username or "").strip() or "Anonymous"
    users = leaderboard_data["users"]
    counts = leaderboard_data["counts"]

    # The lists are mutated in place, so leaderboard_data sees the update
    # without being reassigned.
    try:
        counts[users.index(name)] += 1
    except ValueError:
        users.append(name)
        counts.append(1)

    # sorted() is stable, so users with equal counts keep first-seen order.
    top_three = sorted(zip(users, counts), key=lambda pair: pair[1], reverse=True)[:3]
    rows = "".join(
        f"<tr><td>{rank}</td><td>{html.escape(user)}</td><td>{count}</td></tr>"
        for rank, (user, count) in enumerate(top_three, start=1)
    )
    # NOTE(review): nothing here resets the counts, despite the "Daily" heading.
    return (
        "<h3>🏆 Daily Leaderboard</h3>"
        "<table>"
        "<tr><th>🏅 Rank</th><th>👤 User</th><th>🧪 Tests</th></tr>"
        f"{rows}"
        "</table>"
    )
# -------------------------
# Text Classification
# -------------------------
def classify_text(description):
    """Return the waste category for a free-text item description.

    The description is first matched against ``keyword_map``. A keyword must
    appear as a whole word or phrase, with an optional plural "s"/"es". This
    stops "can" from firing inside "candy wrapper" or "pecan". When several
    keywords match, the longest (most specific) one wins. Ties go to the
    category, then the keyword, listed first in ``keyword_map``.

    If no keyword matches, the zero-shot text model picks one of
    ``categories``.

    Args:
        description: Item description; ``None`` is treated as an empty string.

    Returns:
        The capitalized ``keyword_map`` key on a keyword hit, otherwise the
        top label from the zero-shot model.
    """
    description = description or ""
    desc = description.lower()

    best_category = None
    best_length = 0
    for category, keywords in keyword_map.items():
        for keyword in keywords:
            pattern = r"\b" + re.escape(keyword) + r"(?:e?s)?\b"
            # Strict ">" keeps the earliest keyword on equal-length ties.
            if len(keyword) > best_length and re.search(pattern, desc):
                best_category, best_length = category, len(keyword)
    if best_category is not None:
        return best_category.capitalize()

    # No keyword hit: fall back to zero-shot classification.
    load_models()
    text_pred = text_model(description, candidate_labels=categories)
    return text_pred["labels"][0]
# -------------------------
# Main Function
# -------------------------
def _category_from_image_label(label):
    """Map an image-classifier label to a ``keyword_map`` key, or None.

    Uses whole-word keyword matching (optional plural "s"/"es"). For example,
    "water bottle" maps to "recycle" and "banana" maps to "compost". A label
    that matches no keyword returns None.
    """
    for category, keywords in keyword_map.items():
        for keyword in keywords:
            if re.search(r"\b" + re.escape(keyword) + r"(?:e?s)?\b", label):
                return category
    return None


def classify_trash(image, description, username):
    """Suggest a bin for an item and record the attempt on the leaderboard.

    Args:
        image: PIL image from the Gradio image input, or None when no photo
            was provided.
        description: Free-text description of the item; may be empty or None.
        username: Name used for the leaderboard entry.

    Returns:
        ``(suggestion, leaderboard_html)``. The suggestion is prefixed with
        "✅ Confident:" when the category derived from the photo agrees with
        the category derived from the user's description.

    Raises:
        gr.Error: when neither a photo nor a description is supplied.
    """
    description = (description or "").strip()
    if image is None and not description:
        raise gr.Error("Please upload a photo or describe the item.")

    load_models()
    image_label = ""
    if image is not None:
        image_label = image_model(image)[0]["label"].lower()

    # With no description, the photo's label is the only text available.
    text_pred_label = classify_text(description or image_label)

    # The image model returns object labels (e.g. "banana"), never bin names,
    # so translate the label before comparing it with the text category.
    image_category = _category_from_image_label(image_label) if image_label else None
    # "Confident" only makes sense when two independent signals agree.
    confident = bool(description) and image_category == text_pred_label.lower()
    if confident:
        classification_result = f"✅ Confident: {text_pred_label}"
    else:
        classification_result = f"{text_pred_label}"

    leaderboard_html = update_leaderboard(username)
    return classification_result, leaderboard_html
# -------------------------
# Gradio Interface
# -------------------------
# Inputs map positionally onto classify_trash(image, description, username);
# the two outputs receive its (suggestion, leaderboard_html) return value.
app = gr.Interface(
    fn=classify_trash,
    inputs=[
        gr.Image(type="pil", label="Take or Upload a Photo"),
        gr.Textbox(label="Describe the Item (e.g. banana peel, soda can)"),
        gr.Textbox(label="Enter Your Name"),
    ],
    outputs=[
        gr.Textbox(label="Sorting Suggestion"),
        gr.HTML(label="🏆 Leaderboard (Top 3)"),
    ],
    title="♻️ Cam-Bin",
    description="""
Upload and describe items — figure out where to trash within seconds 🌱
""",
)

if __name__ == "__main__":
    app.launch()