xieli committed
Commit 29b1042 · 1 Parent(s): f21ec03

feat: fix paralinguistic, clone prompt, style prompt, input limit


feat: add log

feat: fix prompt

feat: fix

feat: add style tag

feat: remove bgm

feat: fix vq0206 token

feat: fix

Files changed (5)
  1. app.py +3 -3
  2. config/__init__.py +2 -2
  3. config/edit_config.py +2 -2
  4. config/prompts.py +10 -7
  5. tts.py +35 -88
app.py CHANGED
@@ -290,7 +290,7 @@ class EditxTab:
         self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
 
         # For para-linguistic, use generated_text; otherwise use source text
-        if edit_type not in {"para-linguistic"}:
+        if edit_type not in {"paralinguistic"}:
             generated_text = text_to_use
 
         # Use GPU inference with models loaded inside GPU context
@@ -355,14 +355,14 @@ class EditxTab:
         with gr.Row():
             with gr.Column():
                 self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
-                self.prompt_text_input = gr.Textbox(label="Audio Text Content", value="", scale=1)
+                self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
                 self.prompt_audio_input = gr.Audio(
                     sources=["upload", "microphone"],
                     format="wav",
                     type="filepath",
                     label="Input Audio",
                 )
-                self.generated_text = gr.Textbox(label="Clone Text", lines=1, max_lines=200)
+                self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=100)
                 with gr.Row():
                     self.button_tts = gr.Button("CLONE")
                     self.button_edit = gr.Button("EDIT")
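Two details in this file: the set literal must use the same key spelling ("paralinguistic") as config/edit_config.py, otherwise the membership test can never match; and the relabeled Target Text box now carries the 100-character input limit from the commit title. A minimal sketch of the latter, assuming a Gradio version where gr.Textbox accepts max_length (the surrounding layout is illustrative, not the app's):

import gradio as gr

with gr.Blocks() as demo:
    # max_length caps user input at 100 characters in the UI,
    # mirroring the change to self.generated_text above
    target = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=100)

demo.launch()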
config/__init__.py CHANGED
@@ -2,11 +2,11 @@
 Configuration module for Step-Audio
 """
 
-from .prompts import TTS_SYSTEM_PROMPTS, AUDIO_EDIT_SYSTEM_PROMPT
+from .prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL, AUDIO_EDIT_SYSTEM_PROMPT
 from .edit_config import get_supported_edit_types
 
 __all__ = [
-    'TTS_SYSTEM_PROMPTS',
+    'AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL',
     'AUDIO_EDIT_SYSTEM_PROMPT',
     'get_supported_edit_types'
 ]
config/edit_config.py CHANGED
@@ -23,10 +23,10 @@ def get_supported_edit_types():
             'generous', 'act_coy', 'warm', 'shy', 'comfort', 'authority',
             'chat', 'radio', 'soulful', 'story', 'vivid', 'program',
             'news', 'advertising', 'roar', 'murmur', 'shout', 'deeply', 'loudly',
-            'remove'
+            'remove', 'exaggerated'
         ],
         "vad": [],
         "denoise": [],
-        "para-linguistic": [],
+        "paralinguistic": [],
         "speed": ["faster", "slower", "more faster", "more slower"],
     }
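Since several call sites switch from "para-linguistic" to "paralinguistic" in this commit, a caller-side check against this table is the easiest way to catch a stale spelling. A hedged sketch, assuming get_supported_edit_types() returns the dict shown above (is_supported_edit is illustrative, not a repo function):

from config.edit_config import get_supported_edit_types

def is_supported_edit(edit_type: str, edit_info: str | None) -> bool:
    """Illustrative helper: validate a request against the supported-edit table."""
    table = get_supported_edit_types()
    if edit_type not in table:
        return False
    options = table[edit_type]
    # Edit types with an empty option list (vad, denoise, paralinguistic)
    # take no edit_info; the others require one of the listed values.
    return edit_info in options if options else edit_info is None

assert is_supported_edit("style", "exaggerated")       # added by this commit
assert not is_supported_edit("para-linguistic", None)  # old hyphenated key is gone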
config/prompts.py CHANGED
@@ -3,13 +3,16 @@
 Contains all TTS- and editing-related system prompts
 """
 
-# TTS-related system prompts
-TTS_SYSTEM_PROMPTS = {
-    "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
-    "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
-    "sys_prompt_wo_spk": '以自然的语速读出下面的文字。',
-    "sys_prompt_with_spk": '请用{}的声音尽可能自然地说出下面这些话。',
-}
+AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL = """Generate audio with the following timbre, prosody and speaking style
+
+[speaker_start]
+speaker name: {speaker}
+speaker prompt text:
+{prompt_text}
+speaker audio tokens:
+{prompt_wav_tokens}
+[speaker_end]
+"""
 
 AUDIO_EDIT_SYSTEM_PROMPT = """As a highly skilled audio editing and tuning specialist, you excel in interpreting user instructions and applying precise adjustments to meet their needs. Your expertise spans a wide range of enhancement capabilities, including but not limited to:
 # Emotional Enhancement
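For reference, tts.py fills this template with str.format (see _encode_audio_edit_clone_prompt in the next file). A minimal sketch with placeholder values, where the speaker id and token string are hypothetical:

from config.prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL

# Placeholder values for illustration only
prompt = AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL.format(
    speaker="spk_0001",                       # hypothetical voice id
    prompt_text="Hello, this is a test.",     # transcript of the reference audio
    prompt_wav_tokens="<audio_13><audio_87>"  # hypothetical merged vq02/vq06 token string
)
print(prompt)  # renders the [speaker_start] ... [speaker_end] block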
tts.py CHANGED
@@ -13,7 +13,7 @@ from http import HTTPStatus
 import torchaudio
 
 from model_loader import model_loader, ModelSource
-from config.prompts import TTS_SYSTEM_PROMPTS, AUDIO_EDIT_SYSTEM_PROMPT
+from config.prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL, AUDIO_EDIT_SYSTEM_PROMPT
 from stepvocoder.cosyvoice2.cli.cosyvoice import CosyVoice
 from transformers.generation.logits_process import LogitsProcessor
 from transformers.generation.utils import LogitsProcessorList
@@ -101,40 +101,9 @@ class StepAudioTTS:
         )
 
         # Use system prompts from config module
-        self.tts_sys_prompt_dict = TTS_SYSTEM_PROMPTS
+        self.edit_clone_sys_prompt_tpl = AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL
         self.edit_sys_prompt = AUDIO_EDIT_SYSTEM_PROMPT
 
-    def get_audio_tokens(self, input_audio_data_numpy, input_audio_sample_rate):
-        """
-        Extract audio tokens using audio_tokenizer
-
-        Args:
-            input_audio_data_numpy: Audio data as numpy array
-            input_audio_sample_rate: Sample rate of the audio
-
-        Returns:
-            str: Audio tokens as string
-        """
-        # Convert numpy array to tensor if needed
-        if isinstance(input_audio_data_numpy, torch.Tensor):
-            audio_tensor = input_audio_data_numpy
-        else:
-            audio_tensor = torch.from_numpy(input_audio_data_numpy).float()
-
-        # Ensure proper shape (add batch dimension if needed)
-        if len(audio_tensor.shape) == 1:
-            audio_tensor = audio_tensor.unsqueeze(0)
-
-        # Use the correct API: wav2token returns _, vq02_codes, vq06_codes
-        _, vq02_codes, vq06_codes = self.audio_tokenizer.wav2token(audio_tensor, input_audio_sample_rate)
-
-        # Merge VQ codes to token string
-        audio_tokens = self.audio_tokenizer.merge_vq0206_to_token_str(
-            vq02_codes, vq06_codes
-        )
-
-        return audio_tokens
-
     def clone(
         self,
         prompt_wav_path: str,
@@ -155,16 +124,19 @@
         try:
             logger.debug(f"Starting voice cloning: {prompt_wav_path}")
             prompt_wav, sample_rate = torchaudio.load(prompt_wav_path)
-            prompt_code, prompt_token, prompt_token_len, speech_feat, speech_feat_len, speech_embedding = (
+            vq0206_codes, vq02_codes_ori, vq06_codes_ori, speech_feat, speech_feat_len, speech_embedding = (
                 self.preprocess_prompt_wav(prompt_wav_path)
             )
             prompt_speaker = self.generate_clone_voice_id(prompt_text, prompt_wav)
-
-            token_ids = self._encode_audio_tts_prompt(
+            prompt_wav_tokens = self.audio_tokenizer.merge_vq0206_to_token_str(
+                vq02_codes_ori, vq06_codes_ori
+            )
+            token_ids = self._encode_audio_edit_clone_prompt(
                 target_text,
                 prompt_text,
                 prompt_speaker,
-                prompt_code,
+                vq0206_codes,
+                prompt_wav_tokens,
             )
 
             output_ids = self.llm.generate(
@@ -176,10 +148,11 @@
             )
             output_ids = output_ids[:, len(token_ids) : -1]  # skip eos token
             logger.debug("Voice cloning generation completed")
+            vq0206_codes_vocoder = torch.tensor([vq0206_codes], dtype=torch.long) - 65536
             return (
                 self.cosy_model.token2wav_nonstream(
                     output_ids - 65536,
-                    prompt_token,
+                    vq0206_codes_vocoder,
                     speech_feat.to(torch.bfloat16),
                     speech_embedding.to(torch.bfloat16),
                 ),
@@ -211,16 +184,16 @@
             Tuple[torch.Tensor, int]: Edited audio tensor and sample rate
         """
         try:
-            logger.debug(f"Starting audio editing: {edit_type} - {edit_info}")
-
-            # Load input audio
-            input_audio, sample_rate = torchaudio.load(input_audio_path)
-
-            # Get audio tokens
-            audio_tokens = self.get_audio_tokens(input_audio, sample_rate)
-
+            logger.debug(f"Starting audio editing: {edit_type} - {edit_info}")
+            vq0206_codes, vq02_codes_ori, vq06_codes_ori, speech_feat, _, speech_embedding = (
+                self.preprocess_prompt_wav(input_audio_path)
+            )
+            audio_tokens = self.audio_tokenizer.merge_vq0206_to_token_str(
+                vq02_codes_ori, vq06_codes_ori
+            )
             # Build instruction prefix based on edit type
             instruct_prefix = self._build_audio_edit_instruction(audio_text, edit_type, edit_info, text)
+            print(f"instruct_prefix: {instruct_prefix}")
 
             # Encode the complete prompt to token sequence
             prompt_tokens = self._encode_audio_edit_prompt(
@@ -238,15 +211,12 @@
                 logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]),
             )
             output_ids = output_ids[:, len(prompt_tokens) : -1]  # skip eos token
-
-            _, prompt_token, _, speech_feat, _, speech_embedding = (
-                self.preprocess_prompt_wav(input_audio_path)
-            )
+            vq0206_codes_vocoder = torch.tensor([vq0206_codes], dtype=torch.long) - 65536
             logger.debug("Audio editing generation completed")
             return (
                 self.cosy_model.token2wav_nonstream(
                     output_ids - 65536,
-                    prompt_token,
+                    vq0206_codes_vocoder,
                     speech_feat.to(torch.bfloat16),
                     speech_embedding.to(torch.bfloat16),
                 ),
@@ -285,16 +255,12 @@
         elif edit_type == "style":
             if edit_info == "remove":
                 instruct_prefix = f"Remove any speaking styles in the following audio and the reference text is: {audio_text}\n"
-            elif edit_info in {"exaggerated","ethereal","whisper","act_coy","older"}:
-                instruct_prefix = f"Make the following audio more {edit_info} style. The text corresponding to the audio is: {audio_text}\n"
             else:
-                instruct_prefix=f"Make the following audio more {edit_info}. The text corresponding to the audio is: {audio_text}\n"
+                instruct_prefix = f"Make the following audio more {edit_info} style. The text corresponding to the audio is: {audio_text}\n"
         elif edit_type == "denoise":
             instruct_prefix = f"Remove any noise from the given audio while preserving the voice content clearly. Ensure that the speech quality remains intact with minimal distortion, and eliminate all noise from the audio."
         elif edit_type == "vad":
             instruct_prefix = f"Remove any silent portions from the given audio while preserving the voice content clearly. Ensure that the speech quality remains intact with minimal distortion, and eliminate all silence from the audio."
-        elif edit_type == "bgm":
-            instruct_prefix = f"Remove any background music (BGM) from the given audio while preserving the voice content clearly. Ensure that the speech quality remains intact with minimal distortion, and eliminate all BGM from the audio."
         elif edit_type == "paralinguistic":
             instruct_prefix = f"Add some non-verbal sounds to make the audio more natural, the new text is : {text}\n The text corresponding to the audio is: {audio_text}\n"
         else:
@@ -331,30 +297,22 @@
         history.extend([4] + qrole_toks + human_turn_toks + [3] + [4] + arole_toks)
         return history
 
-    def _encode_audio_tts_prompt(
-        self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list
+    def _encode_audio_edit_clone_prompt(
+        self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list, prompt_wav_tokens: str
     ):
-        rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱")
-
-        if rap_or_vocal:
-            if "哼唱" in text:
-                prompt = self.tts_sys_prompt_dict["sys_prompt_for_vocal"]
-            else:
-                prompt = self.tts_sys_prompt_dict["sys_prompt_for_rap"]
-        elif prompt_speaker:
-            prompt = self.tts_sys_prompt_dict["sys_prompt_with_spk"].format(prompt_speaker)
-        else:
-            prompt = self.tts_sys_prompt_dict["sys_prompt_wo_spk"]
-
+        prompt = self.edit_clone_sys_prompt_tpl.format(
+            speaker=prompt_speaker,
+            prompt_text=prompt_text,
+            prompt_wav_tokens=prompt_wav_tokens
+        )
+        print(f"edit clone prompt: {prompt}")
         sys_tokens = self.tokenizer.encode(f"system\n{prompt}")
 
         history = [1]
         history.extend([4] + sys_tokens + [3])
 
         _prefix_tokens = self.tokenizer.encode("\n")
-        prompt_token_encode = self.tokenizer.encode("\n" + prompt_text)
-        prompt_tokens = prompt_token_encode[len(_prefix_tokens) :]
-
+
         target_token_encode = self.tokenizer.encode("\n" + text)
         target_tokens = target_token_encode[len(_prefix_tokens) :]
 
@@ -364,14 +322,6 @@
         history.extend(
             [4]
            + qrole_toks
-            + prompt_tokens
-            + [3]
-            + [4]
-            + arole_toks
-            + prompt_code
-            + [3]
-            + [4]
-            + qrole_toks
             + target_tokens
             + [3]
             + [4]
@@ -410,20 +360,17 @@
         prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path)
         if prompt_wav.shape[0] > 1:
             prompt_wav = prompt_wav.mean(dim=0, keepdim=True)  # convert multi-channel audio to mono
-        prompt_token, prompt_token_len = self.cosy_model.frontend.extract_speech_token(
-            prompt_wav, prompt_wav_sr
-        )
         speech_feat, speech_feat_len = self.cosy_model.frontend.extract_speech_feat(
            prompt_wav, prompt_wav_sr
         )
         speech_embedding = self.cosy_model.frontend.extract_spk_embedding(
             prompt_wav, prompt_wav_sr
         )
-        prompt_code, _, _ = self.audio_tokenizer.wav2token(prompt_wav, prompt_wav_sr)
+        vq0206_codes, vq02_codes_ori, vq06_codes_ori = self.audio_tokenizer.wav2token(prompt_wav, prompt_wav_sr)
        return (
-            prompt_code,
-            prompt_token,
-            prompt_token_len,
+            vq0206_codes,
+            vq02_codes_ori,
+            vq06_codes_ori,
            speech_feat,
            speech_feat_len,
            speech_embedding,
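A note on the repeated `- 65536` above: the commit shifts both the LLM output ids and the prompt's vq0206 codes by the same constant before handing them to token2wav_nonstream, which suggests audio tokens occupy LLM vocabulary ids starting at 65536 while the vocoder expects raw VQ indices. A small sketch under that assumption (the helper name and example ids are illustrative, not from the repo):

import torch

AUDIO_TOKEN_OFFSET = 65536  # assumed base id of audio tokens in the LLM vocab

def to_vocoder_codes(llm_audio_ids: list[int]) -> torch.Tensor:
    """Illustrative helper: map LLM audio-token ids back to vocoder VQ indices."""
    return torch.tensor([llm_audio_ids], dtype=torch.long) - AUDIO_TOKEN_OFFSET

vq0206_codes = [65540, 65600, 65777]    # hypothetical merged vq02/vq06 ids
print(to_vocoder_codes(vq0206_codes))   # tensor([[  4,  64, 241]])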