Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions compiam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@
},
},
},
"rhythm:tcn-carnatic": {
"module_name": "compiam.rhythm.meter.tcn_carnatic",
"class_name": "TCNTracker",
"default_version": "v1",
"kwargs": {
"v1": {
"model_path": os.path.join(
"models",
"rhythm",
"tcn-carnatic"
),
"download_link": "https://zenodo.org/records/18449067/files/compIAM-TCNCarnatic.zip?download=1",
"download_checksum": "995369933f2a344af0ffa57ea5c15e62",
},
},
},
"structure:dhrupad-bandish-segmentation": {
"module_name": "compiam.structure.segmentation.dhrupad_bandish_segmentation",
"class_name": "DhrupadBandishSegmentation",
Expand Down
3 changes: 2 additions & 1 deletion compiam/rhythm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
|--------------------------------|------------------------------------------------|-----------|
| Akshara Pulse Tracker Detector | Detect onsets of aksharas in tabla recordings | [1] |
| Mnemonic Stroke Transcription | Bol/Solkattu transcription using HMM | [2] |
| TCN Carnatic | Carnatic meter tracking using TCN | [3] |


[1] Originally implemented by Ajay Srinivasamurthy as part of PyCompMusic - https://github.com/MTG/pycompmusic

[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
2 changes: 1 addition & 1 deletion compiam/rhythm/meter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from compiam.data import models_dict

from compiam.rhythm.meter.akshara_pulse_tracker import AksharaPulseTracker

from compiam.rhythm.meter.tcn_carnatic import TCNTracker

# Show user the available tools
def list_tools():
Expand Down
193 changes: 193 additions & 0 deletions compiam/rhythm/meter/tcn_carnatic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import os
import sys
import torch
import numpy as np
import madmom
from typing import Dict
from tqdm import tqdm
from compiam.exceptions import ModelNotTrainedError

from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
from compiam.rhythm.meter.tcn_carnatic.post import beat_tracker, joint_tracker, sequential_tracker
from compiam.utils.download import download_remote_model
from compiam.utils import get_logger, WORKDIR
from compiam.io import write_csv

logger = get_logger(__name__)

class TCNTracker(object):
    """TCN beat tracker tuned to Carnatic Music."""

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: post-processing method to use. Choose from
            'beat', 'joint', or 'sequential'.
        :param model_version: version of the pre-trained model to use.
            Choose from 42, 52, or 62.
        :param model_path: path to the folder holding the model weights.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: id of the available GPU to use (-1 by default, to run on CPU).
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        try:
            global torch
            import torch

            global madmom
            import madmom

            global MultiTracker
            from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker

        except ImportError:
            # Catch only missing optional dependencies; a bare ``except``
            # would also hide unrelated errors raised during import.
            raise ImportError(
                "In order to use this tool you need to have torch and madmom installed. "
                "Install compIAM with torch and madmom support: pip install 'compiam[torch,madmom]'"
            )
        ###
        if post_processor not in ["beat", "joint", "sequential"]:
            # Message now lists every accepted value (previously omitted 'beat').
            raise ValueError(
                f"Invalid post_processor: {post_processor}. "
                "Choose from 'beat', 'joint', or 'sequential'."
            )
        if model_version not in [42, 52, 62]:
            raise ValueError(f"Invalid model_version: {model_version}. Choose from 42, 52, or 62.")

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        self.model_version = f'multitracker_{model_version}.pth'
        self.download_link = download_link
        self.download_checksum = download_checksum

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)
        self.pre_processor = PreProcessor(fps=100)
        # Frames replicated at both ends of the feature matrix to compensate
        # for the temporal context consumed by the 'valid' convolutions.
        self.pad_frames = 2

        # BUGFIX: 'beat' was accepted by the validation above but silently
        # fell through to the sequential tracker; dispatch each option
        # explicitly so beat_tracker is actually reachable.
        self.post_processor = {
            "beat": beat_tracker,
            "joint": joint_tracker,
            "sequential": sequential_tracker,
        }[post_processor]


    def _build_model(self):
        """Build the TCN model on the selected device.

        :returns: an eval-mode ``MultiTracker`` instance.
        """
        model = MultiTracker().to(self.device)
        model.eval()
        return model


    def load_model(self, model_path):
        """Load pre-trained model weights, downloading them if missing.

        :param model_path: path to the folder holding the model weights.
        :returns: None
        """
        if not os.path.exists(os.path.join(model_path, self.model_version)):
            self.download_model(model_path)  # Downloading model weights

        self.model.load_weights(os.path.join(model_path, self.model_version), self.device)

        self.model_path = model_path
        self.trained = True


    def download_model(self, model_path=None, force_overwrite=True):
        """Download the pre-trained model.

        :param model_path: destination folder (defaults to the compIAM workdir).
        :param force_overwrite: whether to overwrite an existing download.
        :returns: None
        """
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights; exist_ok avoids a race
        # between the existence check and the folder creation.
        os.makedirs(download_path, exist_ok=True)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )


    @torch.no_grad()
    def predict(self, input_data: str) -> Dict:
        """Run inference on input audio file.

        :param input_data: path to audio file or numpy array like audio signal.

        :returns: a 2-D list with beats and beat positions.
        """

        if self.trained is False:
            raise ModelNotTrainedError(
                """Model is not trained. Please load model before running inference!
                You can load the pre-trained instance with the load_model wrapper."""
            )

        features = self.preprocess_audio(input_data)
        x = torch.from_numpy(features).to(self.device)
        output = self.model(x)
        beats_act = output["beats"].squeeze().detach().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy()

        # Convert frame-wise activations into beat events via the chosen
        # post-processing strategy.
        pred = self.post_processor(beats_act, downbeats_act)

        return pred

    def preprocess_audio(self, input_data: str, input_sr: int = 44100) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file or an audio array.
        :param input_sr: sampling rate of the input audio (array input only;
            file input uses the file's own rate).

        :returns: preprocessed features as a numpy array with batch and
            channel dimensions prepended.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sr = madmom.io.audio.load_audio_file(input_data)
            # NOTE(review): this tests the first axis for stereo, but madmom
            # returns (num_samples, num_channels); Signal(num_channels=1)
            # downmixes regardless — confirm the intended axis.
            if audio.shape[0] == 2:
                audio = audio.mean(axis=0)
            s = madmom.audio.Signal(audio, sr, num_channels=1)
        elif isinstance(input_data, np.ndarray):
            audio = input_data
            if audio.shape[0] == 2:
                audio = audio.mean(axis=0)
            s = madmom.audio.Signal(audio, input_sr, num_channels=1)
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        x = self.pre_processor(s)

        # Replicate edge frames so the model sees full context at boundaries.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add the batch and channel dimensions expected by the conv front-end.
        x_final = np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

        return x_final

    @staticmethod
    def save_pitch(data, output_path):
        """Calling the write_csv function in compiam.io to write the output beat track in a file

        :param data: the data to write
        :param output_path: the path where the data is going to be stored

        :returns: None
        """
        return write_csv(data, output_path)


    def select_gpu(self, gpu="-1"):
        """Select the device (CPU, CUDA or MPS) to use for inference.

        :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:" + str(gpu))
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps:" + str(gpu))
            else:
                # Fall back to CPU but keep the requested gpu id recorded.
                self.device = torch.device("cpu")
                logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu
135 changes: 135 additions & 0 deletions compiam/rhythm/meter/tcn_carnatic/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Dilated residual block used inside the TCN.

    Two parallel dilated convolutions (dilation d and 2*d) are concatenated,
    passed through ELU and dropout, then projected back to ``n_filters``
    channels with a 1x1 convolution; a 1x1 residual branch adapts the input
    for the residual sum.
    """

    def __init__(self, dilation_rate, n_filters, kernel_size, dropout_rate=0.15):
        """
        :param dilation_rate: base dilation d of the first branch (the second
            branch uses 2*d).
        :param n_filters: number of input and output channels.
        :param kernel_size: kernel size of the two dilated convolutions.
        :param dropout_rate: dropout probability applied after the activation.
        """
        super().__init__()
        # Blocks are stacked channel-preserving, so the input carries
        # n_filters channels. Generalizes the previously hard-coded 20,
        # which is identical for the default n_filters=20 checkpoints.
        in_channels = n_filters
        self.res = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=1, padding='same')
        self.conv_1 = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=kernel_size, dilation=dilation_rate, padding='same')
        self.conv_2 = nn.Conv1d(in_channels=in_channels, out_channels=n_filters, kernel_size=kernel_size, dilation=dilation_rate*2, padding='same')
        self.elu = nn.ELU()
        # BUGFIX: honor the dropout_rate argument (0.15 was hard-coded,
        # silently ignoring the parameter).
        self.dropout = nn.Dropout(dropout_rate)
        self.conv_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=n_filters, kernel_size=1, padding='same')

    def forward(self, x):
        """Return (residual + projection, skip) for an input of shape
        (batch, n_filters, time); both outputs keep that shape.
        """
        res_x = self.res(x)
        conv_1 = self.conv_1(x)
        conv_2 = self.conv_2(x)
        concat = torch.cat([conv_1, conv_2], dim=1)
        out = self.elu(concat)
        out = self.dropout(out)
        out = self.conv_3(out)

        return res_x + out, out

class TCN(nn.Module):
    """Stack of residual blocks with exponentially growing dilation (2**i).

    Module names ``tcn_0`` ... ``tcn_{n_dilations-1}`` are kept stable so
    pre-trained state dicts load by name.
    """

    def __init__(self, n_filters, kernel_size, n_dilations, dropout_rate):
        super().__init__()
        self.tcn_layers = nn.ModuleDict(
            {
                f"tcn_{i}": ResBlock(2 ** i, n_filters, kernel_size, dropout_rate)
                for i in range(n_dilations)
            }
        )
        self.activation = nn.ELU()

    def forward(self, x):
        """Return the activated residual path and the sum of skip outputs."""
        skips = []
        for layer in self.tcn_layers.values():
            x, skip_out = layer(x)
            skips.append(skip_out)

        summed_skip = torch.stack(skips, dim=-1).sum(dim=-1)
        return self.activation(x), summed_skip


class MultiTracker(nn.Module):
    """Beat/downbeat tracking network.

    A three-stage 2D convolutional front-end collapses the frequency axis,
    a TCN models temporal structure, and two sigmoid heads emit frame-wise
    beat and downbeat activations. Attribute names are fixed by the
    pre-trained checkpoints and must not be renamed.
    """

    def __init__(self, n_filters=20, n_dilations=11, kernel_size=5, dropout_rate=0.15):
        super().__init__()
        self.dropout_rate = dropout_rate

        # Front-end stage 1.
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(3, 3), stride=1, padding="valid")
        self.elu_1 = nn.ELU()
        self.mp_1 = nn.MaxPool2d((1, 3))
        self.dropout_1 = nn.Dropout(dropout_rate)

        # Front-end stage 2.
        self.conv_2 = nn.Conv2d(in_channels=n_filters, out_channels=n_filters, kernel_size=(1, 10), padding="valid")
        self.elu_2 = nn.ELU()
        self.mp_2 = nn.MaxPool2d((1, 3))
        self.dropout_2 = nn.Dropout(dropout_rate)

        # Front-end stage 3.
        self.conv_3 = nn.Conv2d(in_channels=n_filters, out_channels=n_filters, kernel_size=(3, 3), padding="valid")
        self.elu_3 = nn.ELU()
        self.mp_3 = nn.MaxPool2d((1, 3))
        self.dropout_3 = nn.Dropout(dropout_rate)

        self.tcn = TCN(n_filters, kernel_size, n_dilations, dropout_rate)

        # Beat and downbeat output heads.
        self.beats_dropout = nn.Dropout(dropout_rate)
        self.beats_dense = nn.Linear(n_filters, 1)
        self.beats_act = nn.Sigmoid()

        self.downbeats_dropout = nn.Dropout(dropout_rate)
        self.downbeats_dense = nn.Linear(n_filters, 1)
        self.downbeats_act = nn.Sigmoid()

        # NOTE(review): the tempo head is never used in forward(); presumably
        # kept so checkpoints that contain tempo weights still load — confirm.
        self.tempo_dropout = nn.Dropout(dropout_rate)
        self.tempo_dense = nn.Linear(n_filters, 300)
        self.tempo_act = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a (batch, 1, time, freq) input to frame-wise activations.

        :returns: dict with 'beats' and 'downbeats' tensors of shape
            (batch, time', 1), where time' reflects the 'valid' convolutions.
        """
        front_end = (
            (self.conv_1, self.elu_1, self.mp_1, self.dropout_1),
            (self.conv_2, self.elu_2, self.mp_2, self.dropout_2),
            (self.conv_3, self.elu_3, self.mp_3, self.dropout_3),
        )
        for conv, act, pool, drop in front_end:
            x = drop(pool(act(conv(x))))

        # The frequency axis has been reduced to a single bin; drop it so
        # the 1D TCN sees (batch, channels, time).
        x = torch.squeeze(x, -1)

        x, skip = self.tcn(x)

        # (batch, channels, time) -> (batch, time, channels) for the heads.
        x = x.transpose(-2, -1)

        beats = self.beats_act(self.beats_dense(self.beats_dropout(x)))
        downbeats = self.downbeats_act(self.downbeats_dense(self.downbeats_dropout(x)))

        return {"beats": beats, "downbeats": downbeats}

    def load_weights(self, model_path, device):
        """Load checkpoint weights onto ``device``, stripping an optional
        'model.' prefix from parameter names before ``load_state_dict``.
        """
        ckpt = torch.load(model_path, map_location=device)
        state_dict = ckpt.get("state_dict", ckpt)
        prefix = "model."
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.load_state_dict(cleaned)
Loading
Loading