File size: 974 Bytes
087f7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import functools
import json

import requests
import torch
from transformers import BertForSequenceClassification, BertTokenizer

_REPO_ID = "logasanjeev/goemotions-bert"
# Upper bound (seconds) on the thresholds download so a stalled connection
# cannot block the caller forever.
_THRESHOLDS_TIMEOUT_SECONDS = 30


@functools.lru_cache(maxsize=1)
def _load_resources(repo_id):
    """Load the model, tokenizer, emotion labels and thresholds for *repo_id*.

    The result is cached so repeated ``predict`` calls do not reload the
    weights or re-download ``thresholds.json``.

    Raises:
        requests.HTTPError: if the thresholds file request returns an error
            status.
        requests.Timeout: if the thresholds file request exceeds the timeout.
        KeyError: if the thresholds file lacks ``emotion_labels`` or
            ``thresholds``.
    """
    model = BertForSequenceClassification.from_pretrained(repo_id)
    tokenizer = BertTokenizer.from_pretrained(repo_id)

    thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
    response = requests.get(thresholds_url, timeout=_THRESHOLDS_TIMEOUT_SECONDS)
    # Fail on 4xx/5xx here instead of trying to parse an error page as JSON.
    response.raise_for_status()
    thresholds_data = json.loads(response.text)

    return (
        model,
        tokenizer,
        thresholds_data["emotion_labels"],
        thresholds_data["thresholds"],
    )


def predict(text):
    """Return the emotions whose score meets its per-label threshold.

    The text is tokenized (padded/truncated to 128 tokens), run through the
    classifier, and each output is passed through a sigmoid. A label is kept
    when its score is greater than or equal to the matching entry in the
    repository's ``thresholds.json``.

    Args:
        text: Input text to classify. NOTE(review): only the first row of the
            model output is used, so passing a batch scores just its first
            item.

    Returns:
        A list of ``{"label": <str>, "score": <float>}`` dicts, sorted by
        score in descending order. Empty when no label reaches its threshold.

    Raises:
        ValueError: if the number of labels, thresholds and model outputs
            disagree.
    """
    model, tokenizer, emotion_labels, thresholds = _load_resources(_REPO_ID)

    encodings = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    with torch.no_grad():
        # Sigmoid, not softmax: each label is scored independently.
        probabilities = torch.sigmoid(model(**encodings).logits).numpy()[0]

    # zip() would silently drop trailing entries on a mismatch, hiding a
    # thresholds file that does not belong to this model.
    if not (len(emotion_labels) == len(thresholds) == len(probabilities)):
        raise ValueError(
            "Mismatched sizes: "
            f"{len(emotion_labels)} labels, {len(thresholds)} thresholds, "
            f"{len(probabilities)} model outputs"
        )

    predictions = [
        {"label": label, "score": float(prob)}
        for label, prob, thresh in zip(emotion_labels, probabilities, thresholds)
        if prob >= thresh
    ]
    return sorted(predictions, key=lambda x: x["score"], reverse=True)