update weights
Browse files- __pycache__/app.cpython-311.pyc +0 -0
- app.py +12 -2
- birdvec.py +95 -0
- fetch_img.py +0 -3
__pycache__/app.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
|
|
|
app.py
CHANGED
|
@@ -23,7 +23,7 @@ from fetch_img import download_images, scientific_to_species_code
|
|
| 23 |
from audio_class_predictor import predict_class
|
| 24 |
from bird_ast_model import birdast_preprocess, birdast_inference
|
| 25 |
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference
|
| 26 |
-
|
| 27 |
from utils import plot_wave, plot_mel, download_model, bandpass_filter
|
| 28 |
|
| 29 |
# Define the default parameters
|
|
@@ -60,10 +60,20 @@ birdast_seq_assets = {
|
|
| 60 |
"inference_fn": birdast_seq_inference,
|
| 61 |
}
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# maintain a dictionary of assets
|
| 64 |
ASSET_DICT = {
|
| 65 |
"BirdAST": birdast_assets,
|
| 66 |
"BirdAST_Seq": birdast_seq_assets,
|
|
|
|
| 67 |
}
|
| 68 |
|
| 69 |
|
|
@@ -251,7 +261,7 @@ with gr.Blocks(theme = seafoam, css = css, js = js) as demo:
|
|
| 251 |
gr.Markdown(DESCRIPTION)
|
| 252 |
|
| 253 |
# add dropdown for model selection
|
| 254 |
-
model_names = ['BirdAST', 'BirdAST_Seq'] #, 'EfficientNet']
|
| 255 |
model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
|
| 256 |
download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False) # Non-interactive textbox for status
|
| 257 |
model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status)
|
|
|
|
| 23 |
from audio_class_predictor import predict_class
|
| 24 |
from bird_ast_model import birdast_preprocess, birdast_inference
|
| 25 |
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference
|
| 26 |
+
from birdvec import birdvec_preprocess, birdvec_inference
|
| 27 |
from utils import plot_wave, plot_mel, download_model, bandpass_filter
|
| 28 |
|
| 29 |
# Define the default parameters
|
|
|
|
| 60 |
"inference_fn": birdast_seq_inference,
|
| 61 |
}
|
| 62 |
|
| 63 |
+
birdvec_assets = {
|
| 64 |
+
"model_weights": [
|
| 65 |
+
f"https://huggingface.co/amroa/BirdVec/resolve/main/fold{i}/best-model{i}.ckpt" for i in range(3)
|
| 66 |
+
],
|
| 67 |
+
"label_mapping": "https://huggingface.co/amroa/BirdVec/resolve/main/new_label_map.csv",
|
| 68 |
+
"preprocess_fn": birdvec_preprocess,
|
| 69 |
+
"inference_fn": birdvec_inference,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
# maintain a dictionary of assets
|
| 73 |
ASSET_DICT = {
|
| 74 |
"BirdAST": birdast_assets,
|
| 75 |
"BirdAST_Seq": birdast_seq_assets,
|
| 76 |
+
"BirdWav2Vec": birdvec_assets,
|
| 77 |
}
|
| 78 |
|
| 79 |
|
|
|
|
| 261 |
gr.Markdown(DESCRIPTION)
|
| 262 |
|
| 263 |
# add dropdown for model selection
|
| 264 |
+
model_names = ['BirdAST', 'BirdAST_Seq', 'BirdWav2Vec'] #, 'EfficientNet']
|
| 265 |
model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
|
| 266 |
download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False) # Non-interactive textbox for status
|
| 267 |
model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status)
|
birdvec.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification
|
| 6 |
+
|
| 7 |
+
DEFAULT_SR = 16_000
|
| 8 |
+
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
|
| 9 |
+
DEFAULT_N_CLASSES = 728
|
| 10 |
+
MODEL_STR = "dima806/bird_sounds_classification" #"facebook/wav2vec2-base-960h"
|
| 11 |
+
RATE_HZ = 16000
|
| 12 |
+
# Define the maximum audio interval length to consider in seconds
|
| 13 |
+
MAX_SECONDS = 10
|
| 14 |
+
# Calculate the maximum audio interval length in samples by multiplying the rate and seconds
|
| 15 |
+
MAX_LENGTH = RATE_HZ * MAX_SECONDS
|
| 16 |
+
|
| 17 |
+
# Create an instance of the feature extractor for audio.
|
| 18 |
+
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
|
| 23 |
+
"""
|
| 24 |
+
Preprocess audio array for BirdAST model
|
| 25 |
+
audio_array: np.array, audio array of the recording, shape (n_samples,) Note: The audio array should be normalized to [-1, 1]
|
| 26 |
+
sr: int, sampling rate of the audio array (default: 16_000)
|
| 27 |
+
|
| 28 |
+
Note:
|
| 29 |
+
1. The audio array should be normalized to [-1, 1].
|
| 30 |
+
2. The audio length should be 10 seconds (or 10.24 seconds). Longer audio will be truncated.
|
| 31 |
+
"""
|
| 32 |
+
# Extract features
|
| 33 |
+
features = FEATURE_EXTRACTOR(audio_array, sampling_rate=DEFAULT_SR, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
|
| 34 |
+
|
| 35 |
+
return features.input_values
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def birdvec_inference(
|
| 39 |
+
model_weights,
|
| 40 |
+
spectrogram,
|
| 41 |
+
device = 'cpu',
|
| 42 |
+
backbone_name=None,
|
| 43 |
+
n_classes=728,
|
| 44 |
+
activation=None,
|
| 45 |
+
n_mlp_layers=None
|
| 46 |
+
):
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
Perform inference on BirdAST model
|
| 50 |
+
model_weights: list, list of model weights
|
| 51 |
+
spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels,)
|
| 52 |
+
device: str, device to run inference (default: 'cpu')
|
| 53 |
+
backbone_name: str, name of the backbone model (default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
|
| 54 |
+
n_classes: int, number of classes (default: 728)
|
| 55 |
+
activation: str, activation function (default: 'silu')
|
| 56 |
+
n_mlp_layers: int, number of MLP layers (default: 1)
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
predictions: np.array, array of predictions, shape (n_models, batch_size, n_classes)
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
predict_collects = []
|
| 65 |
+
for _weights in model_weights:
|
| 66 |
+
#model.load_state_dict(torch.load(_weights, map_location=device)['state_dict'])
|
| 67 |
+
model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights = None)
|
| 68 |
+
if device != 'cpu': model.to(device)
|
| 69 |
+
model.eval()
|
| 70 |
+
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
if device != 'cpu': spectrogram = spectrogram.to(device)
|
| 73 |
+
|
| 74 |
+
output = model(spectrogram)
|
| 75 |
+
logits = output['logits']
|
| 76 |
+
probs = F.softmax(logits, dim=-1)
|
| 77 |
+
predict_collects.append(probs)
|
| 78 |
+
|
| 79 |
+
if device != 'cpu':
|
| 80 |
+
predict_collects = [pred.cpu() for pred in predict_collects]
|
| 81 |
+
|
| 82 |
+
predict_collects = torch.cat(predict_collects, dim=0).numpy()
|
| 83 |
+
|
| 84 |
+
return predict_collects
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BirdSongClassifier(pl.LightningModule):
|
| 88 |
+
def __init__(self, class_weights):
|
| 89 |
+
super().__init__()
|
| 90 |
+
config = AutoConfig.from_pretrained("dima806/bird_sounds_classification")
|
| 91 |
+
config.num_labels = 728
|
| 92 |
+
self.model = AutoModelForAudioClassification.from_config(config)
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
return self.model(x)
|
fetch_img.py
CHANGED
|
@@ -13,9 +13,6 @@ REQ_FMT = {
|
|
| 13 |
"url": 'https://api.ebird.org/v2/ref/taxonomy/ebird',
|
| 14 |
"params" : {
|
| 15 |
'species': 'CHANGE THIS TO SPECIES CODE'
|
| 16 |
-
},
|
| 17 |
-
"headers" : {
|
| 18 |
-
'X-eBirdApiToken': 'id1a0e3q2lt3'
|
| 19 |
}
|
| 20 |
}
|
| 21 |
bird_df = pd.read_csv("ebird_taxonomy_v2023.csv")
|
|
|
|
| 13 |
"url": 'https://api.ebird.org/v2/ref/taxonomy/ebird',
|
| 14 |
"params" : {
|
| 15 |
'species': 'CHANGE THIS TO SPECIES CODE'
|
|
|
|
|
|
|
|
|
|
| 16 |
}
|
| 17 |
}
|
| 18 |
bird_df = pd.read_csv("ebird_taxonomy_v2023.csv")
|