Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse files. This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- src/f5_tts/model/cfm.py +6 -1
- src/f5_tts/model/utils.py +19 -0
src/f5_tts/model/cfm.py
CHANGED
|
@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
|
|
| 22 |
from f5_tts.model.utils import (
|
| 23 |
default,
|
| 24 |
exists,
|
|
|
|
| 25 |
lens_to_mask,
|
| 26 |
list_str_to_idx,
|
| 27 |
list_str_to_tensor,
|
|
@@ -92,6 +93,7 @@ class CFM(nn.Module):
|
|
| 92 |
seed: int | None = None,
|
| 93 |
max_duration=4096,
|
| 94 |
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
|
|
|
| 95 |
no_ref_audio=False,
|
| 96 |
duplicate_test=False,
|
| 97 |
t_inter=0.1,
|
|
@@ -190,7 +192,10 @@ class CFM(nn.Module):
|
|
| 190 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
| 191 |
steps = int(steps * (1 - t_start))
|
| 192 |
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
| 194 |
if sway_sampling_coef is not None:
|
| 195 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
| 196 |
|
|
|
|
| 22 |
from f5_tts.model.utils import (
|
| 23 |
default,
|
| 24 |
exists,
|
| 25 |
+
get_epss_timesteps,
|
| 26 |
lens_to_mask,
|
| 27 |
list_str_to_idx,
|
| 28 |
list_str_to_tensor,
|
|
|
|
| 93 |
seed: int | None = None,
|
| 94 |
max_duration=4096,
|
| 95 |
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
| 96 |
+
use_epss=True,
|
| 97 |
no_ref_audio=False,
|
| 98 |
duplicate_test=False,
|
| 99 |
t_inter=0.1,
|
|
|
|
| 192 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
| 193 |
steps = int(steps * (1 - t_start))
|
| 194 |
|
| 195 |
+
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
|
| 196 |
+
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
|
| 197 |
+
else:
|
| 198 |
+
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
| 199 |
if sway_sampling_coef is not None:
|
| 200 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
| 201 |
|
src/f5_tts/model/utils.py
CHANGED
|
@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
|
|
| 189 |
if count > tolerance:
|
| 190 |
return True
|
| 191 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
if count > tolerance:
|
| 190 |
return True
|
| 191 |
return False
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Get the empirically pruned timesteps used for sampling (EPSS: Empirically Pruned Step Sampling)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_epss_timesteps(n: int, device, dtype) -> torch.Tensor:
    """Return the sampling timesteps for an ``n``-step schedule.

    For the step counts listed in the table below, a predefined schedule is
    returned: integer multiples of 1/32 that are spaced more densely near 0
    and end at 1. For any other ``n`` the result is ``n + 1`` evenly spaced
    points from 0 to 1.

    Args:
        n: Number of sampling steps; the result has ``n + 1`` entries.
        device: Device the returned tensor is created on.
        dtype: Dtype of the returned tensor.

    Returns:
        A 1-D tensor of ``n + 1`` timesteps starting at 0 (and ending at 1
        whenever ``n >= 1``).
    """
    # Every predefined schedule is expressed on a grid of 32 equal sub-steps.
    dt = 1 / 32
    predefined_timesteps = {
        5: [0, 2, 4, 8, 16, 32],
        6: [0, 2, 4, 6, 8, 16, 32],
        7: [0, 2, 4, 6, 8, 16, 24, 32],
        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
    }
    t = predefined_timesteps.get(n)
    if t is None:
        # No pruned schedule for this step count: fall back to uniform spacing.
        return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
    return dt * torch.tensor(t, device=device, dtype=dtype)
|