Skip to content

Commit 548fdcc

Browse files
author
dohyun-s
committed
implement unpaired alignment && fix colabfold_search error
1 parent 7ad01b8 commit 548fdcc

File tree

4 files changed

+104
-53
lines changed

4 files changed

+104
-53
lines changed

colabfold/batch.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121

2222
from colabfold.inputs import (
2323
get_queries_pairwise, unpack_a3ms,
24-
parse_fasta, get_queries,
24+
parse_fasta, get_queries, msa_to_str
2525
)
26+
from colabfold.run_alphafold import set_model_type
2627

2728
from colabfold.download import default_data_dir, download_alphafold_params
28-
29+
import sys
2930
import logging
3031
logger = logging.getLogger(__name__)
3132

@@ -304,17 +305,31 @@ def main():
304305
for x in batch:
305306
if x not in query_seqs_unique:
306307
query_seqs_unique.append(x)
308+
query_seqs_cardinality = [0] * len(query_seqs_unique)
309+
for seq in batch:
310+
seq_idx = query_seqs_unique.index(seq)
311+
query_seqs_cardinality[seq_idx] += 1
307312
use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode
308313
paired_a3m_lines = run_mmseqs2(
309314
query_seqs_unique,
310-
str(Path(args.results).joinpath(str(jobname))),
315+
str(Path(args.results).joinpath(str(jobname)+"_paired")),
311316
use_env=use_env,
312317
use_pairwise=True,
313318
use_pairing=True,
314319
host_url=args.host_url,
315320
)
316321

317-
path_o = Path(args.results).joinpath(f"{jobname}_pairwise")
322+
if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired":
323+
unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env")
324+
unpaired_a3m_lines = run_mmseqs2(
325+
query_seqs_unique,
326+
str(Path(args.results).joinpath(str(jobname)+"_unpaired")),
327+
use_env=use_env,
328+
use_pairwise=False,
329+
use_pairing=False,
330+
host_url=args.host_url,
331+
)
332+
path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise")
318333
for filenum in path_o.iterdir():
319334
queries_new = []
320335
if Path(filenum).suffix.lower() == ".a3m":
@@ -326,6 +341,16 @@ def main():
326341
query_sequence = seqs[0]
327342
a3m_lines = [Path(file).read_text()]
328343
val = int(header[0].split('\t')[1][1:]) - 102
344+
# match paired seq id and unpaired seq id
345+
if args.pair_mode == "none" or "unpaired" or "unpaired_paired":
346+
tmp = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1]
347+
a3m_lines = [msa_to_str(
348+
[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]
349+
)]
350+
## Another way: do not use msa_to_str and unserialize function rather
351+
## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments..
352+
##
353+
# a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]]
329354
queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines))
330355

331356
if args.sort_queries_by == "length":

colabfold/inputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def pad_input_multimer(
8585
pad_len: int,
8686
use_templates: bool,
8787
) -> model.features.FeatureDict:
88-
model_config = model_runner.config
8988
shape_schema = {
9089
"aatype": ["num residues placeholder"],
9190
"residue_index": ["num residues placeholder"],
@@ -696,7 +695,7 @@ def unserialize_msa(
696695
)
697696
prev_query_start += query_len
698697
paired_msa = [""] * len(query_seq_len)
699-
unpaired_msa = None
698+
unpaired_msa = [""] * len(query_seq_len)
700699
already_in = dict()
701700
for i in range(1, len(a3m_lines), 2):
702701
header = a3m_lines[i]
@@ -734,7 +733,6 @@ def unserialize_msa(
734733
paired_msa[j] += ">" + header_no_faster_split[j] + "\n"
735734
paired_msa[j] += seqs_line[j] + "\n"
736735
else:
737-
unpaired_msa = [""] * len(query_seq_len)
738736
for j, seq in enumerate(seqs_line):
739737
if has_amino_acid[j]:
740738
unpaired_msa[j] += header + "\n"
@@ -752,6 +750,8 @@ def unserialize_msa(
752750
template_feature = mk_mock_template(query_seq)
753751
template_features.append(template_feature)
754752

753+
if unpaired_msa == [""] * len(query_seq_len):
754+
unpaired_msa = None
755755
return (
756756
unpaired_msa,
757757
paired_msa,

colabfold/mmseqs/search.py

Lines changed: 63 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,9 @@ def get_queries_pairwise(
3232
df = pandas.read_csv(input_path, sep=sep)
3333
assert "id" in df.columns and "sequence" in df.columns
3434
queries = [
35-
(str(df["id"][0])+'&'+str(seq_id), [df["sequence"][0].upper(),sequence.upper()], None)
36-
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) if i!=0
35+
(str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None)
36+
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False))
3737
]
38-
for i in range(len(queries)):
39-
if len(queries[i][1]) == 1:
40-
queries[i] = (queries[i][0], queries[i][1][0], None)
4138
elif input_path.suffix == ".a3m":
4239
raise NotImplementedError()
4340
elif input_path.suffix in [".fasta", ".faa", ".fa"]:
@@ -47,9 +44,7 @@ def get_queries_pairwise(
4744
sequence = sequence.upper()
4845
if sequence.count(":") == 0:
4946
# Single sequence
50-
if i==0:
51-
continue
52-
queries.append((headers[0]+'&'+header, [sequences[0],sequence], None))
47+
queries.append((header, sequence, None))
5348
else:
5449
# Complex mode
5550
queries.append((header, sequence.upper().split(":"), None))
@@ -449,9 +444,9 @@ def main():
449444
args = parser.parse_args()
450445

451446
if args.interaction_scan:
452-
queries, is_complex = get_queries_pairwise(args.query, None)
447+
queries, is_complex = get_queries_pairwise(args.query)
453448
else:
454-
queries, is_complex = get_queries(args.query, None)
449+
queries, is_complex = get_queries(args.query)
455450

456451
queries_unique = []
457452
for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries):
@@ -481,10 +476,9 @@ def main():
481476
query_seqs_cardinality,
482477
) in enumerate(queries_unique):
483478
if job_number==0:
484-
f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n")
485-
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
479+
f.write(f">{raw_jobname}_0\n{query_sequences}\n")
486480
else:
487-
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
481+
f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n")
488482
else:
489483
with query_file.open("w") as f:
490484
for job_number, (
@@ -498,18 +492,6 @@ def main():
498492
args.mmseqs,
499493
["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"],
500494
)
501-
with args.base.joinpath("qdb.lookup").open("w") as f:
502-
id = 0
503-
file_number = 0
504-
for job_number, (
505-
raw_jobname,
506-
query_sequences,
507-
query_seqs_cardinality,
508-
) in enumerate(queries_unique):
509-
for seq in query_sequences:
510-
f.write(f"{id}\t{raw_jobname}\t{file_number}\n")
511-
id += 1
512-
file_number += 1
513495

514496
mmseqs_search_monomer(
515497
mmseqs=args.mmseqs,
@@ -542,30 +524,66 @@ def main():
542524
interaction_scan=args.interaction_scan,
543525
)
544526

527+
if args.interaction_scan:
528+
if len(queries_unique) > 1:
529+
for i in range(len(queries_unique)-2):
530+
idx = 2 + i*2
531+
## delete duplicated query files 2.paired, 4.paired...
532+
os.remove(args.base.joinpath(f"{idx}.paired.a3m"))
533+
for j in range(len(queries_unique)-2):
534+
# replace targets' right file name
535+
id1 = j*2 + 3
536+
id2 = j + 2
537+
os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m"))
538+
545539
id = 0
546-
for job_number, (
547-
raw_jobname,
548-
query_sequences,
549-
query_seqs_cardinality,
550-
) in enumerate(queries_unique):
551-
unpaired_msa = []
552-
paired_msa = None
553-
if len(query_seqs_cardinality) > 1:
540+
if not args.interaction_scan:
541+
for job_number, (
542+
raw_jobname,
543+
query_sequences,
544+
query_seqs_cardinality,
545+
) in enumerate(queries_unique):
546+
unpaired_msa = []
547+
paired_msa = None
548+
if len(query_seqs_cardinality) > 1:
549+
paired_msa = []
550+
else:
551+
for seq in query_sequences:
552+
with args.base.joinpath(f"{id}.a3m").open("r") as f:
553+
unpaired_msa.append(f.read())
554+
args.base.joinpath(f"{id}.a3m").unlink()
555+
if len(query_seqs_cardinality) > 1:
556+
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
557+
paired_msa.append(f.read())
558+
args.base.joinpath(f"{id}.paired.a3m").unlink()
559+
id += 1
560+
msa = msa_to_str(
561+
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
562+
)
563+
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
564+
else:
565+
for job_number, _ in enumerate(queries_unique[:-1]):
566+
query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]]
567+
unpaired_msa = []
554568
paired_msa = []
555-
for seq in query_sequences:
556-
with args.base.joinpath(f"{id}.a3m").open("r") as f:
569+
with args.base.joinpath(f"0.a3m").open("r") as f:
570+
unpaired_msa.append(f.read())
571+
with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f:
557572
unpaired_msa.append(f.read())
558-
args.base.joinpath(f"{id}.a3m").unlink()
559-
if len(query_seqs_cardinality) > 1:
560-
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
561-
paired_msa.append(f.read())
562-
args.base.joinpath(f"{id}.paired.a3m").unlink()
563-
id += 1
564-
msa = msa_to_str(
565-
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
566-
)
567-
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
568573

574+
with args.base.joinpath(f"0.paired.a3m").open("r") as f:
575+
paired_msa.append(f.read())
576+
with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f:
577+
paired_msa.append(f.read())
578+
msa = msa_to_str(
579+
unpaired_msa, paired_msa, query_sequences, [1,1]
580+
)
581+
args.base.joinpath(f"{job_number}_final.a3m").write_text(msa)
582+
for job_number, _ in enumerate(queries_unique):
583+
args.base.joinpath(f"{job_number}.a3m").unlink()
584+
args.base.joinpath(f"{job_number}.paired.a3m").unlink()
585+
for job_number, _ in enumerate(queries_unique[:-1]):
586+
os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m"))
569587
query_file.unlink()
570588
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
571589
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])

colabfold/run_alphafold.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,15 @@ def run(
259259
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \
260260
= unserialize_msa(a3m_lines, query_sequence)
261261
if not use_templates: template_features = template_features_
262-
262+
## Another way passing argument
263+
##
264+
# (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines
265+
# template_features_ = []
266+
# from colabfold.inputs import mk_mock_template
267+
# for query_seq in query_seqs_unique:
268+
# template_feature = mk_mock_template(query_seq)
269+
# template_features_.append(template_feature)
270+
# if not use_templates: template_features = template_features_
263271
# save a3m
264272
msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
265273
result_dir.joinpath(f"{jobname}.a3m").write_text(msa)

0 commit comments

Comments
 (0)