diff --git a/workflow/Snakefile b/workflow/Snakefile
index 47c90dec..10a95324 100644
--- a/workflow/Snakefile
+++ b/workflow/Snakefile
@@ -554,6 +554,57 @@ def get_output4cell_type_annotation() -> list[str]:
         for k in NORMALISATION_ID
         for m in filenames
     ]
+
+
+def get_combined_label_outputs() -> list[str]:
+    """Return the `combined_labels.parquet` targets for cell type annotation.
+
+    One path is built per segmentation, sample, normalisation and
+    `{approach}/{reference_name}` pair, where approach and reference name are
+    the first two "/"-separated components of an annotation wildcard value.
+    Identical paths are returned only once.
+    """
+    prefix: str = f'{config["output_path"]}/cell_type_annotation'
+
+    samples_by_conditions: dict[str, list[str]] = get_dict_value(
+        config,
+        "experiments",
+        cc.EXPERIMENTS_COLLECTIONS_NAME,
+        cc.EXPERIMENTS_COLLECTIONS_CONDITIONS_NAME,
+    )
+
+    annotations_by_conditions: dict[str, list[str]] = get_dict_value(
+        config,
+        cc.WILDCARDS_NAME,
+        cc.WILDCARDS_CELL_TYPE_ANNOTATION_NAME,
+    )
+
+    samples_annotations: list[tuple[str, str]] = cross_values_by_key(
+        samples_by_conditions,
+        annotations_by_conditions,
+    )
+
+    # Annotations sharing approach and reference name map to the same file;
+    # a dict keeps one entry per path in first-seen order.
+    # NOTE(review): combineAnnotationLabels only declares "reference_based"
+    # outputs -- confirm no other approach can reach this point.
+    outputs: dict[str, None] = {}
+    for seg in SEGMENTATION_ID:
+        for sample, annotation in samples_annotations:
+            parts: list[str] = annotation.split("/")
+            for norm in NORMALISATION_ID:
+                path: str = os.path.join(
+                    prefix,
+                    seg,
+                    sample,
+                    norm,
+                    parts[0],  # approach (e.g. reference_based)
+                    parts[1],  # reference_name
+                    "combined_labels.parquet",
+                )
+                outputs[path] = None
+
+    return list(outputs)
 
 
 def get_output4joint_analysis() -> list[str] | None:
@@ -1319,13 +1370,14 @@ def get_input2all(
 
     if cell_type_annotation:
         ret.extend(get_output4cell_type_annotation())
+        ret.extend(get_combined_label_outputs())
 
     if neighborhood_analysis:
         ret.extend(get_output4neighborhood_analysis())
 
     if doublet_finding:
         ret.extend(get_output4doublet_finding())
 
     if count_correction:
         ret.extend(get_output4count_correction(
             cell_type_annotation,
diff --git a/workflow/rules/_cell_type_annotation/combine_annotation_results.smk b/workflow/rules/_cell_type_annotation/combine_annotation_results.smk
new file mode 100644
index 00000000..39f03bf0
--- /dev/null
+++ b/workflow/rules/_cell_type_annotation/combine_annotation_results.smk
@@ -0,0 +1,36 @@
+#######################################
+#                Rules                #
+#######################################
+
+# Merge all `labels.parquet` files found under one reference directory into a
+# single `combined_labels.parquet` (see combine_annotation_results.R).
+#
+# NOTE(review): `root_dir` is a directory rather than the individual
+# `labels.parquet` files, so this rule does not name the annotation outputs it
+# reads -- confirm they are guaranteed to exist before this rule is evaluated.
+rule combineAnnotationLabels:
+    wildcard_constraints:
+        reference_name="[^/]+"
+
+    input:
+        root_dir=lambda wc: os.path.join(
+            config["output_path"],
+            "cell_type_annotation",
+            wc.segmentation_id,
+            wc.sample_id,
+            wc.normalisation_id,
+            "reference_based",
+            wc.reference_name
+        )
+
+    output:
+        combined=f'{config["output_path"]}/cell_type_annotation/{{segmentation_id}}/{{sample_id}}/{{normalisation_id}}/reference_based/{{reference_name}}/combined_labels.parquet'
+
+    log:
+        f'{config["output_path"]}/cell_type_annotation/{{segmentation_id}}/{{sample_id}}/{{normalisation_id}}/reference_based/{{reference_name}}/logs/combineAnnotationLabels.log'
+    container:
+        config["containers"]["r"]
+    threads: 1
+    script:
+        "../../scripts/_cell_type_annotation/combine_annotation_results.R"
+
diff --git a/workflow/rules/cell_type_annotation.smk b/workflow/rules/cell_type_annotation.smk
index 2615a989..567ff09f 100644
--- a/workflow/rules/cell_type_annotation.smk
+++ b/workflow/rules/cell_type_annotation.smk
@@ -135,3 +135,5 @@ include: "_cell_type_annotation/_reference_based/rctd.smk"
 include: "_cell_type_annotation/_reference_based/singler.smk"
 include: "_cell_type_annotation/_reference_based/seurat.smk"
 include: "_cell_type_annotation/_reference_based/xgboost.smk"
+
+include: "_cell_type_annotation/combine_annotation_results.smk"
diff --git a/workflow/scripts/_cell_type_annotation/combine_annotation_results.R b/workflow/scripts/_cell_type_annotation/combine_annotation_results.R
new file mode 100644
index 00000000..f8d06efc
--- /dev/null
+++ b/workflow/scripts/_cell_type_annotation/combine_annotation_results.R
@@ -0,0 +1,66 @@
+# Merge every `labels.parquet` found below the rule's input directory into one
+# table keyed by `cell_id`, prefixing each file's columns with its origin.
+
+log <- file(snakemake@log[[1]], open = "wt")
+sink(log, type = "output")
+sink(log, type = "message")
+
+library(dplyr)
+library(arrow)
+library(stringr)
+library(purrr)
+
+input_dir <- snakemake@input[["root_dir"]]
+
+# Paths are kept relative to `input_dir` so that the directory names used as
+# column prefixes are guaranteed to lie inside it.
+label_files <- list.files(
+  path = input_dir,
+  pattern = "^labels\\.parquet$",
+  full.names = FALSE,
+  recursive = TRUE
+)
+
+if (length(label_files) == 0) {
+  stop("No labels.parquet files found under: ", input_dir)
+}
+
+# Read one labels file and prefix its columns with `{level}/{tool}/{mode}`.
+#
+# `rel_path` is relative to `root_dir` and must have at least three directory
+# levels: `{tool}/{level}/{mode}/labels.parquet`. Every column except
+# `cell_id` is renamed to `{level}/{tool}/{mode}/{original name}`.
+read_and_annotate <- function(rel_path, root_dir = input_dir) {
+  path <- file.path(root_dir, rel_path)
+
+  # Split on a literal "/" (not a regex) and drop the file name itself.
+  parts <- str_split(rel_path, fixed("/"))[[1]]
+  dirs <- parts[-length(parts)]
+  n <- length(dirs)
+
+  if (n < 3) {
+    stop("Unexpected path structure (too few components): ", path)
+  }
+
+  df <- read_parquet(path)
+
+  if (!"cell_id" %in% colnames(df)) {
+    stop("Missing `cell_id` column in: ", path)
+  }
+
+  mode <- dirs[n]
+  level <- dirs[n - 1]
+  tool <- dirs[n - 2]
+  prefix <- paste(level, tool, mode, sep = "/")
+
+  # Rename columns except cell_id
+  to_rename <- colnames(df) != "cell_id"
+  colnames(df)[to_rename] <- paste0(prefix, "/", colnames(df)[to_rename])
+  df
+}
+
+merged <- label_files %>%
+  map(read_and_annotate) %>%
+  reduce(full_join, by = "cell_id")
+
+write_parquet(merged, snakemake@output[["combined"]])