"""
Simple Emotion Classification Inference

Downloads the model from the HuggingFace Hub and classifies the emotion in an audio file.

Usage:
    python inference.py speech.wav
    python inference.py path/to/audio.mp3
"""

import json
import pickle
import sys

import soundfile as sf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model


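# Expects a list of [batch, feature_dim] tensors (one per wav2vec2 layer) and returns
# their softmax-weighted sum as a single [batch, feature_dim] tensor plus the weights.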
class WeightedLayerPooling(nn.Module):
    """Learnable weighted combination of wav2vec2 layers."""

    def __init__(self, num_layers=12, temperature=1.0):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))
        self.temperature = temperature

    def forward(self, all_layer_features):
        stacked = torch.stack(all_layer_features, dim=0)
        scaled_weights = self.layer_weights / self.temperature
        attention_weights = torch.softmax(scaled_weights, dim=0)
        weights_reshaped = attention_weights.view(-1, 1, 1)
        weighted = torch.sum(stacked * weights_reshaped, dim=0)
        return weighted, attention_weights


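# Splits the pooled wav2vec2 embedding into fixed-size patches (with the defaults,
# 1024 // 64 = 16 patches of 64 dims), projects them to d_model, prepends a learnable
# [CLS] token, and classifies from that token after the transformer encoder.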
class TransformerEmotionClassifier(nn.Module):
    """Transformer Encoder for emotion classification."""

    def __init__(self, input_dim=1024, num_classes=3, patch_size=64, d_model=512,
                 nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.2,
                 use_layer_weighting=False, num_wav2vec_layers=12,
                 layer_attention_temperature=1.0):
        super().__init__()

        self.input_dim = input_dim
        self.patch_size = patch_size
        self.num_patches = input_dim // patch_size
        self.use_layer_weighting = use_layer_weighting

        if use_layer_weighting:
            self.layer_pooling = WeightedLayerPooling(
                num_layers=num_wav2vec_layers,
                temperature=layer_attention_temperature
            )

        self.patch_embed = nn.Linear(patch_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model) * 0.02)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_drop = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True,
            norm_first=True
        )
        try:
            # enable_nested_tensor is a TransformerEncoder argument only available in
            # newer PyTorch releases; fall back to the plain constructor otherwise.
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layers,
                norm=nn.LayerNorm(d_model), enable_nested_tensor=False
            )
        except TypeError:
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layers,
                norm=nn.LayerNorm(d_model)
            )

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )

    def forward(self, x):
        if self.use_layer_weighting:
            if isinstance(x, list):
                x, _ = self.layer_pooling(x)
            else:
                raise ValueError("use_layer_weighting=True but input is not a list of layers")

        batch_size = x.shape[0]

        # Patchify the pooled embedding and project each patch to d_model.
        x = x.view(batch_size, self.num_patches, self.patch_size)
        x = self.patch_embed(x)

        # Prepend the [CLS] token and add positional embeddings.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.transformer_encoder(x)

        # Classify from the [CLS] token representation.
        cls_output = x[:, 0]
        return self.head(cls_output)


def classify_emotion(audio_path):
    """Classify emotion in audio file"""
    print("=" * 60)
    print("EMOTION CLASSIFICATION")
    print("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n")

    print("Loading model from HuggingFace...")
    model_weights_path = hf_hub_download(
        repo_id="vocametrix/speech-emotion-valence-classifier",
        filename="emotion_classifier_complete.safetensors"
    )

    config_path = hf_hub_download(
        repo_id="vocametrix/speech-emotion-valence-classifier",
        filename="config.json"
    )
    with open(config_path) as f:
        config = json.load(f)

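    # Rebuild the classifier from the hyperparameters stored in config.json so the
    # architecture matches the downloaded checkpoint.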
    model = TransformerEmotionClassifier(
        input_dim=config["input_dim"],
        num_classes=config["num_classes"],
        patch_size=config["patch_size"],
        d_model=config["d_model"],
        nhead=config["nhead"],
        num_layers=config["num_layers"],
        dim_feedforward=config["dim_feedforward"],
        dropout=config["dropout"],
        use_layer_weighting=config["use_layer_weighting"],
        num_wav2vec_layers=config["num_wav2vec_layers"],
        layer_attention_temperature=config.get("layer_attention_temperature", 1.0)
    ).to(device)

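    # safetensors' load_file returns a plain dict of tensors already on the requested
    # device, so it can be passed straight to load_state_dict.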
    state_dict = load_file(model_weights_path, device=str(device))
    model.load_state_dict(state_dict)
    model.eval()
    print("✓ Model loaded (SafeTensors format)")

    scaler_path = hf_hub_download(
        repo_id="vocametrix/speech-emotion-valence-classifier",
        filename="emotion_classifier_scaler.pkl"
    )
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    print("✓ Scaler loaded")

    print("Loading Wav2Vec2 models...")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
    wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    print("✓ Wav2Vec2 loaded\n")

    print(f"Processing audio: {audio_path}")
    audio_numpy, sample_rate = sf.read(audio_path)

    waveform = torch.FloatTensor(audio_numpy)
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim == 2 and waveform.shape[1] > 1:
        waveform = waveform.mean(dim=1, keepdim=True).transpose(0, 1)

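    # torchaudio is imported lazily so it is only needed when resampling to 16 kHz.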
    if sample_rate != 16000:
        import torchaudio
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
        sample_rate = 16000

    audio_numpy = waveform.squeeze().numpy()
    duration = len(audio_numpy) / sample_rate
    print(f"✓ Audio loaded: {duration:.2f}s\n")

    print("Extracting features...")
    inputs = feature_extractor(
        audio_numpy,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        wav2vec2_model.eval()
        outputs = wav2vec2_model(inputs.input_values.to(device), output_hidden_states=True)

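    # Keep the first 13 hidden states (embedding output plus the first 12 encoder
    # layers) and mean-pool each over time into one vector per utterance; this count
    # is expected to match num_wav2vec_layers in the downloaded config.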
    all_hidden_states = outputs.hidden_states
    selected_layers = all_hidden_states[:13]
    layer_features = [layer.mean(dim=1) for layer in selected_layers]

    print(f"✓ Extracted {len(layer_features)} layers\n")

    scaled_layers = []
    for i, layer_feat in enumerate(layer_features):
        scaled = scaler.transform(layer_feat.cpu().numpy())
        scaled_layers.append(torch.FloatTensor(scaled).to(device))
        if i == 0:
            print(f"Debug - Layer 0 before scaling: mean={layer_feat.mean().item():.4f}, std={layer_feat.std().item():.4f}")
            print(f"Debug - Layer 0 after scaling: mean={scaled.mean():.4f}, std={scaled.std():.4f}")

    print("Classifying emotion...")
    with torch.no_grad():
        emotion_logits = model(scaled_layers)
        print(f"Debug - Raw logits: {emotion_logits[0].cpu().numpy()}")
        emotion_probs = emotion_logits.softmax(dim=1)
        emotion_idx = emotion_logits.argmax(dim=1)

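    # The script assumes the training label order was 0 = negative, 1 = neutral,
    # 2 = positive.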
    classes = ["negative", "neutral", "positive"]
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(f"Emotion: {classes[emotion_idx.item()].upper()}")
    print(f"Confidence: {emotion_probs.max().item():.1%}")
    print("\nProbabilities:")
    print(f"  Negative: {emotion_probs[0, 0]:.1%}")
    print(f"  Neutral:  {emotion_probs[0, 1]:.1%}")
    print(f"  Positive: {emotion_probs[0, 2]:.1%}")
    print("=" * 60)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python inference.py <audio_file>")
        print("Example: python inference.py speech.wav")
        sys.exit(1)

    audio_file = sys.argv[1]

    try:
        classify_emotion(audio_file)
    except FileNotFoundError:
        print(f"Error: File not found: {audio_file}")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)