Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 34 additions & 33 deletions .github/environment-ci.yml
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
name: compiam-dev

channels:
- conda-forge
- defaults

dependencies:
- python=3.9
- pip
- "attrs>=23.1.0"
- "matplotlib>=3.0.0"
- "numpy>=1.20.3,<=1.26.4"
- "joblib>=1.2.0"
- "pathlib~=1.0.1"
- "tqdm>=4.66.1"
- "IPython>=7.34.0"
- "ipywidgets>=7.0.0,<8"
- "Jinja2~=3.1.2"
- "configobj~=5.0.6"
- "seaborn"
- "librosa>=0.10.1"
- "scikit-learn==1.5.2"
- "scikit-image~=0.24.0"
- "hmmlearn==0.3.3"
- "fastdtw~=0.3.4"
#######
- libvorbis
- pytest>=7.4.3
#######

- pip:
- "keras<3.0.0"
- "tensorflow>=2.12.0,<2.16"
- "torch==2.0.0"
- "torchaudio==2.0.1"
- "essentia"
- "soundfile>=0.12.1"
- "opencv-python~=4.6.0"
- "mirdata==0.3.9"
- "compmusic==0.4"
- "attrs>=23.1.0"
- "black>=23.3.0"
- "decorator>=5.1.1"
- "future>=0.18.3"
- "testcontainers>=3.7.1"
- keras<3.0.0
- tensorflow>=2.12.0,<2.16
- torch==2.0.0
- torchaudio==2.0.1
- essentia
- soundfile>=0.12.1
- opencv-python~=4.6.0
- mirdata==0.3.9
- compmusic==0.4
- attrs>=23.1.0
- black>=23.3.0
- decorator>=5.1.1
- future>=0.18.3
- testcontainers>=3.7.1
- madmom @ git+https://github.com/vivekvjyn/madmom.git
- matplotlib>=3.0.0
- numpy>=1.20.3,<=1.23.5
- joblib>=1.2.0
- tqdm>=4.66.1
- IPython>=7.34.0
- ipywidgets>=7.0.0,<8
- Jinja2~=3.1.2
- configobj~=5.0.6
- seaborn
- librosa>=0.10.1
- scikit-learn==1.5.2
- scikit-image~=0.24.0
- hmmlearn==0.3.3
- fastdtw~=0.3.4
- pytest>=7.4.3
16 changes: 16 additions & 0 deletions compiam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@
},
},
},
"rhythm:tcn-carnatic": {
"module_name": "compiam.rhythm.meter.tcn_carnatic",
"class_name": "TCNTracker",
"default_version": "v1",
"kwargs": {
"v1": {
"model_path": os.path.join(
"models",
"rhythm",
"tcn-carnatic"
),
"download_link": "https://zenodo.org/records/18449067/files/compIAM-TCNCarnatic.zip?download=1",
"download_checksum": "995369933f2a344af0ffa57ea5c15e62",
},
},
},
"structure:dhrupad-bandish-segmentation": {
"module_name": "compiam.structure.segmentation.dhrupad_bandish_segmentation",
"class_name": "DhrupadBandishSegmentation",
Expand Down
3 changes: 2 additions & 1 deletion compiam/rhythm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
|--------------------------------|------------------------------------------------|-----------|
| Akshara Pulse Tracker Detector | Detect onsets of aksharas in tabla recordings | [1] |
| Mnemonic Stroke Transcription | Bol/Solkattu transcription using HMM | [2] |
| TCN Carnatic | Carnatic meter tracking using TCN | [3] |


[1] Originally implemented by Ajay Srinivasamurthy as part of PyCompMusic - https://github.com/MTG/pycompmusic

[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
2 changes: 1 addition & 1 deletion compiam/rhythm/meter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from compiam.data import models_dict

from compiam.rhythm.meter.akshara_pulse_tracker import AksharaPulseTracker

from compiam.rhythm.meter.tcn_carnatic import TCNTracker

# Show user the available tools
def list_tools():
Expand Down
203 changes: 203 additions & 0 deletions compiam/rhythm/meter/tcn_carnatic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import os
import sys
import numpy as np
from typing import Dict
from tqdm import tqdm
from compiam.exceptions import ModelNotTrainedError

from compiam.utils.download import download_remote_model
from compiam.utils import get_logger, WORKDIR
from compiam.io import write_csv

logger = get_logger(__name__)

class TCNTracker(object):
    """TCN-based beat and downbeat tracker tuned to Carnatic Music.

    Wraps a pre-trained temporal convolutional network (MultiTracker) and a
    madmom-based post-processing stage that decodes the network's beat and
    downbeat activations into time-stamped beat positions.
    """

    def __init__(self,
                 post_processor="joint",
                 model_version=42,
                 model_path=None,
                 download_link=None,
                 download_checksum=None,
                 gpu=-1):
        """TCN beat tracker init method.

        :param post_processor: post-processing method to use. Choose from 'joint', or 'sequential'.
        :param model_version: version of the pre-trained model to use. Choose from 42, 52, or 62.
        :param model_path: path to the folder containing the model weights.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: id of the GPU to use (-1 by default, to run on CPU).

        :raises ImportError: if torch or madmom are not installed.
        :raises ValueError: if post_processor or model_version are invalid.
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        # torch and madmom are optional extras of compIAM; they are imported
        # lazily (and exposed as module globals) so the package imports even
        # when these heavy dependencies are missing.
        try:
            global torch
            import torch
        except ImportError:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            )

        try:
            global madmom
            import madmom
        except ImportError:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            )
        ###
        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import joint_tracker, sequential_tracker

        # FIX: the allowed list previously also contained "beat", which is
        # undocumented, contradicted the error message below, and silently
        # mapped to the sequential tracker. Only the two documented options
        # are accepted now.
        if post_processor not in ["joint", "sequential"]:
            raise ValueError(f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'.")
        if model_version not in [42, 52, 62]:
            raise ValueError(f"Invalid model_version: {model_version}. Choose from 42, 52, or 62.")

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        # Weight file name encodes the model version.
        self.model_version = f'multitracker_{model_version}.pth'
        self.download_link = download_link
        self.download_checksum = download_checksum

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)
        # Number of frames replicated at each end of the feature sequence to
        # compensate for the network's receptive field at the boundaries.
        self.pad_frames = 2

        self.post_processor = joint_tracker if post_processor == "joint" else \
            sequential_tracker


    def _build_model(self):
        """Build the TCN model on the selected device, in eval mode."""
        model = MultiTracker().to(self.device)
        model.eval()
        return model


    def load_model(self, model_path):
        """Load pre-trained model weights, downloading them if missing.

        :param model_path: path to the folder containing the weight file.
        """
        if not os.path.exists(os.path.join(model_path, self.model_version)):
            self.download_model(model_path)  # Downloading model weights

        self.model.load_weights(os.path.join(model_path, self.model_version), self.device)

        self.model_path = model_path
        self.trained = True


    def download_model(self, model_path=None, force_overwrite=True):
        """Download the pre-trained model from the remote link.

        :param model_path: destination folder for the weights; defaults to
            the compIAM workdir model store.
        :param force_overwrite: re-download even if a file already exists.
        """
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights
        if not os.path.exists(download_path):
            os.makedirs(download_path)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )

    def predict(self, input_data: str, sr: int = 44100, min_bpm=55, max_bpm=230, beats_per_bar=None) -> Dict:
        """Run inference on input audio file.

        :param input_data: path to audio file or numpy array like audio signal.
        :param sr: sampling rate of the input audio signal (default: 44100).
        :param min_bpm: minimum BPM for beat tracking (default: 55).
        :param max_bpm: maximum BPM for beat tracking (default: 230).
        :param beats_per_bar: list of possible beats per bar for downbeat
            tracking (default: [3, 5, 7, 8]).

        :returns: a 2-D list with beats and beat positions.
        :raises ModelNotTrainedError: if no weights have been loaded.
        """
        # FIX: avoid a mutable default argument; the default list is now
        # created per call.
        if beats_per_bar is None:
            beats_per_bar = [3, 5, 7, 8]

        if self.trained is False:
            raise ModelNotTrainedError(
                """Model is not trained. Please load model before running inference!
                You can load the pre-trained instance with the load_model wrapper."""
            )

        features = self.preprocess_audio(input_data, sr)
        x = torch.from_numpy(features).to(self.device)
        # Inference only: disable autograd to avoid building a graph.
        with torch.no_grad():
            output = self.model(x)
        beats_act = output["beats"].squeeze().detach().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy()

        pred = self.post_processor(beats_act, downbeats_act, min_bpm=min_bpm, max_bpm=max_bpm, beats_per_bar=beats_per_bar)

        return pred

    def preprocess_audio(self, input_data: str, input_sr: int) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file, or audio array.
        :param input_sr: sampling rate of the input audio (ignored when a
            path is given; the file's own rate is used instead).

        :returns: preprocessed features as a numpy array of shape
            (1, 1, frames + 2 * pad_frames, feature_dim).
        :raises FileNotFoundError: if a given path does not exist.
        :raises ValueError: if input_data is neither a path nor an array.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sr = madmom.io.audio.load_audio_file(input_data)
            # NOTE(review): this tests the FIRST axis for 2 channels; for
            # (samples, channels)-shaped audio the channel count is on the
            # second axis — confirm the expected layout upstream.
            if audio.shape[0] == 2:
                audio = audio.mean(axis=0)
            signal = madmom.audio.Signal(audio, sr, num_channels=1)
        elif isinstance(input_data, np.ndarray):
            audio = input_data
            if audio.shape[0] == 2:
                audio = audio.mean(axis=0)
            signal = madmom.audio.Signal(audio, input_sr, num_channels=1)
            sr = input_sr
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        x = PreProcessor(sample_rate=sr)(signal)

        # Replicate the first/last frame to pad the sequence at both ends.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add batch and channel dimensions expected by the network.
        x_final = np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

        return x_final

    @staticmethod
    def save_pitch(data, output_path):
        """Write the output beat track to a file via compiam.io.write_csv.

        NOTE(review): the name ``save_pitch`` looks copy-pasted from a pitch
        tool; kept unchanged for API compatibility.

        :param data: the data to write
        :param output_path: the path where the data is going to be stored

        :returns: None
        """
        return write_csv(data, output_path)


    def select_gpu(self, gpu="-1"):
        """Select the device to use for inference.

        :param gpu: id of the available GPU to use (-1 by default, to run on
            CPU); accepts an int or a numeric string ('0', '1', ...).
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:" + str(gpu))
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps:" + str(gpu))
            else:
                self.device = torch.device("cpu")
                logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu
Loading
Loading