Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import gradio as gr
|
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
# List of
|
| 6 |
MODELS = {
|
| 7 |
"econbert": "climatebert/econbert",
|
| 8 |
"controversy-classification": "climatebert/ClimateControversyBERT_classification",
|
|
@@ -13,46 +13,47 @@ MODELS = {
|
|
| 13 |
"climate-detector": "climatebert/distilroberta-base-climate-detector",
|
| 14 |
"climate-commitment": "climatebert/distilroberta-base-climate-commitment",
|
| 15 |
"climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
|
| 16 |
-
"climate-s": "climatebert/distilroberta-base-climate-s",
|
| 17 |
"climate-specificity": "climatebert/distilroberta-base-climate-specificity",
|
| 18 |
"climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
|
| 19 |
"environmental-claims": "climatebert/environmental-claims",
|
| 20 |
-
"climate-f": "climatebert/distilroberta-base-climate-f",
|
| 21 |
-
"climate-d-s": "climatebert/distilroberta-base-climate-d-s",
|
| 22 |
-
"climate-d": "climatebert/distilroberta-base-climate-d"
|
| 23 |
}
|
| 24 |
|
| 25 |
-
# Cache for loaded pipelines
|
| 26 |
pipelines = {}
|
| 27 |
|
| 28 |
def load_model(model_key):
|
| 29 |
-
"""Load pipeline for the selected model
|
| 30 |
if model_key not in pipelines:
|
| 31 |
repo_id = MODELS[model_key]
|
| 32 |
-
device = 0 if torch.cuda.is_available() else -1
|
| 33 |
pipelines[model_key] = pipeline(
|
| 34 |
"text-classification",
|
| 35 |
model=repo_id,
|
| 36 |
device=device,
|
| 37 |
-
torch_dtype=torch.float16 if device == 0 else None
|
|
|
|
|
|
|
| 38 |
)
|
| 39 |
return pipelines[model_key]
|
| 40 |
|
| 41 |
def predict(model_key, text):
|
| 42 |
-
"""Run inference on selected model."""
|
| 43 |
if not text.strip():
|
| 44 |
return "Please enter some text."
|
| 45 |
|
| 46 |
try:
|
| 47 |
model = load_model(model_key)
|
| 48 |
results = model(text)
|
| 49 |
-
# Format output (e.g., for classification: label + score)
|
| 50 |
formatted = "\n".join([f"{r['label']}: {r['score']:.2f}" for r in results])
|
| 51 |
-
return f"Predictions for '{text}':\n{formatted}"
|
| 52 |
except Exception as e:
|
| 53 |
-
return f"Error: {str(e)} (Check model card for
|
| 54 |
|
| 55 |
-
# Gradio interface
|
| 56 |
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
|
| 57 |
gr.Markdown("# ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification).")
|
| 58 |
|
|
|
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# List of models (unchanged)
|
| 6 |
MODELS = {
|
| 7 |
"econbert": "climatebert/econbert",
|
| 8 |
"controversy-classification": "climatebert/ClimateControversyBERT_classification",
|
|
|
|
| 13 |
"climate-detector": "climatebert/distilroberta-base-climate-detector",
|
| 14 |
"climate-commitment": "climatebert/distilroberta-base-climate-commitment",
|
| 15 |
"climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
|
| 16 |
+
"climate-s": "climatebert/distilroberta-base-climate-s",
|
| 17 |
"climate-specificity": "climatebert/distilroberta-base-climate-specificity",
|
| 18 |
"climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
|
| 19 |
"environmental-claims": "climatebert/environmental-claims",
|
| 20 |
+
"climate-f": "climatebert/distilroberta-base-climate-f",
|
| 21 |
+
"climate-d-s": "climatebert/distilroberta-base-climate-d-s",
|
| 22 |
+
"climate-d": "climatebert/distilroberta-base-climate-d"
|
| 23 |
}
|
| 24 |
|
| 25 |
+
# Cache for loaded pipelines
|
| 26 |
pipelines = {}
|
| 27 |
|
| 28 |
def load_model(model_key):
|
| 29 |
+
"""Load pipeline for the selected model with truncation enabled."""
|
| 30 |
if model_key not in pipelines:
|
| 31 |
repo_id = MODELS[model_key]
|
| 32 |
+
device = 0 if torch.cuda.is_available() else -1
|
| 33 |
pipelines[model_key] = pipeline(
|
| 34 |
"text-classification",
|
| 35 |
model=repo_id,
|
| 36 |
device=device,
|
| 37 |
+
torch_dtype=torch.float16 if device == 0 else None,
|
| 38 |
+
truncation=True, # Truncate input to max_length
|
| 39 |
+
max_length=512 # Set to 512 to match BERT/DistilRoBERTa limit
|
| 40 |
)
|
| 41 |
return pipelines[model_key]
|
| 42 |
|
| 43 |
def predict(model_key, text):
|
| 44 |
+
"""Run inference on selected model with truncation."""
|
| 45 |
if not text.strip():
|
| 46 |
return "Please enter some text."
|
| 47 |
|
| 48 |
try:
|
| 49 |
model = load_model(model_key)
|
| 50 |
results = model(text)
|
|
|
|
| 51 |
formatted = "\n".join([f"{r['label']}: {r['score']:.2f}" for r in results])
|
| 52 |
+
return f"Predictions for '{text[:50]}...':\n{formatted}" # Truncate display for readability
|
| 53 |
except Exception as e:
|
| 54 |
+
return f"Error: {str(e)} (Check input length or model card for details)."
|
| 55 |
|
| 56 |
+
# Gradio interface (unchanged)
|
| 57 |
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
|
| 58 |
gr.Markdown("# ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification).")
|
| 59 |
|