import functools
import json

import requests
import torch
from transformers import BertForSequenceClassification, BertTokenizer
DEFAULT_REPO_ID = "logasanjeev/goemotions-bert"
# Seconds to wait for the thresholds.json download before giving up.
_THRESHOLDS_TIMEOUT = 30


@functools.lru_cache(maxsize=1)
def _load_resources(repo_id):
    """Load model, tokenizer, labels and per-label thresholds for *repo_id*.

    The result is cached so repeated ``predict`` calls do not reload the
    weights or re-download ``thresholds.json``. ``maxsize=1`` keeps at most
    one model resident in memory.

    Returns:
        ``(model, tokenizer, emotion_labels, thresholds)``.

    Raises:
        requests.HTTPError: if ``thresholds.json`` cannot be fetched.
        ValueError: if the label/threshold counts do not match the model's
            number of output labels.
    """
    model = BertForSequenceClassification.from_pretrained(repo_id)
    model.eval()  # inference only: make sure dropout is disabled
    tokenizer = BertTokenizer.from_pretrained(repo_id)

    thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
    response = requests.get(thresholds_url, timeout=_THRESHOLDS_TIMEOUT)
    # Fail loudly on 4xx/5xx instead of trying to JSON-parse an error page.
    response.raise_for_status()
    thresholds_data = json.loads(response.text)
    emotion_labels = thresholds_data["emotion_labels"]
    thresholds = thresholds_data["thresholds"]

    # Labels, thresholds and model outputs are consumed in lockstep, so a
    # mismatch would silently drop labels or mis-assign names.
    num_labels = model.config.num_labels
    if not (len(emotion_labels) == len(thresholds) == num_labels):
        raise ValueError(
            f"thresholds.json for {repo_id!r} has {len(emotion_labels)} labels "
            f"and {len(thresholds)} thresholds, but the model has {num_labels} outputs"
        )
    return model, tokenizer, emotion_labels, thresholds


def predict(text, *, repo_id=DEFAULT_REPO_ID):
    """Return the emotions detected in *text*, highest score first.

    Each label gets an independent sigmoid score. A label is reported when its
    score reaches that label's threshold from the repo's ``thresholds.json``.
    Input is truncated to 128 tokens.

    Args:
        text: The text to classify. Only the first row of the model output is
            used, so pass a single string.
        repo_id: Hugging Face Hub repository holding the model and
            ``thresholds.json``.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, sorted by
        ``score`` in descending order. The list may be empty.
    """
    model, tokenizer, emotion_labels, thresholds = _load_resources(repo_id)
    encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        probabilities = torch.sigmoid(model(**encodings).logits).numpy()[0]
    predictions = [
        {"label": label, "score": float(prob)}
        for label, prob, thresh in zip(emotion_labels, probabilities, thresholds)
        if prob >= thresh
    ]
    return sorted(predictions, key=lambda x: x["score"], reverse=True)