diff --git a/compiam/data.py b/compiam/data.py
index fc1c2a4b..ae4d166d 100644
--- a/compiam/data.py
+++ b/compiam/data.py
@@ -125,6 +125,22 @@
             },
         },
     },
+    "rhythm:tcn-carnatic": {
+        "module_name": "compiam.rhythm.meter.tcn_carnatic",
+        "class_name": "TCNTracker",
+        "default_version": "v1",
+        "kwargs": {
+            "v1": {
+                "model_path": os.path.join(
+                    "models",
+                    "rhythm",
+                    "tcn-carnatic"
+                ),
+                "download_link": "https://zenodo.org/records/18449067/files/compIAM-TCNCarnatic.zip?download=1",
+                "download_checksum": "995369933f2a344af0ffa57ea5c15e62",
+            },
+        },
+    },
     "structure:dhrupad-bandish-segmentation": {
         "module_name": "compiam.structure.segmentation.dhrupad_bandish_segmentation",
         "class_name": "DhrupadBandishSegmentation",
diff --git a/compiam/rhythm/README.md b/compiam/rhythm/README.md
index 6fddc88d..e4a17534 100644
--- a/compiam/rhythm/README.md
+++ b/compiam/rhythm/README.md
@@ -4,8 +4,10 @@
 |--------------------------------|------------------------------------------------|-----------|
 | Akshara Pulse Tracker Detector | Detect onsets of aksharas in tabla recordings | [1] |
 | Mnemonic Stroke Transcription | Bol/Solkattu trasncription using HMM | [2] |
+| TCN Carnatic | Carnatic meter tracking using TCN | [3] |
 
 [1] Originally implemented by Ajay Srinivasamurthy as part of PyCompMusic - https://github.com/MTG/pycompmusic
-[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
\ No newline at end of file
+[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
+[3] TCN-based beat and downbeat tracking model for Carnatic music; pre-trained weights released on Zenodo (record 18449067).
diff --git a/compiam/rhythm/meter/__init__.py b/compiam/rhythm/meter/__init__.py index fbc5feb3..834b748d 100644 --- a/compiam/rhythm/meter/__init__.py +++ b/compiam/rhythm/meter/__init__.py @@ -4,7 +4,7 @@ from compiam.data import models_dict from compiam.rhythm.meter.akshara_pulse_tracker import AksharaPulseTracker - +from compiam.rhythm.meter.tcn_carnatic import TCNTracker # Show user the available tools def list_tools(): diff --git a/compiam/rhythm/meter/tcn_carnatic/__init__.py b/compiam/rhythm/meter/tcn_carnatic/__init__.py new file mode 100644 index 00000000..7718f3e7 --- /dev/null +++ b/compiam/rhythm/meter/tcn_carnatic/__init__.py @@ -0,0 +1,193 @@ +import os +import sys +import torch +import numpy as np +import madmom +from typing import Dict +from tqdm import tqdm +from compiam.exceptions import ModelNotTrainedError + +from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker +from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor +from compiam.rhythm.meter.tcn_carnatic.post import beat_tracker, joint_tracker, sequential_tracker +from compiam.utils.download import download_remote_model +from compiam.utils import get_logger, WORKDIR +from compiam.io import write_csv + +logger = get_logger(__name__) + +class TCNTracker(object): + """TCN beat tracker tuned to Carnatic Music.""" + def __init__(self, post_processor="joint", model_version=42, model_path=None, download_link=None, download_checksum=None, gpu=-1): + """TCN beat tracker init method. + + :param post_processor: Post-processing method to use. Choose from 'joint', or 'sequential'. + :param model_version: Version of the pre-trained model to use. Choose from 42, 52, or 62. + :param model_path: path to file to the model weights. + :param download_link: link to the remote pre-trained model. + :param download_checksum: checksum of the model file. 
+ """ + ### IMPORTING OPTIONAL DEPENDENCIES + try: + global torch + import torch + global madmom + import madmom + + global MultiTracker + from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker + + except: + raise ImportError( + "In order to use this tool you need to have torch and madmom installed. " + "Install compIAM with torch and madmom support: pip install 'compiam[torch,madmom]'" + ) + ### + if post_processor not in ["beat", "joint", "sequential"]: + raise ValueError(f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'.") + if model_version not in [42, 52, 62]: + raise ValueError(f"Invalid model_version: {model_version}. Choose from 42, 52, or 62.") + + self.gpu = gpu + self.device = None + self.select_gpu(gpu) + + self.model_path = model_path + self.model_version = f'multitracker_{model_version}.pth' + self.download_link = download_link + self.download_checksum = download_checksum + + self.trained = False + self.model = self._build_model() + if self.model_path is not None: + self.load_model(self.model_path) + self.pre_processor = PreProcessor(fps=100) + self.pad_frames = 2 + + self.post_processor = joint_tracker if post_processor == "joint" else \ + sequential_tracker + + + def _build_model(self): + """Build the TCN model.""" + model = MultiTracker().to(self.device) + model.eval() + return model + + + def load_model(self, model_path): + """Load pre-trained model weights.""" + if not os.path.exists(os.path.join(model_path, self.model_version)): + self.download_model(model_path) # Downloading model weights + + self.model.load_weights(os.path.join(model_path, self.model_version), self.device) + + self.model_path = model_path + self.trained = True + + + def download_model(self, model_path=None, force_overwrite=True): + """Download pre-trained model.""" + download_path = ( + #os.sep + os.path.join(*model_path.split(os.sep)[:-2]) + model_path + if model_path is not None + else os.path.join(WORKDIR, "models", "rhythm", 
"tcn-carnatic") + ) + # Creating model folder to store the weights + if not os.path.exists(download_path): + os.makedirs(download_path) + download_remote_model( + self.download_link, + self.download_checksum, + download_path, + force_overwrite=force_overwrite, + ) + + + @torch.no_grad() + def predict(self, input_data: str) -> Dict: + """Run inference on input audio file. + + :param input_data: path to audio file or numpy array like audio signal. + + :returns: a 2-D list with beats and beat positions. + """ + + if self.trained is False: + raise ModelNotTrainedError( + """Model is not trained. Please load model before running inference! + You can load the pre-trained instance with the load_model wrapper.""" + ) + + features = self.preprocess_audio(input_data) + x = torch.from_numpy(features).to(self.device) + output = self.model(x) + beats_act = output["beats"].squeeze().detach().cpu().numpy() + downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy() + + pred = self.post_processor(beats_act, downbeats_act) + + return pred + + def preprocess_audio(self, input_data: str, input_sr: int = 44100) -> np.ndarray: + """Preprocess input audio file to extract features for inference. + :param audio_path: Path to the input audio file. + :param input_sr: Sampling rate of the input audio file. + + :returns: Preprocessed features as a numpy array. 
+ """ + if isinstance(input_data, str): + if not os.path.exists(input_data): + raise FileNotFoundError("Target audio not found.") + audio, sr = madmom.io.audio.load_audio_file(input_data) + if audio.shape[0] == 2: + audio = audio.mean(axis=0) + s = madmom.audio.Signal(audio, sr, num_channels=1) + elif isinstance(input_data, np.ndarray): + audio = input_data + if audio.shape[0] == 2: + audio = audio.mean(axis=0) + s = madmom.audio.Signal(audio, input_sr, num_channels=1) + else: + raise ValueError("Input must be path to audio signal or an audio array") + + x = self.pre_processor(s) + + pad_start = np.repeat(x[:1], self.pad_frames, axis=0) + pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0) + x_padded = np.concatenate((pad_start, x, pad_stop)) + + x_final = np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0) + + return x_final + + @staticmethod + def save_pitch(data, output_path): + """Calling the write_csv function in compiam.io to write the output beat track in a file + + :param data: the data to write + :param output_path: the path where the data is going to be stored + + :returns: None + """ + return write_csv(data, output_path) + + + def select_gpu(self, gpu="-1"): + """Select the GPU to use for inference. + + :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc. + :returns: None + """ + if int(gpu) == -1: + self.device = torch.device("cpu") + else: + if torch.cuda.is_available(): + self.device = torch.device("cuda:" + str(gpu)) + elif torch.backends.mps.is_available(): + self.device = torch.device("mps:" + str(gpu)) + else: + self.device = torch.device("cpu") + logger.warning("No GPU available. 
Running on CPU.") + self.gpu = gpu diff --git a/compiam/rhythm/meter/tcn_carnatic/model.py b/compiam/rhythm/meter/tcn_carnatic/model.py new file mode 100644 index 00000000..3b4ff51e --- /dev/null +++ b/compiam/rhythm/meter/tcn_carnatic/model.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn + +class ResBlock(nn.Module): + def __init__(self, dilation_rate, n_filters, kernel_size, dropout_rate=0.15): + super().__init__() + in_channels = 20 + self.res = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=1, padding='same') + self.conv_1 = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=kernel_size, dilation=dilation_rate, padding='same') + self.conv_2 = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=kernel_size, dilation=dilation_rate*2, padding='same') + self.elu = nn.ELU() + self.dropout = nn.Dropout(0.15) + self.conv_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=n_filters, kernel_size=1, padding='same') + + return + + def forward(self, x): + res_x = self.res(x) + conv_1 = self.conv_1(x) + conv_2 = self.conv_2(x) + concat = torch.cat([conv_1, conv_2], dim=1) + out = self.elu(concat) + out = self.dropout(out) + out = self.conv_3(out) + + return res_x + out, out + +class TCN(nn.Module): + def __init__(self, n_filters, kernel_size, n_dilations, dropout_rate): + super().__init__() + dilations = [2**i for i in range(n_dilations)] + + self.tcn_layers = nn.ModuleDict({}) + + for idx, d in enumerate(dilations): + self.tcn_layers[f"tcn_{idx}"] = ResBlock(d, n_filters, kernel_size, dropout_rate) + + self.activation = nn.ELU() + + return + + def forward(self, x): + skip_connections = [] + for tcn_i in self.tcn_layers: + x, skip_out = self.tcn_layers[tcn_i](x) + skip_connections.append(skip_out) + + x = self.activation(x) + + skip = torch.stack(skip_connections, dim=-1).sum(dim=-1) + + return x, skip + + +class MultiTracker(nn.Module): + def __init__(self, n_filters=20, n_dilations=11, 
kernel_size=5, dropout_rate=0.15): + super().__init__() + self.dropout_rate = dropout_rate + + self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(3, 3), stride=1, padding="valid") + self.elu_1 = nn.ELU() + self.mp_1 = nn.MaxPool2d((1, 3)) + self.dropout_1 = nn.Dropout(dropout_rate) + + self.conv_2 = nn.Conv2d(in_channels=n_filters, out_channels=n_filters, kernel_size=(1, 10), padding="valid") + self.elu_2 = nn.ELU() + self.mp_2 = nn.MaxPool2d((1, 3)) + self.dropout_2 = nn.Dropout(dropout_rate) + + self.conv_3 = nn.Conv2d(in_channels=n_filters, out_channels=n_filters, kernel_size=(3, 3), padding="valid") + self.elu_3 = nn.ELU() + self.mp_3 = nn.MaxPool2d((1, 3)) + self.dropout_3 = nn.Dropout(dropout_rate) + + self.tcn = TCN(n_filters, kernel_size, n_dilations, dropout_rate) + + self.beats_dropout = nn.Dropout(dropout_rate) + self.beats_dense = nn.Linear(n_filters, 1) + self.beats_act = nn.Sigmoid() + + self.downbeats_dropout = nn.Dropout(dropout_rate) + self.downbeats_dense = nn.Linear(n_filters, 1) + self.downbeats_act = nn.Sigmoid() + + self.tempo_dropout = nn.Dropout(dropout_rate) + self.tempo_dense = nn.Linear(n_filters, 300) + self.tempo_act = nn.Softmax(dim=1) + + def forward(self, x): + x = self.conv_1(x) + x = self.elu_1(x) + x = self.mp_1(x) + x = self.dropout_1(x) + + x = self.conv_2(x) + x = self.elu_2(x) + x = self.mp_2(x) + x = self.dropout_2(x) + + x = self.conv_3(x) + x = self.elu_3(x) + x = self.mp_3(x) + x = self.dropout_3(x) + + x = torch.squeeze(x, -1) + + x, skip = self.tcn(x) + + x = x.transpose(-2,-1) + + beats = self.beats_dropout(x) + beats = self.beats_dense(beats) + beats = self.beats_act(beats) + + downbeats = self.downbeats_dropout(x) + downbeats = self.downbeats_dense(downbeats) + downbeats = self.downbeats_act(downbeats) + + activations = {} + activations["beats"] = beats + activations["downbeats"] = downbeats + + return activations + + def load_weights(self, model_path, device): + ckpt = torch.load(model_path, 
map_location=device) + state_dict = ckpt.get("state_dict", ckpt) + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("model."): + k = k[len("model."):] + new_state_dict[k] = v + + self.load_state_dict(new_state_dict) + return diff --git a/compiam/rhythm/meter/tcn_carnatic/post.py b/compiam/rhythm/meter/tcn_carnatic/post.py new file mode 100644 index 00000000..e31c7c35 --- /dev/null +++ b/compiam/rhythm/meter/tcn_carnatic/post.py @@ -0,0 +1,62 @@ +#import madmom +from madmom.features.beats import DBNBeatTrackingProcessor +from madmom.features.downbeats import DBNDownBeatTrackingProcessor, DBNBarTrackingProcessor +from scipy.ndimage import maximum_filter1d +import numpy as np + +fps= 100 # frames per second for the DBN processors +min_bpm= 55.0 +max_bpm= 230.0 + +epsilon = 1e-5 + +def clip_probabilities(probs): + """Clip probabilities to avoid exact 0 and 1 values that cause DBN issues.""" + probs = np.maximum(probs, 0) + probs = np.minimum(probs, 1) + return probs * (1 - epsilon) + epsilon / 2 + +def beat_tracker(beats_act): + beats_act = clip_probabilities(beats_act) + beat_dbn = DBNBeatTrackingProcessor( + min_bpm=min_bpm, max_bpm=max_bpm, fps=fps, transition_lambda=100, online=False) + + if beats_act.size > 1: + beats_pred = beat_dbn(beats_act) + return beats_pred + else: + # If no beats are detected, return an empty array + return np.array([]) + +def joint_tracker(beats_act, downbeats_act): + beats_act = clip_probabilities(beats_act) + downbeats_act = clip_probabilities(downbeats_act) + + downbeat_tracker = DBNDownBeatTrackingProcessor( + beats_per_bar=[3, 5, 7, 8], min_bpm=min_bpm, max_bpm=max_bpm, fps=fps) + + combined_act = np.vstack((np.maximum(beats_act - downbeats_act, 0), downbeats_act)).T + pred = downbeat_tracker(combined_act) + + return pred + +def sequential_tracker(beats_act, downbeats_act): + beats_act = clip_probabilities(beats_act) + downbeats_act = clip_probabilities(downbeats_act) + + beats = beat_tracker(beats_act) + + # 
bars (i.e. track beats and then downbeats) + beat_idx = (beats * fps).astype(int) + bar_act = maximum_filter1d(downbeats_act, size=3) + bar_act = bar_act[beat_idx] + bar_act = np.vstack((beats, bar_act)).T + + bar_tracker = DBNBarTrackingProcessor(beats_per_bar=(3, 5, 7, 8), meter_change_prob=1e-3, observation_weight=4) + + try: + pred = bar_tracker(bar_act) + except IndexError: + pred = np.empty((0, 2)) + + return pred diff --git a/compiam/rhythm/meter/tcn_carnatic/pre.py b/compiam/rhythm/meter/tcn_carnatic/pre.py new file mode 100644 index 00000000..8ba2826d --- /dev/null +++ b/compiam/rhythm/meter/tcn_carnatic/pre.py @@ -0,0 +1,17 @@ +import numpy as np +from madmom.processors import SequentialProcessor +from madmom.audio.signal import SignalProcessor, FramedSignalProcessor +from madmom.audio.stft import ShortTimeFourierTransformProcessor +from madmom.audio.spectrogram import FilteredSpectrogramProcessor, LogarithmicSpectrogramProcessor +from torch.utils.data import Dataset + + +class PreProcessor(SequentialProcessor): + def __init__(self, frame_size=2048, num_bands=12, log=np.log, add=1e-6, fps=100): + sig = SignalProcessor(num_channels=1, sample_rate=44100) + frames = FramedSignalProcessor(frame_size=frame_size, fps=fps) + stft = ShortTimeFourierTransformProcessor() + filt = FilteredSpectrogramProcessor(num_bands=num_bands) + spec = LogarithmicSpectrogramProcessor(log=log, add=add) + super(PreProcessor, self).__init__((sig, frames, stft, filt, spec, np.array)) + self.fps = fps diff --git a/docs/source/rhythm.rst b/docs/source/rhythm.rst index 2ca5ef18..1e94a0b7 100644 --- a/docs/source/rhythm.rst +++ b/docs/source/rhythm.rst @@ -24,4 +24,13 @@ Akshara Pulse Tracker --------------------- .. autoclass:: compiam.rhythm.meter.akshara_pulse_tracker.AksharaPulseTracker - :members: + :members: + +TCN Carnatic +--------------------- + +.. note:: + REQUIRES: torch + +.. 
autoclass:: compiam.rhythm.meter.tcn_carnatic.TCNTracker + :members: diff --git a/main.py b/main.py new file mode 100644 index 00000000..f7771bde --- /dev/null +++ b/main.py @@ -0,0 +1,217 @@ +import compiam +import os + + + +import numpy as np +import matplotlib.pyplot as plt +import librosa +import IPython.display as ipd + +#@title Plotting function +def plot_sonify(y, sr=22050, beats=None, labels=None, start_time=0, duration=None): + """ + Plots the waveform with optional beats and downbeats. If beats are None, only the waveform is plotted. + + Parameters: + - y: Audio signal waveform data. + - sr: Sampling rate of the audio signal. + - beats: Array of beat times (in seconds). If None, only the waveform is plotted. + - labels: Optional array of labels corresponding to beats. Downbeats should be labeled as 1. + - start_time: Start time in seconds for the plot window (default: 0). + - duration: Duration in seconds for the plot window. If None, plots entire signal. + """ + time = (np.arange(len(y))/sr) + + plt.figure(figsize=(100, 5)) + + # Apply time windowing if specified + plot_y = y + plot_time = time + + if duration is not None: + end_time = start_time + duration + # Find indices for the time window + start_idx = np.searchsorted(time, start_time) + end_idx = np.searchsorted(time, end_time) + plot_time = time[start_idx:end_idx] + plot_y = y[start_idx:end_idx] + + # Plot waveform + plt.plot(plot_time, plot_y) + plt.ylabel('Amplitude') + plt.xlabel('Time (sec)') + plt.title("Waveform") + + # Set x-axis limits if duration is specified + if duration is not None: + plt.xlim(start_time, start_time + duration) + + # Check if beats are provided + if beats is not None: + # Filter beats to the time window if duration is specified + if duration is not None: + end_time = start_time + duration + mask = (beats >= start_time) & (beats <= end_time) + windowed_beats = beats[mask] + windowed_labels = labels[mask] if labels is not None else None + else: + windowed_beats = beats + 
windowed_labels = labels + + # Separate beats into downbeats and other beats + if windowed_labels is not None: + windowed_labels = np.array(windowed_labels) # Ensure labels is a NumPy array + downbeat_times = windowed_beats[windowed_labels == 1] # Downbeats (label == 1) + beat_times = windowed_beats[windowed_labels != 1] # Other beats (label != 1) + plt.title("Waveform with Beats and Downbeats") + else: + downbeat_times = np.array([]) # Empty array for downbeats + beat_times = windowed_beats # All beats + plt.title("Waveform with Beats") + + # Plot vertical lines for downbeats and beats + ylim = np.max(np.abs(plot_y)) + plt.vlines(downbeat_times, ymin=-ylim-0.1, ymax=ylim+0.1, label='DownBeats', color='black', linewidths=1, linestyle='--') + plt.vlines(beat_times, ymin=-ylim-0.1, ymax=ylim+0.1, label='Beats', color='red', linewidths=1, linestyle=':') + + plt.legend(frameon=True, framealpha=1.0, edgecolor='black', loc='lower left', bbox_to_anchor=(0, 0.05), fontsize='small') + + # Generate click sounds for downbeats and beats + # Use the windowed audio for click generation + audio_length = len(plot_y) + + # Adjust beat times relative to the start of the windowed audio + if duration is not None: + adjusted_downbeat_times = downbeat_times - start_time + adjusted_beat_times = beat_times - start_time + + # Filter out negative times and times beyond the audio length + time_limit = len(plot_y) / sr + adjusted_downbeat_times = adjusted_downbeat_times[(adjusted_downbeat_times >= 0) & (adjusted_downbeat_times <= time_limit)] + adjusted_beat_times = adjusted_beat_times[(adjusted_beat_times >= 0) & (adjusted_beat_times <= time_limit)] + else: + adjusted_downbeat_times = downbeat_times + adjusted_beat_times = beat_times + + downbeat_click = librosa.clicks(times=adjusted_downbeat_times, sr=sr, click_freq=1000, length=audio_length, click_duration=0.1) + beat_click = librosa.clicks(times=adjusted_beat_times, sr=sr, click_freq=500, length=audio_length, click_duration=0.1) + + # 
Combine original audio with clicks + combined_audio = plot_y + beat_click + downbeat_click + + # Normalize combined audio to prevent clipping + combined_audio = combined_audio / np.max(np.abs(combined_audio)) + + else: + combined_audio = plot_y + + # Play the audio with clicks + audio_widget = ipd.Audio(combined_audio, rate=sr) + + plt.savefig('plot.png') + plt.close() + + return audio_widget + + +def plot_spec(y, sr, gt_beats=None, gt_labels=None, pred_beats=None, pred_labels=None, start_time=0, duration=30): + """ + Plots spectrogram with optional ground truth and predicted beats/downbeats. + + Parameters: + - y: Audio signal + - sr: Sampling rate + - gt_beats: Ground truth beat times (optional) + - gt_labels: Ground truth beat labels (optional, 1 for downbeats, other values for beats) + - pred_beats: Predicted beat times (optional) + - pred_labels: Predicted beat labels (optional, 1 for downbeats, other values for beats) + - start_time: Start time in seconds for the plot window (default: 0) + - duration: Duration in seconds for the plot window (default: 20) + """ + hop_length = 512 + + # Calculate end time + end_time = start_time + duration + + # Extract audio segment + start_sample = int(start_time * sr) + end_sample = int(end_time * sr) + audio_segment = y[start_sample:end_sample] + + # Generate spectrogram + spec = librosa.amplitude_to_db(np.abs(librosa.stft(audio_segment, hop_length=hop_length)), ref=np.max) + librosa.display.specshow(spec, y_axis='log', sr=sr, hop_length=hop_length, x_axis='time') + plt.title(f"Log-frequency power spectrogram of track ({start_time}s - {end_time}s)") + plt.colorbar(format="%+2.f dB") + + # Process ground truth beats if provided + if gt_beats is not None: + # Filter beat times and labels to the specified window + gt_mask = (gt_beats >= start_time) & (gt_beats <= end_time) + gt_beats_windowed = gt_beats[gt_mask] + gt_labels_windowed = gt_labels[gt_mask] if gt_labels is not None else None + + # Adjust times relative to the 
start of the window + gt_beats_adjusted = gt_beats_windowed - start_time + + # Separate beats and downbeats for ground truth + if gt_labels_windowed is not None: + gt_labels_windowed = np.array(gt_labels_windowed) + gt_downbeats = gt_beats_adjusted[gt_labels_windowed == 1] + gt_beats_only = gt_beats_adjusted[gt_labels_windowed != 1] + else: + gt_downbeats = np.array([]) + gt_beats_only = gt_beats_adjusted + + # Plot ground truth annotations in the upper part + if len(gt_beats_only) > 0: + plt.vlines(gt_beats_only, hop_length * 2, sr / 2, linestyles='dotted', color='w', alpha=0.8) + if len(gt_downbeats) > 0: + plt.vlines(gt_downbeats, hop_length * 2, sr / 2, color='w', alpha=0.8) + plt.text(duration * 0.35, hop_length * 1.65, 'GT Annotations (above)', color='w', fontsize=10, fontweight='bold') + + # Process predicted beats if provided + if pred_beats is not None: + # Filter beat times and labels to the specified window + pred_mask = (pred_beats >= start_time) & (pred_beats <= end_time) + pred_beats_windowed = pred_beats[pred_mask] + pred_labels_windowed = pred_labels[pred_mask] if pred_labels is not None else None + + # Adjust times relative to the start of the window + pred_beats_adjusted = pred_beats_windowed - start_time + + # Separate beats and downbeats for predictions + if pred_labels_windowed is not None: + pred_labels_windowed = np.array(pred_labels_windowed) + pred_downbeats = pred_beats_adjusted[pred_labels_windowed == 1] + pred_beats_only = pred_beats_adjusted[pred_labels_windowed != 1] + else: + pred_downbeats = np.array([]) + pred_beats_only = pred_beats_adjusted + + # Plot predictions in the lower part + if len(pred_beats_only) > 0: + plt.vlines(pred_beats_only, 0, hop_length, linestyles='dotted', color='w', alpha=0.8) + if len(pred_downbeats) > 0: + plt.vlines(pred_downbeats, 0, hop_length, color='w', alpha=0.8) + plt.text(duration * 0.35, hop_length * 1.1, 'Predictions (below)', color='w', fontsize=10, fontweight='bold') + + plt.show() + +audio_path 
= "../Devi Pavane.wav" + +print(compiam.list_models()) +from compiam.rhythm.meter.tcn_carnatic import TCNTracker +tracker = TCNTracker(post_processor="sequential") +tracker.trained = True +pred = tracker.predict(audio_path) + +beats = pred[:, 0] + +print(pred) +print(pred.shape) + +y, sr = librosa.load(audio_path) + +plot_sonify(y=y, sr=sr, beats=beats) diff --git a/pyproject.toml b/pyproject.toml index f5961a9c..08b3d37e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,9 +74,12 @@ torch = [ essentia = [ "essentia", ] +madmom = [ + "madmom @ git+https://github.com/vivekvjyn/madmom.git", +] [project.urls] Homepage = "https://github.com/MTG/compIAM" Documentation = "https://mtg.github.io/compIAM/" Issues = "https://github.com/MTG/compIAM/issues/" -Releases = "https://github.com/MTG/compIAM/releases/" \ No newline at end of file +Releases = "https://github.com/MTG/compIAM/releases/" diff --git a/tests/resources/rhythm/beat_test.wav b/tests/resources/rhythm/beat_test.wav new file mode 100644 index 00000000..bdfddc8c Binary files /dev/null and b/tests/resources/rhythm/beat_test.wav differ diff --git a/tests/rhythm/test_tcn_carnatic.py b/tests/rhythm/test_tcn_carnatic.py new file mode 100644 index 00000000..0bfc0a55 --- /dev/null +++ b/tests/rhythm/test_tcn_carnatic.py @@ -0,0 +1,55 @@ +import os +import pytest +import librosa + +import numpy as np + +from compiam.data import TESTDIR +from compiam.exceptions import ModelNotTrainedError + + +def test_predict_joint(): + from compiam.rhythm.meter.tcn_carnatic import TCNTracker + + tracker = TCNTracker() + with pytest.raises(ModelNotTrainedError): + tracker.predict(os.path.join(TESTDIR, "resources", "rhythm", "hola.wav")) + tracker.trained = True + with pytest.raises(FileNotFoundError): + tracker.predict(os.path.join(TESTDIR, "resources", "rhythm", "hola.wav")) + beats = tracker.predict( + os.path.join(TESTDIR, "resources", "rhythm", "beat_test.wav") + ) + + print(beats.shape) + + audio_in, sr = librosa.load( + 
os.path.join(TESTDIR, "resources", "rhythm", "beat_test.wav") + ) + beats_2 = tracker.predict(audio_in) + + assert isinstance(beats, np.ndarray) + assert isinstance(beats_2, np.ndarray) + +def test_predict_sequential(): + from compiam.rhythm.meter.tcn_carnatic import TCNTracker + + tracker = TCNTracker(post_processor="sequential") + with pytest.raises(ModelNotTrainedError): + tracker.predict(os.path.join(TESTDIR, "resources", "rhythm", "hola.wav")) + tracker.trained = True + with pytest.raises(FileNotFoundError): + tracker.predict(os.path.join(TESTDIR, "resources", "rhythm", "hola.wav")) + beats = tracker.predict( + os.path.join(TESTDIR, "resources", "rhythm", "beat_test.wav") + ) + + print(beats.shape) + + audio_in, sr = librosa.load( + os.path.join(TESTDIR, "resources", "rhythm", "beat_test.wav") + ) + beats_2 = tracker.predict(audio_in) + + assert isinstance(beats, np.ndarray) + assert isinstance(beats_2, np.ndarray)