Spaces:
Sleeping
Sleeping
Update tasks/audio.py
Browse files- tasks/audio.py +31 -12
tasks/audio.py
CHANGED
|
@@ -5,6 +5,9 @@ from sklearn.metrics import accuracy_score
|
|
| 5 |
import random
|
| 6 |
import os
|
| 7 |
import torch
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
from .utils.evaluation import AudioEvaluationRequest
|
| 10 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
@@ -14,7 +17,7 @@ load_dotenv()
|
|
| 14 |
|
| 15 |
router = APIRouter()
|
| 16 |
|
| 17 |
-
DESCRIPTION = "
|
| 18 |
ROUTE = "/audio"
|
| 19 |
|
| 20 |
|
|
@@ -55,30 +58,46 @@ async def evaluate_audio(request: AudioEvaluationRequest):
|
|
| 55 |
#--------------------------------------------------------------------------------------------
|
| 56 |
|
| 57 |
# Make random predictions (placeholder for actual model inference)
|
| 58 |
-
def preprocess_audio(example):
    """Turn one dataset example into a (waveform, label) tensor pair.

    The audio samples become a float32 tensor with a leading length-1
    dimension; the label becomes an int64 scalar tensor.
    """
    samples = example["audio"]["array"]
    # unsqueeze(0) prepends the length-1 dim the model expects up front
    waveform = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)
    # int64 (torch.long) is what loss/metric functions expect for class ids
    label = torch.tensor(example["label"], dtype=torch.long)
    return waveform, label
|
| 63 |
-
|
| 64 |
model_path = "quantized_teacher_m5_static.pth"
|
| 65 |
model, device = load_model(model_path)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
train_test = train_test.map(preprocess_audio)
|
| 69 |
test_dataset = train_test.map(preprocess_audio)
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
|
| 74 |
-
true_labels =
|
| 75 |
predictions = []
|
| 76 |
|
| 77 |
with torch.no_grad():
|
| 78 |
-
for waveforms, labels in
|
| 79 |
waveforms, labels = waveforms.to(device), labels.to(device)
|
| 80 |
|
| 81 |
-
# Run Model
|
| 82 |
outputs = model(waveforms)
|
| 83 |
predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
| 84 |
true_labels.extend(labels.cpu().numpy())
|
|
|
|
| 5 |
import random
|
| 6 |
import os
|
| 7 |
import torch
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from Model_Loader import load_model
|
| 10 |
+
|
| 11 |
|
| 12 |
from .utils.evaluation import AudioEvaluationRequest
|
| 13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
|
|
| 17 |
|
| 18 |
router = APIRouter()
|
| 19 |
|
| 20 |
+
DESCRIPTION = "Quantized M5"
|
| 21 |
ROUTE = "/audio"
|
| 22 |
|
| 23 |
|
|
|
|
| 58 |
#--------------------------------------------------------------------------------------------
|
| 59 |
|
| 60 |
# Make random predictions (placeholder for actual model inference)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
model_path = "quantized_teacher_m5_static.pth"
|
| 62 |
model, device = load_model(model_path)
|
| 63 |
|
| 64 |
+
def preprocess_audio(example, target_length=32000):
    """Turn one dataset example into fixed-length model inputs.

    Steps: build a float32 tensor with a leading length-1 dimension,
    apply mean/std normalization, then zero-pad or truncate the time
    axis to exactly `target_length` samples.

    Returns a dict with keys "waveform" (shape (1, target_length)) and
    "label" (int64 scalar tensor).
    """
    samples = torch.tensor(example["audio"]["array"], dtype=torch.float32)
    waveform = samples.unsqueeze(0)  # shape (1, n_samples)

    # Standardize; the 1e-6 term guards against division by a zero std
    # (e.g. a perfectly silent clip).
    waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)

    # Force a fixed time length so examples can be stacked into batches.
    n_samples = waveform.shape[1]
    if n_samples >= target_length:
        waveform = waveform[:, :target_length]           # truncate
    else:
        tail = torch.zeros(1, target_length - n_samples)
        waveform = torch.cat((waveform, tail), dim=1)    # zero-pad

    # int64 (torch.long) is what loss/metric functions expect for class ids
    label = torch.tensor(example["label"], dtype=torch.long)
    return {"waveform": waveform, "label": label}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
|
| 88 |
+
train_test = train_test.map(preprocess_audio, batched=True)
|
| 89 |
test_dataset = train_test.map(preprocess_audio)
|
| 90 |
|
| 91 |
+
train_loader = DataLoader(train_test, batch_size=32, shuffle=True)
|
| 92 |
|
| 93 |
|
| 94 |
+
true_labels = train_dataset["label"]
|
| 95 |
predictions = []
|
| 96 |
|
| 97 |
with torch.no_grad():
|
| 98 |
+
for waveforms, labels in train_loader:
|
| 99 |
waveforms, labels = waveforms.to(device), labels.to(device)
|
| 100 |
|
|
|
|
| 101 |
outputs = model(waveforms)
|
| 102 |
predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
| 103 |
true_labels.extend(labels.cpu().numpy())
|