"""Gradio demo for the ClimateBERT family of text-classification models.

Pipelines are loaded lazily on first request and cached for the lifetime of
the process, so startup stays fast and only the models actually used are
held in memory. Runs on GPU (half precision) when available, CPU otherwise.
"""

import gradio as gr
import torch
from transformers import pipeline

# List of your models (using repo IDs for easy loading)
MODELS = {
    "econbert": "climatebert/econbert",
    "controversy-classification": "climatebert/ClimateControversyBERT_classification",
    "controversy-bert": "climatebert/ClimateControversyBert",
    "netzero-reduction": "climatebert/netzero-reduction",
    "transition-physical": "climatebert/transition-physical",
    "renewable": "climatebert/renewable",
    "climate-detector": "climatebert/distilroberta-base-climate-detector",
    "climate-commitment": "climatebert/distilroberta-base-climate-commitment",
    "climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
    "climate-s": "climatebert/distilroberta-base-climate-s",  # Adjust label if truncated
    "climate-specificity": "climatebert/distilroberta-base-climate-specificity",
    "climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
    "environmental-claims": "climatebert/environmental-claims",
    "climate-f": "climatebert/distilroberta-base-climate-f",  # Adjust label if truncated
    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",  # Adjust label if truncated
    "climate-d": "climatebert/distilroberta-base-climate-d",  # Adjust label if truncated
}

# Cache for loaded pipelines (lazy loading)
pipelines = {}


def load_model(model_key):
    """Return a cached text-classification pipeline for ``model_key``.

    The pipeline is created on first use and memoized in ``pipelines``.
    The "text-classification" task is assumed for every repo; adjust the
    task if a model card specifies otherwise.

    Args:
        model_key: A key of ``MODELS``. An unknown key raises ``KeyError``
            (callers go through ``predict``, which catches it).

    Returns:
        A ``transformers`` pipeline ready for inference.
    """
    if model_key not in pipelines:
        repo_id = MODELS[model_key]
        device = 0 if torch.cuda.is_available() else -1  # GPU if available
        pipelines[model_key] = pipeline(
            "text-classification",
            model=repo_id,
            device=device,
            # Half-precision on GPU for memory savings; default dtype on CPU.
            torch_dtype=torch.float16 if device == 0 else None,
        )
    return pipelines[model_key]


def predict(model_key, text):
    """Run inference on the selected model and format label/score lines.

    Args:
        model_key: Key into ``MODELS`` chosen from the dropdown.
        text: User-supplied input text. May be ``None`` or empty when the
            Gradio Textbox is cleared — handled before any processing
            (the original called ``text.strip()`` unconditionally, which
            raised ``AttributeError`` on ``None``).

    Returns:
        A human-readable prediction string, or an error/prompt message.
    """
    if not text or not text.strip():
        return "Please enter some text."
    try:
        model = load_model(model_key)
        results = model(text)
        # Format output (e.g., for classification: label + score)
        formatted = "\n".join(f"{r['label']}: {r['score']:.2f}" for r in results)
        return f"Predictions for '{text}':\n{formatted}"
    except Exception as e:  # UI boundary: surface any failure to the user
        return f"Error: {e} (Check model card for task/input format)."


# Gradio interface
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
    gr.Markdown(
        "# ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification)."
    )
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            value=list(MODELS.keys())[0],
        )
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
            lines=2,
        )
    output = gr.Textbox(label="Output", lines=5)
    predict_btn = gr.Button("Run Inference")
    predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)