Skip to content

Commit

Permalink
Run make style
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Nov 8, 2023
1 parent 1fcb25a commit e213757
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 15 deletions.
9 changes: 6 additions & 3 deletions TTS/cs_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ class CS_API:
},
}


SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]

def __init__(self, api_token=None, model="XTTS"):
Expand Down Expand Up @@ -308,7 +307,11 @@ def tts_to_file(
print(api.list_speakers_as_tts_models())

ts = time.time()
wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name)
wav, sr = api.tts(
"It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
)
print(f" [i] XTTS took {time.time() - ts:.2f}s")

filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav")
filepath = api.tts_to_file(
text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
)
23 changes: 17 additions & 6 deletions TTS/tts/layers/tortoise/dpm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,21 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3,] * (
orders = [
3,
] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [3,] * (
orders = [
3,
] * (
K - 1
) + [1]
else:
orders = [3,] * (
orders = [
3,
] * (
K - 1
) + [2]
elif order == 2:
Expand All @@ -581,7 +587,9 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
] * K
else:
K = steps // 2 + 1
orders = [2,] * (
orders = [
2,
] * (
K - 1
) + [1]
elif order == 1:
Expand Down Expand Up @@ -1440,7 +1448,10 @@ def sample(
model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps,
order=order,
skip_type=skip_type,
Expand Down Expand Up @@ -1548,4 +1559,4 @@ def expand_dims(v, dims):
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
return v[(...,) + (None,) * (dims - 1)]
7 changes: 5 additions & 2 deletions TTS/tts/layers/xtts/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,15 @@ def __init__(self, vocab_file=None):
@cached_property
def katsu(self):
import cutlet

return cutlet.Cutlet()

def check_input_length(self, txt, lang):
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
print(
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
)

def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
if overlap_len > len(wav_chunk):
# wav_chunk is smaller than overlap_len, pass on last wav_gen
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
else:
# this branch is not expected to be reached, since the problem only occurs on the last chunk
wav_chunk = wav_gen[-overlap_len:]
Expand All @@ -616,7 +616,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav

wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
Expand Down
4 changes: 3 additions & 1 deletion tests/xtts_tests/test_xtts_gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@


# Training sentences generations
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = [
"tests/data/ljspeech/wavs/LJ001-0002.wav"
] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language


Expand Down
4 changes: 3 additions & 1 deletion tests/xtts_tests/test_xtts_v2-0_gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@


# Training sentences generations
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = [
"tests/data/ljspeech/wavs/LJ001-0002.wav"
] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language


Expand Down

0 comments on commit e213757

Please sign in to comment.