diff --git a/README.md b/README.md index abe9d48..b039efb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # align_trim -Stand alone version of ARTIC's fieldbioinfomatics align_trim.py +Stand alone version of ARTIC's fieldbioinformatics align_trim.py ## Installation @@ -25,10 +25,10 @@ uv run align_trim --help ### Basic Usage ```bash -aligntrim [OPTIONS] BEDFILE +align_trim [OPTIONS] BEDFILE ``` -The tool reads alignment data from either a SAM file or stdin and outputs trimmed alignments to stdout in SAM format by default. +The tool reads alignment data from either a SAM/BAM file or stdin and outputs trimmed alignments to stdout in SAM format by default. ### Required Arguments @@ -43,9 +43,9 @@ The tool reads alignment data from either a SAM file or stdin and outputs trimme #### Processing Options -- `--normalise`, `-n` : Subsample to N coverage per amplicon. Use 0 for no normalisation (default: 0) +- `--normalise`, `-n` : Normalise to target depth N per amplicon using a greedy per-read algorithm. Each read is kept only if it brings the amplicon depth closer to the target. Use 0 for no normalisation (default: 0) - `--min-mapq`, `-m` : Minimum mapping quality to keep an aligned read (default: 20) -- `--primer-match-threshold`, `-p` : Fuzzy match primer positions within this threshold (default: 35) +- `--primer-match-threshold`, `-p` : Add this many bases of padding to the 5' end of primer coordinates to allow fuzzy matching for reads with barcodes/adapters (default: 35) #### Primer and Read Handling @@ -69,30 +69,30 @@ The tool reads alignment data from either a SAM file or stdin and outputs trimme #### Basic trimming with primer removal ```bash -aligntrim primers.bed --bamfile input.bam --output trimmed.bam +align_trim primers.bed --samfile input.bam --output trimmed.bam ``` #### Normalize coverage and generate reports ```bash -aligntrim primers.bed --bamfile input.bam --normalise 100 \ +align_trim primers.bed --samfile input.bam --normalise 100 \ --report alignment_report.tsv --amp-depth-report depth_report.tsv \ --output normalized.bam ``` #### Process from stdin with verbose output ```bash -samtools view -h input.bam | aligntrim primers.bed --verbose > trimmed.sam 2> verbose.out.txt +samtools view -h input.bam | align_trim primers.bed --verbose > trimmed.sam 2> verbose.out.txt ``` #### Strict full-length read filtering ```bash -aligntrim primers.bed --bamfile input.bam --require-full-length \ +align_trim primers.bed --samfile input.bam --require-full-length \ --min-mapq 30 --output filtered.bam ``` #### Allow mismatched primer pairs with custom threshold ```bash -aligntrim primers.bed --bamfile input.bam --allow-incorrect-pairs \ +align_trim primers.bed --samfile input.bam --allow-incorrect-pairs \ --primer-match-threshold 50 --output relaxed.bam ``` diff --git a/align_trim/main.py b/align_trim/main.py index 66be3bf..7d87552 100644 --- a/align_trim/main.py +++ b/align_trim/main.py @@ -1,24 +1,18 @@ -from copy import copy +import argparse import csv -import pysam +import itertools import sys -import numpy as np -import random -import argparse from collections import defaultdict -from typing import Optional -from pathlib import Path -import itertools -from typing import Union - +from copy import copy from importlib.metadata import version +from pathlib import Path +from typing import Optional, Union -from primalbedtools.scheme import Scheme -from primalbedtools.bedfiles import BedLine, merge_primers +import numpy as np +import pysam from primalbedtools.amplicons import Amplicon, create_amplicons - -RANDOM_SEED = 42 - +from primalbedtools.bedfiles import BedLine, merge_primers +from primalbedtools.scheme import Scheme # consumesReference lookup for if a CIGAR operation consumes the reference sequence consumesReference = [True, False, True, True, False, False, False, True] @@ -458,7 +452,7 @@ def handle_segments( return False # softmask the alignment if right primer start/end inside alignment - if segment.reference_end > p2_position: + if segment.reference_end > p2_position: # type: ignore try: trim(segment, p2_position, True, args.verbose) if args.verbose: @@ -484,7 +478,7 @@ def handle_segments( # Check require-full-length if args.require_full_length: - if segment.reference_start > p1.end or segment.reference_end < p2.start: + if segment.reference_start > p1.end or segment.reference_end < p2.start: # type: ignore if args.verbose: print( f"{segment.query_name}: ref_start {segment.reference_start} > p1.end {p1.end} or ref_end {segment.reference_end} < p2.start {p2.start}, does not span a full amplicon, skipping", @@ -496,7 +490,7 @@ def handle_segments( if not args.normalise: outfile_writer.write(segment) segment_amp_relative_start = segment.reference_start - p1.start - segment_amp_relative_end = segment.reference_end - p1.start + segment_amp_relative_end = segment.reference_end - p1.start # type: ignore if segment_amp_relative_start < 0: segment_amp_relative_start = 0 @@ -562,8 +556,8 @@ def handle_segments( if args.require_full_length: if segment1.reference_start < segment2.reference_start: if ( - segment1.reference_start > p1.end - or segment2.reference_end < p2.start + segment1.reference_start > p1.end # type: ignore + or segment2.reference_end < p2.start # type: ignore ): if args.verbose: print( @@ -574,7 +568,7 @@ def handle_segments( else: if ( segment2.reference_start > p1.end - or segment1.reference_end < p2.start + or segment1.reference_end < p2.start # type: ignore ): if args.verbose: print( @@ -589,7 +583,7 @@ def handle_segments( outfile_writer.write(segment2) for segment_in_pair in (segment1, segment2): segment_amp_relative_start = segment_in_pair.reference_start - p1.start - segment_amp_relative_end = segment_in_pair.reference_end - p1.start + segment_amp_relative_end = segment_in_pair.reference_end - p1.start # type: ignore if segment_amp_relative_start < 0: segment_amp_relative_start = 0 amp_depths[segment1.reference_name][amplicon][ @@ -601,129 +595,6 @@ def handle_segments( return (amplicon, segment) -def normalise( - trimmed_segments: dict, - normalise: int, - primers: list[BedLine], - outfile: pysam.AlignmentFile, - verbose: bool = False, -): - """Normalise the depth of the trimmed segments to a given value. Perform per-amplicon normalisation using numpy vector maths to determine whether the segment in question would take the depth closer to the desired depth across the amplicon. - - Args: - trimmed_segments (dict): Dict containing amplicon number as key and list of pysam.AlignedSegment as value, if paired segments are used, the value will be a list of tuples containing the two segments. - normalise (int): Desired normalised depth - bed (list): Primer scheme as a list of BedLine objects - outfile (pysam.AlignmentFile): Output file handle to write the normalised segments to - verbose (bool): If True, will print normalisation info during processing - - Raises: - ValueError: Amplicon assigned to segment not found in primer scheme file - - Returns: - dict: A dictionary containing the mean depth for each amplicon post normalisation - """ - amplicons = {} - - for amplicon in create_amplicons(primers): - amplicons.setdefault(amplicon.chrom, {}) - amplicons[amplicon.chrom].setdefault( - amplicon.amplicon_number, - { - "length": amplicon.amplicon_end - amplicon.amplicon_start, - "p_start": amplicon.amplicon_start, - }, - ) - - # mean_depths = {x: {} for x in amplicons} - mean_depths = {} - for chrom in amplicons: - for amplicon in amplicons[chrom]: - mean_depths[(chrom, amplicon)] = 0 - - for chrom, amplicon_dict in trimmed_segments.items(): - for amplicon, segments in amplicon_dict.items(): - if amplicon not in amplicons[chrom]: - raise ValueError(f"Amplicon {amplicon} not found in primer scheme file") - - desired_depth = np.full_like( - (amplicons[chrom][amplicon]["length"],), normalise, dtype=int - ) - - amplicon_depth = np.zeros( - (amplicons[chrom][amplicon]["length"],), dtype=int - ) - - if not segments: - if verbose: - print( - f"No segments assigned to amplicon {amplicon}, skipping", - file=sys.stderr, - ) - continue - - random.Random(RANDOM_SEED).shuffle(segments) - - distance = np.mean(np.abs(amplicon_depth - desired_depth)) - - for segment in segments: - paired = isinstance(segment, tuple) - - if paired: - test_depths = np.copy(amplicon_depth) - segment1, segment2 = segment - for segment in (segment1, segment2): - relative_start = ( - segment.reference_start - - amplicons[chrom][amplicon]["p_start"] - ) - - if relative_start < 0: - relative_start = 0 - - relative_end = ( - segment.reference_end - - amplicons[chrom][amplicon]["p_start"] - ) - - test_depths[relative_start:relative_end] += 1 - - test_distance = np.mean(np.abs(test_depths - desired_depth)) - - if test_distance < distance: - amplicon_depth = test_depths - distance = test_distance - # write the segments to the output file - outfile.write(segment1) - outfile.write(segment2) - else: - test_depths = np.copy(amplicon_depth) - - relative_start = ( - segment.reference_start - amplicons[chrom][amplicon]["p_start"] - ) - - if relative_start < 0: - relative_start = 0 - - relative_end = ( - segment.reference_end - amplicons[chrom][amplicon]["p_start"] - ) - - test_depths[relative_start:relative_end] += 1 - - test_distance = np.mean(np.abs(test_depths - desired_depth)) - - if test_distance < distance: - amplicon_depth = test_depths - distance = test_distance - outfile.write(segment) - - mean_depths[(chrom, amplicon)] = np.mean(amplicon_depth) - - return mean_depths - - def read_pair_generator(bam, region_string=None): """ Generate read pairs in a BAM file or within a region string. @@ -823,6 +694,11 @@ def go(args): Based on the most likely primer position, based on the alignment coordinates. """ + # guard for negative normalise + if args.normalise is not None and args.normalise < 0: + print("normalise must be >= 0, exiting.", file=sys.stderr) + sys.exit(1) + # prepare the report outfile if args.report: reportfh = open(args.report, "w") @@ -856,11 +732,10 @@ def go(args): amplicon_list = create_amplicons(scheme.bedlines) amplicons = {} for amplicon in amplicon_list: - amplicon.length = amplicon.amplicon_end - amplicon.amplicon_start + amplicon.length = amplicon.amplicon_end - amplicon.amplicon_start # type: ignore amplicons.setdefault(amplicon.chrom, {})[amplicon.amplicon_number] = amplicon pools = set([bl.pool for bl in scheme.bedlines]) - chroms = set([bl.chrom for bl in scheme.bedlines]) pools_str = {str(x) for x in pools} pools_str.add("unmatched") @@ -919,7 +794,8 @@ def go(args): for amp in amplicon_list: amp_depths.setdefault(amp.chrom, {}) amp_depths[amp.chrom].setdefault( - amp.amplicon_number, np.zeros(amp.length, dtype=int) + amp.amplicon_number, + np.zeros(amp.length, dtype=int), # type: ignore ) # Initialise the mean depths dictionary, this will get stomped over if normalisation is requested @@ -936,7 +812,14 @@ def go(args): padding=args.primer_match_threshold, ) - trimmed_segments = {x: {} for x in chroms} + # Per-amplicon normalisation state: running depth array and current MAD from target + if args.normalise: + norm_state = {} + for amp in amplicon_list: + norm_state[(amp.chrom, amp.amplicon_number)] = { + "depth": np.zeros(amp.length, dtype=int), # type: ignore + "distance": float(args.normalise), + } if paired: read_pairs = read_pair_generator(chained_iterator) @@ -972,22 +855,24 @@ def go(args): if not args.normalise and not trimmed_pair: continue - trimmed_segments[trimmed_pair[0].reference_name].setdefault(amplicon, []) # type: ignore - - if trimmed_segments: - trimmed_segments[trimmed_pair[0].reference_name][amplicon].append( # type: ignore - trimmed_pair - ) + if args.normalise and trimmed_pair: + chrom = trimmed_pair[0].reference_name # type: ignore + state = norm_state[(chrom, amplicon)] + p_start = amplicons[chrom][amplicon].amplicon_start + test_depths = np.copy(state["depth"]) + for seg in trimmed_pair: # type: ignore + relative_start = max(0, seg.reference_start - p_start) + relative_end = seg.reference_end - p_start + test_depths[relative_start:relative_end] += 1 + test_distance = np.mean(np.abs(test_depths - args.normalise)) + if test_distance < state["distance"]: + state["depth"] = test_depths + state["distance"] = test_distance + outfile.write(trimmed_pair[0]) # type: ignore + outfile.write(trimmed_pair[1]) # type: ignore - # normalise if requested and write normalised segments to outfile if args.normalise: - mean_amp_depths = normalise( - trimmed_segments=trimmed_segments, - normalise=args.normalise, - primers=scheme.bedlines, - outfile=outfile, - verbose=args.verbose, - ) + mean_amp_depths = {k: np.mean(v["depth"]) for k, v in norm_state.items()} else: mean_amp_depths = {} for chrom, chrom_amps in amp_depths.items(): @@ -1042,22 +927,23 @@ def go(args): if not args.normalise and not trimmed_segment: continue - trimmed_segments[trimmed_segment.reference_name].setdefault(amplicon, []) # type: ignore - - if trimmed_segment and args.normalise: - trimmed_segments[trimmed_segment.reference_name][amplicon].append( # type: ignore - trimmed_segment - ) + if args.normalise and trimmed_segment: + chrom = trimmed_segment.reference_name # type: ignore + state = norm_state[(chrom, amplicon)] + p_start = amplicons[chrom][amplicon].amplicon_start + test_depths = np.copy(state["depth"]) + relative_start = max(0, trimmed_segment.reference_start - p_start) # type: ignore + relative_end = trimmed_segment.reference_end - p_start # type: ignore + test_depths[relative_start:relative_end] += 1 + test_distance = np.mean(np.abs(test_depths - args.normalise)) + if test_distance < state["distance"]: + state["depth"] = test_depths + state["distance"] = test_distance + outfile.write(trimmed_segment) # type: ignore # normalise if requested if args.normalise: - mean_amp_depths = normalise( - trimmed_segments=trimmed_segments, - normalise=args.normalise, - primers=scheme.bedlines, - outfile=outfile, - verbose=args.verbose, - ) + mean_amp_depths = {k: np.mean(v["depth"]) for k, v in norm_state.items()} else: mean_amp_depths = {} diff --git a/pyproject.toml b/pyproject.toml index 1a98aa1..804ceb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "align_trim" -version = "1.0.2" +version = "1.1.0" license = "MIT" license-files = ["LICEN[CS]E*"] diff --git a/tests/test_integration.py b/tests/test_integration.py index 42fe70b..24e83d1 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,15 +1,16 @@ -import pathlib -import sys -import unittest import argparse -from primalbedtools.scheme import Scheme -from primalbedtools.bedfiles import merge_primers -from primalbedtools.amplicons import create_amplicons +import pathlib import tempfile -import pysam -from align_trim.main import go, create_primer_lookup, find_primer_with_lookup +import unittest from collections import defaultdict +import pysam +from primalbedtools.amplicons import create_amplicons +from primalbedtools.bedfiles import merge_primers +from primalbedtools.scheme import Scheme + +from align_trim.main import create_primer_lookup, find_primer_with_lookup, go + BED_PATH_V5_3_2 = pathlib.Path(__file__).parent / "test_data/v5.3.2.primer.bed" BAM_PATH_V5_3_2 = pathlib.Path(__file__).parent / "test_data/sars-cov-2_v5.3.2.bam" BED_PATH_V3_0_0 = pathlib.Path(__file__).parent / "test_data/v3.0.0.primer.bed" diff --git a/tests/test_legacy.py b/tests/test_legacy.py index 6b8f961..0df2bce 100644 --- a/tests/test_legacy.py +++ b/tests/test_legacy.py @@ -1,8 +1,9 @@ import pathlib import unittest -from primalbedtools.scheme import Scheme -from primalbedtools.bedfiles import BedLine + import pysam +from primalbedtools.bedfiles import BedLine +from primalbedtools.scheme import Scheme from align_trim.main import find_primer, trim diff --git a/tests/test_main.py b/tests/test_main.py index 7aee35a..9e55cfd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,9 @@ import pathlib import unittest -from primalbedtools.scheme import Scheme -from primalbedtools.bedfiles import merge_primers + from primalbedtools.amplicons import create_amplicons +from primalbedtools.bedfiles import merge_primers +from primalbedtools.scheme import Scheme from align_trim.main import ( create_primer_lookup, diff --git a/tests/test_normalise.py b/tests/test_normalise.py new file mode 100644 index 0000000..8bbb42f --- /dev/null +++ b/tests/test_normalise.py @@ -0,0 +1,220 @@ +import argparse +import pathlib +import tempfile +import unittest + +import pysam + +from align_trim.main import go + +BED_PATH_V5_3_2 = pathlib.Path(__file__).parent / "test_data/v5.3.2.primer.bed" +BAM_PATH_V5_3_2 = pathlib.Path(__file__).parent / "test_data/sars-cov-2_v5.3.2.bam" +BED_PATH_V3_0_0 = pathlib.Path(__file__).parent / "test_data/v3.0.0.primer.bed" +BAM_PATH_PAIRED_V3_0_0 = ( + pathlib.Path(__file__).parent / "test_data/sars-cov-2_v3.0.0_paired.bam" +) + + +def create_args(**kwargs): + """Create a fake args object with default values matching main.py argument parser""" + defaults = { + "bedfile": str(BED_PATH_V5_3_2.absolute()), + "samfile": str(BAM_PATH_V5_3_2.absolute()), + "normalise": 0, + "min_mapq": 20, + "primer_match_threshold": 35, + "report": None, + "amp_depth_report": None, + "no_trim_primers": False, + "paired": False, + "no_read_groups": False, + "verbose": False, + "allow_incorrect_pairs": False, + "require_full_length": False, + "output": None, + } + defaults.update(kwargs) + return argparse.Namespace(**defaults) + + +def read_amp_depths(path): + """Parse an amp_depth_report TSV into {amplicon: mean_depth}.""" + depths = {} + with open(path) as f: + next(f) # skip header + for line in f: + _, amplicon, mean_depth = line.strip().split("\t") + depths[amplicon] = float(mean_depth) + return depths + + +def count_reads(sam_path): + """Count records in a SAM/BAM file.""" + return sum(1 for _ in pysam.AlignmentFile(str(sam_path), "r")) + + +class TestNormaliseSingleEnd(unittest.TestCase): + def test_depth_bounded_by_target(self): + """Mean amplicon depth after SE normalisation should not grossly exceed the target.""" + with tempfile.TemporaryDirectory( + dir="tests", suffix="-se_norm_depth_bounded" + ) as tempdir: + tempdir_path = pathlib.Path(tempdir) + normalise_target = 200 + + args = create_args( + output=(tempdir_path / "output.sam").absolute(), + normalise=normalise_target, + amp_depth_report=(tempdir_path / "amp_depths.tsv").absolute(), + ) + go(args) + + for amplicon, mean_depth in read_amp_depths( + tempdir_path / "amp_depths.tsv" + ).items(): + self.assertLessEqual( + mean_depth, + normalise_target * 1.5, + f"Amplicon {amplicon} mean depth {mean_depth:.1f} exceeds 1.5× target {normalise_target}", + ) + + def test_reduces_depth_vs_unnormalized(self): + """SE normalisation must reduce depth for at least one amplicon vs. unnormalized.""" + with tempfile.TemporaryDirectory( + dir="tests", suffix="-se_norm_reduces_depth" + ) as tempdir: + tempdir_path = pathlib.Path(tempdir) + + args_no_norm = create_args( + output=(tempdir_path / "no_norm.sam").absolute(), + normalise=0, + amp_depth_report=(tempdir_path / "depths_no_norm.tsv").absolute(), + ) + go(args_no_norm) + + # Use a low target (10) to ensure reduction fires even on small test BAMs + args_norm = create_args( + output=(tempdir_path / "norm.sam").absolute(), + normalise=10, + amp_depth_report=(tempdir_path / "depths_norm.tsv").absolute(), + ) + go(args_norm) + + no_norm = read_amp_depths(tempdir_path / "depths_no_norm.tsv") + norm = read_amp_depths(tempdir_path / "depths_norm.tsv") + + reduction_found = any( + norm[amp] < no_norm[amp] + for amp in norm + if no_norm.get(amp, 0) > 0 + ) + self.assertTrue( + reduction_found, + "SE normalisation did not reduce depth for any amplicon", + ) + + def test_high_target_keeps_all_reads(self): + """With a target far above actual coverage, all reads must be accepted.""" + with tempfile.TemporaryDirectory( + dir="tests", suffix="-se_norm_high_target" + ) as tempdir: + tempdir_path = pathlib.Path(tempdir) + + args_no_norm = create_args( + output=(tempdir_path / "no_norm.sam").absolute(), + normalise=0, + ) + go(args_no_norm) + + args_high_norm = create_args( + output=(tempdir_path / "high_norm.sam").absolute(), + normalise=10000, + ) + go(args_high_norm) + + count_no_norm = count_reads(tempdir_path / "no_norm.sam") + count_high_norm = count_reads(tempdir_path / "high_norm.sam") + + self.assertEqual( + count_no_norm, + count_high_norm, + f"High-target SE normalisation dropped reads: " + f"expected {count_no_norm}, got {count_high_norm}", + ) + + +class TestNormalisePaired(unittest.TestCase): + def test_reduces_depth_vs_unnormalized(self): + """Paired normalisation must reduce depth for at least one amplicon vs. unnormalized.""" + with tempfile.TemporaryDirectory( + dir="tests", suffix="-paired_norm_reduces_depth" + ) as tempdir: + tempdir_path = pathlib.Path(tempdir) + + args_no_norm = create_args( + output=(tempdir_path / "no_norm.bam").absolute(), + bedfile=BED_PATH_V3_0_0.absolute(), + samfile=BAM_PATH_PAIRED_V3_0_0.absolute(), + normalise=0, + amp_depth_report=(tempdir_path / "depths_no_norm.tsv").absolute(), + ) + go(args_no_norm) + + args_norm = create_args( + output=(tempdir_path / "norm.bam").absolute(), + bedfile=BED_PATH_V3_0_0.absolute(), + samfile=BAM_PATH_PAIRED_V3_0_0.absolute(), + normalise=200, + amp_depth_report=(tempdir_path / "depths_norm.tsv").absolute(), + ) + go(args_norm) + + no_norm = read_amp_depths(tempdir_path / "depths_no_norm.tsv") + norm = read_amp_depths(tempdir_path / "depths_norm.tsv") + + reduction_found = any( + norm[amp] < no_norm[amp] + for amp in norm + if no_norm.get(amp, 0) > 0 + ) + self.assertTrue( + reduction_found, + "Paired normalisation did not reduce depth for any amplicon", + ) + + def test_high_target_keeps_all_reads(self): + """With a target far above actual coverage, all paired reads must be accepted.""" + with tempfile.TemporaryDirectory( + dir="tests", suffix="-paired_norm_high_target" + ) as tempdir: + tempdir_path = pathlib.Path(tempdir) + + args_no_norm = create_args( + output=(tempdir_path / "no_norm.bam").absolute(), + bedfile=BED_PATH_V3_0_0.absolute(), + samfile=BAM_PATH_PAIRED_V3_0_0.absolute(), + normalise=0, + ) + go(args_no_norm) + + args_high_norm = create_args( + output=(tempdir_path / "high_norm.bam").absolute(), + bedfile=BED_PATH_V3_0_0.absolute(), + samfile=BAM_PATH_PAIRED_V3_0_0.absolute(), + normalise=10000, + ) + go(args_high_norm) + + count_no_norm = count_reads(tempdir_path / "no_norm.bam") + count_high_norm = count_reads(tempdir_path / "high_norm.bam") + + self.assertEqual( + count_no_norm, + count_high_norm, + f"High-target paired normalisation dropped reads: " + f"expected {count_no_norm}, got {count_high_norm}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index 3218254..7636af5 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ [[package]] name = "align-trim" -version = "1.0.2" +version = "1.1.0" source = { editable = "." } dependencies = [ { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },