diff --git a/CHANGES.md b/CHANGES.md index 93bd95a..8bf3c48 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,10 @@ This changelog is intended for _humans_ and follows many of the principles from Changes for this project _do not_ currently follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html). Instead, changes appear below grouped by the date they were added to the workflow. +# 23 December 2025 + + - Add frequency and fitness estimates for amino acid haplotypes for each subtype and geographic resolution. See [#26](https://github.com/nextstrain/forecasts-flu/pull/26) for details. + # 4 August 2025 - Enable forecasts for MLR models using a forecast horizon of 6 steps with a frequency estimation interval of 14-days to produce 84-day forecasts. See [#18](https://github.com/nextstrain/forecasts-flu/pull/18) for details. diff --git a/Snakefile b/Snakefile index 1427850..3f577c2 100644 --- a/Snakefile +++ b/Snakefile @@ -1,7 +1,10 @@ configfile: "config/defaults.yaml" wildcard_constraints: - date = r"\d{4}-\d{2}-\d{2}" + data_provenance=r"(gisaid)", + variant_classification=r"(emerging_haplotype|aa_haplotype)", + lineage=r"(h1n1pdm|h3n2|vic)", + date=r"\d{4}-\d{2}-\d{2}" def get_todays_date(): from datetime import datetime @@ -13,24 +16,24 @@ run_date = config.get("run_date", get_todays_date()) if config.get("s3_dst"): rule upload_all_models: input: - expand("results/{lineage}/{geo_resolution}/mlr/{date}_results_s3_upload.done", lineage=config["lineages"], geo_resolution=config["geo_resolutions"], date=run_date), - expand("results/{lineage}/{geo_resolution}/mlr/results_s3_upload.done", lineage=config["lineages"], geo_resolution=config["geo_resolutions"]) + expand("results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/{date}_results_s3_upload.done", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], lineage=config["lineages"], geo_resolution=config["geo_resolutions"], 
date=run_date), + expand("results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/results_s3_upload.done", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], lineage=config["lineages"], geo_resolution=config["geo_resolutions"]) else: rule all: input: - expand("plots/{lineage}/{geo_resolution}/ga/ga_by_variant.png", lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), - expand("plots/{lineage}/{geo_resolution}/ga/ga_by_location.png", lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), - expand("plots/{lineage}/{geo_resolution}/freq/freq_by_location.png", lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), + expand("plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/ga/ga_by_variant.png", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), + expand("plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/ga/ga_by_location.png", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), + expand("plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/freq/freq_by_location.png", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), rule all_models: input: - expand("results/{lineage}/{geo_resolution}/mlr/MLR_results.json", lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), + expand("results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/MLR_results.json", data_provenance=config["data_provenances"], variant_classification=config["variant_classifications"], 
lineage=config["lineages"], geo_resolution=config["geo_resolutions"]), rule download_metadata: output: - "data/{lineage}/metadata.tsv", + "data/{data_provenance}/{lineage}/metadata.tsv", params: - s3_path=lambda wildcards: config["data"][wildcards.lineage]["s3_metadata"] + s3_path=lambda wildcards: config["data"][wildcards.data_provenance][wildcards.lineage]["s3_metadata"], shell: """ aws s3 cp {params.s3_path} - | xz -c -d > {output} @@ -38,9 +41,9 @@ rule download_metadata: rule download_nextclade: output: - "data/{lineage}/nextclade.tsv", + "data/{data_provenance}/{lineage}/nextclade.tsv", params: - s3_path=lambda wildcards: config["data"][wildcards.lineage]["s3_nextclade"] + s3_path=lambda wildcards: config["data"][wildcards.data_provenance][wildcards.lineage]["s3_nextclade"], shell: """ aws s3 cp {params.s3_path} - | xz -c -d > {output} @@ -48,7 +51,7 @@ rule download_nextclade: rule download_haplotype_definitions: output: - haplotypes="results/{lineage}/haplotype_definitions.tsv", + haplotypes="data/nextstrain/{lineage}/haplotype_definitions.tsv", shell: """ curl \ @@ -59,10 +62,10 @@ rule download_haplotype_definitions: rule metadata_with_nextclade: input: - metadata="data/{lineage}/metadata.tsv", - nextclade="data/{lineage}/nextclade.tsv", + metadata="data/{data_provenance}/{lineage}/metadata.tsv", + nextclade="data/{data_provenance}/{lineage}/nextclade.tsv", output: - metadata="data/{lineage}/metadata_with_nextclade.tsv", + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade.tsv", shell: """ augur merge \ @@ -73,12 +76,12 @@ rule metadata_with_nextclade: rule filter_data: input: - metadata="data/{lineage}/metadata_with_nextclade.tsv", + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade.tsv", output: - metadata="results/{lineage}/{geo_resolution}/metadata_with_nextclade.tsv", + metadata="data/{data_provenance}/{lineage}/filtered_metadata_with_nextclade.tsv", params: - min_date=lambda wildcards: 
config["prepare_data"][wildcards.geo_resolution]["min_date"], - max_date=lambda wildcards: config["prepare_data"][wildcards.geo_resolution]["max_date"], + min_date=lambda wildcards: config["min_date"], + max_date=lambda wildcards: config["max_date"], shell: """ augur filter \ @@ -89,14 +92,15 @@ rule filter_data: --output-metadata {output.metadata} """ -rule assign_haplotypes: +rule assign_emerging_haplotypes: input: - metadata="results/{lineage}/{geo_resolution}/metadata_with_nextclade.tsv", - haplotypes="results/{lineage}/haplotype_definitions.tsv", + metadata="data/{data_provenance}/{lineage}/filtered_metadata_with_nextclade.tsv", + haplotypes="data/nextstrain/{lineage}/haplotype_definitions.tsv", output: - metadata="results/{lineage}/{geo_resolution}/metadata_with_nextclade_updated.tsv" + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade_with_emerging_haplotypes.tsv", params: variant_column=config["haplotype_variant_column"], + haplotype_column_name="emerging_haplotype", default_haplotype="other", shell: """ @@ -104,19 +108,41 @@ rule assign_haplotypes: --substitutions {input.metadata} \ --haplotypes {input.haplotypes} \ --clade-column {params.variant_column:q} \ + --haplotype-column-name {params.haplotype_column_name:q} \ --default-haplotype {params.default_haplotype:q} \ --output-table {output.metadata} """ +rule assign_aa_haplotypes: + input: + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade_with_emerging_haplotypes.tsv", + output: + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade_with_aa_haplotypes.tsv", + params: + genes=["HA1"], + clade_column=config["haplotype_variant_column"], + mutations_column=config["mutations_column"], + haplotype_column_name="aa_haplotype", + shell: + r""" + python3 scripts/assign_aa_haplotypes.py \ + --nextclade {input.metadata:q} \ + --genes {params.genes:q} \ + --strip-genes \ + --clade-column {params.clade_column:q} \ + --mutations-column {params.mutations_column:q} \ + 
--attribute-name {params.haplotype_column_name:q} \ + --output {output.metadata:q} + """ + rule clade_seq_counts: input: - metadata="results/{lineage}/{geo_resolution}/metadata_with_nextclade_updated.tsv", + metadata="data/{data_provenance}/{lineage}/metadata_with_nextclade_with_aa_haplotypes.tsv", output: - sequence_counts="results/{lineage}/{geo_resolution}/seq_counts.tsv", + sequence_counts="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/seq_counts.tsv", params: id_column="strain", date_column="date", - variant_column=config["variant"], shell: """ ./scripts/summarize-clade-sequence-counts \ @@ -124,20 +150,20 @@ rule clade_seq_counts: --id-column {params.id_column:q} \ --date-column {params.date_column:q} \ --location-column {wildcards.geo_resolution:q} \ - --clade-column {params.variant_column:q} \ + --clade-column {wildcards.variant_classification:q} \ --output {output.sequence_counts} """ rule prepare_clade_data: """Preparing clade counts for analysis""" input: - sequence_counts = "results/{lineage}/{geo_resolution}/seq_counts.tsv" + sequence_counts = "results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/seq_counts.tsv" output: - sequence_counts = "results/{lineage}/{geo_resolution}/prepared_seq_counts.tsv" + sequence_counts = "results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/prepared_seq_counts.tsv" params: - min_date=lambda wildcards: config["prepare_data"][wildcards.geo_resolution]["min_date"], - location_min_seq=lambda wildcards: config["prepare_data"][wildcards.geo_resolution]["location_min_seq"], - clade_min_seq=lambda wildcards: config["prepare_data"][wildcards.geo_resolution]["clade_min_seq"], + min_date=lambda wildcards: config["min_date"], + location_min_seq=lambda wildcards: config["prepare_data"][wildcards.data_provenance][wildcards.variant_classification][wildcards.geo_resolution]["location_min_seq"], + clade_min_seq=lambda wildcards: 
config["prepare_data"][wildcards.data_provenance][wildcards.variant_classification][wildcards.geo_resolution]["clade_min_seq"], shell: """ python ./scripts/prepare-data.py \ @@ -150,16 +176,16 @@ rule prepare_clade_data: rule mlr_model: input: - counts="results/{lineage}/{geo_resolution}/prepared_seq_counts.tsv", + counts="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/prepared_seq_counts.tsv", config="config/mlr/{lineage}.yaml", output: - model="results/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json", + model="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json", params: data_name="initial_MLR", - path="results/{lineage}/{geo_resolution}/mlr/", - max_date=lambda wildcards: config["prepare_data"][wildcards.geo_resolution].get("max_date", "0D"), + path=subpath(output.model, parent=True), + max_date=config["max_date"], benchmark: - "results/{lineage}/{geo_resolution}/mlr/mlr-model_benchmark.tsv" + "results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/mlr-model_benchmark.tsv" resources: mem_mb=3000, shell: @@ -174,29 +200,29 @@ rule mlr_model: rule add_colors_to_mlr_model: input: - model="results/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json", - auspice_config="results/{lineage}/auspice_config.json", + model="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json", + auspice_config="data/nextstrain/{lineage}/auspice_config.json", + color_schemes="config/color_schemes.tsv", output: - model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json", - params: - coloring_field=config["coloring_field"], + model="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/MLR_results.json", shell: r""" python scripts/add_colors_to_model.py \ --model {input.model:q} \ --auspice-config {input.auspice_config:q} \ - --coloring-field {params.coloring_field:q} \ + 
--color-schemes {input.color_schemes:q} \ + --coloring-field {wildcards.variant_classification:q} \ --output {output.model:q} """ rule parse_mlr_json: input: - model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json", + model="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/MLR_results.json", output: - ga="results/{lineage}/{geo_resolution}/mlr/ga.tsv", - freq="results/{lineage}/{geo_resolution}/mlr/freq.tsv", - emp="results/{lineage}/{geo_resolution}/mlr/raw_freq.tsv", - freq_forecast="results/{lineage}/{geo_resolution}/mlr/freq_forecast.tsv", + ga="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/ga.tsv", + freq="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/freq.tsv", + emp="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/raw_freq.tsv", + freq_forecast="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/freq_forecast.tsv", params: version="MLR", shell: @@ -214,7 +240,7 @@ rule get_pivot: input: config="config/mlr/{lineage}.yaml" output: - "results/{lineage}/{geo_resolution}/pivot.txt" + "results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/pivot.txt" shell: """ python3 scripts/get_pivot.py \ @@ -224,7 +250,7 @@ rule get_pivot: rule download_auspice_config_json: output: - config="results/{lineage}/auspice_config.json", + config="data/nextstrain/{lineage}/auspice_config.json", shell: """ curl \ @@ -235,14 +261,12 @@ rule download_auspice_config_json: rule plot_freq: input: - freq_data="results/{lineage}/{geo_resolution}/mlr/freq.tsv", - raw_data="results/{lineage}/{geo_resolution}/mlr/raw_freq.tsv", + freq_data="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/freq.tsv", + raw_data="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/raw_freq.tsv", color_scheme="config/color_schemes.tsv", - 
auspice_config="results/{lineage}/auspice_config.json" + auspice_config="data/nextstrain/{lineage}/auspice_config.json", output: - variant="plots/{lineage}/{geo_resolution}/freq/freq_by_location.png" - params: - coloring_field=config["coloring_field"], + variant="plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/freq/freq_by_location.png", shell: """ python3 ./scripts/plot-freq.py \ @@ -250,21 +274,19 @@ rule plot_freq: --input_raw {input.raw_data} \ --colors {input.color_scheme} \ --auspice-config {input.auspice_config} \ - --coloring-field {params.coloring_field} \ + --coloring-field {wildcards.variant_classification} \ --output {output.variant} """ rule plot_ga: input: - ga="results/{lineage}/{geo_resolution}/mlr/ga.tsv", + ga="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/ga.tsv", color_scheme="config/color_schemes.tsv", - pivot="results/{lineage}/{geo_resolution}/pivot.txt", - auspice_config="results/{lineage}/auspice_config.json" + pivot="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/pivot.txt", + auspice_config="data/nextstrain/{lineage}/auspice_config.json" output: - variant="plots/{lineage}/{geo_resolution}/ga/ga_by_variant.png", - location="plots/{lineage}/{geo_resolution}/ga/ga_by_location.png", - params: - coloring_field=config["coloring_field"], + variant="plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/ga/ga_by_variant.png", + location="plots/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/ga/ga_by_location.png", shell: """ python3 ./scripts/plot-ga.py \ @@ -275,7 +297,7 @@ rule plot_ga: --out_location {output.location} \ --pivot {input.pivot} \ --auspice-config {input.auspice_config} \ - --coloring-field {params.coloring_field} + --coloring-field {wildcards.variant_classification} """ if config.get("s3_dst"): diff --git a/config/defaults.yaml b/config/defaults.yaml index 9ee49dc..88a192d 100644 --- 
a/config/defaults.yaml +++ b/config/defaults.yaml @@ -1,13 +1,21 @@ data: - h3n2: - s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h3n2/metadata.tsv.xz" - s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h3n2/ha/nextclade.tsv.xz" - h1n1pdm: - s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h1n1pdm/metadata.tsv.xz" - s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h1n1pdm/ha/nextclade.tsv.xz" - vic: - s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/vic/metadata.tsv.xz" - s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/vic/ha/nextclade.tsv.xz" + gisaid: + h3n2: + s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h3n2/metadata.tsv.xz" + s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h3n2/ha/nextclade.tsv.xz" + h1n1pdm: + s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h1n1pdm/metadata.tsv.xz" + s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/h1n1pdm/ha/nextclade.tsv.xz" + vic: + s3_metadata: "s3://nextstrain-data-private/files/workflows/seasonal-flu/vic/metadata.tsv.xz" + s3_nextclade: "s3://nextstrain-data-private/files/workflows/seasonal-flu/vic/ha/nextclade.tsv.xz" + +data_provenances: + - gisaid + +variant_classifications: + - emerging_haplotype + - aa_haplotype lineages: - h1n1pdm @@ -18,18 +26,25 @@ geo_resolutions: - country - region +min_date: "6M" +max_date: "0D" + prepare_data: - country: - min_date: "6M" - max_date: "0D" - location_min_seq: 150 - clade_min_seq: 30 - region: - min_date: "6M" - max_date: "0D" - location_min_seq: 150 - clade_min_seq: 30 + gisaid: + emerging_haplotype: + country: + location_min_seq: 100 + clade_min_seq: 30 + region: + location_min_seq: 100 + clade_min_seq: 30 + aa_haplotype: + country: + location_min_seq: 100 + clade_min_seq: 30 + region: + location_min_seq: 100 + clade_min_seq: 30 
haplotype_variant_column: "clade" -variant: "haplotype" -coloring_field: "emerging_haplotype" +mutations_column: "founderMuts['clade'].aaSubstitutions" diff --git a/scripts/add_colors_to_model.py b/scripts/add_colors_to_model.py index 50d0fb9..967a5af 100644 --- a/scripts/add_colors_to_model.py +++ b/scripts/add_colors_to_model.py @@ -8,11 +8,16 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", required=True, help="JSON file of model to add colors to") parser.add_argument("--auspice-config", required=True, help="Auspice config JSON with a color scale to use to color variants in the given model") + parser.add_argument("--color-schemes", required=True, help="file with color schemes with N tab-delimited colors on row N") parser.add_argument("--coloring-field", required=True, help="name of the coloring field in the given Auspice config JSON where the color scale is stored") parser.add_argument("--output", required=True, help="JSON file of model with colors added") args = parser.parse_args() + # Load the model JSON. + with open(args.model, "r", encoding="utf-8") as fh: + model = json.load(fh) + # Get color scale from Auspice config. with open(args.auspice_config, "r", encoding="utf-8") as fh: auspice_config = json.load(fh) @@ -23,20 +28,33 @@ color_scale = coloring["scale"] break + # Assign colors from color schemes, if no scale is defined in the Auspice + # config for the requested coloring field. if color_scale is None: + # We need one color per named variant excluding the "other" label. + variants = [variant for variant in model["metadata"]["variants"] if variant != "other"] + n_colors = len(variants) + + # Load the required number of colors. 
+ with open(args.color_schemes, "r", encoding="utf-8") as fh: + for line in fh: + if len(colors := line.strip().split("\t")) == n_colors: + color_scale = [ + [variant, color] + for variant, color in zip(variants, colors) + ] + break + + if color_scale: + # Add the color scale to the model JSON. + model["metadata"]["variantColors"] = color_scale + else: print( - f"ERROR: Could not find a color scale for the field '{args.coloring_field}' in the given Auspice config JSON.", + f"ERROR: Could not find a color scale for the field '{args.coloring_field}' in the given Auspice config JSON or assign enough colors from the color schemes file.", file=sys.stderr, ) sys.exit(1) - # Load the model JSON. - with open(args.model, "r", encoding="utf-8") as fh: - model = json.load(fh) - - # Add the color scale to the model JSON. - model["metadata"]["variantColors"] = color_scale - # Save the modified model JSON. with open(args.output, "w", encoding="utf-8") as oh: json.dump(model, oh)
diff --git a/scripts/assign_aa_haplotypes.py b/scripts/assign_aa_haplotypes.py
new file mode 100644
index 0000000..ac9041e
--- /dev/null
+++ b/scripts/assign_aa_haplotypes.py
@@ -0,0 +1,120 @@
+"""Annotate a Nextclade-style TSV with an amino acid haplotype per record using each record's clade and amino acid
+substitutions relative to the clade.
+"""
+import argparse
+import json
+import pandas as pd
+
+
+def create_haplotype_for_record(record, clade_column, mutations_column, genes=None, strip_genes=False, sites_by_gene=None):
+    """Create a haplotype string for the given record based on the values in its
+    clade and mutations columns.
+
+    Mutations are expected to look like "HA1:N145K" and to be comma-delimited.
+    If `sites_by_gene` is given, keep only mutations whose gene and position
+    (compared as a string) appear in that mapping; otherwise, if a list of
+    `genes` is given, keep only mutations in the requested genes. If
+    `strip_genes` is set and `genes` is given, drop the gene-name prefix from
+    mutations in those genes.
+
+    Returns the clade alone when the record has no mutations or none remain
+    after filtering; otherwise returns a label like "clade:HA1-N145K-HA1-K189R"
+    (or "clade:N145K-K189R" with gene names stripped).
+
+    """
+    clade = record[clade_column]
+
+    if record[mutations_column] == "":
+        return clade
+
+    mutations = record[mutations_column].split(",")
+
+    # Filter mutations to requested genes.
+    if sites_by_gene is not None:
+        filtered_mutations = []
+        for mutation in mutations:
+            # mutation looks like "HA1:N145K"
+            gene, allele = mutation.split(":")
+            position = allele[1:-1]
+            if gene in sites_by_gene and position in sites_by_gene[gene]:
+                filtered_mutations.append(mutation)
+
+        mutations = filtered_mutations
+    elif genes is not None:
+        mutations = [
+            mutation
+            for mutation in mutations
+            if mutation.split(":")[0] in genes
+        ]
+
+    if strip_genes and genes is not None:
+        # Strip the gene name from the front of each mutation only. Replacing
+        # "GENE-" anywhere in the joined string would also corrupt other gene
+        # names that contain it (e.g., stripping "PB1" from "PB1-F2").
+        gene_prefixes = tuple(f"{gene}:" for gene in genes)
+        mutations = [
+            mutation.split(":", 1)[1] if mutation.startswith(gene_prefixes) else mutation
+            for mutation in mutations
+        ]
+
+    mutations = "-".join(mutations).replace(":", "-")
+
+    if mutations:
+        return f"{clade}:{mutations}"
+
+    return clade
+
+
+if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--nextclade", required=True, help="TSV file of Nextclade annotations with columns for clade and AA mutations derived from clade") + parser.add_argument("--clade-column", help="name of the branch attribute for clade labels in the given Nextclade annotations", default="subclade") + parser.add_argument("--mutations-column", help="name of the attribute for mutations relative to clades in the given Nextclade annotations", default="founderMuts['subclade'].aaSubstitutions") + parser.add_argument("--genes", nargs="+", help="list of genes to filter mutations to.
If not provided, all mutations will be used.") + parser.add_argument("--distance-map", help="distance map JSON of genes and positions to include in haplotypes") + parser.add_argument("--strip-genes", action="store_true", help="strip gene names from coordinates in output haplotypes") + parser.add_argument("--attribute-name", default="haplotype", help="name of attribute to store the derived haplotype in the output file") + parser.add_argument("--output", help="TSV file of Nextclade annotations with derived haplotype column added", required=True) + args = parser.parse_args() + + # Load Nextclade annotations. + df = pd.read_csv( + args.nextclade, + sep="\t", + dtype={ + args.clade_column: "str", + args.mutations_column: "str", + }, + na_filter=False, + ) + + # Load distance map. + sites_by_gene = None + if args.distance_map: + with open(args.distance_map, "r", encoding="utf-8") as fh: + distance_map = json.load(fh) + sites_by_gene = distance_map["map"] + + # Annotate derived haplotypes. + df[args.attribute_name] = df.apply( + lambda record: create_haplotype_for_record( + record, + args.clade_column, + args.mutations_column, + args.genes, + args.strip_genes, + sites_by_gene, + ), + axis=1 + ) + + # Save updated Nextclade annotations + df.to_csv( + args.output, + sep="\t", + index=False, + ) diff --git a/viz/src/main.jsx b/viz/src/main.jsx index 6f34a90..a141fa1 100644 --- a/viz/src/main.jsx +++ b/viz/src/main.jsx @@ -1,4 +1,4 @@ -import React from 'react' +import React, {useMemo} from 'react' import ReactDOM from 'react-dom/client' import '@nextstrain/evofr-viz/dist/index.css'; import { PanelDisplay, useModelData } from '@nextstrain/evofr-viz'; @@ -29,10 +29,7 @@ const TABS = { } for (const [key,info] of Object.entries(TABS)) { info.modelName = key; - info.modelUrl = modelUrl(key); info.frequency = { - title: `Frequencies for ${info.displayName}`, - description: "", params: { preset: "frequency", rawDataToggleName: "Raw Data", @@ -40,8 +37,6 @@ for (const [key,info] of 
Object.entries(TABS)) { }, }; info.growthAdvantage = { - title: `Growth advantage for ${info.displayName}`, - description: "", params: {preset: "growthAdvantage"}, }; info.sites = { @@ -73,8 +68,14 @@ function getModelDate() { return (new URLSearchParams(window.location.search)).get('date'); } -function modelUrl(subtypeResolution) { - return `https://data.nextstrain.org/files/workflows/forecasts-flu/${subtypeResolution}/mlr/MLR_results.json`; +function modelUrl(variantClassification, subtypeResolution, modelDate) { + let url = `https://data.nextstrain.org/files/workflows/forecasts-flu/gisaid/${variantClassification}/${subtypeResolution}/mlr/MLR_results.json`; + + if (modelDate) { + url = url.replace(/([^/]+)$/, `${modelDate}_MLR_results.json`); + } + + return url; } /** @@ -98,15 +99,25 @@ function filterLocations(model, hierarchical=true) { function App() { - const [tabSelected, setTabSelected] = React.useState(getStartingTab) - const config = TABS[tabSelected]; let modelDate = getModelDate(); - if (modelDate) { - config.modelUrl = config.modelUrl.replace(/([^/]+)$/, `${modelDate}_MLR_results.json`) - } + const [tabSelected, setTabSelected] = React.useState(getStartingTab); + + /* configuration for the viz depends on the tab selected. Wrap this in + `useMemo` so that we only reconstruct it when the tab changes. Because the + data fetching (via `useModelData`) is triggered each time the config + changes we need to avoid recreating this config object else we get into an + infinite loop of data fetches! 
+ */ + const config = useMemo(() => { + return { + emergingHaplotype: Object.assign({}, TABS[tabSelected], {modelUrl: modelUrl("emerging_haplotype", tabSelected, modelDate)}), + aaHaplotype: Object.assign({}, TABS[tabSelected], {modelUrl: modelUrl("aa_haplotype", tabSelected, modelDate)}), + } + }, [tabSelected]) // The `useModelData` hook downloads & parses the config-defined JSON - const model = useModelData(config) + const modelEmergingHaplotype = useModelData(config.emergingHaplotype); + const modelAAHaplotype = useModelData(config.aaHaplotype); function changeTab(key) { setTabSelected(key); @@ -129,26 +140,50 @@ function App() {
- {config.frequency && ( + {config.emergingHaplotype.frequency && ( + <> +

Emerging haplotype frequencies for {TABS[tabSelected].displayName}

+

+ Updated {modelEmergingHaplotype?.modelData?.get('updated') || 'loading'}. +

+
+ +
+ + )} + + {TABS[tabSelected].growthAdvantage && ( + <> +

Emerging haplotype growth advantages for {TABS[tabSelected].displayName}

+

+ Updated {modelEmergingHaplotype?.modelData?.get('updated') || 'loading'}. +

+
+ +
+ + )} + + {config.aaHaplotype.frequency && ( <> -

{config.frequency.title}

+

Amino acid haplotype frequencies for {TABS[tabSelected].displayName}

- {config.frequency.description}. Updated {model?.modelData?.get('updated') || 'loading'}. + Updated {modelAAHaplotype?.modelData?.get('updated') || 'loading'}.

- +
)} {TABS[tabSelected].growthAdvantage && ( <> -

{config.growthAdvantage.title}

+

Amino acid haplotype growth advantages for {TABS[tabSelected].displayName}

- {config.growthAdvantage.description}. Updated {model?.modelData?.get('updated') || 'loading'}. + Updated {modelAAHaplotype?.modelData?.get('updated') || 'loading'}.

- +
)} diff --git a/workflow/upload.smk b/workflow/upload.smk index c086c80..0ea6616 100644 --- a/workflow/upload.smk +++ b/workflow/upload.smk @@ -10,9 +10,9 @@ def _get_s3_url(w, input_file): rule copy_latest_model_results_to_dated: input: - latest_model_results = "results/{lineage}/{geo_resolution}/mlr/MLR_results.json", + latest_model_results="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/MLR_results.json", output: - dated_model_results = "results/{lineage}/{geo_resolution}/mlr/{date}_MLR_results.json", + dated_model_results="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/{date}_MLR_results.json", shell: """ cp {input.latest_model_results} {output.dated_model_results} @@ -20,9 +20,9 @@ rule copy_latest_model_results_to_dated: rule upload_dated_model_results_to_s3: input: - model_results = "results/{lineage}/{geo_resolution}/mlr/{date}_MLR_results.json" + model_results="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/{date}_MLR_results.json", output: - touch("results/{lineage}/{geo_resolution}/mlr/{date}_results_s3_upload.done") + touch("results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/{date}_results_s3_upload.done") params: s3_url=lambda w, input: _get_s3_url(w, input.model_results), shell: @@ -35,9 +35,9 @@ rule upload_dated_model_results_to_s3: rule upload_latest_model_results_to_s3: input: - model_results = "results/{lineage}/{geo_resolution}/mlr/MLR_results.json" + model_results="results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/MLR_results.json" output: - touch("results/{lineage}/{geo_resolution}/mlr/results_s3_upload.done") + touch("results/{data_provenance}/{variant_classification}/{lineage}/{geo_resolution}/mlr/results_s3_upload.done") params: s3_url=lambda w, input: _get_s3_url(w, input.model_results), shell: