diff --git a/convert-to-coreml b/convert-to-coreml index 27bfa55..4305237 100755 --- a/convert-to-coreml +++ b/convert-to-coreml @@ -32,12 +32,8 @@ def main(): # Create sample 'audio' for tracing wav = torch.zeros(2, int(args.length * samplerate)) - # Reproduce the STFT step (which we cannot convert to Core ML, unfortunately) - _, stft_mag = estimator.compute_stft(wav) - print('==> Tracing model') - traced_model = torch.jit.trace(estimator.separator, stft_mag) - out = traced_model(stft_mag) + traced_model = torch.jit.trace(estimator, wav) print('==> Converting to Core ML') mlmodel = ct.convert( @@ -45,7 +41,7 @@ def main(): convert_to='mlprogram', # TODO: Investigate whether we'd want to make the input shape flexible # See https://coremltools.readme.io/docs/flexible-inputs - inputs=[ct.TensorType(shape=stft_mag.shape)] + inputs=[ct.TensorType(shape=wav.shape)] ) output_dir: Path = args.output diff --git a/pyproject.toml b/pyproject.toml index d8383a5..34bacec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "Spleeter implementation in PyTorch" # and fail during model conversions e.g. noting that BlobWriter is not available. requires-python = "<3.11" dependencies = [ - "coremltools >= 6.3, < 7", + "coremltools == 7.0b1", "numpy >= 1.24, < 2", "tensorflow >= 2.13.0rc0", "torch >= 2.0, < 3", diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py index a6c042b..fcb302a 100644 --- a/spleeter_pytorch/estimator.py +++ b/spleeter_pytorch/estimator.py @@ -32,19 +32,27 @@ def compute_stft(self, wav): stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win, center=True, return_complex=True, pad_mode='constant') + + # implement torch.view_as_real(stft) manually since coremltools doesn't support it + stft = torch.stack((torch.real(stft), torch.imag(stft)), axis=-1) # only keep freqs smaller than self.F - stft = stft[:, :self.F, :] - mag = stft.abs() + stft = stft[:, :self.F] - return torch.view_as_real(stft), mag + # implement torch.hypot manually since coremltools doesn't support it + mag = torch.sqrt(stft[..., 0] ** 2 + stft[..., 1] ** 2) + + return stft, mag def inverse_stft(self, stft): """Inverses stft to wave form""" pad = self.win_length // 2 + 1 - stft.size(1) stft = F.pad(stft, (0, 0, 0, 0, 0, pad)) - stft = torch.view_as_complex(stft) + + # implement torch.view_as_complex(stft) manually since coremltools doesn't support it + stft = torch.complex(stft[..., 0], stft[..., 1]) + wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True, window=self.win) return wav.detach()