#!/usr/bin/env python3
"""
Simple Emotion Classification Inference
Downloads model from HuggingFace and classifies emotion in audio
Usage:
python inference.py speech.wav
python inference.py path/to/audio.mp3
"""
import sys
import json
import torch
import torch.nn as nn
import pickle
import soundfile as sf
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Model architecture definition
class WeightedLayerPooling(nn.Module):
"""Learnable weighted combination of wav2vec2 layers."""
def __init__(self, num_layers=12, temperature=1.0):
super().__init__()
self.layer_weights = nn.Parameter(torch.zeros(num_layers))
self.temperature = temperature
def forward(self, all_layer_features):
stacked = torch.stack(all_layer_features, dim=0)
scaled_weights = self.layer_weights / self.temperature
attention_weights = torch.softmax(scaled_weights, dim=0)
weights_reshaped = attention_weights.view(-1, 1, 1)
weighted = torch.sum(stacked * weights_reshaped, dim=0)
return weighted, attention_weights
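# Shape sketch for WeightedLayerPooling.forward (illustrative only, assuming the
# mean-pooled wav2vec2 features produced further below):
#   all_layer_features: list of num_layers tensors, each (batch, feature_dim)
#   stacked:            (num_layers, batch, feature_dim)
#   returns:            weighted features (batch, feature_dim) and the softmaxed
#                       layer weights (num_layers,)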
class TransformerEmotionClassifier(nn.Module):
"""Transformer Encoder for emotion classification."""
def __init__(self, input_dim=1024, num_classes=3, patch_size=64, d_model=512,
nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.2,
use_layer_weighting=False, num_wav2vec_layers=12,
layer_attention_temperature=1.0):
super().__init__()
self.input_dim = input_dim
self.patch_size = patch_size
self.num_patches = input_dim // patch_size
self.use_layer_weighting = use_layer_weighting
if use_layer_weighting:
self.layer_pooling = WeightedLayerPooling(
num_layers=num_wav2vec_layers,
temperature=layer_attention_temperature
)
# Patch embedding
self.patch_embed = nn.Linear(patch_size, d_model)
# Learnable positional embeddings
self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model) * 0.02)
# CLS token
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
# Dropout
self.pos_drop = nn.Dropout(dropout)
        # Create the transformer encoder. enable_nested_tensor is an argument of
        # nn.TransformerEncoder (not of nn.TransformerEncoderLayer) and is only
        # available on newer PyTorch versions, hence the fallback below.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True,
            norm_first=True
        )
        try:
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layers,
                norm=nn.LayerNorm(d_model), enable_nested_tensor=False
            )
        except TypeError:
            # Older PyTorch versions don't accept enable_nested_tensor
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layers,
                norm=nn.LayerNorm(d_model)
            )
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_model // 2, num_classes)
)
def forward(self, x):
if self.use_layer_weighting:
if isinstance(x, list):
x, _ = self.layer_pooling(x)
else:
raise ValueError("use_layer_weighting=True but input is not a list of layers")
batch_size = x.shape[0]
# Reshape into patches
x = x.view(batch_size, self.num_patches, self.patch_size)
# Patch embedding
x = self.patch_embed(x)
# Add CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# Add positional embeddings
x = x + self.pos_embed
x = self.pos_drop(x)
# Transformer encoder
x = self.transformer_encoder(x)
# Use CLS token for classification
cls_output = x[:, 0]
return self.head(cls_output)
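# Minimal usage sketch (shapes only; the real hyperparameters come from config.json,
# loaded in classify_emotion below, and 13 matches the 13 wav2vec2 layers extracted there):
#   model = TransformerEmotionClassifier(input_dim=1024, num_classes=3,
#                                         use_layer_weighting=True, num_wav2vec_layers=13)
#   layers = [torch.randn(2, 1024) for _ in range(13)]   # 13 mean-pooled wav2vec2 layers
#   logits = model(layers)                                # -> shape (2, 3)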
def classify_emotion(audio_path):
"""Classify emotion in audio file"""
print("="*60)
print("EMOTION CLASSIFICATION")
print("="*60)
# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")
# Download and load model from SafeTensors (safer than pickle/PyTorch)
print("Loading model from HuggingFace...")
model_weights_path = hf_hub_download(
repo_id="vocametrix/speech-emotion-valence-classifier",
filename="emotion_classifier_complete.safetensors"
)
# Load config to recreate model architecture
config_path = hf_hub_download(
repo_id="vocametrix/speech-emotion-valence-classifier",
filename="config.json"
)
with open(config_path) as f:
config = json.load(f)
# Create model and load weights from SafeTensors
model = TransformerEmotionClassifier(
input_dim=config["input_dim"],
num_classes=config["num_classes"],
patch_size=config["patch_size"],
d_model=config["d_model"],
nhead=config["nhead"],
num_layers=config["num_layers"],
dim_feedforward=config["dim_feedforward"],
dropout=config["dropout"],
use_layer_weighting=config["use_layer_weighting"],
num_wav2vec_layers=config["num_wav2vec_layers"],
layer_attention_temperature=config.get("layer_attention_temperature", 1.0)
).to(device)
state_dict = load_file(model_weights_path, device=str(device))
model.load_state_dict(state_dict)
model.eval()
print("✓ Model loaded (SafeTensors format)")
# Download and load scaler
scaler_path = hf_hub_download(
repo_id="vocametrix/speech-emotion-valence-classifier",
filename="emotion_classifier_scaler.pkl"
)
with open(scaler_path, 'rb') as f:
scaler = pickle.load(f)
print("✓ Scaler loaded")
# Load wav2vec2 feature extractor and model (multilingual)
print("Loading Wav2Vec2 models...")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
print("✓ Wav2Vec2 loaded\n")
# Load and process audio using soundfile (compatible with training data extraction)
print(f"Processing audio: {audio_path}")
audio_numpy, sample_rate = sf.read(audio_path)
# Convert to torch tensor and add channel dimension if mono
waveform = torch.FloatTensor(audio_numpy)
if waveform.ndim == 1:
waveform = waveform.unsqueeze(0) # Add channel dimension
elif waveform.ndim == 2 and waveform.shape[1] > 1:
# Stereo to mono
waveform = waveform.mean(dim=1, keepdim=True).transpose(0, 1)
# Resample if needed
if sample_rate != 16000:
import torchaudio
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
sample_rate = 16000
audio_numpy = waveform.squeeze().numpy()
duration = len(audio_numpy) / sample_rate
print(f"✓ Audio loaded: {duration:.2f}s\n")
# Extract wav2vec2 features (multi-layer with learned weighting)
print("Extracting features...")
inputs = feature_extractor(
audio_numpy,
sampling_rate=16000,
return_tensors="pt",
padding=True
)
with torch.no_grad():
wav2vec2_model.eval()
outputs = wav2vec2_model(inputs.input_values.to(device), output_hidden_states=True)
        # Extract the first 13 hidden states (layers 0-12, matching the training
        # script's range(13)) and mean-pool each over the time dimension
all_hidden_states = outputs.hidden_states
selected_layers = all_hidden_states[:13] # First 13 layers (0-12)
layer_features = [layer.mean(dim=1) for layer in selected_layers]
print(f"✓ Extracted {len(layer_features)} layers\n")
# Normalize each layer
scaled_layers = []
for i, layer_feat in enumerate(layer_features):
scaled = scaler.transform(layer_feat.cpu().numpy())
scaled_layers.append(torch.FloatTensor(scaled).to(device))
if i == 0: # Debug first layer
print(f"Debug - Layer 0 before scaling: mean={layer_feat.mean().item():.4f}, std={layer_feat.std().item():.4f}")
print(f"Debug - Layer 0 after scaling: mean={scaled.mean():.4f}, std={scaled.std():.4f}")
# Predict (model will apply learned layer weighting)
print("Classifying emotion...")
with torch.no_grad():
emotion_logits = model(scaled_layers)
print(f"Debug - Raw logits: {emotion_logits[0].cpu().numpy()}")
emotion_probs = emotion_logits.softmax(dim=1)
emotion_idx = emotion_logits.argmax(dim=1)
# Display results
classes = ["negative", "neutral", "positive"]
print("\n" + "="*60)
print("RESULTS")
print("="*60)
print(f"Emotion: {classes[emotion_idx.item()].upper()}")
print(f"Confidence: {emotion_probs.max().item():.1%}")
print(f"\nProbabilities:")
print(f" Negative: {emotion_probs[0, 0]:.1%}")
print(f" Neutral: {emotion_probs[0, 1]:.1%}")
print(f" Positive: {emotion_probs[0, 2]:.1%}")
print("="*60)
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python inference.py <audio_file>")
print("Example: python inference.py speech.wav")
sys.exit(1)
audio_file = sys.argv[1]
try:
classify_emotion(audio_file)
except FileNotFoundError:
print(f"Error: File not found: {audio_file}")
sys.exit(1)
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)