File size: 974 Bytes
087f7cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import functools
import json

import requests
import torch
from transformers import BertForSequenceClassification, BertTokenizer
_DEFAULT_REPO_ID = "logasanjeev/goemotions-bert"


@functools.lru_cache(maxsize=4)
def _load_resources(repo_id):
    """Load the model, tokenizer, labels and per-label thresholds for *repo_id*.

    Cached (bounded to a few repos) so repeated ``predict`` calls do not
    re-load the weights or re-download ``thresholds.json`` every time.

    Returns:
        Tuple ``(model, tokenizer, emotion_labels, thresholds)``.

    Raises:
        requests.HTTPError: if ``thresholds.json`` cannot be fetched.
        ValueError: if the label/threshold counts disagree with the model.
    """
    model = BertForSequenceClassification.from_pretrained(repo_id)
    tokenizer = BertTokenizer.from_pretrained(repo_id)

    thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
    # A timeout prevents an indefinite hang on a stalled connection.
    response = requests.get(thresholds_url, timeout=30)
    # Fail loudly on 404/5xx instead of handing an HTML error page to the JSON parser.
    response.raise_for_status()
    thresholds_data = response.json()
    emotion_labels = thresholds_data["emotion_labels"]
    thresholds = thresholds_data["thresholds"]

    # zip() would silently truncate on a mismatch and hide labels; check up front.
    num_labels = model.config.num_labels
    if not (len(emotion_labels) == len(thresholds) == num_labels):
        raise ValueError(
            f"thresholds.json for {repo_id!r} lists {len(emotion_labels)} labels and "
            f"{len(thresholds)} thresholds, but the model has {num_labels} outputs"
        )
    return model, tokenizer, tuple(emotion_labels), tuple(thresholds)


def predict(text, repo_id=_DEFAULT_REPO_ID):
    """Return the emotions detected in *text*, highest score first.

    Each label's sigmoid probability is compared against its own threshold
    from the repo's ``thresholds.json``. Only labels at or above their
    threshold are kept.

    Args:
        text: Input text. It is tokenized with padding/truncation to 128 tokens,
            and only the first sequence of the batch is scored.
        repo_id: Hugging Face model repo to load (defaults to the original
            GoEmotions BERT checkpoint).

    Returns:
        List of ``{"label": str, "score": float}`` dicts sorted by score,
        descending. The list is empty when no label clears its threshold.
    """
    model, tokenizer, emotion_labels, thresholds = _load_resources(repo_id)
    encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        # These are per-label probabilities (sigmoid of the logits), not raw logits.
        probs = torch.sigmoid(model(**encodings).logits)[0].tolist()
    predictions = [
        {"label": label, "score": prob}
        for label, prob, thresh in zip(emotion_labels, probs, thresholds)
        if prob >= thresh
    ]
    return sorted(predictions, key=lambda p: p["score"], reverse=True)