darisdzakwanhoesien committed
Commit c3c791b · verified · 1 Parent(s): 6932b84

Update app.py

Files changed (1):
  1. app.py +15 -14
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 from transformers import pipeline
 import torch
 
-# List of your models (using repo IDs for easy loading)
+# List of models (unchanged)
 MODELS = {
     "econbert": "climatebert/econbert",
     "controversy-classification": "climatebert/ClimateControversyBERT_classification",
@@ -13,46 +13,47 @@ MODELS = {
     "climate-detector": "climatebert/distilroberta-base-climate-detector",
     "climate-commitment": "climatebert/distilroberta-base-climate-commitment",
     "climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
-    "climate-s": "climatebert/distilroberta-base-climate-s",  # Adjust label if truncated
+    "climate-s": "climatebert/distilroberta-base-climate-s",
     "climate-specificity": "climatebert/distilroberta-base-climate-specificity",
     "climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
     "environmental-claims": "climatebert/environmental-claims",
-    "climate-f": "climatebert/distilroberta-base-climate-f",  # Adjust label if truncated
-    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",  # Adjust label if truncated
-    "climate-d": "climatebert/distilroberta-base-climate-d"  # Adjust label if truncated
+    "climate-f": "climatebert/distilroberta-base-climate-f",
+    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
+    "climate-d": "climatebert/distilroberta-base-climate-d"
 }
 
-# Cache for loaded pipelines (lazy loading)
+# Cache for loaded pipelines
 pipelines = {}
 
 def load_model(model_key):
-    """Load pipeline for the selected model (text-classification assumed; adjust task if needed)."""
+    """Load pipeline for the selected model with truncation enabled."""
     if model_key not in pipelines:
         repo_id = MODELS[model_key]
-        device = 0 if torch.cuda.is_available() else -1  # GPU if available
+        device = 0 if torch.cuda.is_available() else -1
         pipelines[model_key] = pipeline(
             "text-classification",
             model=repo_id,
             device=device,
-            torch_dtype=torch.float16 if device == 0 else None  # Half-precision on GPU for memory savings
+            torch_dtype=torch.float16 if device == 0 else None,
+            truncation=True,  # Truncate input to max_length
+            max_length=512  # Set to 512 to match BERT/DistilRoBERTa limit
         )
     return pipelines[model_key]
 
 def predict(model_key, text):
-    """Run inference on selected model."""
+    """Run inference on selected model with truncation."""
     if not text.strip():
         return "Please enter some text."
 
     try:
         model = load_model(model_key)
         results = model(text)
-        # Format output (e.g., for classification: label + score)
         formatted = "\n".join([f"{r['label']}: {r['score']:.2f}" for r in results])
-        return f"Predictions for '{text}':\n{formatted}"
+        return f"Predictions for '{text[:50]}...':\n{formatted}"  # Truncate display for readability
     except Exception as e:
-        return f"Error: {str(e)} (Check model card for task/input format)."
+        return f"Error: {str(e)} (Check input length or model card for details)."
 
-# Gradio interface
+# Gradio interface (unchanged)
 with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
     gr.Markdown("# ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification).")
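Below is a minimal sketch, not part of the commit, of what the newly added truncation=True and max_length=512 arguments do in isolation. It assumes a recent transformers release (construction-time kwargs are forwarded to the tokenizer for text-classification pipelines) and reuses one model ID from the MODELS dict above; the printed result is illustrative, not an actual run.

from transformers import pipeline

# One of the classifiers from MODELS, built with the kwargs added in this commit.
clf = pipeline(
    "text-classification",
    model="climatebert/distilroberta-base-climate-sentiment",
    truncation=True,  # drop tokens beyond max_length instead of raising an error
    max_length=512,   # DistilRoBERTa's maximum sequence length
)

# Hypothetical input far longer than 512 tokens; without truncation, inference
# would fail once the input runs past the model's context window.
long_text = "Our company reduced emissions across all operations. " * 200
print(clf(long_text))  # e.g. [{'label': 'neutral', 'score': 0.87}] (illustrative)

The related change in predict(), echoing only text[:50] of the input, is a display choice: with truncation enabled, users can paste long documents, and repeating them verbatim would flood the Gradio output box.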