Upload folder using huggingface_hub
Browse files- handler.py +50 -34
- test.py +57 -4
handler.py
CHANGED
|
@@ -7,6 +7,7 @@ import numpy as np
|
|
| 7 |
from queue import Queue, Empty
|
| 8 |
import threading
|
| 9 |
import base64
|
|
|
|
| 10 |
|
| 11 |
class EndpointHandler:
|
| 12 |
def __init__(self, path=""):
|
|
@@ -22,7 +23,7 @@ class EndpointHandler:
|
|
| 22 |
self.parler_tts_handler_kwargs,
|
| 23 |
self.melo_tts_handler_kwargs,
|
| 24 |
self.chat_tts_handler_kwargs,
|
| 25 |
-
) = get_default_arguments(mode='none',
|
| 26 |
setup_logger(self.module_kwargs.log_level)
|
| 27 |
|
| 28 |
prepare_all_args(
|
|
@@ -57,65 +58,80 @@ class EndpointHandler:
|
|
| 57 |
|
| 58 |
# Add a new queue for collecting the final output
|
| 59 |
self.final_output_queue = Queue()
|
|
|
|
| 60 |
|
| 61 |
-
def _collect_output(self):
|
| 62 |
while True:
|
| 63 |
try:
|
| 64 |
-
output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=
|
| 65 |
if isinstance(output, (str, bytes)) and output in (b"END", "END"):
|
| 66 |
-
self.
|
| 67 |
break
|
| 68 |
elif isinstance(output, np.ndarray):
|
| 69 |
-
self.
|
| 70 |
else:
|
| 71 |
-
self.
|
| 72 |
except Empty:
|
| 73 |
-
|
| 74 |
-
self.final_output_queue.put("END")
|
| 75 |
break
|
| 76 |
|
| 77 |
-
def __call__(self, data: Dict[str, Any]) ->
|
| 78 |
-
"""
|
| 79 |
-
Args:
|
| 80 |
-
data (Dict[str, Any]): The input data containing the necessary arguments.
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
input_type = data.get("input_type", "text")
|
| 90 |
input_data = data.get("inputs", "")
|
| 91 |
|
| 92 |
if input_type == "speech":
|
| 93 |
-
# Convert input audio data to numpy array
|
| 94 |
audio_array = np.frombuffer(input_data, dtype=np.int16)
|
| 95 |
-
|
| 96 |
-
# Put audio data into the recv_audio_chunks_queue
|
| 97 |
self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
|
| 98 |
elif input_type == "text":
|
| 99 |
-
# Put text data directly into the text_prompt_queue
|
| 100 |
self.queues_and_events['text_prompt_queue'].put(input_data)
|
| 101 |
else:
|
| 102 |
raise ValueError(f"Unsupported input type: {input_type}")
|
| 103 |
|
| 104 |
-
#
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
if chunk == "END":
|
| 109 |
-
break
|
| 110 |
-
output_chunks.append(chunk)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def cleanup(self):
|
| 121 |
# Stop the pipeline
|
|
|
|
| 7 |
from queue import Queue, Empty
|
| 8 |
import threading
|
| 9 |
import base64
|
| 10 |
+
import uuid
|
| 11 |
|
| 12 |
class EndpointHandler:
|
| 13 |
def __init__(self, path=""):
|
|
|
|
| 23 |
self.parler_tts_handler_kwargs,
|
| 24 |
self.melo_tts_handler_kwargs,
|
| 25 |
self.chat_tts_handler_kwargs,
|
| 26 |
+
) = get_default_arguments(mode='none', log_level='DEBUG')
|
| 27 |
setup_logger(self.module_kwargs.log_level)
|
| 28 |
|
| 29 |
prepare_all_args(
|
|
|
|
| 58 |
|
| 59 |
# Add a new queue for collecting the final output
|
| 60 |
self.final_output_queue = Queue()
|
| 61 |
+
self.sessions = {} # Store session information
|
| 62 |
|
| 63 |
+
def _collect_output(self, session_id):
|
| 64 |
while True:
|
| 65 |
try:
|
| 66 |
+
output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=2)
|
| 67 |
if isinstance(output, (str, bytes)) and output in (b"END", "END"):
|
| 68 |
+
self.sessions[session_id]['status'] = 'completed'
|
| 69 |
break
|
| 70 |
elif isinstance(output, np.ndarray):
|
| 71 |
+
self.sessions[session_id]['chunks'].append(output.tobytes())
|
| 72 |
else:
|
| 73 |
+
self.sessions[session_id]['chunks'].append(output)
|
| 74 |
except Empty:
|
| 75 |
+
self.sessions[session_id]['status'] = 'completed'
|
|
|
|
| 76 |
break
|
| 77 |
|
| 78 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 79 |
+
request_type = data.get("request_type", "start")
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
if request_type == "start":
|
| 82 |
+
return self._handle_start_request(data)
|
| 83 |
+
elif request_type == "continue":
|
| 84 |
+
return self._handle_continue_request(data)
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Unsupported request type: {request_type}")
|
| 87 |
+
|
| 88 |
+
def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 89 |
+
session_id = str(uuid.uuid4())
|
| 90 |
+
self.sessions[session_id] = {
|
| 91 |
+
'status': 'processing',
|
| 92 |
+
'chunks': [],
|
| 93 |
+
'last_sent_index': 0
|
| 94 |
+
}
|
| 95 |
|
| 96 |
input_type = data.get("input_type", "text")
|
| 97 |
input_data = data.get("inputs", "")
|
| 98 |
|
| 99 |
if input_type == "speech":
|
|
|
|
| 100 |
audio_array = np.frombuffer(input_data, dtype=np.int16)
|
|
|
|
|
|
|
| 101 |
self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
|
| 102 |
elif input_type == "text":
|
|
|
|
| 103 |
self.queues_and_events['text_prompt_queue'].put(input_data)
|
| 104 |
else:
|
| 105 |
raise ValueError(f"Unsupported input type: {input_type}")
|
| 106 |
|
| 107 |
+
# Start output collection in a separate thread
|
| 108 |
+
threading.Thread(target=self._collect_output, args=(session_id,)).start()
|
| 109 |
+
|
| 110 |
+
return {"session_id": session_id, "status": "processing"}
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 113 |
+
session_id = data.get("session_id")
|
| 114 |
+
if not session_id or session_id not in self.sessions:
|
| 115 |
+
raise ValueError("Invalid or missing session_id")
|
| 116 |
|
| 117 |
+
session = self.sessions[session_id]
|
| 118 |
+
chunks_to_send = session['chunks'][session['last_sent_index']:]
|
| 119 |
+
session['last_sent_index'] = len(session['chunks'])
|
| 120 |
|
| 121 |
+
if chunks_to_send:
|
| 122 |
+
combined_audio = b''.join(chunks_to_send)
|
| 123 |
+
base64_audio = base64.b64encode(combined_audio).decode('utf-8')
|
| 124 |
+
return {
|
| 125 |
+
"session_id": session_id,
|
| 126 |
+
"status": session['status'],
|
| 127 |
+
"output": base64_audio
|
| 128 |
+
}
|
| 129 |
+
else:
|
| 130 |
+
return {
|
| 131 |
+
"session_id": session_id,
|
| 132 |
+
"status": session['status'],
|
| 133 |
+
"output": None
|
| 134 |
+
}
|
| 135 |
|
| 136 |
def cleanup(self):
|
| 137 |
# Stop the pipeline
|
test.py
CHANGED
|
@@ -1,7 +1,60 @@
|
|
| 1 |
from handler import EndpointHandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from handler import EndpointHandler
|
| 2 |
+
import requests
|
| 3 |
+
import base64
|
| 4 |
+
import numpy as np
|
| 5 |
+
import sounddevice as sd
|
| 6 |
+
import time
|
| 7 |
|
| 8 |
+
my_handler = EndpointHandler('')
|
| 9 |
|
| 10 |
+
|
| 11 |
def play_audio(audio_data, sample_rate=16000):
    """Play a numpy audio buffer on the default output device, blocking until playback ends."""
    # sd.wait() blocks until the asynchronous sd.play() finishes.
    sd.play(audio_data, sample_rate)
    sd.wait()
|
| 14 |
+
|
| 15 |
def stream_audio(session_id):
    """Poll the handler for audio chunks until the session completes.

    Plays each chunk as it arrives and returns the concatenated int16
    audio as a numpy array, or None if no audio was ever received.
    """
    collected = []
    while True:
        reply = my_handler({
            "request_type": "continue",
            "session_id": session_id,
        })

        finished = reply["status"] == "completed" and reply["output"] is None
        if finished:
            break

        if reply["output"]:
            raw = base64.b64decode(reply["output"])
            pcm = np.frombuffer(raw, dtype=np.int16)
            collected.append(pcm)
            # Play the chunk immediately (optional)
            play_audio(pcm)

        time.sleep(0.01)  # Small delay to prevent overwhelming the server

    return np.concatenate(collected) if collected else None
|
| 38 |
+
|
| 39 |
# Test with text input
text_payload = {
    "request_type": "start",
    "input_type": "text",
    "inputs": "Tell me a cool fact about Messi.",
}

start_response = my_handler(text_payload)

if "session_id" not in start_response:
    print("Error:", start_response)
else:
    sid = start_response["session_id"]
    print(f"Session started. Session ID: {sid}")
    print("Streaming audio response...")

    full_audio = stream_audio(sid)

    if full_audio is not None:
        print("Received complete audio response. Playing...")
    else:
        print("No audio received.")
|