diff --git a/.github/environment-ci.yml b/.github/environment-ci.yml index d3c24541..4e84827d 100644 --- a/.github/environment-ci.yml +++ b/.github/environment-ci.yml @@ -1,41 +1,42 @@ name: compiam-dev + channels: - conda-forge - defaults + dependencies: + - python=3.9 - pip - - "attrs>=23.1.0" - - "matplotlib>=3.0.0" - - "numpy>=1.20.3,<=1.26.4" - - "joblib>=1.2.0" - - "pathlib~=1.0.1" - - "tqdm>=4.66.1" - - "IPython>=7.34.0" - - "ipywidgets>=7.0.0,<8" - - "Jinja2~=3.1.2" - - "configobj~=5.0.6" - - "seaborn" - - "librosa>=0.10.1" - - "scikit-learn==1.5.2" - - "scikit-image~=0.24.0" - - "hmmlearn==0.3.3" - - "fastdtw~=0.3.4" - ####### - libvorbis - - pytest>=7.4.3 - ####### + - pip: - - "keras<3.0.0" - - "tensorflow>=2.12.0,<2.16" - - "torch==2.0.0" - - "torchaudio==2.0.1" - - "essentia" - - "soundfile>=0.12.1" - - "opencv-python~=4.6.0" - - "mirdata==0.3.9" - - "compmusic==0.4" - - "attrs>=23.1.0" - - "black>=23.3.0" - - "decorator>=5.1.1" - - "future>=0.18.3" - - "testcontainers>=3.7.1" \ No newline at end of file + - keras<3.0.0 + - tensorflow>=2.12.0,<2.16 + - torch==2.0.0 + - torchaudio==2.0.1 + - essentia + - soundfile>=0.12.1 + - opencv-python~=4.6.0 + - mirdata==0.3.9 + - compmusic==0.4 + - attrs>=23.1.0 + - black>=23.3.0 + - decorator>=5.1.1 + - future>=0.18.3 + - testcontainers>=3.7.1 + - madmom @ git+https://github.com/vivekvjyn/madmom.git + - matplotlib>=3.0.0 + - numpy>=1.20.3,<=1.23.5 + - joblib>=1.2.0 + - tqdm>=4.66.1 + - IPython>=7.34.0 + - ipywidgets>=7.0.0,<8 + - Jinja2~=3.1.2 + - configobj~=5.0.6 + - seaborn + - librosa>=0.10.1 + - scikit-learn==1.5.2 + - scikit-image~=0.24.0 + - hmmlearn==0.3.3 + - fastdtw~=0.3.4 + - pytest>=7.4.3 diff --git a/compiam/data.py b/compiam/data.py index fc1c2a4b..ae4d166d 100644 --- a/compiam/data.py +++ b/compiam/data.py @@ -125,6 +125,22 @@ }, }, }, + "rhythm:tcn-carnatic": { + "module_name": "compiam.rhythm.meter.tcn_carnatic", + "class_name": "TCNTracker", + "default_version": "v1", + "kwargs": { + "v1": { + 
"model_path": os.path.join( + "models", + "rhythm", + "tcn-carnatic" + ), + "download_link": "https://zenodo.org/records/18449067/files/compIAM-TCNCarnatic.zip?download=1", + "download_checksum": "995369933f2a344af0ffa57ea5c15e62", + }, + }, + }, "structure:dhrupad-bandish-segmentation": { "module_name": "compiam.structure.segmentation.dhrupad_bandish_segmentation", "class_name": "DhrupadBandishSegmentation", diff --git a/compiam/rhythm/README.md b/compiam/rhythm/README.md index 6fddc88d..e4a17534 100644 --- a/compiam/rhythm/README.md +++ b/compiam/rhythm/README.md @@ -4,8 +4,9 @@ |--------------------------------|------------------------------------------------|-----------| | Akshara Pulse Tracker Detector | Detect onsets of aksharas in tabla recordings | [1] | | Mnemonic Stroke Transcription | Bol/Solkattu trasncription using HMM | [2] | +| TCN Carnatic | Carnatic meter tracking using TCN | [3] | [1] Originally implemented by Ajay Srinivasamurthy as part of PyCompMusic - https://github.com/MTG/pycompmusic -[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain. \ No newline at end of file +[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain. 
class TCNTracker(object):
    """TCN beat and downbeat tracker tuned to Carnatic Music.

    The tracker runs a ``MultiTracker`` network on spectrogram features
    computed with madmom, then decodes the beat / downbeat activations with
    one of two DBN-based post-processors ("joint" or "sequential").
    """

    # Accepted constructor values. ``predict`` always calls the post-processor
    # with both beat and downbeat activations, so only the two trackers that
    # take both are valid choices.
    _POST_PROCESSORS = ("joint", "sequential")
    _MODEL_VERSIONS = (42, 52, 62)

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: Post-processing method to use. Choose from 'joint', or 'sequential'.
        :param model_version: Version of the pre-trained model to use. Choose from 42, 52, or 62.
        :param model_path: path to the folder holding the model weights. If None, no
            weights are loaded and ``load_model`` must be called before ``predict``.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: id of the GPU to use (-1, the default, runs on CPU).

        :raises ValueError: if ``post_processor`` or ``model_version`` is not supported.
        :raises ImportError: if torch or madmom is not installed.
        """
        # Validate the cheap arguments first so that a typo is reported
        # before any heavy optional dependency is imported.
        if post_processor not in self._POST_PROCESSORS:
            raise ValueError(
                f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'."
            )
        if model_version not in self._MODEL_VERSIONS:
            raise ValueError(
                f"Invalid model_version: {model_version}. Choose from 42, 52, or 62."
            )

        ### IMPORTING OPTIONAL DEPENDENCIES
        try:
            global torch
            import torch
        except ImportError as err:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            ) from err

        try:
            global madmom
            import madmom
        except ImportError as err:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            ) from err
        ###

        # These submodules import torch / madmom at module level, so they can
        # only be imported once the optional dependencies are known to exist.
        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import (
            joint_tracker,
            sequential_tracker,
        )

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        # NOTE: despite its name this attribute stores the weights *file name*.
        self.model_version = f"multitracker_{model_version}.pth"
        self.download_link = download_link
        self.download_checksum = download_checksum

        # Number of frames repeated at each end of the feature matrix in
        # ``preprocess_audio`` before it is fed to the network.
        self.pad_frames = 2
        self.post_processor = (
            joint_tracker if post_processor == "joint" else sequential_tracker
        )

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)

    def _build_model(self):
        """Build the TCN model on the selected device, in evaluation mode."""
        model = MultiTracker().to(self.device)
        model.eval()
        return model

    def load_model(self, model_path):
        """Load pre-trained model weights.

        :param model_path: folder expected to contain the weights file for the
            selected model version. If the file is missing, the model is
            downloaded into this folder first.
        """
        weights_path = os.path.join(model_path, self.model_version)
        if not os.path.exists(weights_path):
            self.download_model(model_path)  # Downloading model weights

        self.model.load_weights(weights_path, self.device)

        self.model_path = model_path
        self.trained = True

    def download_model(self, model_path=None, force_overwrite=True):
        """Download pre-trained model.

        :param model_path: folder to download into. Defaults to
            ``<WORKDIR>/models/rhythm/tcn-carnatic``.
        :param force_overwrite: forwarded to ``download_remote_model``.
        """
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights
        os.makedirs(download_path, exist_ok=True)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )

    def predict(
        self,
        input_data,
        sr: int = 44100,
        min_bpm=55,
        max_bpm=230,
        beats_per_bar=None,
    ) -> np.ndarray:
        """Run inference on input audio file.

        :param input_data: path to audio file or numpy array like audio signal.
        :param sr: sampling rate of the input audio signal (default: 44100).
            Only used when ``input_data`` is an array; for a path the rate
            reported by the audio loader is used instead.
        :param min_bpm: minimum BPM for beat tracking (default: 55).
        :param max_bpm: maximum BPM for beat tracking (default: 230).
        :param beats_per_bar: list of possible beats per bar for downbeat
            tracking (default: [3, 5, 7, 8]).

        :returns: the output of the selected post-processor, built from the
            beat and downbeat activations.
        :raises ModelNotTrainedError: if no weights have been loaded.
        """
        if self.trained is False:
            raise ModelNotTrainedError(
                "Model is not trained. Please load model before running inference! "
                "You can load the pre-trained instance with the load_model wrapper."
            )
        # Fresh list per call instead of a shared mutable default argument.
        if beats_per_bar is None:
            beats_per_bar = [3, 5, 7, 8]

        features = self.preprocess_audio(input_data, sr)
        x = torch.from_numpy(features).float().to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(x)
        beats_act = output["beats"].squeeze().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().cpu().numpy()

        return self.post_processor(
            beats_act,
            downbeats_act,
            min_bpm=min_bpm,
            max_bpm=max_bpm,
            beats_per_bar=beats_per_bar,
        )

    def preprocess_audio(self, input_data, input_sr: int) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file, or the audio samples
            as a numpy array.
        :param input_sr: sampling rate of ``input_data`` when it is an array.

        :returns: Preprocessed features as a numpy array with two leading
            singleton axes added and ``pad_frames`` edge frames repeated at
            both ends of the first feature axis.
        :raises FileNotFoundError: if ``input_data`` is a path that does not exist.
        :raises ValueError: if ``input_data`` is neither a string nor an array.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sr = madmom.io.audio.load_audio_file(input_data)
        elif isinstance(input_data, np.ndarray):
            audio, sr = input_data, input_sr
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        # Down-mix a 2-D array with a leading axis of length 2 (two rows,
        # e.g. channels-first stereo). The ndim guard keeps a mono signal
        # that happens to have exactly two samples from collapsing to a scalar.
        if audio.ndim == 2 and audio.shape[0] == 2:
            audio = audio.mean(axis=0)
        signal = madmom.audio.Signal(audio, sr, num_channels=1)

        x = PreProcessor(sample_rate=sr)(signal)

        # Repeat the first and last frame pad_frames times at each end.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add two leading singleton axes in front of the padded features.
        return np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

    @staticmethod
    def save_pitch(data, output_path):
        """Calling the write_csv function in compiam.io to write the output beat track in a file

        NOTE(review): the method name mentions pitch although it stores the
        beat track; it is kept unchanged so existing callers keep working.

        :param data: the data to write
        :param output_path: the path where the data is going to be stored

        :returns: None
        """
        return write_csv(data, output_path)

    def select_gpu(self, gpu="-1"):
        """Select the GPU to use for inference.

        If a model has already been built it is moved to the newly selected
        device, so that ``predict`` does not mix devices.

        :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda:" + str(gpu))
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps:" + str(gpu))
        else:
            self.device = torch.device("cpu")
            logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu

        # During __init__ the model does not exist yet, hence the getattr.
        model = getattr(self, "model", None)
        if model is not None:
            model.to(self.device)
class ResBlock(nn.Module):
    """Residual block with two parallel dilated 1-D convolutions.

    The input is convolved with dilation ``dilation_rate`` and with
    ``2 * dilation_rate``; both results are concatenated on the channel axis,
    passed through ELU and dropout, and reduced back to ``n_filters`` channels
    by a 1x1 convolution. A separate 1x1 convolution of the input is added to
    that result to form the residual output.

    :param dilation_rate: dilation of the first convolution (the second one
        uses twice this value).
    :param n_filters: number of input and output channels.
    :param kernel_size: kernel size of the two dilated convolutions.
    :param dropout_rate: dropout probability applied after the ELU.
    """

    def __init__(self, dilation_rate, n_filters, kernel_size, dropout_rate=0.15):
        super().__init__()
        # Blocks are chained in TCN, so each one must accept as many channels
        # as the previous one emits, i.e. n_filters.
        in_channels = n_filters
        self.res = nn.Conv1d(
            in_channels=in_channels,
            out_channels=n_filters,
            kernel_size=1,
            padding="same",
        )
        self.conv_1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=n_filters,
            kernel_size=kernel_size,
            dilation=dilation_rate,
            padding="same",
        )
        self.conv_2 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=n_filters,
            kernel_size=kernel_size,
            dilation=dilation_rate * 2,
            padding="same",
        )
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout_rate)
        self.conv_3 = nn.Conv1d(
            in_channels=2 * n_filters,
            out_channels=n_filters,
            kernel_size=1,
            padding="same",
        )

    def forward(self, x):
        """Return ``(residual_output, skip_output)`` for input ``x``.

        ``skip_output`` is the block output before the residual branch is
        added; ``residual_output`` is ``res(x) + skip_output``.
        """
        res_x = self.res(x)
        concat = torch.cat([self.conv_1(x), self.conv_2(x)], dim=1)
        out = self.conv_3(self.dropout(self.elu(concat)))
        return res_x + out, out


class TCN(nn.Module):
    """Stack of ``ResBlock`` layers with dilations 1, 2, 4, ..., 2**(n_dilations-1).

    :param n_filters: number of channels of every block.
    :param kernel_size: kernel size of the dilated convolutions.
    :param n_dilations: number of stacked blocks.
    :param dropout_rate: dropout probability used inside every block.
    """

    def __init__(self, n_filters, kernel_size, n_dilations, dropout_rate):
        super().__init__()
        # Keys "tcn_<idx>" become part of the state_dict key names.
        self.tcn_layers = nn.ModuleDict(
            {
                f"tcn_{idx}": ResBlock(2**idx, n_filters, kernel_size, dropout_rate)
                for idx in range(n_dilations)
            }
        )
        self.activation = nn.ELU()

    def forward(self, x):
        """Return ``(activated_output, summed_skip_outputs)``."""
        skip_connections = []
        for layer in self.tcn_layers.values():
            x, skip_out = layer(x)
            skip_connections.append(skip_out)

        x = self.activation(x)
        skip = torch.stack(skip_connections, dim=-1).sum(dim=-1)
        return x, skip


class MultiTracker(nn.Module):
    """Convolutional front-end + TCN with beat and downbeat output heads.

    Three Conv2d / ELU / MaxPool2d((1, 3)) / Dropout stages reduce the last
    input axis; ``forward`` then squeezes that axis, so it must have been
    reduced to length 1 for the Conv1d-based TCN to accept the tensor.

    :param n_filters: number of channels used throughout the network.
    :param n_dilations: number of residual blocks in the TCN.
    :param kernel_size: kernel size of the TCN convolutions.
    :param dropout_rate: dropout probability used throughout the network.
    """

    def __init__(self, n_filters=20, n_dilations=11, kernel_size=5, dropout_rate=0.15):
        super().__init__()
        self.dropout_rate = dropout_rate

        self.conv_1 = nn.Conv2d(
            in_channels=1,
            out_channels=n_filters,
            kernel_size=(3, 3),
            stride=1,
            padding="valid",
        )
        self.elu_1 = nn.ELU()
        self.mp_1 = nn.MaxPool2d((1, 3))
        self.dropout_1 = nn.Dropout(dropout_rate)

        self.conv_2 = nn.Conv2d(
            in_channels=n_filters,
            out_channels=n_filters,
            kernel_size=(1, 10),
            padding="valid",
        )
        self.elu_2 = nn.ELU()
        self.mp_2 = nn.MaxPool2d((1, 3))
        self.dropout_2 = nn.Dropout(dropout_rate)

        self.conv_3 = nn.Conv2d(
            in_channels=n_filters,
            out_channels=n_filters,
            kernel_size=(3, 3),
            padding="valid",
        )
        self.elu_3 = nn.ELU()
        self.mp_3 = nn.MaxPool2d((1, 3))
        self.dropout_3 = nn.Dropout(dropout_rate)

        self.tcn = TCN(n_filters, kernel_size, n_dilations, dropout_rate)

        self.beats_dropout = nn.Dropout(dropout_rate)
        self.beats_dense = nn.Linear(n_filters, 1)
        self.beats_act = nn.Sigmoid()

        self.downbeats_dropout = nn.Dropout(dropout_rate)
        self.downbeats_dense = nn.Linear(n_filters, 1)
        self.downbeats_act = nn.Sigmoid()

        # The tempo head is not used by forward(). It is kept because its
        # parameters are part of this module's state_dict, which
        # load_weights() loads with the default strict key matching.
        # NOTE(review): presumably the released checkpoints contain these
        # weights - confirm before removing.
        self.tempo_dropout = nn.Dropout(dropout_rate)
        self.tempo_dense = nn.Linear(n_filters, 300)
        self.tempo_act = nn.Softmax(dim=1)

    def forward(self, x):
        """Compute beat and downbeat activations.

        :param x: 4-D input with a single channel on axis 1 (``conv_1`` has
            ``in_channels=1``).
        :returns: dict with keys ``"beats"`` and ``"downbeats"``, each holding
            sigmoid activations with a trailing axis of length 1.
        """
        x = self.dropout_1(self.mp_1(self.elu_1(self.conv_1(x))))
        x = self.dropout_2(self.mp_2(self.elu_2(self.conv_2(x))))
        x = self.dropout_3(self.mp_3(self.elu_3(self.conv_3(x))))

        # Drop the (fully reduced) last axis to obtain a Conv1d input.
        x = torch.squeeze(x, -1)

        # The summed skip connections are not used by the output heads.
        x, _ = self.tcn(x)

        # Move channels last so the Linear heads act on them.
        x = x.transpose(-2, -1)

        beats = self.beats_act(self.beats_dense(self.beats_dropout(x)))
        downbeats = self.downbeats_act(
            self.downbeats_dense(self.downbeats_dropout(x))
        )

        return {"beats": beats, "downbeats": downbeats}

    def load_weights(self, model_path, device):
        """Load a checkpoint file into this module.

        Accepts either a plain state dict or a dict holding it under the
        ``"state_dict"`` key; a leading ``"model."`` in parameter names is
        stripped before loading.

        :param model_path: path to the checkpoint file.
        :param device: device passed to ``torch.load`` as ``map_location``.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        ckpt = torch.load(model_path, map_location=device)
        state_dict = ckpt.get("state_dict", ckpt)
        prefix = "model."
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.load_state_dict(cleaned)
# Default candidate bar lengths; stored as a tuple and copied into a fresh
# list per call so no mutable default is shared between calls.
_DEFAULT_BEATS_PER_BAR = (3, 5, 7, 8)


def clip_probabilities(probs, epsilon=1e-5):
    """Clip probabilities to avoid exact 0 and 1 values that cause DBN issues.

    Values are first limited to [0, 1] and then mapped linearly onto
    [epsilon / 2, 1 - epsilon / 2].

    :param probs: array-like of activation values.
    :param epsilon: total margin removed from the [0, 1] range.
    :returns: numpy array of the clipped and rescaled values.
    """
    probs = np.clip(probs, 0, 1)
    return probs * (1 - epsilon) + epsilon / 2


def beat_tracker(beats_act, min_bpm=55, max_bpm=230, fps=100, transition_lambda=100):
    """Decode beats from a beat activation function with madmom's DBN.

    :param beats_act: 1-D beat activation function.
    :param min_bpm: minimum tempo considered by the DBN.
    :param max_bpm: maximum tempo considered by the DBN.
    :param fps: frame rate of the activation function.
    :param transition_lambda: tempo transition parameter forwarded to the DBN.
    :returns: the DBN output, or an empty array when the activation holds at
        most one value.
    """
    beats_act = clip_probabilities(beats_act)
    # Nothing to decode: return before building the DBN.
    if beats_act.size <= 1:
        return np.array([])

    beat_dbn = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        fps=fps,
        transition_lambda=transition_lambda,
        online=False,
    )
    return beat_dbn(beats_act)


def joint_tracker(beats_act, downbeats_act, min_bpm=55, max_bpm=230, fps=100, beats_per_bar=None):
    """Decode beats and downbeats jointly with madmom's downbeat DBN.

    :param beats_act: 1-D beat activation function.
    :param downbeats_act: 1-D downbeat activation function of the same length.
    :param min_bpm: minimum tempo considered by the DBN.
    :param max_bpm: maximum tempo considered by the DBN.
    :param fps: frame rate of the activation functions.
    :param beats_per_bar: candidate bar lengths (default: [3, 5, 7, 8]).
    :returns: the output of ``DBNDownBeatTrackingProcessor``.
    """
    if beats_per_bar is None:
        beats_per_bar = list(_DEFAULT_BEATS_PER_BAR)

    beats_act = clip_probabilities(beats_act)
    downbeats_act = clip_probabilities(downbeats_act)

    downbeat_tracker = DBNDownBeatTrackingProcessor(
        beats_per_bar=beats_per_bar, min_bpm=min_bpm, max_bpm=max_bpm, fps=fps
    )

    # Two columns: beat-only activation (downbeat part removed, floored at 0)
    # and downbeat activation.
    beat_only_act = np.maximum(beats_act - downbeats_act, 0)
    combined_act = np.vstack((beat_only_act, downbeats_act)).T
    return downbeat_tracker(combined_act)


def sequential_tracker(beats_act, downbeats_act, min_bpm=55, max_bpm=230, fps=100, beats_per_bar=None):
    """Track beats first, then assign bar positions to the tracked beats.

    :param beats_act: 1-D beat activation function.
    :param downbeats_act: 1-D downbeat activation function.
    :param min_bpm: minimum tempo used for beat tracking.
    :param max_bpm: maximum tempo used for beat tracking.
    :param fps: frame rate of the activation functions.
    :param beats_per_bar: candidate bar lengths (default: [3, 5, 7, 8]).
    :returns: the output of ``DBNBarTrackingProcessor``, or an empty
        ``(0, 2)`` array when no beats are found or the bar tracker raises
        ``IndexError``.
    """
    if beats_per_bar is None:
        beats_per_bar = list(_DEFAULT_BEATS_PER_BAR)

    # beat_tracker clips beats_act itself; clipping here as well would apply
    # the rescaling twice.
    beats = beat_tracker(beats_act, min_bpm=min_bpm, max_bpm=max_bpm, fps=fps)
    if beats.size == 0:
        return np.empty((0, 2))

    downbeats_act = clip_probabilities(downbeats_act)

    # bars (i.e. track beats and then downbeats)
    # Round rather than truncate: a product such as 0.29 * 100 evaluates to
    # 28.999..., which truncation would map to the previous frame.
    beat_idx = np.round(beats * fps).astype(int)
    # Guard against an index at or past the end of the activation.
    beat_idx = np.clip(beat_idx, 0, len(downbeats_act) - 1)
    bar_act = maximum_filter1d(downbeats_act, size=3)[beat_idx]
    bar_act = np.vstack((beats, bar_act)).T

    bar_tracker = DBNBarTrackingProcessor(
        beats_per_bar=beats_per_bar, meter_change_prob=1e-3, observation_weight=4
    )

    try:
        # NOTE(review): min_bpm / max_bpm are kept from the original call;
        # presumably the processor ignores them at call time - confirm
        # against the madmom fork in use.
        pred = bar_tracker(bar_act, min_bpm=min_bpm, max_bpm=max_bpm)
    except IndexError:
        pred = np.empty((0, 2))

    return pred
class PreProcessor(SequentialProcessor):
    """madmom processing chain that turns a signal into a log spectrogram.

    The chain runs, in order: signal conversion, framing, STFT, filterbank,
    logarithmic scaling, and finally ``np.array`` on the result.

    :param sample_rate: sample rate handed to the signal stage.
    :param frame_size: frame length handed to the framing stage.
    :param num_bands: value handed to the filterbank stage.
    :param log: logarithm function handed to the scaling stage.
    :param add: offset handed to the scaling stage.
    :param fps: frames per second handed to the framing stage; also stored on
        the instance as ``fps``.
    :param num_channels: number of channels handed to the signal stage.
    """

    def __init__(self, sample_rate, frame_size=2048, num_bands=12, log=np.log, add=1e-6, fps=100, num_channels=1):
        stages = (
            SignalProcessor(num_channels=num_channels, sample_rate=sample_rate),
            FramedSignalProcessor(frame_size=frame_size, fps=fps),
            ShortTimeFourierTransformProcessor(),
            FilteredSpectrogramProcessor(num_bands=num_bands),
            LogarithmicSpectrogramProcessor(log=log, add=add),
            np.array,
        )
        super().__init__(stages)
        self.fps = fps
def _rhythm_resource(filename):
    """Return the path of ``filename`` inside the rhythm test resources."""
    return os.path.join(TESTDIR, "resources", "rhythm", filename)


def _assert_two_column_array(pred):
    """Check that a prediction is a numpy array with two columns."""
    assert isinstance(pred, np.ndarray)
    assert pred.shape[1] == 2


def _untrained_tracker_checks(tracker):
    """Run the checks shared by the post-processor tests.

    Verifies that an unloaded tracker refuses to predict, then flips the
    ``trained`` flag by hand and verifies the missing-file error. No weights
    are loaded, so the later predictions are only checked for type and shape.
    """
    with pytest.raises(ModelNotTrainedError):
        tracker.predict(_rhythm_resource("hola.wav"))
    tracker.trained = True
    with pytest.raises(FileNotFoundError):
        tracker.predict(_rhythm_resource("hola.wav"))


def test_predict_joint():
    """Joint post-processor: prediction from a file path and from an array."""
    # Imported here because the tracker needs optional dependencies.
    from compiam.rhythm.meter.tcn_carnatic import TCNTracker

    tracker = TCNTracker()
    _untrained_tracker_checks(tracker)

    beats = tracker.predict(_rhythm_resource("beat_test.wav"))

    audio_in, sr = librosa.load(_rhythm_resource("beat_test.wav"))
    beats_2 = tracker.predict(audio_in, sr)

    _assert_two_column_array(beats)
    _assert_two_column_array(beats_2)


def test_predict_sequential():
    """Sequential post-processor: prediction from a file path and from an array."""
    from compiam.rhythm.meter.tcn_carnatic import TCNTracker

    tracker = TCNTracker(post_processor="sequential")
    _untrained_tracker_checks(tracker)

    beats = tracker.predict(_rhythm_resource("beat_test.wav"))

    audio_in, sr = librosa.load(_rhythm_resource("beat_test.wav"))
    # Pass the rate returned by the loader; predict() otherwise assumes its
    # default of 44100 Hz for array input.
    beats_2 = tracker.predict(audio_in, sr)

    _assert_two_column_array(beats)
    _assert_two_column_array(beats_2)


def test_48k():
    """Array input sampled at 48 kHz is accepted when its rate is passed."""
    from compiam.rhythm.meter.tcn_carnatic import TCNTracker

    tracker = TCNTracker(post_processor="sequential")
    # No weights are loaded; only the output type and shape are checked.
    tracker.trained = True
    audio_in, sr = librosa.load(_rhythm_resource("48k.wav"), sr=48000)
    beats = tracker.predict(audio_in, sr)

    _assert_two_column_array(beats)