upload files
- Deepfake/model/config.json +126 -0
- Deepfake/model/model.safetensors +3 -0
- Deepfake/model/preprocessor_config.json +10 -0
- Deepfake/model/special_tokens_map.json +6 -0
- Deepfake/model/tokenizer_config.json +50 -0
- Deepfake/model/vocab.json +34 -0
- DeepfakeModel.ipynb +278 -0
- DeepfakeModel.py +118 -0
- ModelTest.ipynb +134 -0
- ScamTextModel.py +26 -0
- notebook.ipynb +1119 -0
- requirements.txt +8 -0
- vercel.json +18 -0
Deepfake/model/config.json
ADDED
@@ -0,0 +1,126 @@
+{
+  "_name_or_path": "facebook/wav2vec2-base",
+  "activation_dropout": 0.0,
+  "adapter_attn_dim": null,
+  "adapter_kernel_size": 3,
+  "adapter_stride": 2,
+  "add_adapter": false,
+  "apply_spec_augment": true,
+  "architectures": [
+    "Wav2Vec2ForSequenceClassification"
+  ],
+  "attention_dropout": 0.1,
+  "bos_token_id": 1,
+  "classifier_proj_size": 256,
+  "codevector_dim": 256,
+  "contrastive_logits_temperature": 0.1,
+  "conv_bias": false,
+  "conv_dim": [
+    512,
+    512,
+    512,
+    512,
+    512,
+    512,
+    512
+  ],
+  "conv_kernel": [
+    10,
+    3,
+    3,
+    3,
+    3,
+    2,
+    2
+  ],
+  "conv_stride": [
+    5,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2
+  ],
+  "ctc_loss_reduction": "sum",
+  "ctc_zero_infinity": false,
+  "diversity_loss_weight": 0.1,
+  "do_stable_layer_norm": false,
+  "eos_token_id": 2,
+  "feat_extract_activation": "gelu",
+  "feat_extract_norm": "group",
+  "feat_proj_dropout": 0.1,
+  "feat_quantizer_dropout": 0.0,
+  "final_dropout": 0.0,
+  "freeze_feat_extract_train": true,
+  "hidden_act": "gelu",
+  "hidden_dropout": 0.1,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "Fake",
+    "1": "Real"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "Fake": 0,
+    "Real": 1
+  },
+  "layer_norm_eps": 1e-05,
+  "layerdrop": 0.0,
+  "mask_channel_length": 10,
+  "mask_channel_min_space": 1,
+  "mask_channel_other": 0.0,
+  "mask_channel_prob": 0.0,
+  "mask_channel_selection": "static",
+  "mask_feature_length": 10,
+  "mask_feature_min_masks": 0,
+  "mask_feature_prob": 0.0,
+  "mask_time_length": 10,
+  "mask_time_min_masks": 2,
+  "mask_time_min_space": 1,
+  "mask_time_other": 0.0,
+  "mask_time_prob": 0.05,
+  "mask_time_selection": "static",
+  "model_type": "wav2vec2",
+  "no_mask_channel_overlap": false,
+  "no_mask_time_overlap": false,
+  "num_adapter_layers": 3,
+  "num_attention_heads": 12,
+  "num_codevector_groups": 2,
+  "num_codevectors_per_group": 320,
+  "num_conv_pos_embedding_groups": 16,
+  "num_conv_pos_embeddings": 128,
+  "num_feat_extract_layers": 7,
+  "num_hidden_layers": 12,
+  "num_negatives": 100,
+  "output_hidden_size": 768,
+  "pad_token_id": 0,
+  "proj_codevector_dim": 256,
+  "tdnn_dilation": [
+    1,
+    2,
+    3,
+    1,
+    1
+  ],
+  "tdnn_dim": [
+    512,
+    512,
+    512,
+    512,
+    1500
+  ],
+  "tdnn_kernel": [
+    5,
+    3,
+    3,
+    1,
+    1
+  ],
+  "torch_dtype": "float32",
+  "transformers_version": "4.45.2",
+  "use_weighted_layer_sum": false,
+  "vocab_size": 32,
+  "xvector_output_dim": 512
+}
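The config above declares a facebook/wav2vec2-base backbone fine-tuned as Wav2Vec2ForSequenceClassification with two labels (Fake/Real). A minimal sketch, not part of the commit, of loading the saved checkpoint directory and checking that mapping (assumes transformers is installed and the repo is checked out locally with the LFS weights pulled):

from transformers import AutoConfig, AutoModelForAudioClassification

config = AutoConfig.from_pretrained("Deepfake/model")              # reads this config.json
print(config.id2label)                                             # {0: 'Fake', 1: 'Real'}

model = AutoModelForAudioClassification.from_pretrained("Deepfake/model")
print(model.config.num_labels)                                     # 2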
Deepfake/model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4265db7332aaff6149687072b860aaea1767c9a1756dbbecc23647e130724bbc
+size 378302360
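model.safetensors is tracked with Git LFS, so the diff only records the pointer (sha256 and byte size), not the weights themselves. A minimal sketch, assuming `git lfs pull` has already replaced the pointer with the real ~378 MB file, to check the download against the pointer:

import hashlib
import os

path = "Deepfake/model/model.safetensors"
digest = hashlib.sha256()
with open(path, "rb") as f:
    for block in iter(lambda: f.read(1 << 20), b""):   # hash in 1 MiB blocks
        digest.update(block)

print(os.path.getsize(path) == 378302360)  # size recorded in the LFS pointer above
print(digest.hexdigest() == "4265db7332aaff6149687072b860aaea1767c9a1756dbbecc23647e130724bbc")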
Deepfake/model/preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "do_normalize": true,
+  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+  "feature_size": 1,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "processor_class": "Wav2Vec2Processor",
+  "return_attention_mask": false,
+  "sampling_rate": 16000
+}
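The preprocessor is a Wav2Vec2 feature extractor that expects mono float audio at 16 kHz and normalizes it (feature_size 1, do_normalize true). A minimal sketch of what it produces; the `sample` array here is a stand-in, not data from the repo:

import numpy as np
from transformers import AutoFeatureExtractor

fe = AutoFeatureExtractor.from_pretrained("Deepfake/model")   # reads preprocessor_config.json
sample = np.zeros(16000, dtype=np.float32)                    # 1 second of silence at 16 kHz
inputs = fe(sample, sampling_rate=16000, return_tensors="pt")
print(inputs.input_values.shape)                              # torch.Size([1, 16000])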
Deepfake/model/special_tokens_map.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "bos_token": "<s>",
+  "eos_token": "</s>",
+  "pad_token": "<pad>",
+  "unk_token": "<unk>"
+}
Deepfake/model/tokenizer_config.json
ADDED
@@ -0,0 +1,50 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<pad>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": true,
+      "single_word": false,
+      "special": false
+    },
+    "1": {
+      "content": "<s>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": true,
+      "single_word": false,
+      "special": false
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": true,
+      "single_word": false,
+      "special": false
+    },
+    "3": {
+      "content": "<unk>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": true,
+      "single_word": false,
+      "special": false
+    }
+  },
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": false,
+  "do_lower_case": false,
+  "do_normalize": true,
+  "eos_token": "</s>",
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<pad>",
+  "processor_class": "Wav2Vec2Processor",
+  "replace_word_delimiter_char": " ",
+  "return_attention_mask": false,
+  "target_lang": null,
+  "tokenizer_class": "Wav2Vec2CTCTokenizer",
+  "unk_token": "<unk>",
+  "word_delimiter_token": "|"
+}
Deepfake/model/vocab.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "'": 27,
+  "</s>": 2,
+  "<pad>": 0,
+  "<s>": 1,
+  "<unk>": 3,
+  "A": 7,
+  "B": 24,
+  "C": 19,
+  "D": 14,
+  "E": 5,
+  "F": 20,
+  "G": 21,
+  "H": 11,
+  "I": 10,
+  "J": 29,
+  "K": 26,
+  "L": 15,
+  "M": 17,
+  "N": 9,
+  "O": 8,
+  "P": 23,
+  "Q": 30,
+  "R": 13,
+  "S": 12,
+  "T": 6,
+  "U": 16,
+  "V": 25,
+  "W": 18,
+  "X": 28,
+  "Y": 22,
+  "Z": 31,
+  "|": 4
+}
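The tokenizer files (tokenizer_config.json, special_tokens_map.json, vocab.json) are the character-level CTC vocabulary inherited from the wav2vec2-base speech-recognition checkpoint; the audio-classification head never tokenizes text, but the files still load. A small sketch, assuming the same local checkout:

from transformers import Wav2Vec2CTCTokenizer

tok = Wav2Vec2CTCTokenizer.from_pretrained("Deepfake/model")
print(tok.vocab_size)                            # 32, matching "vocab_size" in config.json
print(tok.pad_token, tok.word_delimiter_token)   # <pad> |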
DeepfakeModel.ipynb
ADDED
@@ -0,0 +1,278 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: torch in c:\\users\\asus\\anaconda3\\lib\\site-packages (2.5.1)Note: you may need to restart the kernel to use updated packages.\n",
+      "\n",
+      "Requirement already satisfied: torchvision in c:\\users\\asus\\anaconda3\\lib\\site-packages (0.20.1)\n",
+      "Requirement already satisfied: torchaudio in c:\\users\\asus\\anaconda3\\lib\\site-packages (2.5.1)\n",
+      "Requirement already satisfied: PySoundFile in c:\\users\\asus\\anaconda3\\lib\\site-packages (0.9.0.post1)\n",
+      "Requirement already satisfied: filelock in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.13.1)\n",
+      "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (4.11.0)\n",
+      "Requirement already satisfied: networkx in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.3)\n",
+      "Requirement already satisfied: jinja2 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (3.1.4)\n",
+      "Requirement already satisfied: fsspec in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (2024.6.1)\n",
+      "Requirement already satisfied: setuptools in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (75.1.0)\n",
+      "Requirement already satisfied: sympy==1.13.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch) (1.13.1)\n",
+      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
+      "Requirement already satisfied: numpy in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchvision) (1.26.4)\n",
+      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchvision) (10.4.0)\n",
+      "Requirement already satisfied: cffi>=0.6 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from PySoundFile) (1.17.1)\n",
+      "Requirement already satisfied: pycparser in c:\\users\\asus\\anaconda3\\lib\\site-packages (from cffi>=0.6->PySoundFile) (2.21)\n",
+      "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "pip install torch torchvision torchaudio PySoundFile ffmpeg-python"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['soundfile']\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torchaudio\n",
+    "print(str(torchaudio.list_audio_backends()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "<>:13: SyntaxWarning: invalid escape sequence '\\m'\n",
+      "<>:17: SyntaxWarning: invalid escape sequence '\\H'\n",
+      "<>:13: SyntaxWarning: invalid escape sequence '\\m'\n",
+      "<>:17: SyntaxWarning: invalid escape sequence '\\H'\n",
+      "C:\\Users\\Asus\\AppData\\Local\\Temp\\ipykernel_18220\\208613059.py:13: SyntaxWarning: invalid escape sequence '\\m'\n",
+      "  model_path = \"Deepfake\\model\"\n",
+      "C:\\Users\\Asus\\AppData\\Local\\Temp\\ipykernel_18220\\208613059.py:17: SyntaxWarning: invalid escape sequence '\\H'\n",
+      "  cache_dir=\"D:\\HuggingFace\",\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import pipeline\n",
+    "from transformers import AutoProcessor, AutoModelForAudioClassification\n",
+    "from fastapi import FastAPI\n",
+    "from pydantic import BaseModel\n",
+    "import uvicorn\n",
+    "import torchaudio\n",
+    "import torch\n",
+    "\n",
+    "# Define the input schema\n",
+    "class InputData(BaseModel):\n",
+    "    input: str\n",
+    "\n",
+    "model_path = \"Deepfake\\model\"\n",
+    "processor = AutoProcessor.from_pretrained(model_path)\n",
+    "# Instantiate the model\n",
+    "model = AutoModelForAudioClassification.from_pretrained(pretrained_model_name_or_path=model_path,\n",
+    "                                                        cache_dir=\"D:\\HuggingFace\",\n",
+    "                                                        local_files_only=True,\n",
+    "                                                        )\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def prepare_audio(file_path, sampling_rate=16000, duration=10):\n",
+    "    \"\"\"\n",
+    "    Prepares audio by loading, resampling, and returning it in manageable chunks.\n",
+    "    \n",
+    "    Parameters:\n",
+    "    - file_path: Path to the audio file.\n",
+    "    - sampling_rate: Target sampling rate for the audio.\n",
+    "    - duration: Duration in seconds for each chunk.\n",
+    "    \n",
+    "    Returns:\n",
+    "    - A list of audio chunks, each as a numpy array.\n",
+    "    \"\"\"\n",
+    "    # Load and resample the audio file\n",
+    "    waveform, original_sampling_rate = torchaudio.load(file_path)\n",
+    "    \n",
+    "    # Convert stereo to mono if necessary\n",
+    "    if waveform.shape[0] > 1:  # More than 1 channel\n",
+    "        waveform = torch.mean(waveform, dim=0, keepdim=True)\n",
+    "    \n",
+    "    # Resample if needed\n",
+    "    if original_sampling_rate != sampling_rate:\n",
+    "        resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)\n",
+    "        waveform = resampler(waveform)\n",
+    "    \n",
+    "    # Calculate chunk size in samples\n",
+    "    chunk_size = sampling_rate * duration\n",
+    "    audio_chunks = []\n",
+    "\n",
+    "    # Split the audio into chunks\n",
+    "    for start in range(0, waveform.shape[1], chunk_size):\n",
+    "        chunk = waveform[:, start:start + chunk_size]\n",
+    "        \n",
+    "        # Pad the last chunk if it's shorter than the chunk size\n",
+    "        if chunk.shape[1] < chunk_size:\n",
+    "            padding = chunk_size - chunk.shape[1]\n",
+    "            chunk = torch.nn.functional.pad(chunk, (0, padding))\n",
+    "        \n",
+    "        audio_chunks.append(chunk.squeeze().numpy())\n",
+    "    \n",
+    "    return audio_chunks\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.nn.functional as F\n",
+    "\n",
+    "def predict_audio(file_path):\n",
+    "    \"\"\"\n",
+    "    Predicts the class of an audio file by aggregating predictions from chunks and calculates confidence.\n",
+    "    \n",
+    "    Args:\n",
+    "        file_path (str): Path to the audio file.\n",
+    "\n",
+    "    Returns:\n",
+    "        dict: Contains the predicted class label and average confidence score.\n",
+    "    \"\"\"\n",
+    "    # Prepare audio chunks\n",
+    "    audio_chunks = prepare_audio(file_path)\n",
+    "    predictions = []\n",
+    "    confidences = []\n",
+    "\n",
+    "    for i, chunk in enumerate(audio_chunks):\n",
+    "        # Prepare input for the model\n",
+    "        inputs = processor(\n",
+    "            chunk, sampling_rate=16000, return_tensors=\"pt\", padding=True\n",
+    "        )\n",
+    "        \n",
+    "        # Perform inference\n",
+    "        with torch.no_grad():\n",
+    "            outputs = model(**inputs)\n",
+    "            logits = outputs.logits\n",
+    "        \n",
+    "        # Apply softmax to calculate probabilities\n",
+    "        probabilities = F.softmax(logits, dim=1)\n",
+    "        \n",
+    "        # Get the predicted class and its confidence\n",
+    "        confidence, predicted_class = torch.max(probabilities, dim=1)\n",
+    "        predictions.append(predicted_class.item())\n",
+    "        confidences.append(confidence.item())\n",
+    "    \n",
+    "    # Aggregate predictions (majority voting)\n",
+    "    aggregated_prediction_id = max(set(predictions), key=predictions.count)\n",
+    "    predicted_label = model.config.id2label[aggregated_prediction_id]\n",
+    "    \n",
+    "    # Calculate average confidence across chunks\n",
+    "    average_confidence = sum(confidences) / len(confidences)\n",
+    "\n",
+    "    return {\n",
+    "        \"predicted_label\": predicted_label,\n",
+    "        \"average_confidence\": average_confidence\n",
+    "    }\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Chunk shape: (160000,)\n",
+      "Predicted Class: {'predicted_label': 'Real', 'average_confidence': 0.9984144032001495}\n"
+     ]
+    },
+    {
+     "ename": "",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
+      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
+     ]
+    }
+   ],
+   "source": [
+    "# Example: Test a single audio file\n",
+    "file_path = r\"D:\\repos\\GODAM\\audioFiles\\test.wav\"  # Replace with your audio file path\n",
+    "predicted_class = predict_audio(file_path)\n",
+    "print(f\"Predicted Class: {predicted_class}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
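The repeated "Chunk shape: (160000,)" lines in the last cell follow directly from prepare_audio's defaults: 10-second windows at 16 kHz give 160000 samples per chunk, with the final window zero-padded to the same length. A small sketch of that arithmetic (the 95-second clip length below is hypothetical, not taken from the repo):

import math

sampling_rate, duration = 16000, 10
chunk_size = sampling_rate * duration            # 160000 samples per chunk
total_samples = 95 * sampling_rate               # hypothetical clip length
n_chunks = math.ceil(total_samples / chunk_size)
padding = n_chunks * chunk_size - total_samples  # zeros appended to the last chunk
print(chunk_size, n_chunks, padding)             # 160000 10 80000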
DeepfakeModel.py
ADDED
@@ -0,0 +1,118 @@
+from fastapi import FastAPI, File, UploadFile
+from pydantic import BaseModel
+import uvicorn
+import os
+import torchaudio
+import torch.nn.functional as F
+import torch
+from transformers import AutoProcessor, AutoModelForAudioClassification
+
+# Model setup
+model_path = r"Deepfake\model"
+processor = AutoProcessor.from_pretrained(model_path)
+model = AutoModelForAudioClassification.from_pretrained(
+    pretrained_model_name_or_path=model_path,
+    local_files_only=True,
+)
+
+def prepare_audio(file_path, sampling_rate=16000, duration=10):
+    """
+    Prepares audio by loading, resampling, and returning it in manageable chunks.
+    """
+    # Load and resample the audio file
+    waveform, original_sampling_rate = torchaudio.load(file_path)
+
+    # Convert stereo to mono if necessary
+    if waveform.shape[0] > 1:  # More than 1 channel
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    # Resample if needed
+    if original_sampling_rate != sampling_rate:
+        resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)
+        waveform = resampler(waveform)
+
+    # Calculate chunk size in samples
+    chunk_size = sampling_rate * duration
+    audio_chunks = []
+
+    # Split the audio into chunks
+    for start in range(0, waveform.shape[1], chunk_size):
+        chunk = waveform[:, start:start + chunk_size]
+
+        # Pad the last chunk if it's shorter than the chunk size
+        if chunk.shape[1] < chunk_size:
+            padding = chunk_size - chunk.shape[1]
+            chunk = torch.nn.functional.pad(chunk, (0, padding))
+
+        audio_chunks.append(chunk.squeeze().numpy())
+
+    return audio_chunks
+
+def predict_audio(file_path):
+    """
+    Predicts the class of an audio file by aggregating predictions from chunks and calculates confidence.
+    """
+    # Prepare audio chunks
+    audio_chunks = prepare_audio(file_path)
+    predictions = []
+    confidences = []
+
+    for i, chunk in enumerate(audio_chunks):
+        # Prepare input for the model
+        inputs = processor(
+            chunk, sampling_rate=16000, return_tensors="pt", padding=True
+        )
+
+        # Perform inference
+        with torch.no_grad():
+            outputs = model(**inputs)
+            logits = outputs.logits
+
+        # Apply softmax to calculate probabilities
+        probabilities = F.softmax(logits, dim=1)
+
+        # Get the predicted class and its confidence
+        confidence, predicted_class = torch.max(probabilities, dim=1)
+        predictions.append(predicted_class.item())
+        confidences.append(confidence.item())
+
+    # Aggregate predictions (majority voting)
+    aggregated_prediction_id = max(set(predictions), key=predictions.count)
+    predicted_label = model.config.id2label[aggregated_prediction_id]
+
+    # Calculate average confidence across chunks
+    average_confidence = sum(confidences) / len(confidences)
+
+    return {
+        "predicted_label": predicted_label,
+        "average_confidence": average_confidence
+    }
+
+# Initialize FastAPI
+app = FastAPI()
+
+@app.post("/infer")
+async def infer(file: UploadFile = File(...)):
+    """
+    Accepts an audio file and returns the prediction and confidence.
+    """
+    # Save the uploaded file to a temporary location
+    temp_file_path = f"temp_{file.filename}"
+    with open(temp_file_path, "wb") as temp_file:
+        temp_file.write(await file.read())
+
+    try:
+        # Perform inference
+        predictions = predict_audio(temp_file_path)
+    finally:
+        # Clean up the temporary file
+        os.remove(temp_file_path)
+
+    return predictions
+
+@app.get("/health")
+async def health():
+    return {"message": "ok"}
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
ModelTest.ipynb
ADDED
@@ -0,0 +1,134 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import requests\n",
+    "\n",
+    "def infer_text(api_url, input_text):\n",
+    "    url = f\"{api_url}/infer\"\n",
+    "    try:\n",
+    "        # Send the input as a JSON object\n",
+    "        response = requests.post(url, json={\"input\": input_text})\n",
+    "        response.raise_for_status()\n",
+    "        return response.json()\n",
+    "    except requests.exceptions.RequestException as e:\n",
+    "        print(f\"Error during API call: {e}\")\n",
+    "        return None\n",
+    "\n",
+    "def check_health(api_url):\n",
+    "    url = f\"{api_url}/health\"\n",
+    "    try:\n",
+    "        response = requests.get(url)\n",
+    "        response.raise_for_status()\n",
+    "        return response.json()\n",
+    "    except requests.exceptions.RequestException as e:\n",
+    "        print(f\"Error during API health check: {e}\")\n",
+    "        return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "API Health Check: {'message': 'ok'}\n",
+      "Predictions: [{'label': 'LABEL_0', 'score': 0.9927427768707275}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "api_url = \"http://localhost:8000\"\n",
+    "\n",
+    "# Check the API health status\n",
+    "health_status = check_health(api_url)\n",
+    "if health_status:\n",
+    "    print(\"API Health Check:\", health_status)\n",
+    "else:\n",
+    "    print(\"Failed to connect to the API.\")\n",
+    "\n",
+    "# Example input text\n",
+    "input_text = \"Congratulations! You've won a prize. Click the link to claim your reward.\"\n",
+    "\n",
+    "# Call the /infer endpoint\n",
+    "predictions = infer_text(api_url, input_text)\n",
+    "if predictions:\n",
+    "    print(\"Predictions:\", predictions)\n",
+    "else:\n",
+    "    print(\"Failed to get predictions from the API.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "DeepFakeModel Test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Response JSON: {'predicted_label': 'Real', 'average_confidence': 0.9984144032001495}\n"
+     ]
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "\n",
+    "# Define the API endpoint\n",
+    "url = \"http://127.0.0.1:8000/infer\"\n",
+    "\n",
+    "# Path to the audio file you want to test\n",
+    "file_path = r\"D:\\repos\\GODAM\\audioFiles\\test.wav\"  # Replace with the path to your audio file\n",
+    "\n",
+    "# Open the file in binary mode\n",
+    "with open(file_path, \"rb\") as audio_file:\n",
+    "    # Prepare the file payload\n",
+    "    files = {\"file\": (\"audio.wav\", audio_file, \"audio/wav\")}\n",
+    "    \n",
+    "    # Send the POST request\n",
+    "    response = requests.post(url, files=files)\n",
+    "\n",
+    "# Print the response from the API\n",
+    "if response.status_code == 200:\n",
+    "    print(\"Response JSON:\", response.json())\n",
+    "else:\n",
+    "    print(f\"Error {response.status_code}: {response.text}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
ScamTextModel.py
ADDED
@@ -0,0 +1,26 @@
+from transformers import pipeline
+from fastapi import FastAPI
+from pydantic import BaseModel
+import uvicorn
+
+# Define the input schema
+class InputData(BaseModel):
+    input: str
+
+# Initialize the pipeline
+pipe = pipeline("text-classification", model="phishbot/ScamLLM")
+
+app = FastAPI()
+
+# Define API endpoints
+@app.post("/infer")
+async def infer(data: InputData):
+    predictions = pipe(data.input)
+    return predictions
+
+@app.get("/health")
+async def health():
+    return {"message": "ok"}
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
notebook.ipynb
ADDED
@@ -0,0 +1,1119 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Import data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting torchaudio\n",
+      "  Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl.metadata (6.5 kB)\n",
+      "Requirement already satisfied: torch==2.5.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchaudio) (2.5.1)\n",
+      "Requirement already satisfied: filelock in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.13.1)\n",
+      "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (4.11.0)\n",
+      "Requirement already satisfied: networkx in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.3)\n",
+      "Requirement already satisfied: jinja2 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.1.4)\n",
+      "Requirement already satisfied: fsspec in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (2024.6.1)\n",
+      "Requirement already satisfied: setuptools in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (75.1.0)\n",
+      "Requirement already satisfied: sympy==1.13.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (1.13.1)\n",
+      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch==2.5.1->torchaudio) (1.3.0)\n",
+      "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from jinja2->torch==2.5.1->torchaudio) (2.1.3)\n",
+      "Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl (2.4 MB)\n",
+      "   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--\n",
+      "   ---------------------------------------- 2.4/2.4 MB 11.6 MB/s eta 0:00:00\n",
+      "Installing collected packages: torchaudio\n",
+      "Successfully installed torchaudio-2.5.1\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
+   "source": [
+    "pip install torchaudio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'datasets'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "Cell \u001b[1;32mIn[3], line 5\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorchaudio\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DatasetDict, load_dataset\n\u001b[0;32m      7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprepare_dataset\u001b[39m(directory):\n\u001b[0;32m      8\u001b[0m     data \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n",
+      "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'datasets'"
+     ]
+    },
+    {
+     "ename": "",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
+      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
+     ]
+    }
+   ],
+   "source": [
+    "## The data is not pushed to repo, only model and training logs etc are uploaded\n",
+    "\n",
+    "import os\n",
+    "import torchaudio\n",
+    "from datasets import DatasetDict, load_dataset\n",
+    "\n",
+    "def prepare_dataset(directory):\n",
+    "    data = {\"path\": [], \"label\": []}\n",
+    "    labels = {\"fake\": 0, \"real\": 1}  # Map fake to 0 and real to 1\n",
+    "\n",
+    "    for label, label_id in labels.items():\n",
+    "        folder_path = os.path.join(directory, label)\n",
+    "        for file in os.listdir(folder_path):\n",
+    "            if file.endswith(\".wav\"):\n",
+    "                data[\"path\"].append(os.path.join(folder_path, file))\n",
+    "                data[\"label\"].append(label_id)\n",
+    "    return data\n",
+    "\n",
+    "# Prepare train, validation, and test datasets\n",
+    "train_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n",
+    "val_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n",
+    "test_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import Dataset\n",
+    "\n",
+    "train_dataset = Dataset.from_dict(train_data)\n",
+    "val_dataset = Dataset.from_dict(val_data)\n",
+    "test_dataset = Dataset.from_dict(test_data)\n",
+    "\n",
+    "dataset = DatasetDict({\"train\": train_dataset, \"validation\": val_dataset, \"test\": test_dataset})\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Import Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\configuration_utils.py:302: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n",
+      "  warnings.warn(\n",
+      "Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import AutoProcessor\n",
+    "from transformers import AutoModelForAudioClassification\n",
+    "\n",
+    "# Initialize processor\n",
+    "model_name = \"facebook/wav2vec2-base\"  # Replace with your model if different\n",
+    "model = AutoModelForAudioClassification.from_pretrained(model_name, num_labels=2)  # Adjust `num_labels` based on your dataset\n",
+    "processor = AutoProcessor.from_pretrained(model_name)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocess Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "af2c5f31f0db43ee9975023b28e2c57c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map:   0%|          | 0/4634 [00:00<?, ? examples/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "051c3226b30145f3adbc708cf8afd6a5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map:   0%|          | 0/4634 [00:00<?, ? examples/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "86de10f37b7441e3a9c9f0ee88bf5149",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map:   0%|          | 0/4634 [00:00<?, ? examples/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor(0) <class 'torch.Tensor'>\n",
+      "torch.int64\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "\n",
+    "def preprocess_function(batch):\n",
+    "    audio = torchaudio.load(batch[\"path\"])[0].squeeze().numpy()\n",
+    "    inputs = processor(\n",
+    "        audio,\n",
+    "        sampling_rate=16000,\n",
+    "        padding=True,\n",
+    "        truncation=True,\n",
+    "        max_length=32000, \n",
+    "        return_tensors=\"pt\"\n",
+    "    )\n",
+    "    batch[\"input_values\"] = inputs.input_values[0]\n",
+    "    # Ensure labels are converted to LongTensor\n",
+    "    batch[\"label\"] = torch.tensor(batch[\"label\"], dtype=torch.long)  # Convert label to LongTensor\n",
+    "    return batch\n",
+    "\n",
+    "processed_dataset = dataset.map(preprocess_function, remove_columns=[\"path\"], batched=False)\n",
+    "# Set format to torch tensors for compatibility with PyTorch\n",
+    "processed_dataset.set_format(type=\"torch\", columns=[\"input_values\", \"label\"])\n",
+    "\n",
+    "# Double-check the label type again\n",
+    "print(processed_dataset[\"train\"][0][\"label\"], type(processed_dataset[\"train\"][0][\"label\"]))\n",
+    "print(processed_dataset[\"train\"][0][\"label\"].dtype)  # Should print torch.int64\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Map Training Labels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Labels: {0: 'Fake', 1: 'Real'}\n",
+      "Labels: {0: 'Fake', 1: 'Real'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Ensure labels are in numerical format (e.g., 0, 1)\n",
+    "id2label = {0: \"Fake\", 1: \"Real\"}  # Define the mapping based on your dataset\n",
+    "label2id = {v: k for k, v in id2label.items()}  # Reverse mapping\n",
+    "\n",
+    "\n",
+    "print(\"Labels:\", id2label)\n",
+    "\n",
+    "# Update the model's configuration with labels\n",
+    "model.config.id2label = id2label\n",
+    "model.config.label2id = label2id\n",
+    "\n",
+    "print(\"Labels:\", model.config.id2label)  # Verify\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import DataCollatorWithPadding\n",
+    "\n",
+    "# Use the processor's tokenizer for padding\n",
+    "data_collator = DataCollatorWithPadding(tokenizer=processor, padding=True)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Initialize Training Arguments"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TrainingArguments initialized successfully!\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import TrainingArguments\n",
+    "\n",
+    "training_args = TrainingArguments(\n",
+    "    output_dir=\"./results\",\n",
+    "    evaluation_strategy=\"epoch\",\n",
+    "    save_strategy=\"epoch\",\n",
+    "    learning_rate=5e-5,\n",
+    "    per_device_train_batch_size=8,\n",
+    "    per_device_eval_batch_size=8,\n",
+    "    num_train_epochs=3,\n",
+    "    weight_decay=0.01,\n",
+    "    logging_dir=\"./logs\",\n",
+    "    logging_steps=10,\n",
+    "    save_total_limit=2,\n",
+    "    fp16=True, \n",
+    "    push_to_hub=False,\n",
+    ")\n",
+    "print(\"TrainingArguments initialized successfully!\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import Trainer\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=training_args,\n",
+    "    train_dataset=processed_dataset[\"train\"],\n",
+    "    eval_dataset=processed_dataset[\"validation\"],\n",
+    "    tokenizer=processor,  # Required for the data collator\n",
+    "    data_collator=data_collator,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Start Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "53044eac53174227a4eb19becbeb0bbf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1740 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.6607, 'grad_norm': 2.5936832427978516, 'learning_rate': 4.971264367816092e-05, 'epoch': 0.02}\n",
+      "{'loss': 0.4545, 'grad_norm': 5.04218864440918, 'learning_rate': 4.9425287356321845e-05, 'epoch': 0.03}\n",
+      "{'loss': 0.1779, 'grad_norm': 0.8874927163124084, 'learning_rate': 4.913793103448276e-05, 'epoch': 0.05}\n",
+      "{'loss': 0.0833, 'grad_norm': 0.40262681245803833, 'learning_rate': 4.885057471264368e-05, 'epoch': 0.07}\n",
+      "{'loss': 0.0948, 'grad_norm': 0.5579108595848083, 'learning_rate': 4.85632183908046e-05, 'epoch': 0.09}\n",
+      "{'loss': 0.0128, 'grad_norm': 0.1585635393857956, 'learning_rate': 4.827586206896552e-05, 'epoch': 0.1}\n",
+      "{'loss': 0.0696, 'grad_norm': 0.12149885296821594, 'learning_rate': 4.798850574712644e-05, 'epoch': 0.12}\n",
+      "{'loss': 0.0065, 'grad_norm': 0.09655608981847763, 'learning_rate': 4.770114942528736e-05, 'epoch': 0.14}\n",
+      "{'loss': 0.0052, 'grad_norm': 0.08148041367530823, 'learning_rate': 4.741379310344828e-05, 'epoch': 0.16}\n",
+      "{'loss': 0.0041, 'grad_norm': 0.07030971348285675, 'learning_rate': 4.7126436781609195e-05, 'epoch': 0.17}\n",
+      "{'loss': 0.0034, 'grad_norm': 0.05784648284316063, 'learning_rate': 4.6839080459770116e-05, 'epoch': 0.19}\n",
+      "{'loss': 0.0029, 'grad_norm': 0.04742524400353432, 'learning_rate': 4.655172413793104e-05, 'epoch': 0.21}\n",
+      "{'loss': 0.0024, 'grad_norm': 0.04222293198108673, 'learning_rate': 4.626436781609196e-05, 'epoch': 0.22}\n",
+      "{'loss': 0.0021, 'grad_norm': 0.039586298167705536, 'learning_rate': 4.597701149425287e-05, 'epoch': 0.24}\n",
+      "{'loss': 0.0018, 'grad_norm': 0.034121569246053696, 'learning_rate': 4.5689655172413794e-05, 'epoch': 0.26}\n",
+      "{'loss': 0.0016, 'grad_norm': 0.031423550099134445, 'learning_rate': 4.5402298850574716e-05, 'epoch': 0.28}\n",
+      "{'loss': 0.0418, 'grad_norm': 0.02912888117134571, 'learning_rate': 4.511494252873563e-05, 'epoch': 0.29}\n",
+      "{'loss': 0.0831, 'grad_norm': 0.027829233556985855, 'learning_rate': 4.482758620689655e-05, 'epoch': 0.31}\n",
+      "{'loss': 0.2431, 'grad_norm': 0.10419639199972153, 'learning_rate': 4.454022988505747e-05, 'epoch': 0.33}\n",
+      "{'loss': 0.0024, 'grad_norm': 0.046179670840501785, 'learning_rate': 4.4252873563218394e-05, 'epoch': 0.34}\n",
+      "{'loss': 0.0729, 'grad_norm': 0.03943876922130585, 'learning_rate': 4.396551724137931e-05, 'epoch': 0.36}\n",
+      "{'loss': 0.0806, 'grad_norm': 0.05223412811756134, 'learning_rate': 4.367816091954024e-05, 'epoch': 0.38}\n",
+      "{'loss': 0.0023, 'grad_norm': 0.041366685181856155, 'learning_rate': 4.339080459770115e-05, 'epoch': 0.4}\n",
+      "{'loss': 0.0019, 'grad_norm': 0.03711611405014992, 'learning_rate': 4.3103448275862066e-05, 'epoch': 0.41}\n",
+      "{'loss': 0.0016, 'grad_norm': 0.030888166278600693, 'learning_rate': 4.2816091954022994e-05, 'epoch': 0.43}\n",
"{'loss': 0.0016, 'grad_norm': 0.030888166278600693, 'learning_rate': 4.2816091954022994e-05, 'epoch': 0.43}\n",
|
| 406 |
+
"{'loss': 0.0015, 'grad_norm': 0.027964089065790176, 'learning_rate': 4.252873563218391e-05, 'epoch': 0.45}\n",
|
| 407 |
+
"{'loss': 0.0013, 'grad_norm': 0.025312568992376328, 'learning_rate': 4.224137931034483e-05, 'epoch': 0.47}\n",
|
| 408 |
+
"{'loss': 0.0012, 'grad_norm': 0.02383635751903057, 'learning_rate': 4.195402298850575e-05, 'epoch': 0.48}\n",
|
| 409 |
+
"{'loss': 0.0011, 'grad_norm': 0.021753674373030663, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}\n",
|
| 410 |
+
"{'loss': 0.001, 'grad_norm': 0.019886162132024765, 'learning_rate': 4.1379310344827587e-05, 'epoch': 0.52}\n",
|
| 411 |
+
"{'loss': 0.0009, 'grad_norm': 0.018876325339078903, 'learning_rate': 4.109195402298851e-05, 'epoch': 0.53}\n",
|
| 412 |
+
"{'loss': 0.0008, 'grad_norm': 0.017580321058630943, 'learning_rate': 4.080459770114943e-05, 'epoch': 0.55}\n",
|
| 413 |
+
"{'loss': 0.0008, 'grad_norm': 0.015544956550002098, 'learning_rate': 4.0517241379310344e-05, 'epoch': 0.57}\n",
|
| 414 |
+
"{'loss': 0.0007, 'grad_norm': 0.015221121720969677, 'learning_rate': 4.0229885057471265e-05, 'epoch': 0.59}\n",
|
| 415 |
+
"{'loss': 0.0007, 'grad_norm': 0.014960471540689468, 'learning_rate': 3.9942528735632186e-05, 'epoch': 0.6}\n",
|
| 416 |
+
"{'loss': 0.0006, 'grad_norm': 0.013560828752815723, 'learning_rate': 3.965517241379311e-05, 'epoch': 0.62}\n",
|
| 417 |
+
"{'loss': 0.0006, 'grad_norm': 0.013700570911169052, 'learning_rate': 3.936781609195402e-05, 'epoch': 0.64}\n",
|
| 418 |
+
"{'loss': 0.0006, 'grad_norm': 0.011968374252319336, 'learning_rate': 3.908045977011495e-05, 'epoch': 0.66}\n",
|
| 419 |
+
"{'loss': 0.0005, 'grad_norm': 0.011644795536994934, 'learning_rate': 3.8793103448275865e-05, 'epoch': 0.67}\n",
|
| 420 |
+
"{'loss': 0.0005, 'grad_norm': 0.011345883831381798, 'learning_rate': 3.850574712643678e-05, 'epoch': 0.69}\n",
|
| 421 |
+
"{'loss': 0.0005, 'grad_norm': 0.010393058881163597, 'learning_rate': 3.82183908045977e-05, 'epoch': 0.71}\n",
|
| 422 |
+
"{'loss': 0.0005, 'grad_norm': 0.010386484675109386, 'learning_rate': 3.793103448275862e-05, 'epoch': 0.72}\n",
|
| 423 |
+
"{'loss': 0.0004, 'grad_norm': 0.009744665585458279, 'learning_rate': 3.764367816091954e-05, 'epoch': 0.74}\n",
|
| 424 |
+
"{'loss': 0.0004, 'grad_norm': 0.009590468369424343, 'learning_rate': 3.735632183908046e-05, 'epoch': 0.76}\n",
|
| 425 |
+
"{'loss': 0.0004, 'grad_norm': 0.009154150262475014, 'learning_rate': 3.7068965517241385e-05, 'epoch': 0.78}\n",
|
| 426 |
+
"{'loss': 0.0004, 'grad_norm': 0.008997919037938118, 'learning_rate': 3.67816091954023e-05, 'epoch': 0.79}\n",
|
| 427 |
+
"{'loss': 0.0004, 'grad_norm': 0.008509515784680843, 'learning_rate': 3.649425287356322e-05, 'epoch': 0.81}\n",
|
| 428 |
+
"{'loss': 0.0004, 'grad_norm': 0.008223678916692734, 'learning_rate': 3.620689655172414e-05, 'epoch': 0.83}\n",
|
| 429 |
+
"{'loss': 0.0003, 'grad_norm': 0.00758435670286417, 'learning_rate': 3.591954022988506e-05, 'epoch': 0.84}\n",
|
| 430 |
+
"{'loss': 0.0003, 'grad_norm': 0.0074744271114468575, 'learning_rate': 3.563218390804598e-05, 'epoch': 0.86}\n",
|
| 431 |
+
"{'loss': 0.0003, 'grad_norm': 0.007454875390976667, 'learning_rate': 3.53448275862069e-05, 'epoch': 0.88}\n",
|
| 432 |
+
"{'loss': 0.0003, 'grad_norm': 0.007157924585044384, 'learning_rate': 3.505747126436782e-05, 'epoch': 0.9}\n",
|
| 433 |
+
"{'loss': 0.0003, 'grad_norm': 0.006946589332073927, 'learning_rate': 3.4770114942528735e-05, 'epoch': 0.91}\n",
|
| 434 |
+
"{'loss': 0.0003, 'grad_norm': 0.0067284563556313515, 'learning_rate': 3.4482758620689657e-05, 'epoch': 0.93}\n",
|
| 435 |
+
"{'loss': 0.0003, 'grad_norm': 0.00652291439473629, 'learning_rate': 3.419540229885058e-05, 'epoch': 0.95}\n",
|
| 436 |
+
"{'loss': 0.0003, 'grad_norm': 0.006468599662184715, 'learning_rate': 3.390804597701149e-05, 'epoch': 0.97}\n",
|
| 437 |
+
"{'loss': 0.0003, 'grad_norm': 0.0061700050719082355, 'learning_rate': 3.3620689655172414e-05, 'epoch': 0.98}\n",
|
| 438 |
+
"{'loss': 0.0002, 'grad_norm': 0.005980886053293943, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}\n"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"data": {
|
| 443 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 444 |
+
"model_id": "9fe6f453d3dc4644b7e7adf483e8064b",
|
| 445 |
+
"version_major": 2,
|
| 446 |
+
"version_minor": 0
|
| 447 |
+
},
|
| 448 |
+
"text/plain": [
|
| 449 |
+
" 0%| | 0/580 [00:00<?, ?it/s]"
|
| 450 |
+
]
|
| 451 |
+
},
|
| 452 |
+
"metadata": {},
|
| 453 |
+
"output_type": "display_data"
|
| 454 |
+
},
|
| 455 |
+
{
|
| 456 |
+
"name": "stdout",
|
| 457 |
+
"output_type": "stream",
|
| 458 |
+
"text": [
|
| 459 |
+
"{'eval_loss': 0.00017967642634175718, 'eval_runtime': 566.6055, 'eval_samples_per_second': 8.179, 'eval_steps_per_second': 1.024, 'epoch': 1.0}\n",
|
| 460 |
+
"{'loss': 0.0002, 'grad_norm': 0.005443067755550146, 'learning_rate': 3.3045977011494256e-05, 'epoch': 1.02}\n",
|
| 461 |
+
"{'loss': 0.0002, 'grad_norm': 0.005919346585869789, 'learning_rate': 3.275862068965517e-05, 'epoch': 1.03}\n",
|
| 462 |
+
"{'loss': 0.0002, 'grad_norm': 0.00538916140794754, 'learning_rate': 3.24712643678161e-05, 'epoch': 1.05}\n",
|
| 463 |
+
"{'loss': 0.0002, 'grad_norm': 0.00514333276078105, 'learning_rate': 3.218390804597701e-05, 'epoch': 1.07}\n",
|
| 464 |
+
"{'loss': 0.0002, 'grad_norm': 0.005011783912777901, 'learning_rate': 3.1896551724137935e-05, 'epoch': 1.09}\n",
|
| 465 |
+
"{'loss': 0.0002, 'grad_norm': 0.005112846381962299, 'learning_rate': 3.160919540229885e-05, 'epoch': 1.1}\n",
|
| 466 |
+
"{'loss': 0.0002, 'grad_norm': 0.004895139951258898, 'learning_rate': 3.132183908045977e-05, 'epoch': 1.12}\n",
|
| 467 |
+
"{'loss': 0.0002, 'grad_norm': 0.004565018694847822, 'learning_rate': 3.103448275862069e-05, 'epoch': 1.14}\n",
|
| 468 |
+
"{'loss': 0.0002, 'grad_norm': 0.00477330107241869, 'learning_rate': 3.0747126436781606e-05, 'epoch': 1.16}\n",
|
| 469 |
+
"{'loss': 0.0002, 'grad_norm': 0.004563583992421627, 'learning_rate': 3.045977011494253e-05, 'epoch': 1.17}\n",
|
| 470 |
+
"{'loss': 0.0002, 'grad_norm': 0.004568100906908512, 'learning_rate': 3.017241379310345e-05, 'epoch': 1.19}\n",
|
| 471 |
+
"{'loss': 0.0002, 'grad_norm': 0.0046887630596756935, 'learning_rate': 2.988505747126437e-05, 'epoch': 1.21}\n",
|
| 472 |
+
"{'loss': 0.0002, 'grad_norm': 0.004261981230229139, 'learning_rate': 2.9597701149425288e-05, 'epoch': 1.22}\n",
|
| 473 |
+
"{'loss': 0.0002, 'grad_norm': 0.004290647804737091, 'learning_rate': 2.9310344827586206e-05, 'epoch': 1.24}\n",
|
| 474 |
+
"{'loss': 0.0002, 'grad_norm': 0.004188802093267441, 'learning_rate': 2.9022988505747127e-05, 'epoch': 1.26}\n",
|
| 475 |
+
"{'loss': 0.0002, 'grad_norm': 0.003917949739843607, 'learning_rate': 2.8735632183908045e-05, 'epoch': 1.28}\n",
|
| 476 |
+
"{'loss': 0.0002, 'grad_norm': 0.003940439783036709, 'learning_rate': 2.844827586206897e-05, 'epoch': 1.29}\n",
|
| 477 |
+
"{'loss': 0.0002, 'grad_norm': 0.004128696396946907, 'learning_rate': 2.8160919540229884e-05, 'epoch': 1.31}\n",
|
| 478 |
+
"{'loss': 0.0002, 'grad_norm': 0.004070794675499201, 'learning_rate': 2.787356321839081e-05, 'epoch': 1.33}\n",
|
| 479 |
+
"{'loss': 0.0002, 'grad_norm': 0.0037206218112260103, 'learning_rate': 2.7586206896551727e-05, 'epoch': 1.34}\n",
|
| 480 |
+
"{'loss': 0.0001, 'grad_norm': 0.00400934973731637, 'learning_rate': 2.7298850574712648e-05, 'epoch': 1.36}\n",
|
| 481 |
+
"{'loss': 0.0001, 'grad_norm': 0.0037558332551270723, 'learning_rate': 2.7011494252873566e-05, 'epoch': 1.38}\n",
|
| 482 |
+
"{'loss': 0.0001, 'grad_norm': 0.0035916264168918133, 'learning_rate': 2.672413793103448e-05, 'epoch': 1.4}\n",
|
| 483 |
+
"{'loss': 0.0001, 'grad_norm': 0.003707454539835453, 'learning_rate': 2.6436781609195405e-05, 'epoch': 1.41}\n",
|
| 484 |
+
"{'loss': 0.0001, 'grad_norm': 0.0034801277797669172, 'learning_rate': 2.6149425287356323e-05, 'epoch': 1.43}\n",
|
| 485 |
+
"{'loss': 0.0001, 'grad_norm': 0.003501839004456997, 'learning_rate': 2.5862068965517244e-05, 'epoch': 1.45}\n",
|
| 486 |
+
"{'loss': 0.0001, 'grad_norm': 0.003458078484982252, 'learning_rate': 2.5574712643678162e-05, 'epoch': 1.47}\n",
|
| 487 |
+
"{'loss': 0.0001, 'grad_norm': 0.0031666585709899664, 'learning_rate': 2.5287356321839083e-05, 'epoch': 1.48}\n",
|
| 488 |
+
"{'loss': 0.0001, 'grad_norm': 0.0033736126497387886, 'learning_rate': 2.5e-05, 'epoch': 1.5}\n",
|
| 489 |
+
"{'loss': 0.0001, 'grad_norm': 0.003289664164185524, 'learning_rate': 2.4712643678160922e-05, 'epoch': 1.52}\n",
|
| 490 |
+
"{'loss': 0.0001, 'grad_norm': 0.0031710772309452295, 'learning_rate': 2.442528735632184e-05, 'epoch': 1.53}\n",
|
| 491 |
+
"{'loss': 0.0001, 'grad_norm': 0.0029777924064546824, 'learning_rate': 2.413793103448276e-05, 'epoch': 1.55}\n",
|
| 492 |
+
"{'loss': 0.0001, 'grad_norm': 0.003144340356811881, 'learning_rate': 2.385057471264368e-05, 'epoch': 1.57}\n",
|
| 493 |
+
"{'loss': 0.0001, 'grad_norm': 0.0029524925630539656, 'learning_rate': 2.3563218390804597e-05, 'epoch': 1.59}\n",
|
| 494 |
+
"{'loss': 0.0001, 'grad_norm': 0.002912016585469246, 'learning_rate': 2.327586206896552e-05, 'epoch': 1.6}\n",
|
| 495 |
+
"{'loss': 0.0001, 'grad_norm': 0.002869528019800782, 'learning_rate': 2.2988505747126437e-05, 'epoch': 1.62}\n",
|
| 496 |
+
"{'loss': 0.0001, 'grad_norm': 0.002966447500512004, 'learning_rate': 2.2701149425287358e-05, 'epoch': 1.64}\n",
|
| 497 |
+
"{'loss': 0.0001, 'grad_norm': 0.0027959959115833044, 'learning_rate': 2.2413793103448276e-05, 'epoch': 1.66}\n",
|
| 498 |
+
"{'loss': 0.0001, 'grad_norm': 0.0030403095297515392, 'learning_rate': 2.2126436781609197e-05, 'epoch': 1.67}\n",
|
| 499 |
+
"{'loss': 0.0001, 'grad_norm': 0.0026587999891489744, 'learning_rate': 2.183908045977012e-05, 'epoch': 1.69}\n",
|
| 500 |
+
"{'loss': 0.0001, 'grad_norm': 0.002689346671104431, 'learning_rate': 2.1551724137931033e-05, 'epoch': 1.71}\n",
|
| 501 |
+
"{'loss': 0.0001, 'grad_norm': 0.002710141707211733, 'learning_rate': 2.1264367816091954e-05, 'epoch': 1.72}\n",
|
| 502 |
+
"{'loss': 0.0001, 'grad_norm': 0.002674366347491741, 'learning_rate': 2.0977011494252875e-05, 'epoch': 1.74}\n",
|
| 503 |
+
"{'loss': 0.0001, 'grad_norm': 0.0026578502729535103, 'learning_rate': 2.0689655172413793e-05, 'epoch': 1.76}\n",
|
| 504 |
+
"{'loss': 0.0001, 'grad_norm': 0.00243232655338943, 'learning_rate': 2.0402298850574715e-05, 'epoch': 1.78}\n",
|
| 505 |
+
"{'loss': 0.0001, 'grad_norm': 0.0025773164816200733, 'learning_rate': 2.0114942528735632e-05, 'epoch': 1.79}\n",
|
| 506 |
+
"{'loss': 0.0001, 'grad_norm': 0.0024439615663141012, 'learning_rate': 1.9827586206896554e-05, 'epoch': 1.81}\n",
|
| 507 |
+
"{'loss': 0.0001, 'grad_norm': 0.0024733347818255424, 'learning_rate': 1.9540229885057475e-05, 'epoch': 1.83}\n",
|
| 508 |
+
"{'loss': 0.0001, 'grad_norm': 0.002439699834212661, 'learning_rate': 1.925287356321839e-05, 'epoch': 1.84}\n",
|
| 509 |
+
"{'loss': 0.0001, 'grad_norm': 0.0025980097707360983, 'learning_rate': 1.896551724137931e-05, 'epoch': 1.86}\n",
|
| 510 |
+
"{'loss': 0.0001, 'grad_norm': 0.002387199318036437, 'learning_rate': 1.867816091954023e-05, 'epoch': 1.88}\n",
|
| 511 |
+
"{'loss': 0.0001, 'grad_norm': 0.0023106117732822895, 'learning_rate': 1.839080459770115e-05, 'epoch': 1.9}\n",
|
| 512 |
+
"{'loss': 0.0001, 'grad_norm': 0.0023344189394265413, 'learning_rate': 1.810344827586207e-05, 'epoch': 1.91}\n",
|
| 513 |
+
"{'loss': 0.0001, 'grad_norm': 0.0023740960750728846, 'learning_rate': 1.781609195402299e-05, 'epoch': 1.93}\n",
|
| 514 |
+
"{'loss': 0.0001, 'grad_norm': 0.002346088644117117, 'learning_rate': 1.752873563218391e-05, 'epoch': 1.95}\n",
|
| 515 |
+
"{'loss': 0.0001, 'grad_norm': 0.002391340211033821, 'learning_rate': 1.7241379310344828e-05, 'epoch': 1.97}\n",
|
| 516 |
+
"{'loss': 0.0001, 'grad_norm': 0.002250733319669962, 'learning_rate': 1.6954022988505746e-05, 'epoch': 1.98}\n",
|
| 517 |
+
"{'loss': 0.0001, 'grad_norm': 0.002164299599826336, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}\n"
|
| 518 |
+
]
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"data": {
|
| 522 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 523 |
+
"model_id": "ed9239dd5b4346509d243cbc378bcf44",
|
| 524 |
+
"version_major": 2,
|
| 525 |
+
"version_minor": 0
|
| 526 |
+
},
|
| 527 |
+
"text/plain": [
|
| 528 |
+
" 0%| | 0/580 [00:00<?, ?it/s]"
|
| 529 |
+
]
|
| 530 |
+
},
|
| 531 |
+
"metadata": {},
|
| 532 |
+
"output_type": "display_data"
|
| 533 |
+
},
|
| 534 |
+
{
|
| 535 |
+
"name": "stdout",
|
| 536 |
+
"output_type": "stream",
|
| 537 |
+
"text": [
|
| 538 |
+
"{'eval_loss': 6.0841484810225666e-05, 'eval_runtime': 581.8939, 'eval_samples_per_second': 7.964, 'eval_steps_per_second': 0.997, 'epoch': 2.0}\n",
|
| 539 |
+
"{'loss': 0.0001, 'grad_norm': 0.0022026619408279657, 'learning_rate': 1.6379310344827585e-05, 'epoch': 2.02}\n",
|
| 540 |
+
"{'loss': 0.0001, 'grad_norm': 0.002242449903860688, 'learning_rate': 1.6091954022988507e-05, 'epoch': 2.03}\n",
|
| 541 |
+
"{'loss': 0.0001, 'grad_norm': 0.002464097458869219, 'learning_rate': 1.5804597701149425e-05, 'epoch': 2.05}\n",
|
| 542 |
+
"{'loss': 0.0001, 'grad_norm': 0.0022003022022545338, 'learning_rate': 1.5517241379310346e-05, 'epoch': 2.07}\n",
|
| 543 |
+
"{'loss': 0.0001, 'grad_norm': 0.0021785416174679995, 'learning_rate': 1.5229885057471265e-05, 'epoch': 2.09}\n",
|
| 544 |
+
"{'loss': 0.0001, 'grad_norm': 0.0021638190373778343, 'learning_rate': 1.4942528735632185e-05, 'epoch': 2.1}\n",
|
| 545 |
+
"{'loss': 0.0001, 'grad_norm': 0.0022439502645283937, 'learning_rate': 1.4655172413793103e-05, 'epoch': 2.12}\n",
|
| 546 |
+
"{'loss': 0.0001, 'grad_norm': 0.0020717435982078314, 'learning_rate': 1.4367816091954022e-05, 'epoch': 2.14}\n",
|
| 547 |
+
"{'loss': 0.0001, 'grad_norm': 0.0020531516056507826, 'learning_rate': 1.4080459770114942e-05, 'epoch': 2.16}\n",
|
| 548 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019899189937859774, 'learning_rate': 1.3793103448275863e-05, 'epoch': 2.17}\n",
|
| 549 |
+
"{'loss': 0.0001, 'grad_norm': 0.0020303332712501287, 'learning_rate': 1.3505747126436783e-05, 'epoch': 2.19}\n",
|
| 550 |
+
"{'loss': 0.0001, 'grad_norm': 0.0021102086175233126, 'learning_rate': 1.3218390804597702e-05, 'epoch': 2.21}\n",
|
| 551 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019932142458856106, 'learning_rate': 1.2931034482758622e-05, 'epoch': 2.22}\n",
|
| 552 |
+
"{'loss': 0.0001, 'grad_norm': 0.002080111298710108, 'learning_rate': 1.2643678160919542e-05, 'epoch': 2.24}\n",
|
| 553 |
+
"{'loss': 0.0001, 'grad_norm': 0.0020179597195237875, 'learning_rate': 1.2356321839080461e-05, 'epoch': 2.26}\n",
|
| 554 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019549003336578608, 'learning_rate': 1.206896551724138e-05, 'epoch': 2.28}\n",
|
| 555 |
+
"{'loss': 0.0001, 'grad_norm': 0.0020865327678620815, 'learning_rate': 1.1781609195402299e-05, 'epoch': 2.29}\n",
|
| 556 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018828624160960317, 'learning_rate': 1.1494252873563218e-05, 'epoch': 2.31}\n",
|
| 557 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018662698566913605, 'learning_rate': 1.1206896551724138e-05, 'epoch': 2.33}\n",
|
| 558 |
+
"{'loss': 0.0001, 'grad_norm': 0.001857285387814045, 'learning_rate': 1.091954022988506e-05, 'epoch': 2.34}\n",
|
| 559 |
+
"{'loss': 0.0001, 'grad_norm': 0.001844724640250206, 'learning_rate': 1.0632183908045977e-05, 'epoch': 2.36}\n",
|
| 560 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017886353889480233, 'learning_rate': 1.0344827586206897e-05, 'epoch': 2.38}\n",
|
| 561 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019668207969516516, 'learning_rate': 1.0057471264367816e-05, 'epoch': 2.4}\n",
|
| 562 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018605877412483096, 'learning_rate': 9.770114942528738e-06, 'epoch': 2.41}\n",
|
| 563 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018027386395260692, 'learning_rate': 9.482758620689655e-06, 'epoch': 2.43}\n",
|
| 564 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018370413454249501, 'learning_rate': 9.195402298850575e-06, 'epoch': 2.45}\n",
|
| 565 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019249517936259508, 'learning_rate': 8.908045977011495e-06, 'epoch': 2.47}\n",
|
| 566 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019102703081443906, 'learning_rate': 8.620689655172414e-06, 'epoch': 2.48}\n",
|
| 567 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019130830187350512, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}\n",
|
| 568 |
+
"{'loss': 0.0001, 'grad_norm': 0.0019449306419119239, 'learning_rate': 8.045977011494253e-06, 'epoch': 2.52}\n",
|
| 569 |
+
"{'loss': 0.0001, 'grad_norm': 0.001796119031496346, 'learning_rate': 7.758620689655173e-06, 'epoch': 2.53}\n",
|
| 570 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017440468072891235, 'learning_rate': 7.4712643678160925e-06, 'epoch': 2.55}\n",
|
| 571 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017786856042221189, 'learning_rate': 7.183908045977011e-06, 'epoch': 2.57}\n",
|
| 572 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018597355810925364, 'learning_rate': 6.896551724137932e-06, 'epoch': 2.59}\n",
|
| 573 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017648187931627035, 'learning_rate': 6.609195402298851e-06, 'epoch': 2.6}\n",
|
| 574 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017601278377696872, 'learning_rate': 6.321839080459771e-06, 'epoch': 2.62}\n",
|
| 575 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017502185655757785, 'learning_rate': 6.03448275862069e-06, 'epoch': 2.64}\n",
|
| 576 |
+
"{'loss': 0.0001, 'grad_norm': 0.001753892400301993, 'learning_rate': 5.747126436781609e-06, 'epoch': 2.66}\n",
|
| 577 |
+
"{'loss': 0.0001, 'grad_norm': 0.0016946644755080342, 'learning_rate': 5.45977011494253e-06, 'epoch': 2.67}\n",
|
| 578 |
+
"{'loss': 0.0001, 'grad_norm': 0.001783599378541112, 'learning_rate': 5.172413793103448e-06, 'epoch': 2.69}\n",
|
| 579 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017759180627763271, 'learning_rate': 4.885057471264369e-06, 'epoch': 2.71}\n",
|
| 580 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017218819120898843, 'learning_rate': 4.5977011494252875e-06, 'epoch': 2.72}\n",
|
| 581 |
+
"{'loss': 0.0001, 'grad_norm': 0.0016811942914500833, 'learning_rate': 4.310344827586207e-06, 'epoch': 2.74}\n",
|
| 582 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017582618165761232, 'learning_rate': 4.022988505747127e-06, 'epoch': 2.76}\n",
|
| 583 |
+
"{'loss': 0.0001, 'grad_norm': 0.001848816522397101, 'learning_rate': 3.7356321839080462e-06, 'epoch': 2.78}\n",
|
| 584 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017523870337754488, 'learning_rate': 3.448275862068966e-06, 'epoch': 2.79}\n",
|
| 585 |
+
"{'loss': 0.0001, 'grad_norm': 0.001707645715214312, 'learning_rate': 3.1609195402298854e-06, 'epoch': 2.81}\n",
|
| 586 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017925987485796213, 'learning_rate': 2.8735632183908046e-06, 'epoch': 2.83}\n",
|
| 587 |
+
"{'loss': 0.0001, 'grad_norm': 0.001785592525266111, 'learning_rate': 2.586206896551724e-06, 'epoch': 2.84}\n",
|
| 588 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017369745764881372, 'learning_rate': 2.2988505747126437e-06, 'epoch': 2.86}\n",
|
| 589 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017363271908834577, 'learning_rate': 2.0114942528735633e-06, 'epoch': 2.88}\n",
|
| 590 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017762900097295642, 'learning_rate': 1.724137931034483e-06, 'epoch': 2.9}\n",
|
| 591 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017800360219553113, 'learning_rate': 1.4367816091954023e-06, 'epoch': 2.91}\n",
|
| 592 |
+
"{'loss': 0.0001, 'grad_norm': 0.0016894094878807664, 'learning_rate': 1.1494252873563219e-06, 'epoch': 2.93}\n",
|
| 593 |
+
"{'loss': 0.0001, 'grad_norm': 0.0016883889911696315, 'learning_rate': 8.620689655172415e-07, 'epoch': 2.95}\n",
|
| 594 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017332383431494236, 'learning_rate': 5.747126436781609e-07, 'epoch': 2.97}\n",
|
| 595 |
+
"{'loss': 0.0001, 'grad_norm': 0.0018206291133537889, 'learning_rate': 2.8735632183908047e-07, 'epoch': 2.98}\n",
|
| 596 |
+
"{'loss': 0.0001, 'grad_norm': 0.0017392894951626658, 'learning_rate': 0.0, 'epoch': 3.0}\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 602 |
+
"model_id": "de595459d7b642babd74e21ce354d064",
|
| 603 |
+
"version_major": 2,
|
| 604 |
+
"version_minor": 0
|
| 605 |
+
},
|
| 606 |
+
"text/plain": [
|
| 607 |
+
" 0%| | 0/580 [00:00<?, ?it/s]"
|
| 608 |
+
]
|
| 609 |
+
},
|
| 610 |
+
"metadata": {},
|
| 611 |
+
"output_type": "display_data"
|
| 612 |
+
},
|
| 613 |
+
{
|
| 614 |
+
"name": "stdout",
|
| 615 |
+
"output_type": "stream",
|
| 616 |
+
"text": [
|
| 617 |
+
"{'eval_loss': 4.472154250834137e-05, 'eval_runtime': 541.083, 'eval_samples_per_second': 8.564, 'eval_steps_per_second': 1.072, 'epoch': 3.0}\n",
|
| 618 |
+
"{'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'train_loss': 0.01232232698998762, 'epoch': 3.0}\n"
|
| 619 |
+
]
|
| 620 |
+
},
|
| 621 |
+
{
|
| 622 |
+
"data": {
|
| 623 |
+
"text/plain": [
|
| 624 |
+
"TrainOutput(global_step=1740, training_loss=0.01232232698998762, metrics={'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'total_flos': 2.5228134820702045e+17, 'train_loss': 0.01232232698998762, 'epoch': 3.0})"
|
| 625 |
+
]
|
| 626 |
+
},
|
| 627 |
+
"execution_count": 33,
|
| 628 |
+
"metadata": {},
|
| 629 |
+
"output_type": "execute_result"
|
| 630 |
+
}
|
| 631 |
+
],
"source": [
"trainer.train()\n"
]
},
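The TrainOutput above reports 1740 global steps over 3 epochs, i.e. 580 optimisation steps per epoch; the evaluation progress bars also run to 580 steps at the eval batch size of 8. Assuming a single device and no gradient accumulation, a quick sanity check of what those numbers imply about the dataset size:

# Sanity check of the numbers reported by trainer.train()
total_steps = 1740                                    # global_step in TrainOutput
epochs = 3
batch_size = 8                                        # per_device_train_batch_size

steps_per_epoch = total_steps // epochs               # 580
max_train_examples = steps_per_epoch * batch_size     # at most 4640 (last batch may be partial)
print(steps_per_epoch, max_train_examples)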
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save The Model"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the trained model and processor\n",
"trainer.save_model(\"./trained_model\") # Saves the model to the specified directory\n",
"processor.save_pretrained(\"./trained_model\") # Saves the processor as well\n"
]
},
|
| 665 |
+
{
|
| 666 |
+
"cell_type": "code",
|
| 667 |
+
"execution_count": 35,
|
| 668 |
+
"metadata": {},
|
| 669 |
+
"outputs": [
|
| 670 |
+
{
|
| 671 |
+
"data": {
|
| 672 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 673 |
+
"model_id": "ae19dee506db46f4b39d402e43d16276",
|
| 674 |
+
"version_major": 2,
|
| 675 |
+
"version_minor": 0
|
| 676 |
+
},
|
| 677 |
+
"text/plain": [
|
| 678 |
+
" 0%| | 0/580 [00:00<?, ?it/s]"
|
| 679 |
+
]
|
| 680 |
+
},
|
| 681 |
+
"metadata": {},
|
| 682 |
+
"output_type": "display_data"
|
| 683 |
+
},
|
| 684 |
+
{
|
| 685 |
+
"name": "stdout",
|
| 686 |
+
"output_type": "stream",
|
| 687 |
+
"text": [
|
| 688 |
+
"Evaluation Metrics:\n",
|
| 689 |
+
"eval_loss: 4.472154250834137e-05\n",
|
| 690 |
+
"eval_runtime: 527.4045\n",
|
| 691 |
+
"eval_samples_per_second: 8.786\n",
|
| 692 |
+
"eval_steps_per_second: 1.1\n",
|
| 693 |
+
"epoch: 3.0\n"
|
| 694 |
+
]
|
| 695 |
+
}
|
| 696 |
+
],
|
| 697 |
+
"source": [
|
| 698 |
+
"# Evaluate the model on the validation dataset\n",
|
| 699 |
+
"evaluation_metrics = trainer.evaluate()\n",
|
| 700 |
+
"\n",
|
| 701 |
+
"# Print the evaluation metrics\n",
|
| 702 |
+
"print(\"Evaluation Metrics:\")\n",
|
| 703 |
+
"for metric, value in evaluation_metrics.items():\n",
|
| 704 |
+
" print(f\"{metric}: {value}\")\n"
|
| 705 |
+
]
|
| 706 |
+
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoProcessor, AutoModelForAudioClassification\n",
"\n",
"# Load the trained model and processor\n",
"model_path = \"./trained_model\" # Path to your saved model\n",
"model = AutoModelForAudioClassification.from_pretrained(model_path)\n",
"processor = AutoProcessor.from_pretrained(model_path)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single Audio Testing"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"def prepare_audio(file_path, sampling_rate=16000, duration=10):\n",
"    \"\"\"\n",
"    Prepares audio by loading, resampling, and returning it in manageable chunks.\n",
"\n",
"    Parameters:\n",
"    - file_path: Path to the audio file.\n",
"    - sampling_rate: Target sampling rate for the audio.\n",
"    - duration: Duration in seconds for each chunk.\n",
"\n",
"    Returns:\n",
"    - A list of audio chunks, each as a numpy array.\n",
"    \"\"\"\n",
"    # Load and resample the audio file\n",
"    waveform, original_sampling_rate = torchaudio.load(file_path)\n",
"\n",
"    # Convert stereo to mono if necessary\n",
"    if waveform.shape[0] > 1: # More than 1 channel\n",
"        waveform = torch.mean(waveform, dim=0, keepdim=True)\n",
"\n",
"    # Resample if needed\n",
"    if original_sampling_rate != sampling_rate:\n",
"        resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)\n",
"        waveform = resampler(waveform)\n",
"\n",
"    # Calculate chunk size in samples\n",
"    chunk_size = sampling_rate * duration\n",
"    audio_chunks = []\n",
"\n",
"    # Split the audio into chunks\n",
"    for start in range(0, waveform.shape[1], chunk_size):\n",
"        chunk = waveform[:, start:start + chunk_size]\n",
"\n",
"        # Pad the last chunk if it's shorter than the chunk size\n",
"        if chunk.shape[1] < chunk_size:\n",
"            padding = chunk_size - chunk.shape[1]\n",
"            chunk = torch.nn.functional.pad(chunk, (0, padding))\n",
"\n",
"        audio_chunks.append(chunk.squeeze().numpy())\n",
"\n",
"    return audio_chunks\n"
]
},
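Each chunk produced above holds sampling_rate * duration = 16,000 Hz x 10 s = 160,000 samples, which is exactly the `Chunk shape: (160000,)` printed for every chunk in the next cell; the last chunk is zero-padded up to that length. A hypothetical usage sketch (the file name is illustrative only):

# Chunk a clip into 10-second windows (file name is illustrative)
chunks = prepare_audio("example_clip.wav")
print(len(chunks), chunks[0].shape)   # e.g. 3 chunks of shape (160000,) for a ~25 s clip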
|
| 776 |
+
{
|
| 777 |
+
"cell_type": "code",
|
| 778 |
+
"execution_count": 35,
|
| 779 |
+
"metadata": {},
|
| 780 |
+
"outputs": [
|
| 781 |
+
{
|
| 782 |
+
"name": "stdout",
|
| 783 |
+
"output_type": "stream",
|
| 784 |
+
"text": [
|
| 785 |
+
"Chunk shape: (160000,)\n",
|
| 786 |
+
"Logits for chunk 1: tensor([[ 4.6742, -5.1778]])\n",
|
| 787 |
+
"Chunk shape: (160000,)\n",
|
| 788 |
+
"Logits for chunk 2: tensor([[ 4.7219, -5.2332]])\n",
|
| 789 |
+
"Chunk shape: (160000,)\n",
|
| 790 |
+
"Logits for chunk 3: tensor([[ 4.7545, -5.2641]])\n",
|
| 791 |
+
"Chunk shape: (160000,)\n",
|
| 792 |
+
"Logits for chunk 4: tensor([[ 4.6714, -5.1740]])\n",
|
| 793 |
+
"Chunk shape: (160000,)\n",
|
| 794 |
+
"Logits for chunk 5: tensor([[ 4.7660, -5.2743]])\n",
|
| 795 |
+
"Chunk shape: (160000,)\n",
|
| 796 |
+
"Logits for chunk 6: tensor([[ 4.7724, -5.2836]])\n",
|
| 797 |
+
"Chunk shape: (160000,)\n",
|
| 798 |
+
"Logits for chunk 7: tensor([[ 4.7268, -5.2362]])\n",
|
| 799 |
+
"Chunk shape: (160000,)\n",
|
| 800 |
+
"Logits for chunk 8: tensor([[ 4.6898, -5.1898]])\n",
|
| 801 |
+
"Chunk shape: (160000,)\n",
|
| 802 |
+
"Logits for chunk 9: tensor([[ 4.6646, -5.1708]])\n",
|
| 803 |
+
"Chunk shape: (160000,)\n",
|
| 804 |
+
"Logits for chunk 10: tensor([[ 4.5948, -5.0867]])\n",
|
| 805 |
+
"Chunk shape: (160000,)\n",
|
| 806 |
+
"Logits for chunk 11: tensor([[ 4.7512, -5.2579]])\n",
|
| 807 |
+
"Chunk shape: (160000,)\n",
|
| 808 |
+
"Logits for chunk 12: tensor([[-4.5599, 5.0363]])\n",
|
| 809 |
+
"Chunk shape: (160000,)\n",
|
| 810 |
+
"Logits for chunk 13: tensor([[-0.4980, 0.5546]])\n",
|
| 811 |
+
"Chunk shape: (160000,)\n",
|
| 812 |
+
"Logits for chunk 14: tensor([[ 4.7295, -5.2358]])\n",
|
| 813 |
+
"Chunk shape: (160000,)\n",
|
| 814 |
+
"Logits for chunk 15: tensor([[ 4.7426, -5.2534]])\n",
|
| 815 |
+
"Chunk shape: (160000,)\n",
|
| 816 |
+
"Logits for chunk 16: tensor([[ 1.9405, -2.1493]])\n",
|
| 817 |
+
"Chunk shape: (160000,)\n",
|
| 818 |
+
"Logits for chunk 17: tensor([[ 4.7168, -5.2235]])\n",
|
| 819 |
+
"Chunk shape: (160000,)\n",
|
| 820 |
+
"Logits for chunk 18: tensor([[ 4.6801, -5.1907]])\n",
|
| 821 |
+
"Chunk shape: (160000,)\n",
|
| 822 |
+
"Logits for chunk 19: tensor([[ 4.7454, -5.2568]])\n",
|
| 823 |
+
"Chunk shape: (160000,)\n",
|
| 824 |
+
"Logits for chunk 20: tensor([[ 4.7642, -5.2723]])\n",
|
| 825 |
+
"Chunk shape: (160000,)\n",
|
| 826 |
+
"Logits for chunk 21: tensor([[ 4.7868, -5.2969]])\n",
|
| 827 |
+
"Chunk shape: (160000,)\n",
|
| 828 |
+
"Logits for chunk 22: tensor([[ 4.7600, -5.2690]])\n",
|
| 829 |
+
"Chunk shape: (160000,)\n",
|
| 830 |
+
"Logits for chunk 23: tensor([[ 4.7337, -5.2411]])\n",
|
| 831 |
+
"Chunk shape: (160000,)\n",
|
| 832 |
+
"Logits for chunk 24: tensor([[ 4.7835, -5.2943]])\n",
|
| 833 |
+
"Chunk shape: (160000,)\n",
|
| 834 |
+
"Logits for chunk 25: tensor([[ 4.7572, -5.2647]])\n",
|
| 835 |
+
"Chunk shape: (160000,)\n",
|
| 836 |
+
"Logits for chunk 26: tensor([[ 4.7485, -5.2581]])\n",
|
| 837 |
+
"Chunk shape: (160000,)\n",
|
| 838 |
+
"Logits for chunk 27: tensor([[ 4.6874, -5.2023]])\n",
|
| 839 |
+
"Chunk shape: (160000,)\n",
|
| 840 |
+
"Logits for chunk 28: tensor([[ 4.6877, -5.1922]])\n",
|
| 841 |
+
"Chunk shape: (160000,)\n",
|
| 842 |
+
"Logits for chunk 29: tensor([[ 4.7474, -5.2561]])\n",
|
| 843 |
+
"Chunk shape: (160000,)\n",
|
| 844 |
+
"Logits for chunk 30: tensor([[-4.3064, 4.7629]])\n",
|
| 845 |
+
"Chunk shape: (160000,)\n",
|
| 846 |
+
"Logits for chunk 31: tensor([[-3.8067, 4.2312]])\n",
|
| 847 |
+
"Chunk shape: (160000,)\n",
|
| 848 |
+
"Logits for chunk 32: tensor([[ 4.7217, -5.2325]])\n",
|
| 849 |
+
"Chunk shape: (160000,)\n",
|
| 850 |
+
"Logits for chunk 33: tensor([[ 4.7798, -5.2913]])\n",
|
| 851 |
+
"Chunk shape: (160000,)\n",
|
| 852 |
+
"Logits for chunk 34: tensor([[ 4.7214, -5.2355]])\n",
|
| 853 |
+
"Chunk shape: (160000,)\n",
|
| 854 |
+
"Logits for chunk 35: tensor([[ 4.7116, -5.2192]])\n",
|
| 855 |
+
"Chunk shape: (160000,)\n",
|
| 856 |
+
"Logits for chunk 36: tensor([[ 4.6687, -5.1812]])\n",
|
| 857 |
+
"Chunk shape: (160000,)\n",
|
| 858 |
+
"Logits for chunk 37: tensor([[-0.8128, 0.9402]])\n",
|
| 859 |
+
"Chunk shape: (160000,)\n",
|
| 860 |
+
"Logits for chunk 38: tensor([[ 4.7259, -5.2333]])\n",
|
| 861 |
+
"Chunk shape: (160000,)\n",
|
| 862 |
+
"Logits for chunk 39: tensor([[ 4.5698, -5.0731]])\n",
|
| 863 |
+
"Chunk shape: (160000,)\n",
|
| 864 |
+
"Logits for chunk 40: tensor([[ 4.7467, -5.2544]])\n",
|
| 865 |
+
"Chunk shape: (160000,)\n",
|
| 866 |
+
"Logits for chunk 41: tensor([[ 4.7781, -5.2884]])\n",
|
| 867 |
+
"Chunk shape: (160000,)\n",
|
| 868 |
+
"Logits for chunk 42: tensor([[ 4.7243, -5.2365]])\n",
|
| 869 |
+
"Chunk shape: (160000,)\n",
|
| 870 |
+
"Logits for chunk 43: tensor([[ 3.9325, -4.3570]])\n",
|
| 871 |
+
"Chunk shape: (160000,)\n",
|
| 872 |
+
"Logits for chunk 44: tensor([[-3.8786, 4.3105]])\n",
|
| 873 |
+
"Chunk shape: (160000,)\n",
|
| 874 |
+
"Logits for chunk 45: tensor([[ 3.3633, -3.6958]])\n",
|
| 875 |
+
"Chunk shape: (160000,)\n",
|
| 876 |
+
"Logits for chunk 46: tensor([[ 4.7127, -5.2213]])\n",
|
| 877 |
+
"Chunk shape: (160000,)\n",
|
| 878 |
+
"Logits for chunk 47: tensor([[ 0.0519, -0.0359]])\n",
|
| 879 |
+
"Chunk shape: (160000,)\n",
|
| 880 |
+
"Logits for chunk 48: tensor([[ 4.7457, -5.2535]])\n",
|
| 881 |
+
"Chunk shape: (160000,)\n",
|
| 882 |
+
"Logits for chunk 49: tensor([[ 3.4856, -3.8528]])\n",
|
| 883 |
+
"Chunk shape: (160000,)\n",
|
| 884 |
+
"Logits for chunk 50: tensor([[ 4.6485, -5.1538]])\n",
|
| 885 |
+
"Chunk shape: (160000,)\n",
|
| 886 |
+
"Logits for chunk 51: tensor([[ 4.6274, -5.1355]])\n",
|
| 887 |
+
"Chunk shape: (160000,)\n",
|
| 888 |
+
"Logits for chunk 52: tensor([[ 4.6852, -5.1872]])\n",
|
| 889 |
+
"Chunk shape: (160000,)\n",
|
| 890 |
+
"Logits for chunk 53: tensor([[ 4.7341, -5.2452]])\n",
|
| 891 |
+
"Chunk shape: (160000,)\n",
|
| 892 |
+
"Logits for chunk 54: tensor([[-4.5378, 5.0152]])\n",
|
| 893 |
+
"Chunk shape: (160000,)\n",
|
| 894 |
+
"Logits for chunk 55: tensor([[ 4.6822, -5.1887]])\n",
|
| 895 |
+
"Chunk shape: (160000,)\n",
|
| 896 |
+
"Logits for chunk 56: tensor([[ 4.7186, -5.2252]])\n",
|
| 897 |
+
"Chunk shape: (160000,)\n",
|
| 898 |
+
"Logits for chunk 57: tensor([[ 4.7688, -5.2787]])\n",
|
| 899 |
+
"Chunk shape: (160000,)\n",
|
| 900 |
+
"Logits for chunk 58: tensor([[ 4.7285, -5.2342]])\n",
|
| 901 |
+
"Chunk shape: (160000,)\n",
|
| 902 |
+
"Logits for chunk 59: tensor([[ 4.7447, -5.2550]])\n",
|
| 903 |
+
"Chunk shape: (160000,)\n",
|
| 904 |
+
"Logits for chunk 60: tensor([[ 4.5292, -5.0253]])\n",
|
| 905 |
+
"Predicted Class: Fake\n"
|
| 906 |
+
]
|
| 907 |
+
}
|
| 908 |
+
],
|
| 909 |
+
"source": [
|
| 910 |
+
"def predict_audio(file_path):\n",
|
| 911 |
+
" \"\"\"\n",
|
| 912 |
+
" Predicts the class of an audio file by aggregating predictions from chunks.\n",
|
| 913 |
+
" \n",
|
| 914 |
+
" Args:\n",
|
| 915 |
+
" file_path (str): Path to the audio file.\n",
|
| 916 |
+
"\n",
|
| 917 |
+
" Returns:\n",
|
| 918 |
+
" str: Predicted class label.\n",
|
| 919 |
+
" \"\"\"\n",
|
| 920 |
+
" # Prepare audio chunks\n",
|
| 921 |
+
" audio_chunks = prepare_audio(file_path)\n",
|
| 922 |
+
" predictions = []\n",
|
| 923 |
+
"\n",
|
| 924 |
+
" for i, chunk in enumerate(audio_chunks):\n",
|
| 925 |
+
" # Prepare input for the model\n",
|
| 926 |
+
" print(f\"Chunk shape: {chunk.shape}\")\n",
|
| 927 |
+
" inputs = processor(\n",
|
| 928 |
+
" chunk, sampling_rate=16000, return_tensors=\"pt\", padding=True\n",
|
| 929 |
+
" )\n",
|
| 930 |
+
" \n",
|
| 931 |
+
" # Perform inference\n",
|
| 932 |
+
" with torch.no_grad():\n",
|
| 933 |
+
" outputs = model(**inputs)\n",
|
| 934 |
+
" logits = outputs.logits\n",
|
| 935 |
+
" print(f\"Logits for chunk {i + 1}: {logits}\") # Print the logits\n",
|
| 936 |
+
" predicted_class = torch.argmax(logits, dim=1).item()\n",
|
| 937 |
+
" predictions.append(predicted_class)\n",
|
| 938 |
+
" \n",
|
| 939 |
+
" # Aggregate predictions (e.g., majority voting)\n",
|
| 940 |
+
" aggregated_prediction = max(set(predictions), key=predictions.count)\n",
|
| 941 |
+
" \n",
|
| 942 |
+
" # Convert class ID to label\n",
|
| 943 |
+
" return model.config.id2label[aggregated_prediction]\n",
|
| 944 |
+
"\n",
|
| 945 |
+
"# Example: Test a single audio file\n",
|
| 946 |
+
"file_path = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\KAGGLE\\AUDIO\\FAKE\\biden-to-linus.wav\" # Replace with your audio file path\n",
|
| 947 |
+
"predicted_class = predict_audio(file_path)\n",
|
| 948 |
+
"print(f\"Predicted Class: {predicted_class}\")\n"
|
| 949 |
+
]
|
| 950 |
+
},
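predict_audio aggregates the per-chunk argmax labels by majority vote, which is why the clip above, with mostly Fake-leaning logits and only a handful of Real-leaning chunks, resolves to Fake. A possible variation, shown only as a hedged sketch and not what the notebook uses, is to average the softmax probabilities across chunks before taking the argmax, so that confident chunks carry more weight:

import torch

def predict_audio_soft(file_path):
    # Variant of predict_audio: average class probabilities over chunks
    # instead of majority voting on per-chunk argmax predictions.
    audio_chunks = prepare_audio(file_path)
    probs = []
    for chunk in audio_chunks:
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs.append(torch.softmax(logits, dim=-1))
    mean_probs = torch.cat(probs, dim=0).mean(dim=0)
    return model.config.id2label[int(mean_probs.argmax())]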
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Batch Testing"
]
},
|
| 958 |
+
{
|
| 959 |
+
"cell_type": "code",
|
| 960 |
+
"execution_count": 36,
|
| 961 |
+
"metadata": {},
|
| 962 |
+
"outputs": [
|
| 963 |
+
{
|
| 964 |
+
"name": "stdout",
|
| 965 |
+
"output_type": "stream",
|
| 966 |
+
"text": [
|
| 967 |
+
"Chunk shape: (160000,)\n",
|
| 968 |
+
"Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
|
| 969 |
+
"Chunk shape: (160000,)\n",
|
| 970 |
+
"Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
|
| 971 |
+
"Chunk shape: (160000,)\n",
|
| 972 |
+
"Logits for chunk 1: tensor([[-1.5531, 1.7190]])\n",
|
| 973 |
+
"Chunk shape: (160000,)\n",
|
| 974 |
+
"Logits for chunk 1: tensor([[-1.5917, 1.7620]])\n",
|
| 975 |
+
"Chunk shape: (160000,)\n",
|
| 976 |
+
"Logits for chunk 1: tensor([[ 4.7569, -5.2631]])\n",
|
| 977 |
+
"Chunk shape: (160000,)\n",
|
| 978 |
+
"Logits for chunk 1: tensor([[ 4.7569, -5.2630]])\n",
|
| 979 |
+
"Chunk shape: (160000,)\n",
|
| 980 |
+
"Logits for chunk 1: tensor([[-4.5033, 4.9768]])\n",
|
| 981 |
+
"Chunk shape: (160000,)\n",
|
| 982 |
+
"Logits for chunk 1: tensor([[-4.5029, 4.9765]])\n",
|
| 983 |
+
"Chunk shape: (160000,)\n",
|
| 984 |
+
"Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
|
| 985 |
+
"Chunk shape: (160000,)\n",
|
| 986 |
+
"Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
|
| 987 |
+
"{'file': 'human voice 1 to mr beast.mp3', 'predicted_class': 'Real'}\n",
|
| 988 |
+
"{'file': 'human voice 1 to mr beast.wav', 'predicted_class': 'Real'}\n",
|
| 989 |
+
"{'file': 'human voice 1.mp3', 'predicted_class': 'Real'}\n",
|
| 990 |
+
"{'file': 'human voice 1.wav', 'predicted_class': 'Real'}\n",
|
| 991 |
+
"{'file': 'human voice 2 to Jett.mp3', 'predicted_class': 'Fake'}\n",
|
| 992 |
+
"{'file': 'human voice 2 to Jett.wav', 'predicted_class': 'Fake'}\n",
|
| 993 |
+
"{'file': 'human voice 2.mp3', 'predicted_class': 'Real'}\n",
|
| 994 |
+
"{'file': 'human voice 2.wav', 'predicted_class': 'Real'}\n",
|
| 995 |
+
"{'file': 'text to audio jett.mp3', 'predicted_class': 'Fake'}\n",
|
| 996 |
+
"{'file': 'text to audio jett.wav', 'predicted_class': 'Fake'}\n"
|
| 997 |
+
]
|
| 998 |
+
}
|
| 999 |
+
],
|
| 1000 |
+
"source": [
|
| 1001 |
+
"import os\n",
|
| 1002 |
+
"\n",
|
| 1003 |
+
"def batch_predict(test_folder, limit=10):\n",
|
| 1004 |
+
" \"\"\"\n",
|
| 1005 |
+
" Batch processes audio files for predictions.\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
" Args:\n",
|
| 1008 |
+
" test_folder (str): Path to the folder containing audio files.\n",
|
| 1009 |
+
" limit (int): Maximum number of files to process. Set to None for all files.\n",
|
| 1010 |
+
"\n",
|
| 1011 |
+
" Returns:\n",
|
| 1012 |
+
" list: A list of dictionaries containing file names and predicted classes.\n",
|
| 1013 |
+
" \"\"\"\n",
|
| 1014 |
+
" results = []\n",
|
| 1015 |
+
" files = os.listdir(test_folder)\n",
|
| 1016 |
+
"\n",
|
| 1017 |
+
" # Limit the number of files processed if a limit is provided\n",
|
| 1018 |
+
" if limit is not None:\n",
|
| 1019 |
+
" files = files[:limit]\n",
|
| 1020 |
+
"\n",
|
| 1021 |
+
" # Process each file in the folder\n",
|
| 1022 |
+
" for file_name in files:\n",
|
| 1023 |
+
" file_path = os.path.join(test_folder, file_name)\n",
|
| 1024 |
+
" try:\n",
|
| 1025 |
+
" predicted_class = predict_audio(file_path) # Use the predict_audio function\n",
|
| 1026 |
+
" results.append({\"file\": file_name, \"predicted_class\": predicted_class})\n",
|
| 1027 |
+
" except Exception as e:\n",
|
| 1028 |
+
" print(f\"Error processing {file_name}: {e}\")\n",
|
| 1029 |
+
" \n",
|
| 1030 |
+
" return results\n",
|
| 1031 |
+
"\n",
|
| 1032 |
+
"# Specify the folder path and limit\n",
|
| 1033 |
+
"test_folder = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\real life test audio\" # Replace with your test folder path\n",
|
| 1034 |
+
"results = batch_predict(test_folder, limit=10)\n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
"# Print results\n",
|
| 1037 |
+
"for result in results:\n",
|
| 1038 |
+
" print(result)\n"
|
| 1039 |
+
]
|
| 1040 |
+
},
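batch_predict returns a list of dictionaries with "file" and "predicted_class" keys, so the results can be persisted with the standard library alone; a small sketch (the output file name is illustrative):

import csv

# Write the batch predictions to a CSV file (file name is illustrative)
with open("batch_predictions.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["file", "predicted_class"])
    writer.writeheader()
    writer.writerows(results)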
|
| 1041 |
+
{
|
| 1042 |
+
"cell_type": "code",
|
| 1043 |
+
"execution_count": 14,
|
| 1044 |
+
"metadata": {},
|
| 1045 |
+
"outputs": [
|
| 1046 |
+
{
|
| 1047 |
+
"data": {
|
| 1048 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1049 |
+
"model_id": "479899631f95453e9f82355f7511cff3",
|
| 1050 |
+
"version_major": 2,
|
| 1051 |
+
"version_minor": 0
|
| 1052 |
+
},
|
| 1053 |
+
"text/plain": [
|
| 1054 |
+
" 0%| | 0/580 [00:00<?, ?it/s]"
|
| 1055 |
+
]
|
| 1056 |
+
},
|
| 1057 |
+
"metadata": {},
|
| 1058 |
+
"output_type": "display_data"
|
| 1059 |
+
},
|
| 1060 |
+
{
|
| 1061 |
+
"name": "stdout",
|
| 1062 |
+
"output_type": "stream",
|
| 1063 |
+
"text": [
|
| 1064 |
+
"{'eval_loss': 4.472154250834137e-05, 'eval_model_preparation_time': 0.0032, 'eval_accuracy': 1.0, 'eval_runtime': 1168.3696, 'eval_samples_per_second': 3.966, 'eval_steps_per_second': 0.496}\n"
|
| 1065 |
+
]
|
| 1066 |
+
}
|
| 1067 |
+
],
|
| 1068 |
+
"source": [
|
| 1069 |
+
"import evaluate\n",
|
| 1070 |
+
"\n",
|
| 1071 |
+
"# Load the accuracy metric\n",
|
| 1072 |
+
"accuracy_metric = evaluate.load(\"accuracy\")\n",
|
| 1073 |
+
"\n",
|
| 1074 |
+
"def compute_metrics(eval_pred):\n",
|
| 1075 |
+
" logits, labels = eval_pred\n",
|
| 1076 |
+
" predictions = logits.argmax(axis=-1) # Get the predicted class\n",
|
| 1077 |
+
" accuracy = accuracy_metric.compute(predictions=predictions, references=labels)\n",
|
| 1078 |
+
" return {\"accuracy\": accuracy[\"accuracy\"]}\n",
|
| 1079 |
+
"\n",
|
| 1080 |
+
"from transformers import Trainer\n",
|
| 1081 |
+
"\n",
|
| 1082 |
+
"trainer = Trainer(\n",
|
| 1083 |
+
" model=model,\n",
|
| 1084 |
+
" args=training_args,\n",
|
| 1085 |
+
" train_dataset=processed_dataset[\"train\"],\n",
|
| 1086 |
+
" eval_dataset=processed_dataset[\"validation\"],\n",
|
| 1087 |
+
" tokenizer=processor, # Required for padding\n",
|
| 1088 |
+
" data_collator=data_collator,\n",
|
| 1089 |
+
" compute_metrics=compute_metrics, # Add this line\n",
|
| 1090 |
+
")\n",
|
| 1091 |
+
"\n",
|
| 1092 |
+
"# Evaluate the model\n",
|
| 1093 |
+
"metrics = trainer.evaluate()\n",
|
| 1094 |
+
"print(metrics)\n"
|
| 1095 |
+
]
|
| 1096 |
+
}
|
| 1097 |
+
],
|
| 1098 |
+
"metadata": {
|
| 1099 |
+
"kernelspec": {
|
| 1100 |
+
"display_name": "base",
|
| 1101 |
+
"language": "python",
|
| 1102 |
+
"name": "python3"
|
| 1103 |
+
},
|
| 1104 |
+
"language_info": {
|
| 1105 |
+
"codemirror_mode": {
|
| 1106 |
+
"name": "ipython",
|
| 1107 |
+
"version": 3
|
| 1108 |
+
},
|
| 1109 |
+
"file_extension": ".py",
|
| 1110 |
+
"mimetype": "text/x-python",
|
| 1111 |
+
"name": "python",
|
| 1112 |
+
"nbconvert_exporter": "python",
|
| 1113 |
+
"pygments_lexer": "ipython3",
|
| 1114 |
+
"version": "3.12.7"
|
| 1115 |
+
}
|
| 1116 |
+
},
|
| 1117 |
+
"nbformat": 4,
|
| 1118 |
+
"nbformat_minor": 2
|
| 1119 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
fastapi
uvicorn
transformers
torch
torchaudio
pydantic
numpy
python-multipart
vercel.json
ADDED
|
@@ -0,0 +1,18 @@
{
  "builds": [
    {
      "src": "DeepfakeModel.py",
      "use": "@vercel/python",
      "config": {
        "maxLambdaSize": "450mb"
      }
    }
  ],
  "routes": [
    {
      "src": "/(.*)",
      "dest": "DeepfakeModel.py"
    }
  ]
}