Max005 committed on
Commit
3562f68
·
1 Parent(s): 21657a1
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. main.py +95 -0
Dockerfile CHANGED
@@ -13,4 +13,4 @@ COPY --chown=user ./requirements.txt requirements.txt
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
 
15
  COPY --chown=user . /app
16
- CMD ["uvicorn", "DeepfakeModel:app", "--host", "0.0.0.0", "--port", "7860"]
 
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
 
15
  COPY --chown=user . /app
16
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from pydantic import BaseModel
3
+ import os
4
+ import torchaudio
5
+ import torch.nn.functional as F
6
+ import torch
7
+ from transformers import AutoProcessor, AutoModelForAudioClassification, pipeline
8
+ from pathlib import Path
9
+
10
# Directory containing this module; bundled model assets are resolved against it.
app_dir = Path(__file__).parent

# Deepfake model setup
deepfake_model_path = app_dir / "Deepfake" / "model"
# Both artifacts are loaded with local_files_only=True so that a missing or
# incomplete model directory fails at startup instead of attempting a
# Hugging Face Hub lookup for one of them.
deepfake_processor = AutoProcessor.from_pretrained(
    deepfake_model_path,
    local_files_only=True,
)
deepfake_model = AutoModelForAudioClassification.from_pretrained(
    pretrained_model_name_or_path=deepfake_model_path,
    local_files_only=True,
)
19
+
20
def prepare_audio(file_path, sampling_rate=16000, duration=10):
    """Load an audio file and cut it into equally sized mono chunks.

    The file is read with ``torchaudio.load``. When the first dimension of
    the loaded tensor is larger than one it is averaged down to a single
    row, and the signal is resampled whenever the file's rate differs from
    ``sampling_rate``.

    Args:
        file_path: Path of the audio file to read.
        sampling_rate: Target sample rate in Hz.
        duration: Length of each chunk in seconds.

    Returns:
        A list of NumPy arrays, one per chunk. Every chunk spans
        ``sampling_rate * duration`` samples; a shorter trailing chunk is
        right-padded with zeros. The list is empty when the file holds no
        samples.
    """
    samples, source_rate = torchaudio.load(file_path)

    # Collapse multi-row (presumably multi-channel) audio into one row.
    if samples.shape[0] > 1:
        samples = torch.mean(samples, dim=0, keepdim=True)

    if source_rate != sampling_rate:
        to_target_rate = torchaudio.transforms.Resample(
            orig_freq=source_rate, new_freq=sampling_rate
        )
        samples = to_target_rate(samples)

    samples_per_chunk = sampling_rate * duration
    total_samples = samples.shape[1]

    def _as_fixed_length_array(segment):
        # Right-pad a short (trailing) segment so every chunk has equal length.
        missing = samples_per_chunk - segment.shape[1]
        if missing > 0:
            segment = torch.nn.functional.pad(segment, (0, missing))
        return segment.squeeze().numpy()

    return [
        _as_fixed_length_array(samples[:, offset:offset + samples_per_chunk])
        for offset in range(0, total_samples, samples_per_chunk)
    ]
36
+
37
def predict_audio(file_path):
    """Classify an audio file chunk by chunk and aggregate the verdict.

    The file is split by ``prepare_audio``; each chunk is run through
    ``deepfake_processor`` and ``deepfake_model``, and the most frequently
    predicted class across chunks becomes the final label.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        A dict with two keys:
        ``"predicted_label"`` - the ``id2label`` entry of the majority class;
        ``"average_confidence"`` - the mean of each chunk's top-class
        probability (averaged over all chunks, including those that voted
        for a different class than the majority).

    Raises:
        ValueError: If the file yields no audio chunks (e.g. it is empty).
    """
    from collections import Counter

    # Single source of truth so chunking and feature extraction agree.
    sampling_rate = 16000

    audio_chunks = prepare_audio(file_path, sampling_rate=sampling_rate)
    if not audio_chunks:
        # Without this guard the aggregation below fails with an opaque
        # "max() arg is an empty sequence" error.
        raise ValueError(f"no audio samples found in {file_path!r}")

    predictions = []
    confidences = []
    for chunk in audio_chunks:
        inputs = deepfake_processor(
            chunk, sampling_rate=sampling_rate, return_tensors="pt", padding=True
        )
        with torch.no_grad():
            outputs = deepfake_model(**inputs)
        probabilities = F.softmax(outputs.logits, dim=1)
        confidence, predicted_class = torch.max(probabilities, dim=1)
        predictions.append(predicted_class.item())
        confidences.append(confidence.item())

    # Majority vote in a single counting pass. Iterating the class ids in
    # sorted order makes ties resolve deterministically to the lowest id.
    votes = Counter(predictions)
    aggregated_prediction_id = max(sorted(votes), key=votes.__getitem__)
    predicted_label = deepfake_model.config.id2label[aggregated_prediction_id]
    average_confidence = sum(confidences) / len(confidences)
    return {
        "predicted_label": predicted_label,
        "average_confidence": average_confidence,
    }
59
+
60
# ScamText model setup: a text-classification pipeline backed by the
# "phishbot/ScamLLM" checkpoint, created once at import time.
scamtext_pipe = pipeline(task="text-classification", model="phishbot/ScamLLM")

# Initialize FastAPI: the application object that the route decorators
# in this module register their endpoints on.
app = FastAPI()
65
+
66
@app.post("/deepfake/infer")
async def deepfake_infer(file: UploadFile = File(...)):
    """Run deepfake detection on an uploaded audio file.

    The upload is spooled to a temporary file, passed to ``predict_audio``
    and removed again whether or not prediction succeeds.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        The dict produced by ``predict_audio``.
    """
    import tempfile

    # Never build a path from the client-supplied filename: it may contain
    # directory components ("../../x") and two concurrent uploads with the
    # same name would overwrite each other. Only the extension is kept, in
    # case the audio loader uses it to pick a decoder.
    suffix = Path(file.filename or "").suffix
    temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp_file_path = temp_file.name
    try:
        # Close the handle before prediction so the loader reads a fully
        # flushed file.
        with temp_file:
            temp_file.write(await file.read())
        predictions = predict_audio(temp_file_path)
    finally:
        # Runs even when the write itself fails, so no temp file is leaked.
        os.remove(temp_file_path)
    return predictions
76
+
77
class ScamTextRequest(BaseModel):
    """JSON request body accepted by ``/scamtext/infer``."""

    # Text to classify. The field name is the wire name the handler reads.
    input: str


@app.post("/scamtext/infer")
async def scamtext_infer(data: ScamTextRequest):
    """Classify a piece of text with the ScamText pipeline.

    The body was previously typed as the bare ``BaseModel``, which declares
    no fields, so reading ``data.input`` raised ``AttributeError`` on every
    request. A concrete model makes FastAPI parse and validate ``input``.

    Args:
        data: Request body of the form ``{"input": "<text>"}``.

    Returns:
        Whatever ``scamtext_pipe`` returns for the given text.
    """
    predictions = scamtext_pipe(data.input)
    return predictions
81
+
82
@app.get("/deepfake/health")
async def deepfake_health():
    """Health probe for the deepfake service.

    Returns:
        A dict with ``"message"`` set to ``"ok"`` and ``"Sound"`` set to the
        string form of the audio backends reported by torchaudio.
    """
    available_backends = torchaudio.list_audio_backends()
    payload = {"message": "ok"}
    payload["Sound"] = str(available_backends)
    return payload
88
+
89
@app.get("/scamtext/health")
async def scamtext_health():
    """Health probe for the ScamText service; always reports ``ok``."""
    payload = {"message": "ok"}
    return payload
92
+
93
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when this module is
    # executed directly rather than imported.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)