Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/f5_tts/model/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
self._resamplers = {}

def get_frame_len(self, index):
row = self.data[index]
Expand All @@ -51,8 +52,6 @@ def __getitem__(self, index):
row = self.data[index]
audio = row["audio"]["array"]

# logger.info(f"Audio shape: {audio.shape}")

sample_rate = row["audio"]["sampling_rate"]
duration = audio.shape[-1] / sample_rate

Expand All @@ -62,8 +61,9 @@ def __getitem__(self, index):
audio_tensor = torch.from_numpy(audio).float()

if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = resampler(audio_tensor)
if sample_rate not in self._resamplers:
self._resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = self._resamplers[sample_rate](audio_tensor)

audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t'

Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(
mel_spec_type=mel_spec_type,
),
)
self._resamplers = {}

def get_frame_len(self, index):
if (
Expand Down Expand Up @@ -149,8 +150,11 @@ def __getitem__(self, index):

# resample if necessary
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
if source_sample_rate not in self._resamplers:
self._resamplers[source_sample_rate] = torchaudio.transforms.Resample(
source_sample_rate, self.target_sample_rate
)
audio = self._resamplers[source_sample_rate](audio)

# to mel spectrogram
mel_spec = self.mel_spectrogram(audio)
Expand Down
28 changes: 16 additions & 12 deletions src/f5_tts/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

mel_basis_cache = {}
hann_window_cache = {}
vocos_mel_stft_cache = {}


def get_bigvgan_mel_spectrogram(
Expand Down Expand Up @@ -84,23 +85,26 @@ def get_vocos_mel_spectrogram(
hop_length=256,
win_length=1024,
):
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(waveform.device)
device = waveform.device
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{device}"
if key not in vocos_mel_stft_cache:
vocos_mel_stft_cache[key] = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'

assert len(waveform.shape) == 2

mel = mel_stft(waveform)
mel = vocos_mel_stft_cache[key](waveform)
mel = mel.clamp(min=1e-5).log()
return mel

Expand Down
Loading