import gradio as gr
from transformers import pipeline
import torch

# List of models
MODELS = {
    "econbert": "climatebert/econbert",
    "controversy-classification": "climatebert/ClimateControversyBERT_classification",
    "controversy-bert": "climatebert/ClimateControversyBert",
    "netzero-reduction": "climatebert/netzero-reduction",
    "transition-physical": "climatebert/transition-physical",
    "renewable": "climatebert/renewable",
    "climate-detector": "climatebert/distilroberta-base-climate-detector",
    "climate-commitment": "climatebert/distilroberta-base-climate-commitment",
    "climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
    "climate-s": "climatebert/distilroberta-base-climate-s",
    "climate-specificity": "climatebert/distilroberta-base-climate-specificity",
    "climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
    "environmental-claims": "climatebert/environmental-claims",
    "climate-f": "climatebert/distilroberta-base-climate-f",
    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
    "climate-d": "climatebert/distilroberta-base-climate-d",
}

# Human-readable label mappings
LABEL_MAPS = {
    "climate-commitment": {
        "LABEL_0": "Not about climate commitments",
        "LABEL_1": "About climate commitments",
    },
    "climate-detector": {
        "LABEL_0": "Not climate-related",
        "LABEL_1": "Climate-related",
    },
    "climate-sentiment": {
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
    },
    "climate-specificity": {
        "LABEL_0": "Low specificity",
        "LABEL_1": "Medium specificity",
        "LABEL_2": "High specificity",
    },
    "netzero-reduction": {
        "LABEL_0": "No net-zero / reduction commitment",
        "LABEL_1": "Net-zero / reduction commitment",
    },
    "transition-physical": {
        "LABEL_0": "Transition risk",
        "LABEL_1": "Physical risk",
    },
    "renewable": {
        "LABEL_0": "Not about renewables",
        "LABEL_1": "About renewables",
    },
    # You can expand mappings for other models after checking their model cards
}

# Cache for loaded pipelines
pipelines = {}


def load_model(model_key):
    """Load pipeline for the selected model with truncation enabled."""
    if model_key not in pipelines:
        repo_id = MODELS[model_key]
        device = 0 if torch.cuda.is_available() else -1
        pipelines[model_key] = pipeline(
            "text-classification",
            model=repo_id,
            device=device,
            torch_dtype=torch.float16 if device == 0 else None,
            truncation=True,
            max_length=512,
        )
    return pipelines[model_key]


def predict(model_key, text):
    """Run inference on selected model with truncation and readable labels."""
    if not text.strip():
        return "Please enter some text."
    try:
        model = load_model(model_key)
        results = model(text)
        label_map = LABEL_MAPS.get(model_key, {})
        formatted = "\n".join([
            f"{label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
            for r in results
        ])
        return f"Predictions for '{text[:50]}...':\n{formatted}"
    except Exception as e:
        return f"Error: {str(e)} (Check input length or model card for details)."
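
# Note: the LABEL_0/LABEL_1 names in LABEL_MAPS above are assumptions to verify
# against each checkpoint's model card / id2label config; some ClimateBERT
# checkpoints ship descriptive label names of their own. Unmapped models still
# work because label_map.get(...) falls back to the raw label the pipeline returns.
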
# Gradio interface
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
    gr.Markdown(
        "# 🌍 ClimateBERT Models Demo\n"
        "Select a model and input text for climate-related analysis "
        "(e.g., sentiment, classification)."
    )
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            value=list(MODELS.keys())[0],
        )
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
            lines=2,
        )
    output = gr.Textbox(label="Output", lines=5)
    predict_btn = gr.Button("Run Inference")
    predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
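
# To run locally (assumed minimal dependencies; pin versions as needed):
#   pip install gradio transformers torch
#   python app.py   # "app.py" is an assumed filename for this script
# The app then serves on http://0.0.0.0:7860.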