Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ test: tool
test-example: tool
[ -n "$(MODULE)" ] && module=modules/$(MODULE).py || module=; \
. .venv/bin/activate && export NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0 && \
python -m pytest --doctest-modules --no-cov --ignore=diffsptk/third_party diffsptk/$$module
python -m pytest --doctest-modules --no-cov --ignore=$(PROJECT)/third_party $(PROJECT)/$$module

test-clean:
rm -rf tests/__pycache__
Expand Down
7 changes: 3 additions & 4 deletions diffsptk/modules/acorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ def _precompute(
elif out_format in (2, "biased"):
formatter = lambda x: x / frame_length
elif out_format in (3, "unbiased"):
formatter = lambda x: x / (
torch.arange(
frame_length, frame_length - acr_order - 1, -1, device=x.device
)
n = frame_length - acr_order - 1
formatter = lambda x: (
x / (torch.arange(frame_length, n, -1, device=x.device))
)
else:
raise ValueError(f"out_format {out_format} is not supported.")
Expand Down
6 changes: 3 additions & 3 deletions diffsptk/modules/mcep.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def forward(self, x: torch.Tensor):
... )
>>> x = diffsptk.ramp(19)
>>> mc = mcep(stft(x))
>>> mc
tensor([[-0.8851, 0.7917, -0.1737, 0.0175],
[-0.3522, 4.4222, -1.0882, -0.0510]])
>>> mc.round(decimals=3)
tensor([[-0.8850, 0.7920, -0.1740, 0.0170],
[-0.3520, 4.4220, -1.0880, -0.0510]])

"""
check_size(x.size(-1), self.in_dim, "dimension of spectrum")
Expand Down
245 changes: 236 additions & 9 deletions diffsptk/modules/mglsadf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@

from typing import Any

import mpmath as mp
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..utils.private import Lambda, check_size, get_gamma, remove_gain
from ..utils.private import Lambda, check_size, get_gamma, remove_gain, to
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .base import BaseNonFunctionalModule
from .c2mpir import CepstrumToMinimumPhaseImpulseResponse
from .frame import Frame
from .gnorm import GeneralizedCepstrumGainNormalization
from .istft import InverseShortTimeFourierTransform
from .linear_intpl import LinearInterpolation
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum
from .mgc2sp import MelGeneralizedCepstrumToSpectrum
from .root_pol import PolynomialToRoots
from .stft import ShortTimeFourierTransform


Expand Down Expand Up @@ -74,12 +77,15 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
phase : ['minimum', 'maximum', 'zero', 'mixed']
The filter type.

mode : ['multi-stage', 'single-stage', 'freq-domain']
mode : ['multi-stage', 'single-stage', 'freq-domain', 'pade-approx']
'multi-stage' approximates the MLSA filter by cascading FIR filters based on the
Taylor series expansion. 'single-stage' uses an FIR filter with the coefficients
derived from the impulse response converted from the input mel-cepstral
coefficients using FFT. 'freq-domain' performs filtering in the frequency domain
rather than the time domain.
rather than the time domain. 'pade-approx' implements the MLSA filter by
cascading all-zero and all-pole filters derived from the factorization. While
this approach is not computationally efficient, it allows for the optimization
of the Pade approximation coefficients.

n_fft : int >= 1
The number of FFT bins used for conversion. Higher values result in increased
Expand All @@ -89,12 +95,20 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
The order of the Taylor series expansion (valid only if **mode** is
'multi-stage').

pade_order : int >= 3
The order of Pade approximation (valid only if **mode** is 'pade-approx').

cep_order : int >= 0 or tuple[int, int]
The order of the linear cepstrum (valid only if **mode** is 'multi-stage').
The order of the linear cepstrum (valid only if **mode** is 'multi-stage' or
'pade-approx').

ir_length : int >= 1 or tuple[int, int]
The length of the impulse response (valid only if **mode** is 'single-stage').

learnable : bool
If True, the polynomial coefficients used in the approximation are learnable
(valid only if **mode** is 'multi-stage' or 'pade-approx').

device : torch.device or None
The device of this module.

Expand Down Expand Up @@ -181,6 +195,16 @@ def flip(x):
phase=phase,
**modified_kwargs,
)
elif mode == "pade-approx":
self.mglsadf = MultiStageIIRFilter(
flipped_filter_order,
frame_period,
alpha=alpha,
gamma=gamma,
ignore_gain=ignore_gain,
phase=phase,
**modified_kwargs,
)
else:
raise ValueError(f"mode {mode} is not supported.")

Expand Down Expand Up @@ -238,6 +262,7 @@ def __init__(
taylor_order: int = 20,
cep_order: tuple[int, int] | int = 199,
n_fft: int = 512,
learnable: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
Expand All @@ -248,7 +273,6 @@ def __init__(

self.ignore_gain = ignore_gain
self.phase = phase
self.taylor_order = taylor_order

if alpha == 0 and gamma == 0:
cep_order = filter_order
Expand Down Expand Up @@ -294,6 +318,19 @@ def __init__(

self.linear_intpl = LinearInterpolation(frame_period)

cp = mp.taylor(mp.exp, 0, taylor_order)
cp = np.array([float(x) for x in cp])
weights = cp[1:] / cp[:-1]
weights = np.insert(weights, 0, 1)
self.register_buffer("weights", to(weights, device=device, dtype=dtype))

a = np.ones(taylor_order + 1)
a = to(a, device=device, dtype=dtype)
if learnable:
self.a = nn.Parameter(a)
else:
self.register_buffer("a", a)

def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -322,12 +359,12 @@ def forward(

c = self.linear_intpl(c)

y = x.clone()
for a in range(1, self.taylor_order + 1):
y = x * self.a[0]
for i in range(1, len(self.a)):
x = self.pad(x)
x = x.unfold(-1, c.size(-1), 1)
x = (x * c).sum(-1) / a
y += x
x = (x * c).sum(-1) * self.weights[i]
y += x * self.a[i]

if not self.ignore_gain:
K = torch.exp(self.linear_intpl(c0))
Expand Down Expand Up @@ -586,3 +623,193 @@ def forward(
Y = H * X
y = self.istft(Y, out_length=x.size(-1))
return y


class MultiStageIIRFilter(nn.Module):
def __init__(
self,
filter_order: tuple[int, int] | int,
frame_period: int,
*,
alpha: float = 0,
gamma: float = 0,
ignore_gain: bool = False,
phase: str = "minimum",
pade_order: int = 5,
cep_order: tuple[int, int] | int = 199,
n_fft: int = 512,
chunk_length: int | None = None,
warmup_length: int | None = None,
learnable: bool = False,
per_stage_pade_coefficients: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()

if phase != "minimum" or is_array_like(filter_order):
raise ValueError("Only minimum-phase filter is supported.")

self.ignore_gain = ignore_gain

self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
filter_order,
cep_order,
in_alpha=alpha,
in_gamma=gamma,
n_fft=n_fft,
device=device,
dtype=dtype,
)
self.linear_intpl = LinearInterpolation(frame_period)
self.root_pol = PolynomialToRoots(pade_order, device=device, dtype=dtype)

from torchlpc import sample_wise_lpc

self.sample_wise_lpc = sample_wise_lpc

if chunk_length is None:
self.chuking = False
else:
self.chuking = True
self.warmup_length = (
warmup_length if warmup_length is not None else cep_order
)
if chunk_length <= 0:
raise ValueError("chunk_length must be positive.")
if self.warmup_length < 0:
raise ValueError("warmup_length must be non-negative.")
frame_period = chunk_length - self.warmup_length
self.frame_x = Frame(chunk_length, frame_period, center=False)
self.frame_c = Frame(
cep_order * chunk_length, cep_order * frame_period, center=False
)

cr = mp.taylor(mp.exp, 0, pade_order * 2)
cp, cq = mp.pade(cr, pade_order, pade_order)
cp = np.array([float(x) for x in cp])
weights = cp[1:] / cp[:-1]
weights = np.insert(weights, 0, 1)
self.register_buffer("weights", to(weights, device=device, dtype=dtype))

if pade_order == 3:
a1 = np.linspace(1.0, 0.4, pade_order + 1)
elif pade_order == 4:
a1 = np.linspace(1.0, 0.6, pade_order + 1)
elif 5 <= pade_order <= 14:
a1 = np.ones(pade_order + 1)
else:
raise ValueError("pade_order must be in [3, 14].")

if learnable and per_stage_pade_coefficients:
a2 = a1
a1 = np.ones(pade_order + 1)
a1 = to(a1, device=device, dtype=dtype)
a2 = to(a2, device=device, dtype=dtype)
self.a1 = nn.Parameter(a1)
self.a2 = nn.Parameter(a2)
else:
a1 = to(a1, device=device, dtype=dtype)
if learnable:
self.a1 = nn.Parameter(a1)
else:
self.register_buffer("a1", a1)
self.a2 = self.a1

def forward(
self, x: torch.Tensor, mc: torch.Tensor, return_roots: bool = False
) -> torch.Tensor:
if x.dim() == 1:
x = x.unsqueeze(0)
mc = mc.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False

if x.dim() != 2 or mc.dim() != 3:
raise ValueError("x and mc must be 2-D and 3-D tensors, respectively.")

c = self.mgc2c(mc)
c0, c1 = torch.split(c, [1, c.size(-1) - 1], dim=-1)
c_b = self.linear_intpl(c1.flip(-1))
c_a = self.linear_intpl(c1)

T = x.size(-1)
B, _, M = c_a.size()

a1 = torch.clip(self.a1, min=1e-1, max=1e1)
a1[0] = 1.0
a2 = torch.clip(self.a2, min=1e-1, max=1e1)
a2[0] = 1.0

c_b2, c_b1 = torch.split(c_b, [c_b.size(-1) - 1, 1], dim=-1)
c_b1 = c_b1.squeeze(-1)

# Numerator, 1st stage:
y = x * a1[0]
for i in range(1, len(a1)):
x = F.pad(x[..., :-1], (1, 0))
x = x * c_b1 * self.weights[i]
y += x * a1[i]

# Numerator, 2nd stage:
x = y
y = x * a2[0]
for i in range(1, len(a2)):
x = F.pad(x, (M, 0))
x = x.unfold(-1, M + 1, 1)
x = (x[..., :-2] * c_b2).sum(-1) * self.weights[i]
y += x * a2[i]

if self.chuking:
y = F.pad(y, (self.warmup_length, 0))
y = self.frame_x(y)
y = y.reshape(-1, y.size(-1))

c_a = c_a.reshape(B, -1)
c_a = F.pad(c_a, (M * self.warmup_length, 0))
c_a = self.frame_c(c_a)
c_a = c_a.reshape(y.size(0), y.size(1), M)

c_a1, c_a2 = torch.split(c_a, [1, c_a.size(-1) - 1], dim=-1)
c_a2 = F.pad(c_a2, (1, 0))

def compute_roots(a: torch.Tensor) -> torch.Tensor:
pade_coefficients = torch.cumprod(self.weights, 0) * a
roots = self.root_pol(pade_coefficients.flip(0).double())
roots = roots.to(
torch.complex64 if a.dtype == torch.float32 else torch.complex128
)
return roots

roots1 = compute_roots(a1)
roots2 = compute_roots(a2)
roots = torch.stack([roots1, roots2], dim=0)

# Denominator, 1st stage:
y = y.to(roots.dtype)
p1 = torch.reciprocal(roots1)
for i in range(len(p1)):
y = self.sample_wise_lpc(y, (p1[i] * c_a1))

# Denominator, 2nd stage:
p2 = torch.reciprocal(roots2)
for i in range(len(p2)):
y = self.sample_wise_lpc(y, (p2[i] * c_a2))
y = y.real

if self.chuking:
y = y[..., self.warmup_length :]
y = y.reshape(B, -1)
y = y[..., :T]

if not self.ignore_gain:
K = torch.exp(self.linear_intpl(c0))
y = y * K.squeeze(-1)

if unsqueezed:
y = y.squeeze(0)

if return_roots:
return y, roots
return y
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
dependencies = [
"numpy >= 1.23.0",
"scipy >= 1.12.0",
"mpmath >= 0.17.0",
"tqdm >= 4.63.0",
"soundfile >= 0.10.2",
"torch >= 2.3.1",
Expand Down
Loading