Max005 commited on
Commit
151ed35
·
1 Parent(s): 39c3fb8

upload files

Browse files
Deepfake/model/config.json ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-base",
3
+ "activation_dropout": 0.0,
4
+ "adapter_attn_dim": null,
5
+ "adapter_kernel_size": 3,
6
+ "adapter_stride": 2,
7
+ "add_adapter": false,
8
+ "apply_spec_augment": true,
9
+ "architectures": [
10
+ "Wav2Vec2ForSequenceClassification"
11
+ ],
12
+ "attention_dropout": 0.1,
13
+ "bos_token_id": 1,
14
+ "classifier_proj_size": 256,
15
+ "codevector_dim": 256,
16
+ "contrastive_logits_temperature": 0.1,
17
+ "conv_bias": false,
18
+ "conv_dim": [
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512,
25
+ 512
26
+ ],
27
+ "conv_kernel": [
28
+ 10,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 3,
33
+ 2,
34
+ 2
35
+ ],
36
+ "conv_stride": [
37
+ 5,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2,
43
+ 2
44
+ ],
45
+ "ctc_loss_reduction": "sum",
46
+ "ctc_zero_infinity": false,
47
+ "diversity_loss_weight": 0.1,
48
+ "do_stable_layer_norm": false,
49
+ "eos_token_id": 2,
50
+ "feat_extract_activation": "gelu",
51
+ "feat_extract_norm": "group",
52
+ "feat_proj_dropout": 0.1,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "freeze_feat_extract_train": true,
56
+ "hidden_act": "gelu",
57
+ "hidden_dropout": 0.1,
58
+ "hidden_size": 768,
59
+ "id2label": {
60
+ "0": "Fake",
61
+ "1": "Real"
62
+ },
63
+ "initializer_range": 0.02,
64
+ "intermediate_size": 3072,
65
+ "label2id": {
66
+ "Fake": 0,
67
+ "Real": 1
68
+ },
69
+ "layer_norm_eps": 1e-05,
70
+ "layerdrop": 0.0,
71
+ "mask_channel_length": 10,
72
+ "mask_channel_min_space": 1,
73
+ "mask_channel_other": 0.0,
74
+ "mask_channel_prob": 0.0,
75
+ "mask_channel_selection": "static",
76
+ "mask_feature_length": 10,
77
+ "mask_feature_min_masks": 0,
78
+ "mask_feature_prob": 0.0,
79
+ "mask_time_length": 10,
80
+ "mask_time_min_masks": 2,
81
+ "mask_time_min_space": 1,
82
+ "mask_time_other": 0.0,
83
+ "mask_time_prob": 0.05,
84
+ "mask_time_selection": "static",
85
+ "model_type": "wav2vec2",
86
+ "no_mask_channel_overlap": false,
87
+ "no_mask_time_overlap": false,
88
+ "num_adapter_layers": 3,
89
+ "num_attention_heads": 12,
90
+ "num_codevector_groups": 2,
91
+ "num_codevectors_per_group": 320,
92
+ "num_conv_pos_embedding_groups": 16,
93
+ "num_conv_pos_embeddings": 128,
94
+ "num_feat_extract_layers": 7,
95
+ "num_hidden_layers": 12,
96
+ "num_negatives": 100,
97
+ "output_hidden_size": 768,
98
+ "pad_token_id": 0,
99
+ "proj_codevector_dim": 256,
100
+ "tdnn_dilation": [
101
+ 1,
102
+ 2,
103
+ 3,
104
+ 1,
105
+ 1
106
+ ],
107
+ "tdnn_dim": [
108
+ 512,
109
+ 512,
110
+ 512,
111
+ 512,
112
+ 1500
113
+ ],
114
+ "tdnn_kernel": [
115
+ 5,
116
+ 3,
117
+ 3,
118
+ 1,
119
+ 1
120
+ ],
121
+ "torch_dtype": "float32",
122
+ "transformers_version": "4.45.2",
123
+ "use_weighted_layer_sum": false,
124
+ "vocab_size": 32,
125
+ "xvector_output_dim": 512
126
+ }
Deepfake/model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4265db7332aaff6149687072b860aaea1767c9a1756dbbecc23647e130724bbc
3
+ size 378302360
Deepfake/model/preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "processor_class": "Wav2Vec2Processor",
8
+ "return_attention_mask": false,
9
+ "sampling_rate": 16000
10
+ }
Deepfake/model/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
Deepfake/model/tokenizer_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": true,
6
+ "normalized": false,
7
+ "rstrip": true,
8
+ "single_word": false,
9
+ "special": false
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": true,
14
+ "normalized": false,
15
+ "rstrip": true,
16
+ "single_word": false,
17
+ "special": false
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": true,
22
+ "normalized": false,
23
+ "rstrip": true,
24
+ "single_word": false,
25
+ "special": false
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": true,
30
+ "normalized": false,
31
+ "rstrip": true,
32
+ "single_word": false,
33
+ "special": false
34
+ }
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": false,
38
+ "do_lower_case": false,
39
+ "do_normalize": true,
40
+ "eos_token": "</s>",
41
+ "model_max_length": 1000000000000000019884624838656,
42
+ "pad_token": "<pad>",
43
+ "processor_class": "Wav2Vec2Processor",
44
+ "replace_word_delimiter_char": " ",
45
+ "return_attention_mask": false,
46
+ "target_lang": null,
47
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
48
+ "unk_token": "<unk>",
49
+ "word_delimiter_token": "|"
50
+ }
Deepfake/model/vocab.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "'": 27,
3
+ "</s>": 2,
4
+ "<pad>": 0,
5
+ "<s>": 1,
6
+ "<unk>": 3,
7
+ "A": 7,
8
+ "B": 24,
9
+ "C": 19,
10
+ "D": 14,
11
+ "E": 5,
12
+ "F": 20,
13
+ "G": 21,
14
+ "H": 11,
15
+ "I": 10,
16
+ "J": 29,
17
+ "K": 26,
18
+ "L": 15,
19
+ "M": 17,
20
+ "N": 9,
21
+ "O": 8,
22
+ "P": 23,
23
+ "Q": 30,
24
+ "R": 13,
25
+ "S": 12,
26
+ "T": 6,
27
+ "U": 16,
28
+ "V": 25,
29
+ "W": 18,
30
+ "X": 28,
31
+ "Y": 22,
32
+ "Z": 31,
33
+ "|": 4
34
+ }
DeepfakeModel.ipynb ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Requirement already satisfied: torch in c:\\users\\asus\\anaconda3\\lib\\site-packages (2.5.1)Note: you may need to restart the kernel to use updated packages.\n",
13
+ "\n",
14
+ "Requirement already satisfied: torchvision in c:\\users\\asus\\anaconda3\\lib\\site-packages (0.20.1)\n",
15
+ "Requirement already satisfied: torchaudio in c:\\users\\asus\\anaconda3\\lib\\site-packages (2.5.1)\n",
16
+ "Requirement already satisfied: PySoundFile in c:\\users\\asus\\anaconda3\\lib\\site-packages (0.9.0.post1)\n",
17
+ "Requirement already satisfied: filelock in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.13.1)\n",
18
+ "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (4.11.0)\n",
19
+ "Requirement already satisfied: networkx in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.3)\n",
20
+ "Requirement already satisfied: jinja2 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.1.4)\n",
21
+ "Requirement already satisfied: fsspec in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (2024.6.1)\n",
22
+ "Requirement already satisfied: setuptools in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (75.1.0)\n",
23
+ "Requirement already satisfied: sympy==1.13.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (1.13.1)\n",
24
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
25
+ "Requirement already satisfied: numpy in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchvision) (1.26.4)\n",
26
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchvision) (10.4.0)\n",
27
+ "Requirement already satisfied: cffi>=0.6 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from PySoundFile) (1.17.1)\n",
28
+ "Requirement already satisfied: pycparser in c:\\users\\asus\\anaconda3\\lib\\site-packages (from cffi>=0.6->PySoundFile) (2.21)\n",
29
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.3)\n"
30
+ ]
31
+ }
32
+ ],
33
+ "source": [
34
+ "pip install torch torchvision torchaudio PySoundFile ffmpeg-python"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 4,
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "['soundfile']\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "import torchaudio\n",
52
+ "print(str(torchaudio.list_audio_backends()))"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 12,
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "name": "stderr",
62
+ "output_type": "stream",
63
+ "text": [
64
+ "<>:13: SyntaxWarning: invalid escape sequence '\\m'\n",
65
+ "<>:17: SyntaxWarning: invalid escape sequence '\\H'\n",
66
+ "<>:13: SyntaxWarning: invalid escape sequence '\\m'\n",
67
+ "<>:17: SyntaxWarning: invalid escape sequence '\\H'\n",
68
+ "C:\\Users\\Asus\\AppData\\Local\\Temp\\ipykernel_18220\\208613059.py:13: SyntaxWarning: invalid escape sequence '\\m'\n",
69
+ " model_path = \"Deepfake\\model\"\n",
70
+ "C:\\Users\\Asus\\AppData\\Local\\Temp\\ipykernel_18220\\208613059.py:17: SyntaxWarning: invalid escape sequence '\\H'\n",
71
+ " cache_dir=\"D:\\HuggingFace\",\n"
72
+ ]
73
+ }
74
+ ],
75
+ "source": [
76
+ "from transformers import pipeline\n",
77
+ "from transformers import AutoProcessor, AutoModelForAudioClassification\n",
78
+ "from fastapi import FastAPI\n",
79
+ "from pydantic import BaseModel\n",
80
+ "import uvicorn\n",
81
+ "import torchaudio\n",
82
+ "import torch\n",
83
+ "\n",
84
+ "# Define the input schema\n",
85
+ "class InputData(BaseModel):\n",
86
+ " input: str\n",
87
+ "\n",
88
+ "model_path = \"Deepfake\\model\"\n",
89
+ "processor = AutoProcessor.from_pretrained(model_path)\n",
90
+ "# Instantiate the model\n",
91
+ "model = AutoModelForAudioClassification.from_pretrained(pretrained_model_name_or_path=model_path,\n",
92
+ " cache_dir=\"D:\\HuggingFace\",\n",
93
+ " local_files_only=True,\n",
94
+ " )\n"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "metadata": {},
100
+ "source": [
101
+ "Functions"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 29,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": []
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 6,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "def prepare_audio(file_path, sampling_rate=16000, duration=10):\n",
118
+ " \"\"\"\n",
119
+ " Prepares audio by loading, resampling, and returning it in manageable chunks.\n",
120
+ " \n",
121
+ " Parameters:\n",
122
+ " - file_path: Path to the audio file.\n",
123
+ " - sampling_rate: Target sampling rate for the audio.\n",
124
+ " - duration: Duration in seconds for each chunk.\n",
125
+ " \n",
126
+ " Returns:\n",
127
+ " - A list of audio chunks, each as a numpy array.\n",
128
+ " \"\"\"\n",
129
+ " # Load and resample the audio file\n",
130
+ " waveform, original_sampling_rate = torchaudio.load(file_path)\n",
131
+ " \n",
132
+ " # Convert stereo to mono if necessary\n",
133
+ " if waveform.shape[0] > 1: # More than 1 channel\n",
134
+ " waveform = torch.mean(waveform, dim=0, keepdim=True)\n",
135
+ " \n",
136
+ " # Resample if needed\n",
137
+ " if original_sampling_rate != sampling_rate:\n",
138
+ " resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)\n",
139
+ " waveform = resampler(waveform)\n",
140
+ " \n",
141
+ " # Calculate chunk size in samples\n",
142
+ " chunk_size = sampling_rate * duration\n",
143
+ " audio_chunks = []\n",
144
+ "\n",
145
+ " # Split the audio into chunks\n",
146
+ " for start in range(0, waveform.shape[1], chunk_size):\n",
147
+ " chunk = waveform[:, start:start + chunk_size]\n",
148
+ " \n",
149
+ " # Pad the last chunk if it's shorter than the chunk size\n",
150
+ " if chunk.shape[1] < chunk_size:\n",
151
+ " padding = chunk_size - chunk.shape[1]\n",
152
+ " chunk = torch.nn.functional.pad(chunk, (0, padding))\n",
153
+ " \n",
154
+ " audio_chunks.append(chunk.squeeze().numpy())\n",
155
+ " \n",
156
+ " return audio_chunks\n"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 14,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "import torch.nn.functional as F\n",
166
+ "\n",
167
+ "def predict_audio(file_path):\n",
168
+ " \"\"\"\n",
169
+ " Predicts the class of an audio file by aggregating predictions from chunks and calculates confidence.\n",
170
+ " \n",
171
+ " Args:\n",
172
+ " file_path (str): Path to the audio file.\n",
173
+ "\n",
174
+ " Returns:\n",
175
+ " dict: Contains the predicted class label and average confidence score.\n",
176
+ " \"\"\"\n",
177
+ " # Prepare audio chunks\n",
178
+ " audio_chunks = prepare_audio(file_path)\n",
179
+ " predictions = []\n",
180
+ " confidences = []\n",
181
+ "\n",
182
+ " for i, chunk in enumerate(audio_chunks):\n",
183
+ " # Prepare input for the model\n",
184
+ " inputs = processor(\n",
185
+ " chunk, sampling_rate=16000, return_tensors=\"pt\", padding=True\n",
186
+ " )\n",
187
+ " \n",
188
+ " # Perform inference\n",
189
+ " with torch.no_grad():\n",
190
+ " outputs = model(**inputs)\n",
191
+ " logits = outputs.logits\n",
192
+ " \n",
193
+ " # Apply softmax to calculate probabilities\n",
194
+ " probabilities = F.softmax(logits, dim=1)\n",
195
+ " \n",
196
+ " # Get the predicted class and its confidence\n",
197
+ " confidence, predicted_class = torch.max(probabilities, dim=1)\n",
198
+ " predictions.append(predicted_class.item())\n",
199
+ " confidences.append(confidence.item())\n",
200
+ " \n",
201
+ " # Aggregate predictions (majority voting)\n",
202
+ " aggregated_prediction_id = max(set(predictions), key=predictions.count)\n",
203
+ " predicted_label = model.config.id2label[aggregated_prediction_id]\n",
204
+ " \n",
205
+ " # Calculate average confidence across chunks\n",
206
+ " average_confidence = sum(confidences) / len(confidences)\n",
207
+ "\n",
208
+ " return {\n",
209
+ " \"predicted_label\": predicted_label,\n",
210
+ " \"average_confidence\": average_confidence\n",
211
+ " }\n",
212
+ "\n"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 15,
218
+ "metadata": {},
219
+ "outputs": [
220
+ {
221
+ "name": "stdout",
222
+ "output_type": "stream",
223
+ "text": [
224
+ "Chunk shape: (160000,)\n",
225
+ "Chunk shape: (160000,)\n",
226
+ "Chunk shape: (160000,)\n",
227
+ "Chunk shape: (160000,)\n",
228
+ "Chunk shape: (160000,)\n",
229
+ "Chunk shape: (160000,)\n",
230
+ "Chunk shape: (160000,)\n",
231
+ "Chunk shape: (160000,)\n",
232
+ "Chunk shape: (160000,)\n",
233
+ "Chunk shape: (160000,)\n",
234
+ "Predicted Class: {'predicted_label': 'Real', 'average_confidence': 0.9984144032001495}\n"
235
+ ]
236
+ },
237
+ {
238
+ "ename": "",
239
+ "evalue": "",
240
+ "output_type": "error",
241
+ "traceback": [
242
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
243
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
244
+ "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
245
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
246
+ ]
247
+ }
248
+ ],
249
+ "source": [
250
+ "# Example: Test a single audio file\n",
251
+ "file_path = r\"D:\\repos\\GODAM\\audioFiles\\test.wav\" # Replace with your audio file path\n",
252
+ "predicted_class = predict_audio(file_path)\n",
253
+ "print(f\"Predicted Class: {predicted_class}\")"
254
+ ]
255
+ }
256
+ ],
257
+ "metadata": {
258
+ "kernelspec": {
259
+ "display_name": "base",
260
+ "language": "python",
261
+ "name": "python3"
262
+ },
263
+ "language_info": {
264
+ "codemirror_mode": {
265
+ "name": "ipython",
266
+ "version": 3
267
+ },
268
+ "file_extension": ".py",
269
+ "mimetype": "text/x-python",
270
+ "name": "python",
271
+ "nbconvert_exporter": "python",
272
+ "pygments_lexer": "ipython3",
273
+ "version": "3.12.7"
274
+ }
275
+ },
276
+ "nbformat": 4,
277
+ "nbformat_minor": 2
278
+ }
DeepfakeModel.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from pydantic import BaseModel
3
+ import uvicorn
4
+ import os
5
+ import torchaudio
6
+ import torch.nn.functional as F
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForAudioClassification
9
+
10
+ # Model setup
11
+ model_path = r"Deepfake\model"
12
+ processor = AutoProcessor.from_pretrained(model_path)
13
+ model = AutoModelForAudioClassification.from_pretrained(
14
+ pretrained_model_name_or_path=model_path,
15
+ local_files_only=True,
16
+ )
17
+
18
+ def prepare_audio(file_path, sampling_rate=16000, duration=10):
19
+ """
20
+ Prepares audio by loading, resampling, and returning it in manageable chunks.
21
+ """
22
+ # Load and resample the audio file
23
+ waveform, original_sampling_rate = torchaudio.load(file_path)
24
+
25
+ # Convert stereo to mono if necessary
26
+ if waveform.shape[0] > 1: # More than 1 channel
27
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
28
+
29
+ # Resample if needed
30
+ if original_sampling_rate != sampling_rate:
31
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)
32
+ waveform = resampler(waveform)
33
+
34
+ # Calculate chunk size in samples
35
+ chunk_size = sampling_rate * duration
36
+ audio_chunks = []
37
+
38
+ # Split the audio into chunks
39
+ for start in range(0, waveform.shape[1], chunk_size):
40
+ chunk = waveform[:, start:start + chunk_size]
41
+
42
+ # Pad the last chunk if it's shorter than the chunk size
43
+ if chunk.shape[1] < chunk_size:
44
+ padding = chunk_size - chunk.shape[1]
45
+ chunk = torch.nn.functional.pad(chunk, (0, padding))
46
+
47
+ audio_chunks.append(chunk.squeeze().numpy())
48
+
49
+ return audio_chunks
50
+
51
+ def predict_audio(file_path):
52
+ """
53
+ Predicts the class of an audio file by aggregating predictions from chunks and calculates confidence.
54
+ """
55
+ # Prepare audio chunks
56
+ audio_chunks = prepare_audio(file_path)
57
+ predictions = []
58
+ confidences = []
59
+
60
+ for i, chunk in enumerate(audio_chunks):
61
+ # Prepare input for the model
62
+ inputs = processor(
63
+ chunk, sampling_rate=16000, return_tensors="pt", padding=True
64
+ )
65
+
66
+ # Perform inference
67
+ with torch.no_grad():
68
+ outputs = model(**inputs)
69
+ logits = outputs.logits
70
+
71
+ # Apply softmax to calculate probabilities
72
+ probabilities = F.softmax(logits, dim=1)
73
+
74
+ # Get the predicted class and its confidence
75
+ confidence, predicted_class = torch.max(probabilities, dim=1)
76
+ predictions.append(predicted_class.item())
77
+ confidences.append(confidence.item())
78
+
79
+ # Aggregate predictions (majority voting)
80
+ aggregated_prediction_id = max(set(predictions), key=predictions.count)
81
+ predicted_label = model.config.id2label[aggregated_prediction_id]
82
+
83
+ # Calculate average confidence across chunks
84
+ average_confidence = sum(confidences) / len(confidences)
85
+
86
+ return {
87
+ "predicted_label": predicted_label,
88
+ "average_confidence": average_confidence
89
+ }
90
+
91
+ # Initialize FastAPI
92
+ app = FastAPI()
93
+
94
+ @app.post("/infer")
95
+ async def infer(file: UploadFile = File(...)):
96
+ """
97
+ Accepts an audio file and returns the prediction and confidence.
98
+ """
99
+ # Save the uploaded file to a temporary location
100
+ temp_file_path = f"temp_{file.filename}"
101
+ with open(temp_file_path, "wb") as temp_file:
102
+ temp_file.write(await file.read())
103
+
104
+ try:
105
+ # Perform inference
106
+ predictions = predict_audio(temp_file_path)
107
+ finally:
108
+ # Clean up the temporary file
109
+ os.remove(temp_file_path)
110
+
111
+ return predictions
112
+
113
+ @app.get("/health")
114
+ async def health():
115
+ return {"message": "ok"}
116
+
117
+ if __name__ == "__main__":
118
+ uvicorn.run(app, host="0.0.0.0", port=8000)
ModelTest.ipynb ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import requests\n",
10
+ "\n",
11
+ "def infer_text(api_url, input_text):\n",
12
+ " url = f\"{api_url}/infer\"\n",
13
+ " try:\n",
14
+ " # Send the input as a JSON object\n",
15
+ " response = requests.post(url, json={\"input\": input_text})\n",
16
+ " response.raise_for_status()\n",
17
+ " return response.json()\n",
18
+ " except requests.exceptions.RequestException as e:\n",
19
+ " print(f\"Error during API call: {e}\")\n",
20
+ " return None\n",
21
+ "\n",
22
+ "def check_health(api_url):\n",
23
+ " url = f\"{api_url}/health\"\n",
24
+ " try:\n",
25
+ " response = requests.get(url)\n",
26
+ " response.raise_for_status()\n",
27
+ " return response.json()\n",
28
+ " except requests.exceptions.RequestException as e:\n",
29
+ " print(f\"Error during API health check: {e}\")\n",
30
+ " return None"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 5,
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "API Health Check: {'message': 'ok'}\n",
43
+ "Predictions: [{'label': 'LABEL_0', 'score': 0.9927427768707275}]\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "api_url = \"http://localhost:8000\"\n",
49
+ "\n",
50
+ "# Check the API health status\n",
51
+ "health_status = check_health(api_url)\n",
52
+ "if health_status:\n",
53
+ " print(\"API Health Check:\", health_status)\n",
54
+ "else:\n",
55
+ " print(\"Failed to connect to the API.\")\n",
56
+ "\n",
57
+ "# Example input text\n",
58
+ "input_text = \"Congratulations! You've won a prize. Click the link to claim your reward.\"\n",
59
+ "\n",
60
+ "# Call the /infer endpoint\n",
61
+ "predictions = infer_text(api_url, input_text)\n",
62
+ "if predictions:\n",
63
+ " print(\"Predictions:\", predictions)\n",
64
+ "else:\n",
65
+ " print(\"Failed to get predictions from the API.\")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": [
72
+ "DeepFakeModel Test"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 4,
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "Response JSON: {'predicted_label': 'Real', 'average_confidence': 0.9984144032001495}\n"
85
+ ]
86
+ }
87
+ ],
88
+ "source": [
89
+ "import requests\n",
90
+ "\n",
91
+ "# Define the API endpoint\n",
92
+ "url = \"http://127.0.0.1:8000/infer\"\n",
93
+ "\n",
94
+ "# Path to the audio file you want to test\n",
95
+ "file_path = r\"D:\\repos\\GODAM\\audioFiles\\test.wav\" # Replace with the path to your audio file\n",
96
+ "\n",
97
+ "# Open the file in binary mode\n",
98
+ "with open(file_path, \"rb\") as audio_file:\n",
99
+ " # Prepare the file payload\n",
100
+ " files = {\"file\": (\"audio.wav\", audio_file, \"audio/wav\")}\n",
101
+ " \n",
102
+ " # Send the POST request\n",
103
+ " response = requests.post(url, files=files)\n",
104
+ "\n",
105
+ "# Print the response from the API\n",
106
+ "if response.status_code == 200:\n",
107
+ " print(\"Response JSON:\", response.json())\n",
108
+ "else:\n",
109
+ " print(f\"Error {response.status_code}: {response.text}\")"
110
+ ]
111
+ }
112
+ ],
113
+ "metadata": {
114
+ "kernelspec": {
115
+ "display_name": "base",
116
+ "language": "python",
117
+ "name": "python3"
118
+ },
119
+ "language_info": {
120
+ "codemirror_mode": {
121
+ "name": "ipython",
122
+ "version": 3
123
+ },
124
+ "file_extension": ".py",
125
+ "mimetype": "text/x-python",
126
+ "name": "python",
127
+ "nbconvert_exporter": "python",
128
+ "pygments_lexer": "ipython3",
129
+ "version": "3.12.7"
130
+ }
131
+ },
132
+ "nbformat": 4,
133
+ "nbformat_minor": 2
134
+ }
ScamTextModel.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ import uvicorn
5
+
6
+ # Define the input schema
7
+ class InputData(BaseModel):
8
+ input: str
9
+
10
+ # Initialize the pipeline
11
+ pipe = pipeline("text-classification", model="phishbot/ScamLLM")
12
+
13
+ app = FastAPI()
14
+
15
+ # Define API endpoints
16
+ @app.post("/infer")
17
+ async def infer(data: InputData):
18
+ predictions = pipe(data.input)
19
+ return predictions
20
+
21
+ @app.get("/health")
22
+ async def health():
23
+ return {"message": "ok"}
24
+
25
+ if __name__ == "__main__":
26
+ uvicorn.run(app, host="0.0.0.0", port=8000)
notebook.ipynb ADDED
@@ -0,0 +1,1119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Import data"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 2,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "Collecting torchaudio\n",
20
+ " Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl.metadata (6.5 kB)\n",
21
+ "Requirement already satisfied: torch==2.5.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchaudio) (2.5.1)\n",
22
+ "Requirement already satisfied: filelock in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.13.1)\n",
23
+ "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (4.11.0)\n",
24
+ "Requirement already satisfied: networkx in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.3)\n",
25
+ "Requirement already satisfied: jinja2 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.1.4)\n",
26
+ "Requirement already satisfied: fsspec in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (2024.6.1)\n",
27
+ "Requirement already satisfied: setuptools in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (75.1.0)\n",
28
+ "Requirement already satisfied: sympy==1.13.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (1.13.1)\n",
29
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch==2.5.1->torchaudio) (1.3.0)\n",
30
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from jinja2->torch==2.5.1->torchaudio) (2.1.3)\n",
31
+ "Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl (2.4 MB)\n",
32
+ " ---------------------------------------- 0.0/2.4 MB ? eta -:--:--\n",
33
+ " ---------------------------------------- 2.4/2.4 MB 11.6 MB/s eta 0:00:00\n",
34
+ "Installing collected packages: torchaudio\n",
35
+ "Successfully installed torchaudio-2.5.1\n",
36
+ "Note: you may need to restart the kernel to use updated packages.\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "pip install torchaudio"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "ename": "ModuleNotFoundError",
51
+ "evalue": "No module named 'datasets'",
52
+ "output_type": "error",
53
+ "traceback": [
54
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
55
+ "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
56
+ "Cell \u001b[1;32mIn[3], line 5\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorchaudio\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DatasetDict, load_dataset\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprepare_dataset\u001b[39m(directory):\n\u001b[0;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n",
57
+ "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'datasets'"
58
+ ]
59
+ },
60
+ {
61
+ "ename": "",
62
+ "evalue": "",
63
+ "output_type": "error",
64
+ "traceback": [
65
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
66
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
67
+ "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
68
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "## The data is not pushed to repo, only model and training logs etc are uploaded\n",
74
+ "\n",
75
+ "import os\n",
76
+ "import torchaudio\n",
77
+ "from datasets import DatasetDict, load_dataset\n",
78
+ "\n",
79
+ "def prepare_dataset(directory):\n",
80
+ " data = {\"path\": [], \"label\": []}\n",
81
+ " labels = {\"fake\": 0, \"real\": 1} # Map fake to 0 and real to 1\n",
82
+ "\n",
83
+ " for label, label_id in labels.items():\n",
84
+ " folder_path = os.path.join(directory, label)\n",
85
+ " for file in os.listdir(folder_path):\n",
86
+ " if file.endswith(\".wav\"):\n",
87
+ " data[\"path\"].append(os.path.join(folder_path, file))\n",
88
+ " data[\"label\"].append(label_id)\n",
89
+ " return data\n",
90
+ "\n",
91
+ "# Prepare train, validation, and test datasets\n",
92
+ "train_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n",
93
+ "val_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n",
94
+ "test_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 4,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "from datasets import Dataset\n",
104
+ "\n",
105
+ "train_dataset = Dataset.from_dict(train_data)\n",
106
+ "val_dataset = Dataset.from_dict(val_data)\n",
107
+ "test_dataset = Dataset.from_dict(test_data)\n",
108
+ "\n",
109
+ "dataset = DatasetDict({\"train\": train_dataset, \"validation\": val_dataset, \"test\": test_dataset})\n"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "metadata": {},
115
+ "source": [
116
+ "## Import Model"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 5,
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stderr",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\configuration_utils.py:302: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n",
129
+ " warnings.warn(\n",
130
+ "Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']\n",
131
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "from transformers import AutoProcessor\n",
137
+ "from transformers import AutoModelForAudioClassification\n",
138
+ "\n",
139
+ "# Initialize processor\n",
140
+ "model_name = \"facebook/wav2vec2-base\" # Replace with your model if different\n",
141
+ "model = AutoModelForAudioClassification.from_pretrained(model_name, num_labels=2) # Adjust `num_labels` based on your dataset\n",
142
+ "processor = AutoProcessor.from_pretrained(model_name)\n"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "metadata": {},
148
+ "source": [
149
+ "## Preprocess Data"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "data": {
159
+ "application/vnd.jupyter.widget-view+json": {
160
+ "model_id": "af2c5f31f0db43ee9975023b28e2c57c",
161
+ "version_major": 2,
162
+ "version_minor": 0
163
+ },
164
+ "text/plain": [
165
+ "Map: 0%| | 0/4634 [00:00<?, ? examples/s]"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "data": {
173
+ "application/vnd.jupyter.widget-view+json": {
174
+ "model_id": "051c3226b30145f3adbc708cf8afd6a5",
175
+ "version_major": 2,
176
+ "version_minor": 0
177
+ },
178
+ "text/plain": [
179
+ "Map: 0%| | 0/4634 [00:00<?, ? examples/s]"
180
+ ]
181
+ },
182
+ "metadata": {},
183
+ "output_type": "display_data"
184
+ },
185
+ {
186
+ "data": {
187
+ "application/vnd.jupyter.widget-view+json": {
188
+ "model_id": "86de10f37b7441e3a9c9f0ee88bf5149",
189
+ "version_major": 2,
190
+ "version_minor": 0
191
+ },
192
+ "text/plain": [
193
+ "Map: 0%| | 0/4634 [00:00<?, ? examples/s]"
194
+ ]
195
+ },
196
+ "metadata": {},
197
+ "output_type": "display_data"
198
+ },
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "tensor(0) <class 'torch.Tensor'>\n",
204
+ "torch.int64\n"
205
+ ]
206
+ }
207
+ ],
208
+ "source": [
209
+ "import torch\n",
210
+ "\n",
211
+ "\n",
212
+ "def preprocess_function(batch):\n",
213
+ " audio = torchaudio.load(batch[\"path\"])[0].squeeze().numpy()\n",
214
+ " inputs = processor(\n",
215
+ " audio,\n",
216
+ " sampling_rate=16000,\n",
217
+ " padding=True,\n",
218
+ " truncation=True,\n",
219
+ " max_length=32000, \n",
220
+ " return_tensors=\"pt\"\n",
221
+ " )\n",
222
+ " batch[\"input_values\"] = inputs.input_values[0]\n",
223
+ " # Ensure labels are converted to LongTensor\n",
224
+ " batch[\"label\"] = torch.tensor(batch[\"label\"], dtype=torch.long) # Convert label to LongTensor\n",
225
+ " return batch\n",
226
+ "\n",
227
+ "processed_dataset = dataset.map(preprocess_function, remove_columns=[\"path\"], batched=False)\n",
228
+ "# Set format to torch tensors for compatibility with PyTorch\n",
229
+ "processed_dataset.set_format(type=\"torch\", columns=[\"input_values\", \"label\"])\n",
230
+ "\n",
231
+ "# Double-check the label type again\n",
232
+ "print(processed_dataset[\"train\"][0][\"label\"], type(processed_dataset[\"train\"][0][\"label\"]))\n",
233
+ "print(processed_dataset[\"train\"][0][\"label\"].dtype) # Should print torch.int64\n"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "## Map Training Labels"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 12,
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "Labels: {0: 'Fake', 1: 'Real'}\n",
253
+ "Labels: {0: 'Fake', 1: 'Real'}\n"
254
+ ]
255
+ }
256
+ ],
257
+ "source": [
258
+ "# Ensure labels are in numerical format (e.g., 0, 1)\n",
259
+ "id2label = {0: \"Fake\", 1: \"Real\"} # Define the mapping based on your dataset\n",
260
+ "label2id = {v: k for k, v in id2label.items()} # Reverse mapping\n",
261
+ "\n",
262
+ "\n",
263
+ "print(\"Labels:\", id2label)\n",
264
+ "\n",
265
+ "# Update the model's configuration with labels\n",
266
+ "model.config.id2label = id2label\n",
267
+ "model.config.label2id = label2id\n",
268
+ "\n",
269
+ "print(\"Labels:\", model.config.id2label) # Verify\n"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "from transformers import DataCollatorWithPadding\n",
279
+ "\n",
280
+ "# Use the processor's tokenizer for padding\n",
281
+ "data_collator = DataCollatorWithPadding(tokenizer=processor, padding=True)\n"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "markdown",
286
+ "metadata": {},
287
+ "source": [
288
+ "## Initialize Training Arguments"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 8,
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "TrainingArguments initialized successfully!\n"
301
+ ]
302
+ },
303
+ {
304
+ "name": "stderr",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
308
+ " warnings.warn(\n"
309
+ ]
310
+ }
311
+ ],
312
+ "source": [
313
+ "from transformers import TrainingArguments\n",
314
+ "\n",
315
+ "training_args = TrainingArguments(\n",
316
+ " output_dir=\"./results\",\n",
317
+ " evaluation_strategy=\"epoch\",\n",
318
+ " save_strategy=\"epoch\",\n",
319
+ " learning_rate=5e-5,\n",
320
+ " per_device_train_batch_size=8,\n",
321
+ " per_device_eval_batch_size=8,\n",
322
+ " num_train_epochs=3,\n",
323
+ " weight_decay=0.01,\n",
324
+ " logging_dir=\"./logs\",\n",
325
+ " logging_steps=10,\n",
326
+ " save_total_limit=2,\n",
327
+ " fp16=True, \n",
328
+ " push_to_hub=False,\n",
329
+ ")\n",
330
+ "print(\"TrainingArguments initialized successfully!\")\n"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 9,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "from transformers import Trainer\n",
340
+ "\n",
341
+ "trainer = Trainer(\n",
342
+ " model=model,\n",
343
+ " args=training_args,\n",
344
+ " train_dataset=processed_dataset[\"train\"],\n",
345
+ " eval_dataset=processed_dataset[\"validation\"],\n",
346
+ " tokenizer=processor, # Required for the data collator\n",
347
+ " data_collator=data_collator,\n",
348
+ ")\n"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "markdown",
353
+ "metadata": {},
354
+ "source": [
355
+ "## Start Training"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": 33,
361
+ "metadata": {},
362
+ "outputs": [
363
+ {
364
+ "data": {
365
+ "application/vnd.jupyter.widget-view+json": {
366
+ "model_id": "53044eac53174227a4eb19becbeb0bbf",
367
+ "version_major": 2,
368
+ "version_minor": 0
369
+ },
370
+ "text/plain": [
371
+ " 0%| | 0/1740 [00:00<?, ?it/s]"
372
+ ]
373
+ },
374
+ "metadata": {},
375
+ "output_type": "display_data"
376
+ },
377
+ {
378
+ "name": "stdout",
379
+ "output_type": "stream",
380
+ "text": [
381
+ "{'loss': 0.6607, 'grad_norm': 2.5936832427978516, 'learning_rate': 4.971264367816092e-05, 'epoch': 0.02}\n",
382
+ "{'loss': 0.4545, 'grad_norm': 5.04218864440918, 'learning_rate': 4.9425287356321845e-05, 'epoch': 0.03}\n",
383
+ "{'loss': 0.1779, 'grad_norm': 0.8874927163124084, 'learning_rate': 4.913793103448276e-05, 'epoch': 0.05}\n",
384
+ "{'loss': 0.0833, 'grad_norm': 0.40262681245803833, 'learning_rate': 4.885057471264368e-05, 'epoch': 0.07}\n",
385
+ "{'loss': 0.0948, 'grad_norm': 0.5579108595848083, 'learning_rate': 4.85632183908046e-05, 'epoch': 0.09}\n",
386
+ "{'loss': 0.0128, 'grad_norm': 0.1585635393857956, 'learning_rate': 4.827586206896552e-05, 'epoch': 0.1}\n",
387
+ "{'loss': 0.0696, 'grad_norm': 0.12149885296821594, 'learning_rate': 4.798850574712644e-05, 'epoch': 0.12}\n",
388
+ "{'loss': 0.0065, 'grad_norm': 0.09655608981847763, 'learning_rate': 4.770114942528736e-05, 'epoch': 0.14}\n",
389
+ "{'loss': 0.0052, 'grad_norm': 0.08148041367530823, 'learning_rate': 4.741379310344828e-05, 'epoch': 0.16}\n",
390
+ "{'loss': 0.0041, 'grad_norm': 0.07030971348285675, 'learning_rate': 4.7126436781609195e-05, 'epoch': 0.17}\n",
391
+ "{'loss': 0.0034, 'grad_norm': 0.05784648284316063, 'learning_rate': 4.6839080459770116e-05, 'epoch': 0.19}\n",
392
+ "{'loss': 0.0029, 'grad_norm': 0.04742524400353432, 'learning_rate': 4.655172413793104e-05, 'epoch': 0.21}\n",
393
+ "{'loss': 0.0024, 'grad_norm': 0.04222293198108673, 'learning_rate': 4.626436781609196e-05, 'epoch': 0.22}\n",
394
+ "{'loss': 0.0021, 'grad_norm': 0.039586298167705536, 'learning_rate': 4.597701149425287e-05, 'epoch': 0.24}\n",
395
+ "{'loss': 0.0018, 'grad_norm': 0.034121569246053696, 'learning_rate': 4.5689655172413794e-05, 'epoch': 0.26}\n",
396
+ "{'loss': 0.0016, 'grad_norm': 0.031423550099134445, 'learning_rate': 4.5402298850574716e-05, 'epoch': 0.28}\n",
397
+ "{'loss': 0.0418, 'grad_norm': 0.02912888117134571, 'learning_rate': 4.511494252873563e-05, 'epoch': 0.29}\n",
398
+ "{'loss': 0.0831, 'grad_norm': 0.027829233556985855, 'learning_rate': 4.482758620689655e-05, 'epoch': 0.31}\n",
399
+ "{'loss': 0.2431, 'grad_norm': 0.10419639199972153, 'learning_rate': 4.454022988505747e-05, 'epoch': 0.33}\n",
400
+ "{'loss': 0.0024, 'grad_norm': 0.046179670840501785, 'learning_rate': 4.4252873563218394e-05, 'epoch': 0.34}\n",
401
+ "{'loss': 0.0729, 'grad_norm': 0.03943876922130585, 'learning_rate': 4.396551724137931e-05, 'epoch': 0.36}\n",
402
+ "{'loss': 0.0806, 'grad_norm': 0.05223412811756134, 'learning_rate': 4.367816091954024e-05, 'epoch': 0.38}\n",
403
+ "{'loss': 0.0023, 'grad_norm': 0.041366685181856155, 'learning_rate': 4.339080459770115e-05, 'epoch': 0.4}\n",
404
+ "{'loss': 0.0019, 'grad_norm': 0.03711611405014992, 'learning_rate': 4.3103448275862066e-05, 'epoch': 0.41}\n",
405
+ "{'loss': 0.0016, 'grad_norm': 0.030888166278600693, 'learning_rate': 4.2816091954022994e-05, 'epoch': 0.43}\n",
406
+ "{'loss': 0.0015, 'grad_norm': 0.027964089065790176, 'learning_rate': 4.252873563218391e-05, 'epoch': 0.45}\n",
407
+ "{'loss': 0.0013, 'grad_norm': 0.025312568992376328, 'learning_rate': 4.224137931034483e-05, 'epoch': 0.47}\n",
408
+ "{'loss': 0.0012, 'grad_norm': 0.02383635751903057, 'learning_rate': 4.195402298850575e-05, 'epoch': 0.48}\n",
409
+ "{'loss': 0.0011, 'grad_norm': 0.021753674373030663, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}\n",
410
+ "{'loss': 0.001, 'grad_norm': 0.019886162132024765, 'learning_rate': 4.1379310344827587e-05, 'epoch': 0.52}\n",
411
+ "{'loss': 0.0009, 'grad_norm': 0.018876325339078903, 'learning_rate': 4.109195402298851e-05, 'epoch': 0.53}\n",
412
+ "{'loss': 0.0008, 'grad_norm': 0.017580321058630943, 'learning_rate': 4.080459770114943e-05, 'epoch': 0.55}\n",
413
+ "{'loss': 0.0008, 'grad_norm': 0.015544956550002098, 'learning_rate': 4.0517241379310344e-05, 'epoch': 0.57}\n",
414
+ "{'loss': 0.0007, 'grad_norm': 0.015221121720969677, 'learning_rate': 4.0229885057471265e-05, 'epoch': 0.59}\n",
415
+ "{'loss': 0.0007, 'grad_norm': 0.014960471540689468, 'learning_rate': 3.9942528735632186e-05, 'epoch': 0.6}\n",
416
+ "{'loss': 0.0006, 'grad_norm': 0.013560828752815723, 'learning_rate': 3.965517241379311e-05, 'epoch': 0.62}\n",
417
+ "{'loss': 0.0006, 'grad_norm': 0.013700570911169052, 'learning_rate': 3.936781609195402e-05, 'epoch': 0.64}\n",
418
+ "{'loss': 0.0006, 'grad_norm': 0.011968374252319336, 'learning_rate': 3.908045977011495e-05, 'epoch': 0.66}\n",
419
+ "{'loss': 0.0005, 'grad_norm': 0.011644795536994934, 'learning_rate': 3.8793103448275865e-05, 'epoch': 0.67}\n",
420
+ "{'loss': 0.0005, 'grad_norm': 0.011345883831381798, 'learning_rate': 3.850574712643678e-05, 'epoch': 0.69}\n",
421
+ "{'loss': 0.0005, 'grad_norm': 0.010393058881163597, 'learning_rate': 3.82183908045977e-05, 'epoch': 0.71}\n",
422
+ "{'loss': 0.0005, 'grad_norm': 0.010386484675109386, 'learning_rate': 3.793103448275862e-05, 'epoch': 0.72}\n",
423
+ "{'loss': 0.0004, 'grad_norm': 0.009744665585458279, 'learning_rate': 3.764367816091954e-05, 'epoch': 0.74}\n",
424
+ "{'loss': 0.0004, 'grad_norm': 0.009590468369424343, 'learning_rate': 3.735632183908046e-05, 'epoch': 0.76}\n",
425
+ "{'loss': 0.0004, 'grad_norm': 0.009154150262475014, 'learning_rate': 3.7068965517241385e-05, 'epoch': 0.78}\n",
426
+ "{'loss': 0.0004, 'grad_norm': 0.008997919037938118, 'learning_rate': 3.67816091954023e-05, 'epoch': 0.79}\n",
427
+ "{'loss': 0.0004, 'grad_norm': 0.008509515784680843, 'learning_rate': 3.649425287356322e-05, 'epoch': 0.81}\n",
428
+ "{'loss': 0.0004, 'grad_norm': 0.008223678916692734, 'learning_rate': 3.620689655172414e-05, 'epoch': 0.83}\n",
429
+ "{'loss': 0.0003, 'grad_norm': 0.00758435670286417, 'learning_rate': 3.591954022988506e-05, 'epoch': 0.84}\n",
430
+ "{'loss': 0.0003, 'grad_norm': 0.0074744271114468575, 'learning_rate': 3.563218390804598e-05, 'epoch': 0.86}\n",
431
+ "{'loss': 0.0003, 'grad_norm': 0.007454875390976667, 'learning_rate': 3.53448275862069e-05, 'epoch': 0.88}\n",
432
+ "{'loss': 0.0003, 'grad_norm': 0.007157924585044384, 'learning_rate': 3.505747126436782e-05, 'epoch': 0.9}\n",
433
+ "{'loss': 0.0003, 'grad_norm': 0.006946589332073927, 'learning_rate': 3.4770114942528735e-05, 'epoch': 0.91}\n",
434
+ "{'loss': 0.0003, 'grad_norm': 0.0067284563556313515, 'learning_rate': 3.4482758620689657e-05, 'epoch': 0.93}\n",
435
+ "{'loss': 0.0003, 'grad_norm': 0.00652291439473629, 'learning_rate': 3.419540229885058e-05, 'epoch': 0.95}\n",
436
+ "{'loss': 0.0003, 'grad_norm': 0.006468599662184715, 'learning_rate': 3.390804597701149e-05, 'epoch': 0.97}\n",
437
+ "{'loss': 0.0003, 'grad_norm': 0.0061700050719082355, 'learning_rate': 3.3620689655172414e-05, 'epoch': 0.98}\n",
438
+ "{'loss': 0.0002, 'grad_norm': 0.005980886053293943, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}\n"
439
+ ]
440
+ },
441
+ {
442
+ "data": {
443
+ "application/vnd.jupyter.widget-view+json": {
444
+ "model_id": "9fe6f453d3dc4644b7e7adf483e8064b",
445
+ "version_major": 2,
446
+ "version_minor": 0
447
+ },
448
+ "text/plain": [
449
+ " 0%| | 0/580 [00:00<?, ?it/s]"
450
+ ]
451
+ },
452
+ "metadata": {},
453
+ "output_type": "display_data"
454
+ },
455
+ {
456
+ "name": "stdout",
457
+ "output_type": "stream",
458
+ "text": [
459
+ "{'eval_loss': 0.00017967642634175718, 'eval_runtime': 566.6055, 'eval_samples_per_second': 8.179, 'eval_steps_per_second': 1.024, 'epoch': 1.0}\n",
460
+ "{'loss': 0.0002, 'grad_norm': 0.005443067755550146, 'learning_rate': 3.3045977011494256e-05, 'epoch': 1.02}\n",
461
+ "{'loss': 0.0002, 'grad_norm': 0.005919346585869789, 'learning_rate': 3.275862068965517e-05, 'epoch': 1.03}\n",
462
+ "{'loss': 0.0002, 'grad_norm': 0.00538916140794754, 'learning_rate': 3.24712643678161e-05, 'epoch': 1.05}\n",
463
+ "{'loss': 0.0002, 'grad_norm': 0.00514333276078105, 'learning_rate': 3.218390804597701e-05, 'epoch': 1.07}\n",
464
+ "{'loss': 0.0002, 'grad_norm': 0.005011783912777901, 'learning_rate': 3.1896551724137935e-05, 'epoch': 1.09}\n",
465
+ "{'loss': 0.0002, 'grad_norm': 0.005112846381962299, 'learning_rate': 3.160919540229885e-05, 'epoch': 1.1}\n",
466
+ "{'loss': 0.0002, 'grad_norm': 0.004895139951258898, 'learning_rate': 3.132183908045977e-05, 'epoch': 1.12}\n",
467
+ "{'loss': 0.0002, 'grad_norm': 0.004565018694847822, 'learning_rate': 3.103448275862069e-05, 'epoch': 1.14}\n",
468
+ "{'loss': 0.0002, 'grad_norm': 0.00477330107241869, 'learning_rate': 3.0747126436781606e-05, 'epoch': 1.16}\n",
469
+ "{'loss': 0.0002, 'grad_norm': 0.004563583992421627, 'learning_rate': 3.045977011494253e-05, 'epoch': 1.17}\n",
470
+ "{'loss': 0.0002, 'grad_norm': 0.004568100906908512, 'learning_rate': 3.017241379310345e-05, 'epoch': 1.19}\n",
471
+ "{'loss': 0.0002, 'grad_norm': 0.0046887630596756935, 'learning_rate': 2.988505747126437e-05, 'epoch': 1.21}\n",
472
+ "{'loss': 0.0002, 'grad_norm': 0.004261981230229139, 'learning_rate': 2.9597701149425288e-05, 'epoch': 1.22}\n",
473
+ "{'loss': 0.0002, 'grad_norm': 0.004290647804737091, 'learning_rate': 2.9310344827586206e-05, 'epoch': 1.24}\n",
474
+ "{'loss': 0.0002, 'grad_norm': 0.004188802093267441, 'learning_rate': 2.9022988505747127e-05, 'epoch': 1.26}\n",
475
+ "{'loss': 0.0002, 'grad_norm': 0.003917949739843607, 'learning_rate': 2.8735632183908045e-05, 'epoch': 1.28}\n",
476
+ "{'loss': 0.0002, 'grad_norm': 0.003940439783036709, 'learning_rate': 2.844827586206897e-05, 'epoch': 1.29}\n",
477
+ "{'loss': 0.0002, 'grad_norm': 0.004128696396946907, 'learning_rate': 2.8160919540229884e-05, 'epoch': 1.31}\n",
478
+ "{'loss': 0.0002, 'grad_norm': 0.004070794675499201, 'learning_rate': 2.787356321839081e-05, 'epoch': 1.33}\n",
479
+ "{'loss': 0.0002, 'grad_norm': 0.0037206218112260103, 'learning_rate': 2.7586206896551727e-05, 'epoch': 1.34}\n",
480
+ "{'loss': 0.0001, 'grad_norm': 0.00400934973731637, 'learning_rate': 2.7298850574712648e-05, 'epoch': 1.36}\n",
481
+ "{'loss': 0.0001, 'grad_norm': 0.0037558332551270723, 'learning_rate': 2.7011494252873566e-05, 'epoch': 1.38}\n",
482
+ "{'loss': 0.0001, 'grad_norm': 0.0035916264168918133, 'learning_rate': 2.672413793103448e-05, 'epoch': 1.4}\n",
483
+ "{'loss': 0.0001, 'grad_norm': 0.003707454539835453, 'learning_rate': 2.6436781609195405e-05, 'epoch': 1.41}\n",
484
+ "{'loss': 0.0001, 'grad_norm': 0.0034801277797669172, 'learning_rate': 2.6149425287356323e-05, 'epoch': 1.43}\n",
485
+ "{'loss': 0.0001, 'grad_norm': 0.003501839004456997, 'learning_rate': 2.5862068965517244e-05, 'epoch': 1.45}\n",
486
+ "{'loss': 0.0001, 'grad_norm': 0.003458078484982252, 'learning_rate': 2.5574712643678162e-05, 'epoch': 1.47}\n",
487
+ "{'loss': 0.0001, 'grad_norm': 0.0031666585709899664, 'learning_rate': 2.5287356321839083e-05, 'epoch': 1.48}\n",
488
+ "{'loss': 0.0001, 'grad_norm': 0.0033736126497387886, 'learning_rate': 2.5e-05, 'epoch': 1.5}\n",
489
+ "{'loss': 0.0001, 'grad_norm': 0.003289664164185524, 'learning_rate': 2.4712643678160922e-05, 'epoch': 1.52}\n",
490
+ "{'loss': 0.0001, 'grad_norm': 0.0031710772309452295, 'learning_rate': 2.442528735632184e-05, 'epoch': 1.53}\n",
491
+ "{'loss': 0.0001, 'grad_norm': 0.0029777924064546824, 'learning_rate': 2.413793103448276e-05, 'epoch': 1.55}\n",
492
+ "{'loss': 0.0001, 'grad_norm': 0.003144340356811881, 'learning_rate': 2.385057471264368e-05, 'epoch': 1.57}\n",
493
+ "{'loss': 0.0001, 'grad_norm': 0.0029524925630539656, 'learning_rate': 2.3563218390804597e-05, 'epoch': 1.59}\n",
494
+ "{'loss': 0.0001, 'grad_norm': 0.002912016585469246, 'learning_rate': 2.327586206896552e-05, 'epoch': 1.6}\n",
495
+ "{'loss': 0.0001, 'grad_norm': 0.002869528019800782, 'learning_rate': 2.2988505747126437e-05, 'epoch': 1.62}\n",
496
+ "{'loss': 0.0001, 'grad_norm': 0.002966447500512004, 'learning_rate': 2.2701149425287358e-05, 'epoch': 1.64}\n",
497
+ "{'loss': 0.0001, 'grad_norm': 0.0027959959115833044, 'learning_rate': 2.2413793103448276e-05, 'epoch': 1.66}\n",
498
+ "{'loss': 0.0001, 'grad_norm': 0.0030403095297515392, 'learning_rate': 2.2126436781609197e-05, 'epoch': 1.67}\n",
499
+ "{'loss': 0.0001, 'grad_norm': 0.0026587999891489744, 'learning_rate': 2.183908045977012e-05, 'epoch': 1.69}\n",
500
+ "{'loss': 0.0001, 'grad_norm': 0.002689346671104431, 'learning_rate': 2.1551724137931033e-05, 'epoch': 1.71}\n",
501
+ "{'loss': 0.0001, 'grad_norm': 0.002710141707211733, 'learning_rate': 2.1264367816091954e-05, 'epoch': 1.72}\n",
502
+ "{'loss': 0.0001, 'grad_norm': 0.002674366347491741, 'learning_rate': 2.0977011494252875e-05, 'epoch': 1.74}\n",
503
+ "{'loss': 0.0001, 'grad_norm': 0.0026578502729535103, 'learning_rate': 2.0689655172413793e-05, 'epoch': 1.76}\n",
504
+ "{'loss': 0.0001, 'grad_norm': 0.00243232655338943, 'learning_rate': 2.0402298850574715e-05, 'epoch': 1.78}\n",
505
+ "{'loss': 0.0001, 'grad_norm': 0.0025773164816200733, 'learning_rate': 2.0114942528735632e-05, 'epoch': 1.79}\n",
506
+ "{'loss': 0.0001, 'grad_norm': 0.0024439615663141012, 'learning_rate': 1.9827586206896554e-05, 'epoch': 1.81}\n",
507
+ "{'loss': 0.0001, 'grad_norm': 0.0024733347818255424, 'learning_rate': 1.9540229885057475e-05, 'epoch': 1.83}\n",
508
+ "{'loss': 0.0001, 'grad_norm': 0.002439699834212661, 'learning_rate': 1.925287356321839e-05, 'epoch': 1.84}\n",
509
+ "{'loss': 0.0001, 'grad_norm': 0.0025980097707360983, 'learning_rate': 1.896551724137931e-05, 'epoch': 1.86}\n",
510
+ "{'loss': 0.0001, 'grad_norm': 0.002387199318036437, 'learning_rate': 1.867816091954023e-05, 'epoch': 1.88}\n",
511
+ "{'loss': 0.0001, 'grad_norm': 0.0023106117732822895, 'learning_rate': 1.839080459770115e-05, 'epoch': 1.9}\n",
512
+ "{'loss': 0.0001, 'grad_norm': 0.0023344189394265413, 'learning_rate': 1.810344827586207e-05, 'epoch': 1.91}\n",
513
+ "{'loss': 0.0001, 'grad_norm': 0.0023740960750728846, 'learning_rate': 1.781609195402299e-05, 'epoch': 1.93}\n",
514
+ "{'loss': 0.0001, 'grad_norm': 0.002346088644117117, 'learning_rate': 1.752873563218391e-05, 'epoch': 1.95}\n",
515
+ "{'loss': 0.0001, 'grad_norm': 0.002391340211033821, 'learning_rate': 1.7241379310344828e-05, 'epoch': 1.97}\n",
516
+ "{'loss': 0.0001, 'grad_norm': 0.002250733319669962, 'learning_rate': 1.6954022988505746e-05, 'epoch': 1.98}\n",
517
+ "{'loss': 0.0001, 'grad_norm': 0.002164299599826336, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}\n"
518
+ ]
519
+ },
520
+ {
521
+ "data": {
522
+ "application/vnd.jupyter.widget-view+json": {
523
+ "model_id": "ed9239dd5b4346509d243cbc378bcf44",
524
+ "version_major": 2,
525
+ "version_minor": 0
526
+ },
527
+ "text/plain": [
528
+ " 0%| | 0/580 [00:00<?, ?it/s]"
529
+ ]
530
+ },
531
+ "metadata": {},
532
+ "output_type": "display_data"
533
+ },
534
+ {
535
+ "name": "stdout",
536
+ "output_type": "stream",
537
+ "text": [
538
+ "{'eval_loss': 6.0841484810225666e-05, 'eval_runtime': 581.8939, 'eval_samples_per_second': 7.964, 'eval_steps_per_second': 0.997, 'epoch': 2.0}\n",
539
+ "{'loss': 0.0001, 'grad_norm': 0.0022026619408279657, 'learning_rate': 1.6379310344827585e-05, 'epoch': 2.02}\n",
540
+ "{'loss': 0.0001, 'grad_norm': 0.002242449903860688, 'learning_rate': 1.6091954022988507e-05, 'epoch': 2.03}\n",
541
+ "{'loss': 0.0001, 'grad_norm': 0.002464097458869219, 'learning_rate': 1.5804597701149425e-05, 'epoch': 2.05}\n",
542
+ "{'loss': 0.0001, 'grad_norm': 0.0022003022022545338, 'learning_rate': 1.5517241379310346e-05, 'epoch': 2.07}\n",
543
+ "{'loss': 0.0001, 'grad_norm': 0.0021785416174679995, 'learning_rate': 1.5229885057471265e-05, 'epoch': 2.09}\n",
544
+ "{'loss': 0.0001, 'grad_norm': 0.0021638190373778343, 'learning_rate': 1.4942528735632185e-05, 'epoch': 2.1}\n",
545
+ "{'loss': 0.0001, 'grad_norm': 0.0022439502645283937, 'learning_rate': 1.4655172413793103e-05, 'epoch': 2.12}\n",
546
+ "{'loss': 0.0001, 'grad_norm': 0.0020717435982078314, 'learning_rate': 1.4367816091954022e-05, 'epoch': 2.14}\n",
547
+ "{'loss': 0.0001, 'grad_norm': 0.0020531516056507826, 'learning_rate': 1.4080459770114942e-05, 'epoch': 2.16}\n",
548
+ "{'loss': 0.0001, 'grad_norm': 0.0019899189937859774, 'learning_rate': 1.3793103448275863e-05, 'epoch': 2.17}\n",
549
+ "{'loss': 0.0001, 'grad_norm': 0.0020303332712501287, 'learning_rate': 1.3505747126436783e-05, 'epoch': 2.19}\n",
550
+ "{'loss': 0.0001, 'grad_norm': 0.0021102086175233126, 'learning_rate': 1.3218390804597702e-05, 'epoch': 2.21}\n",
551
+ "{'loss': 0.0001, 'grad_norm': 0.0019932142458856106, 'learning_rate': 1.2931034482758622e-05, 'epoch': 2.22}\n",
552
+ "{'loss': 0.0001, 'grad_norm': 0.002080111298710108, 'learning_rate': 1.2643678160919542e-05, 'epoch': 2.24}\n",
553
+ "{'loss': 0.0001, 'grad_norm': 0.0020179597195237875, 'learning_rate': 1.2356321839080461e-05, 'epoch': 2.26}\n",
554
+ "{'loss': 0.0001, 'grad_norm': 0.0019549003336578608, 'learning_rate': 1.206896551724138e-05, 'epoch': 2.28}\n",
555
+ "{'loss': 0.0001, 'grad_norm': 0.0020865327678620815, 'learning_rate': 1.1781609195402299e-05, 'epoch': 2.29}\n",
556
+ "{'loss': 0.0001, 'grad_norm': 0.0018828624160960317, 'learning_rate': 1.1494252873563218e-05, 'epoch': 2.31}\n",
557
+ "{'loss': 0.0001, 'grad_norm': 0.0018662698566913605, 'learning_rate': 1.1206896551724138e-05, 'epoch': 2.33}\n",
558
+ "{'loss': 0.0001, 'grad_norm': 0.001857285387814045, 'learning_rate': 1.091954022988506e-05, 'epoch': 2.34}\n",
559
+ "{'loss': 0.0001, 'grad_norm': 0.001844724640250206, 'learning_rate': 1.0632183908045977e-05, 'epoch': 2.36}\n",
560
+ "{'loss': 0.0001, 'grad_norm': 0.0017886353889480233, 'learning_rate': 1.0344827586206897e-05, 'epoch': 2.38}\n",
561
+ "{'loss': 0.0001, 'grad_norm': 0.0019668207969516516, 'learning_rate': 1.0057471264367816e-05, 'epoch': 2.4}\n",
562
+ "{'loss': 0.0001, 'grad_norm': 0.0018605877412483096, 'learning_rate': 9.770114942528738e-06, 'epoch': 2.41}\n",
563
+ "{'loss': 0.0001, 'grad_norm': 0.0018027386395260692, 'learning_rate': 9.482758620689655e-06, 'epoch': 2.43}\n",
564
+ "{'loss': 0.0001, 'grad_norm': 0.0018370413454249501, 'learning_rate': 9.195402298850575e-06, 'epoch': 2.45}\n",
565
+ "{'loss': 0.0001, 'grad_norm': 0.0019249517936259508, 'learning_rate': 8.908045977011495e-06, 'epoch': 2.47}\n",
566
+ "{'loss': 0.0001, 'grad_norm': 0.0019102703081443906, 'learning_rate': 8.620689655172414e-06, 'epoch': 2.48}\n",
567
+ "{'loss': 0.0001, 'grad_norm': 0.0019130830187350512, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}\n",
568
+ "{'loss': 0.0001, 'grad_norm': 0.0019449306419119239, 'learning_rate': 8.045977011494253e-06, 'epoch': 2.52}\n",
569
+ "{'loss': 0.0001, 'grad_norm': 0.001796119031496346, 'learning_rate': 7.758620689655173e-06, 'epoch': 2.53}\n",
570
+ "{'loss': 0.0001, 'grad_norm': 0.0017440468072891235, 'learning_rate': 7.4712643678160925e-06, 'epoch': 2.55}\n",
571
+ "{'loss': 0.0001, 'grad_norm': 0.0017786856042221189, 'learning_rate': 7.183908045977011e-06, 'epoch': 2.57}\n",
572
+ "{'loss': 0.0001, 'grad_norm': 0.0018597355810925364, 'learning_rate': 6.896551724137932e-06, 'epoch': 2.59}\n",
573
+ "{'loss': 0.0001, 'grad_norm': 0.0017648187931627035, 'learning_rate': 6.609195402298851e-06, 'epoch': 2.6}\n",
574
+ "{'loss': 0.0001, 'grad_norm': 0.0017601278377696872, 'learning_rate': 6.321839080459771e-06, 'epoch': 2.62}\n",
575
+ "{'loss': 0.0001, 'grad_norm': 0.0017502185655757785, 'learning_rate': 6.03448275862069e-06, 'epoch': 2.64}\n",
576
+ "{'loss': 0.0001, 'grad_norm': 0.001753892400301993, 'learning_rate': 5.747126436781609e-06, 'epoch': 2.66}\n",
577
+ "{'loss': 0.0001, 'grad_norm': 0.0016946644755080342, 'learning_rate': 5.45977011494253e-06, 'epoch': 2.67}\n",
578
+ "{'loss': 0.0001, 'grad_norm': 0.001783599378541112, 'learning_rate': 5.172413793103448e-06, 'epoch': 2.69}\n",
579
+ "{'loss': 0.0001, 'grad_norm': 0.0017759180627763271, 'learning_rate': 4.885057471264369e-06, 'epoch': 2.71}\n",
580
+ "{'loss': 0.0001, 'grad_norm': 0.0017218819120898843, 'learning_rate': 4.5977011494252875e-06, 'epoch': 2.72}\n",
581
+ "{'loss': 0.0001, 'grad_norm': 0.0016811942914500833, 'learning_rate': 4.310344827586207e-06, 'epoch': 2.74}\n",
582
+ "{'loss': 0.0001, 'grad_norm': 0.0017582618165761232, 'learning_rate': 4.022988505747127e-06, 'epoch': 2.76}\n",
583
+ "{'loss': 0.0001, 'grad_norm': 0.001848816522397101, 'learning_rate': 3.7356321839080462e-06, 'epoch': 2.78}\n",
584
+ "{'loss': 0.0001, 'grad_norm': 0.0017523870337754488, 'learning_rate': 3.448275862068966e-06, 'epoch': 2.79}\n",
585
+ "{'loss': 0.0001, 'grad_norm': 0.001707645715214312, 'learning_rate': 3.1609195402298854e-06, 'epoch': 2.81}\n",
586
+ "{'loss': 0.0001, 'grad_norm': 0.0017925987485796213, 'learning_rate': 2.8735632183908046e-06, 'epoch': 2.83}\n",
587
+ "{'loss': 0.0001, 'grad_norm': 0.001785592525266111, 'learning_rate': 2.586206896551724e-06, 'epoch': 2.84}\n",
588
+ "{'loss': 0.0001, 'grad_norm': 0.0017369745764881372, 'learning_rate': 2.2988505747126437e-06, 'epoch': 2.86}\n",
589
+ "{'loss': 0.0001, 'grad_norm': 0.0017363271908834577, 'learning_rate': 2.0114942528735633e-06, 'epoch': 2.88}\n",
590
+ "{'loss': 0.0001, 'grad_norm': 0.0017762900097295642, 'learning_rate': 1.724137931034483e-06, 'epoch': 2.9}\n",
591
+ "{'loss': 0.0001, 'grad_norm': 0.0017800360219553113, 'learning_rate': 1.4367816091954023e-06, 'epoch': 2.91}\n",
592
+ "{'loss': 0.0001, 'grad_norm': 0.0016894094878807664, 'learning_rate': 1.1494252873563219e-06, 'epoch': 2.93}\n",
593
+ "{'loss': 0.0001, 'grad_norm': 0.0016883889911696315, 'learning_rate': 8.620689655172415e-07, 'epoch': 2.95}\n",
594
+ "{'loss': 0.0001, 'grad_norm': 0.0017332383431494236, 'learning_rate': 5.747126436781609e-07, 'epoch': 2.97}\n",
595
+ "{'loss': 0.0001, 'grad_norm': 0.0018206291133537889, 'learning_rate': 2.8735632183908047e-07, 'epoch': 2.98}\n",
596
+ "{'loss': 0.0001, 'grad_norm': 0.0017392894951626658, 'learning_rate': 0.0, 'epoch': 3.0}\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "application/vnd.jupyter.widget-view+json": {
602
+ "model_id": "de595459d7b642babd74e21ce354d064",
603
+ "version_major": 2,
604
+ "version_minor": 0
605
+ },
606
+ "text/plain": [
607
+ " 0%| | 0/580 [00:00<?, ?it/s]"
608
+ ]
609
+ },
610
+ "metadata": {},
611
+ "output_type": "display_data"
612
+ },
613
+ {
614
+ "name": "stdout",
615
+ "output_type": "stream",
616
+ "text": [
617
+ "{'eval_loss': 4.472154250834137e-05, 'eval_runtime': 541.083, 'eval_samples_per_second': 8.564, 'eval_steps_per_second': 1.072, 'epoch': 3.0}\n",
618
+ "{'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'train_loss': 0.01232232698998762, 'epoch': 3.0}\n"
619
+ ]
620
+ },
621
+ {
622
+ "data": {
623
+ "text/plain": [
624
+ "TrainOutput(global_step=1740, training_loss=0.01232232698998762, metrics={'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'total_flos': 2.5228134820702045e+17, 'train_loss': 0.01232232698998762, 'epoch': 3.0})"
625
+ ]
626
+ },
627
+ "execution_count": 33,
628
+ "metadata": {},
629
+ "output_type": "execute_result"
630
+ }
631
+ ],
632
+ "source": [
633
+ "trainer.train()\n"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "markdown",
638
+ "metadata": {},
639
+ "source": [
640
+ "## Save The Model"
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "execution_count": 34,
646
+ "metadata": {},
647
+ "outputs": [
648
+ {
649
+ "data": {
650
+ "text/plain": [
651
+ "[]"
652
+ ]
653
+ },
654
+ "execution_count": 34,
655
+ "metadata": {},
656
+ "output_type": "execute_result"
657
+ }
658
+ ],
659
+ "source": [
660
+ "# Save the trained model and processor\n",
661
+ "trainer.save_model(\"./trained_model\") # Saves the model to the specified directory\n",
662
+ "processor.save_pretrained(\"./trained_model\") # Saves the processor as well\n"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": 35,
668
+ "metadata": {},
669
+ "outputs": [
670
+ {
671
+ "data": {
672
+ "application/vnd.jupyter.widget-view+json": {
673
+ "model_id": "ae19dee506db46f4b39d402e43d16276",
674
+ "version_major": 2,
675
+ "version_minor": 0
676
+ },
677
+ "text/plain": [
678
+ " 0%| | 0/580 [00:00<?, ?it/s]"
679
+ ]
680
+ },
681
+ "metadata": {},
682
+ "output_type": "display_data"
683
+ },
684
+ {
685
+ "name": "stdout",
686
+ "output_type": "stream",
687
+ "text": [
688
+ "Evaluation Metrics:\n",
689
+ "eval_loss: 4.472154250834137e-05\n",
690
+ "eval_runtime: 527.4045\n",
691
+ "eval_samples_per_second: 8.786\n",
692
+ "eval_steps_per_second: 1.1\n",
693
+ "epoch: 3.0\n"
694
+ ]
695
+ }
696
+ ],
697
+ "source": [
698
+ "# Evaluate the model on the validation dataset\n",
699
+ "evaluation_metrics = trainer.evaluate()\n",
700
+ "\n",
701
+ "# Print the evaluation metrics\n",
702
+ "print(\"Evaluation Metrics:\")\n",
703
+ "for metric, value in evaluation_metrics.items():\n",
704
+ " print(f\"{metric}: {value}\")\n"
705
+ ]
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "execution_count": 10,
710
+ "metadata": {},
711
+ "outputs": [],
712
+ "source": [
713
+ "from transformers import AutoProcessor, AutoModelForAudioClassification\n",
714
+ "\n",
715
+ "# Load the trained model and processor\n",
716
+ "model_path = \"./trained_model\" # Path to your saved model\n",
717
+ "model = AutoModelForAudioClassification.from_pretrained(model_path)\n",
718
+ "processor = AutoProcessor.from_pretrained(model_path)\n"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "markdown",
723
+ "metadata": {},
724
+ "source": [
725
+ "## Single Audio Testing"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": 34,
731
+ "metadata": {},
732
+ "outputs": [],
733
+ "source": [
734
+ "def prepare_audio(file_path, sampling_rate=16000, duration=10):\n",
735
+ " \"\"\"\n",
736
+ " Prepares audio by loading, resampling, and returning it in manageable chunks.\n",
737
+ " \n",
738
+ " Parameters:\n",
739
+ " - file_path: Path to the audio file.\n",
740
+ " - sampling_rate: Target sampling rate for the audio.\n",
741
+ " - duration: Duration in seconds for each chunk.\n",
742
+ " \n",
743
+ " Returns:\n",
744
+ " - A list of audio chunks, each as a numpy array.\n",
745
+ " \"\"\"\n",
746
+ " # Load and resample the audio file\n",
747
+ " waveform, original_sampling_rate = torchaudio.load(file_path)\n",
748
+ " \n",
749
+ " # Convert stereo to mono if necessary\n",
750
+ " if waveform.shape[0] > 1: # More than 1 channel\n",
751
+ " waveform = torch.mean(waveform, dim=0, keepdim=True)\n",
752
+ " \n",
753
+ " # Resample if needed\n",
754
+ " if original_sampling_rate != sampling_rate:\n",
755
+ " resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)\n",
756
+ " waveform = resampler(waveform)\n",
757
+ " \n",
758
+ " # Calculate chunk size in samples\n",
759
+ " chunk_size = sampling_rate * duration\n",
760
+ " audio_chunks = []\n",
761
+ "\n",
762
+ " # Split the audio into chunks\n",
763
+ " for start in range(0, waveform.shape[1], chunk_size):\n",
764
+ " chunk = waveform[:, start:start + chunk_size]\n",
765
+ " \n",
766
+ " # Pad the last chunk if it's shorter than the chunk size\n",
767
+ " if chunk.shape[1] < chunk_size:\n",
768
+ " padding = chunk_size - chunk.shape[1]\n",
769
+ " chunk = torch.nn.functional.pad(chunk, (0, padding))\n",
770
+ " \n",
771
+ " audio_chunks.append(chunk.squeeze().numpy())\n",
772
+ " \n",
773
+ " return audio_chunks\n"
774
+ ]
775
+ },
776
+ {
777
+ "cell_type": "code",
778
+ "execution_count": 35,
779
+ "metadata": {},
780
+ "outputs": [
781
+ {
782
+ "name": "stdout",
783
+ "output_type": "stream",
784
+ "text": [
785
+ "Chunk shape: (160000,)\n",
786
+ "Logits for chunk 1: tensor([[ 4.6742, -5.1778]])\n",
787
+ "Chunk shape: (160000,)\n",
788
+ "Logits for chunk 2: tensor([[ 4.7219, -5.2332]])\n",
789
+ "Chunk shape: (160000,)\n",
790
+ "Logits for chunk 3: tensor([[ 4.7545, -5.2641]])\n",
791
+ "Chunk shape: (160000,)\n",
792
+ "Logits for chunk 4: tensor([[ 4.6714, -5.1740]])\n",
793
+ "Chunk shape: (160000,)\n",
794
+ "Logits for chunk 5: tensor([[ 4.7660, -5.2743]])\n",
795
+ "Chunk shape: (160000,)\n",
796
+ "Logits for chunk 6: tensor([[ 4.7724, -5.2836]])\n",
797
+ "Chunk shape: (160000,)\n",
798
+ "Logits for chunk 7: tensor([[ 4.7268, -5.2362]])\n",
799
+ "Chunk shape: (160000,)\n",
800
+ "Logits for chunk 8: tensor([[ 4.6898, -5.1898]])\n",
801
+ "Chunk shape: (160000,)\n",
802
+ "Logits for chunk 9: tensor([[ 4.6646, -5.1708]])\n",
803
+ "Chunk shape: (160000,)\n",
804
+ "Logits for chunk 10: tensor([[ 4.5948, -5.0867]])\n",
805
+ "Chunk shape: (160000,)\n",
806
+ "Logits for chunk 11: tensor([[ 4.7512, -5.2579]])\n",
807
+ "Chunk shape: (160000,)\n",
808
+ "Logits for chunk 12: tensor([[-4.5599, 5.0363]])\n",
809
+ "Chunk shape: (160000,)\n",
810
+ "Logits for chunk 13: tensor([[-0.4980, 0.5546]])\n",
811
+ "Chunk shape: (160000,)\n",
812
+ "Logits for chunk 14: tensor([[ 4.7295, -5.2358]])\n",
813
+ "Chunk shape: (160000,)\n",
814
+ "Logits for chunk 15: tensor([[ 4.7426, -5.2534]])\n",
815
+ "Chunk shape: (160000,)\n",
816
+ "Logits for chunk 16: tensor([[ 1.9405, -2.1493]])\n",
817
+ "Chunk shape: (160000,)\n",
818
+ "Logits for chunk 17: tensor([[ 4.7168, -5.2235]])\n",
819
+ "Chunk shape: (160000,)\n",
820
+ "Logits for chunk 18: tensor([[ 4.6801, -5.1907]])\n",
821
+ "Chunk shape: (160000,)\n",
822
+ "Logits for chunk 19: tensor([[ 4.7454, -5.2568]])\n",
823
+ "Chunk shape: (160000,)\n",
824
+ "Logits for chunk 20: tensor([[ 4.7642, -5.2723]])\n",
825
+ "Chunk shape: (160000,)\n",
826
+ "Logits for chunk 21: tensor([[ 4.7868, -5.2969]])\n",
827
+ "Chunk shape: (160000,)\n",
828
+ "Logits for chunk 22: tensor([[ 4.7600, -5.2690]])\n",
829
+ "Chunk shape: (160000,)\n",
830
+ "Logits for chunk 23: tensor([[ 4.7337, -5.2411]])\n",
831
+ "Chunk shape: (160000,)\n",
832
+ "Logits for chunk 24: tensor([[ 4.7835, -5.2943]])\n",
833
+ "Chunk shape: (160000,)\n",
834
+ "Logits for chunk 25: tensor([[ 4.7572, -5.2647]])\n",
835
+ "Chunk shape: (160000,)\n",
836
+ "Logits for chunk 26: tensor([[ 4.7485, -5.2581]])\n",
837
+ "Chunk shape: (160000,)\n",
838
+ "Logits for chunk 27: tensor([[ 4.6874, -5.2023]])\n",
839
+ "Chunk shape: (160000,)\n",
840
+ "Logits for chunk 28: tensor([[ 4.6877, -5.1922]])\n",
841
+ "Chunk shape: (160000,)\n",
842
+ "Logits for chunk 29: tensor([[ 4.7474, -5.2561]])\n",
843
+ "Chunk shape: (160000,)\n",
844
+ "Logits for chunk 30: tensor([[-4.3064, 4.7629]])\n",
845
+ "Chunk shape: (160000,)\n",
846
+ "Logits for chunk 31: tensor([[-3.8067, 4.2312]])\n",
847
+ "Chunk shape: (160000,)\n",
848
+ "Logits for chunk 32: tensor([[ 4.7217, -5.2325]])\n",
849
+ "Chunk shape: (160000,)\n",
850
+ "Logits for chunk 33: tensor([[ 4.7798, -5.2913]])\n",
851
+ "Chunk shape: (160000,)\n",
852
+ "Logits for chunk 34: tensor([[ 4.7214, -5.2355]])\n",
853
+ "Chunk shape: (160000,)\n",
854
+ "Logits for chunk 35: tensor([[ 4.7116, -5.2192]])\n",
855
+ "Chunk shape: (160000,)\n",
856
+ "Logits for chunk 36: tensor([[ 4.6687, -5.1812]])\n",
857
+ "Chunk shape: (160000,)\n",
858
+ "Logits for chunk 37: tensor([[-0.8128, 0.9402]])\n",
859
+ "Chunk shape: (160000,)\n",
860
+ "Logits for chunk 38: tensor([[ 4.7259, -5.2333]])\n",
861
+ "Chunk shape: (160000,)\n",
862
+ "Logits for chunk 39: tensor([[ 4.5698, -5.0731]])\n",
863
+ "Chunk shape: (160000,)\n",
864
+ "Logits for chunk 40: tensor([[ 4.7467, -5.2544]])\n",
865
+ "Chunk shape: (160000,)\n",
866
+ "Logits for chunk 41: tensor([[ 4.7781, -5.2884]])\n",
867
+ "Chunk shape: (160000,)\n",
868
+ "Logits for chunk 42: tensor([[ 4.7243, -5.2365]])\n",
869
+ "Chunk shape: (160000,)\n",
870
+ "Logits for chunk 43: tensor([[ 3.9325, -4.3570]])\n",
871
+ "Chunk shape: (160000,)\n",
872
+ "Logits for chunk 44: tensor([[-3.8786, 4.3105]])\n",
873
+ "Chunk shape: (160000,)\n",
874
+ "Logits for chunk 45: tensor([[ 3.3633, -3.6958]])\n",
875
+ "Chunk shape: (160000,)\n",
876
+ "Logits for chunk 46: tensor([[ 4.7127, -5.2213]])\n",
877
+ "Chunk shape: (160000,)\n",
878
+ "Logits for chunk 47: tensor([[ 0.0519, -0.0359]])\n",
879
+ "Chunk shape: (160000,)\n",
880
+ "Logits for chunk 48: tensor([[ 4.7457, -5.2535]])\n",
881
+ "Chunk shape: (160000,)\n",
882
+ "Logits for chunk 49: tensor([[ 3.4856, -3.8528]])\n",
883
+ "Chunk shape: (160000,)\n",
884
+ "Logits for chunk 50: tensor([[ 4.6485, -5.1538]])\n",
885
+ "Chunk shape: (160000,)\n",
886
+ "Logits for chunk 51: tensor([[ 4.6274, -5.1355]])\n",
887
+ "Chunk shape: (160000,)\n",
888
+ "Logits for chunk 52: tensor([[ 4.6852, -5.1872]])\n",
889
+ "Chunk shape: (160000,)\n",
890
+ "Logits for chunk 53: tensor([[ 4.7341, -5.2452]])\n",
891
+ "Chunk shape: (160000,)\n",
892
+ "Logits for chunk 54: tensor([[-4.5378, 5.0152]])\n",
893
+ "Chunk shape: (160000,)\n",
894
+ "Logits for chunk 55: tensor([[ 4.6822, -5.1887]])\n",
895
+ "Chunk shape: (160000,)\n",
896
+ "Logits for chunk 56: tensor([[ 4.7186, -5.2252]])\n",
897
+ "Chunk shape: (160000,)\n",
898
+ "Logits for chunk 57: tensor([[ 4.7688, -5.2787]])\n",
899
+ "Chunk shape: (160000,)\n",
900
+ "Logits for chunk 58: tensor([[ 4.7285, -5.2342]])\n",
901
+ "Chunk shape: (160000,)\n",
902
+ "Logits for chunk 59: tensor([[ 4.7447, -5.2550]])\n",
903
+ "Chunk shape: (160000,)\n",
904
+ "Logits for chunk 60: tensor([[ 4.5292, -5.0253]])\n",
905
+ "Predicted Class: Fake\n"
906
+ ]
907
+ }
908
+ ],
909
+ "source": [
910
+ "def predict_audio(file_path):\n",
911
+ " \"\"\"\n",
912
+ " Predicts the class of an audio file by aggregating predictions from chunks.\n",
913
+ " \n",
914
+ " Args:\n",
915
+ " file_path (str): Path to the audio file.\n",
916
+ "\n",
917
+ " Returns:\n",
918
+ " str: Predicted class label.\n",
919
+ " \"\"\"\n",
920
+ " # Prepare audio chunks\n",
921
+ " audio_chunks = prepare_audio(file_path)\n",
922
+ " predictions = []\n",
923
+ "\n",
924
+ " for i, chunk in enumerate(audio_chunks):\n",
925
+ " # Prepare input for the model\n",
926
+ " print(f\"Chunk shape: {chunk.shape}\")\n",
927
+ " inputs = processor(\n",
928
+ " chunk, sampling_rate=16000, return_tensors=\"pt\", padding=True\n",
929
+ " )\n",
930
+ " \n",
931
+ " # Perform inference\n",
932
+ " with torch.no_grad():\n",
933
+ " outputs = model(**inputs)\n",
934
+ " logits = outputs.logits\n",
935
+ " print(f\"Logits for chunk {i + 1}: {logits}\") # Print the logits\n",
936
+ " predicted_class = torch.argmax(logits, dim=1).item()\n",
937
+ " predictions.append(predicted_class)\n",
938
+ " \n",
939
+ " # Aggregate predictions (e.g., majority voting)\n",
940
+ " aggregated_prediction = max(set(predictions), key=predictions.count)\n",
941
+ " \n",
942
+ " # Convert class ID to label\n",
943
+ " return model.config.id2label[aggregated_prediction]\n",
944
+ "\n",
945
+ "# Example: Test a single audio file\n",
946
+ "file_path = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\KAGGLE\\AUDIO\\FAKE\\biden-to-linus.wav\" # Replace with your audio file path\n",
947
+ "predicted_class = predict_audio(file_path)\n",
948
+ "print(f\"Predicted Class: {predicted_class}\")\n"
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "markdown",
953
+ "metadata": {},
954
+ "source": [
955
+ "## Batch Testing"
956
+ ]
957
+ },
958
+ {
959
+ "cell_type": "code",
960
+ "execution_count": 36,
961
+ "metadata": {},
962
+ "outputs": [
963
+ {
964
+ "name": "stdout",
965
+ "output_type": "stream",
966
+ "text": [
967
+ "Chunk shape: (160000,)\n",
968
+ "Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
969
+ "Chunk shape: (160000,)\n",
970
+ "Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
971
+ "Chunk shape: (160000,)\n",
972
+ "Logits for chunk 1: tensor([[-1.5531, 1.7190]])\n",
973
+ "Chunk shape: (160000,)\n",
974
+ "Logits for chunk 1: tensor([[-1.5917, 1.7620]])\n",
975
+ "Chunk shape: (160000,)\n",
976
+ "Logits for chunk 1: tensor([[ 4.7569, -5.2631]])\n",
977
+ "Chunk shape: (160000,)\n",
978
+ "Logits for chunk 1: tensor([[ 4.7569, -5.2630]])\n",
979
+ "Chunk shape: (160000,)\n",
980
+ "Logits for chunk 1: tensor([[-4.5033, 4.9768]])\n",
981
+ "Chunk shape: (160000,)\n",
982
+ "Logits for chunk 1: tensor([[-4.5029, 4.9765]])\n",
983
+ "Chunk shape: (160000,)\n",
984
+ "Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
985
+ "Chunk shape: (160000,)\n",
986
+ "Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
987
+ "{'file': 'human voice 1 to mr beast.mp3', 'predicted_class': 'Real'}\n",
988
+ "{'file': 'human voice 1 to mr beast.wav', 'predicted_class': 'Real'}\n",
989
+ "{'file': 'human voice 1.mp3', 'predicted_class': 'Real'}\n",
990
+ "{'file': 'human voice 1.wav', 'predicted_class': 'Real'}\n",
991
+ "{'file': 'human voice 2 to Jett.mp3', 'predicted_class': 'Fake'}\n",
992
+ "{'file': 'human voice 2 to Jett.wav', 'predicted_class': 'Fake'}\n",
993
+ "{'file': 'human voice 2.mp3', 'predicted_class': 'Real'}\n",
994
+ "{'file': 'human voice 2.wav', 'predicted_class': 'Real'}\n",
995
+ "{'file': 'text to audio jett.mp3', 'predicted_class': 'Fake'}\n",
996
+ "{'file': 'text to audio jett.wav', 'predicted_class': 'Fake'}\n"
997
+ ]
998
+ }
999
+ ],
1000
+ "source": [
1001
+ "import os\n",
1002
+ "\n",
1003
+ "def batch_predict(test_folder, limit=10):\n",
1004
+ " \"\"\"\n",
1005
+ " Batch processes audio files for predictions.\n",
1006
+ "\n",
1007
+ " Args:\n",
1008
+ " test_folder (str): Path to the folder containing audio files.\n",
1009
+ " limit (int): Maximum number of files to process. Set to None for all files.\n",
1010
+ "\n",
1011
+ " Returns:\n",
1012
+ " list: A list of dictionaries containing file names and predicted classes.\n",
1013
+ " \"\"\"\n",
1014
+ " results = []\n",
1015
+ " files = os.listdir(test_folder)\n",
1016
+ "\n",
1017
+ " # Limit the number of files processed if a limit is provided\n",
1018
+ " if limit is not None:\n",
1019
+ " files = files[:limit]\n",
1020
+ "\n",
1021
+ " # Process each file in the folder\n",
1022
+ " for file_name in files:\n",
1023
+ " file_path = os.path.join(test_folder, file_name)\n",
1024
+ " try:\n",
1025
+ " predicted_class = predict_audio(file_path) # Use the predict_audio function\n",
1026
+ " results.append({\"file\": file_name, \"predicted_class\": predicted_class})\n",
1027
+ " except Exception as e:\n",
1028
+ " print(f\"Error processing {file_name}: {e}\")\n",
1029
+ " \n",
1030
+ " return results\n",
1031
+ "\n",
1032
+ "# Specify the folder path and limit\n",
1033
+ "test_folder = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\real life test audio\" # Replace with your test folder path\n",
1034
+ "results = batch_predict(test_folder, limit=10)\n",
1035
+ "\n",
1036
+ "# Print results\n",
1037
+ "for result in results:\n",
1038
+ " print(result)\n"
1039
+ ]
1040
+ },
1041
+ {
1042
+ "cell_type": "code",
1043
+ "execution_count": 14,
1044
+ "metadata": {},
1045
+ "outputs": [
1046
+ {
1047
+ "data": {
1048
+ "application/vnd.jupyter.widget-view+json": {
1049
+ "model_id": "479899631f95453e9f82355f7511cff3",
1050
+ "version_major": 2,
1051
+ "version_minor": 0
1052
+ },
1053
+ "text/plain": [
1054
+ " 0%| | 0/580 [00:00<?, ?it/s]"
1055
+ ]
1056
+ },
1057
+ "metadata": {},
1058
+ "output_type": "display_data"
1059
+ },
1060
+ {
1061
+ "name": "stdout",
1062
+ "output_type": "stream",
1063
+ "text": [
1064
+ "{'eval_loss': 4.472154250834137e-05, 'eval_model_preparation_time': 0.0032, 'eval_accuracy': 1.0, 'eval_runtime': 1168.3696, 'eval_samples_per_second': 3.966, 'eval_steps_per_second': 0.496}\n"
1065
+ ]
1066
+ }
1067
+ ],
1068
+ "source": [
1069
+ "import evaluate\n",
1070
+ "\n",
1071
+ "# Load the accuracy metric\n",
1072
+ "accuracy_metric = evaluate.load(\"accuracy\")\n",
1073
+ "\n",
1074
+ "def compute_metrics(eval_pred):\n",
1075
+ " logits, labels = eval_pred\n",
1076
+ " predictions = logits.argmax(axis=-1) # Get the predicted class\n",
1077
+ " accuracy = accuracy_metric.compute(predictions=predictions, references=labels)\n",
1078
+ " return {\"accuracy\": accuracy[\"accuracy\"]}\n",
1079
+ "\n",
1080
+ "from transformers import Trainer\n",
1081
+ "\n",
1082
+ "trainer = Trainer(\n",
1083
+ " model=model,\n",
1084
+ " args=training_args,\n",
1085
+ " train_dataset=processed_dataset[\"train\"],\n",
1086
+ " eval_dataset=processed_dataset[\"validation\"],\n",
1087
+ " tokenizer=processor, # Required for padding\n",
1088
+ " data_collator=data_collator,\n",
1089
+ " compute_metrics=compute_metrics, # Add this line\n",
1090
+ ")\n",
1091
+ "\n",
1092
+ "# Evaluate the model\n",
1093
+ "metrics = trainer.evaluate()\n",
1094
+ "print(metrics)\n"
1095
+ ]
1096
+ }
1097
+ ],
1098
+ "metadata": {
1099
+ "kernelspec": {
1100
+ "display_name": "base",
1101
+ "language": "python",
1102
+ "name": "python3"
1103
+ },
1104
+ "language_info": {
1105
+ "codemirror_mode": {
1106
+ "name": "ipython",
1107
+ "version": 3
1108
+ },
1109
+ "file_extension": ".py",
1110
+ "mimetype": "text/x-python",
1111
+ "name": "python",
1112
+ "nbconvert_exporter": "python",
1113
+ "pygments_lexer": "ipython3",
1114
+ "version": "3.12.7"
1115
+ }
1116
+ },
1117
+ "nbformat": 4,
1118
+ "nbformat_minor": 2
1119
+ }
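
Note on the notebook above: the aggregation step in predict_audio, max(set(predictions), key=predictions.count), is a plain majority vote over the per-chunk argmax labels. A minimal, self-contained sketch of the same idea is shown below; the chunk labels are illustrative only, not taken from the run printed above, and Counter is used here simply as an equivalent-in-spirit formulation.

    from collections import Counter

    def majority_vote(chunk_labels):
        # Most frequent chunk-level label; equivalent in spirit to
        # max(set(chunk_labels), key=chunk_labels.count) used in the notebook.
        # (Tie-breaking differs slightly: Counter keeps first-seen order.)
        return Counter(chunk_labels).most_common(1)[0][0]

    # Illustrative only: 0 = "Fake", 1 = "Real", matching the model's id2label.
    print(majority_vote([0, 0, 1, 0, 0]))  # -> 0, i.e. "Fake"
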
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers
4
+ torch
5
+ torchaudio
6
+ pydantic
7
+ numpy
8
+ python-multipart
vercel.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "builds": [
3
+ {
4
+ "src": "DeepfakeModel.py",
5
+ "use": "@vercel/python",
6
+ "config": {
7
+ "maxLambdaSize": "450mb"
8
+ }
9
+ }
10
+ ],
11
+ "routes": [
12
+ {
13
+ "src": "/(.*)",
14
+ "dest": "DeepfakeModel.py"
15
+ }
16
+ ]
17
+
18
+ }
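
DeepfakeModel.py, which vercel.json routes all requests to, is not included in this commit. For context only, here is a rough, assumed sketch of how the serving pieces could fit together, reusing the notebook's preprocessing and the dependencies listed in requirements.txt. The model path, endpoint name, and module layout are illustrative guesses, not taken from the actual file, and chunking plus chunk-level voting for long clips is omitted for brevity.

    import io

    import torch
    import torchaudio
    from fastapi import FastAPI, File, UploadFile
    from transformers import AutoModelForAudioClassification, AutoProcessor

    MODEL_PATH = "./trained_model"  # assumed: the directory the notebook saves to

    app = FastAPI()
    model = AutoModelForAudioClassification.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model.eval()

    @app.post("/predict")  # hypothetical endpoint name
    async def predict(file: UploadFile = File(...)):  # File upload needs python-multipart
        # Decode the uploaded audio, downmix to mono, and resample to 16 kHz,
        # mirroring the preprocessing in the notebook's prepare_audio function.
        waveform, sr = torchaudio.load(io.BytesIO(await file.read()))
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
        inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000,
                           return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        label = model.config.id2label[int(logits.argmax(dim=-1))]
        return {"file": file.filename, "predicted_class": label}

    # Local run (assumption, not part of the commit):
    #   uvicorn DeepfakeModel:app --reload
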