Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 129 additions & 15 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Literal, NamedTuple
from typing import Any, Literal, NamedTuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,6 +60,59 @@ def __str__(self) -> str:
return self.name.lower()


def _parse_a3m_file_by_m(a3m_file_path: Path, m_filter: dict[int, Any] | None = None) -> dict[int, list[str]]:
"""Parse an a3m file and extract MSA sections by M value.

This function uses the same parsing logic as the main query_colabfold_msa_server
to extract individual MSA sections from batch a3m files.

Real a3m files from ColabFold have null bytes (\x00) before new M value headers,
which is how the parser detects new M sections.

Args:
a3m_file_path: Path to the a3m file to parse
m_filter: Optional dictionary mapping M values to filter by. If provided,
only M values in this dictionary will be included in the result.
If None, all M values will be included.

Returns:
Dictionary mapping M values to lists of lines (the MSA section for that M)
"""
a3m_lines = {}
try:
with open(a3m_file_path) as f:
update_M, M = True, None
for line in f:
if len(line) > 0:
if "\x00" in line:
line = line.replace("\x00", "")
update_M = True
if line.startswith(">") and update_M:
try:
M = int(line[1:].rstrip())
update_M = False
# If filter is provided, only include M values in the filter
if m_filter is None or M in m_filter:
if M not in a3m_lines:
a3m_lines[M] = []
except ValueError:
# Not a pure integer (e.g., UniRef header), keep current M
pass
# Add line to current M section (only if M is set and in filter if provided)
if M is not None:
if m_filter is None or M in m_filter:
a3m_lines[M].append(line)
except FileNotFoundError:
# File doesn't exist - return empty dict without logging (expected in some cases)
return {}
except Exception as e:
# Only log warnings for unexpected errors
logger.warning(f"Error parsing a3m file {a3m_file_path}: {e}")
return {}

return a3m_lines


def query_colabfold_msa_server(
x: list[str],
prefix: Path,
Expand Down Expand Up @@ -395,21 +448,15 @@ def download(ID, path):
# Gather a3m lines
a3m_lines = {}
for a3m_file in a3m_files:
update_M, M = True, None
with open(a3m_file) as f:
for line in f:
if len(line) > 0:
if "\x00" in line:
line = line.replace("\x00", "")
update_M = True
if line.startswith(">") and update_M:
M = int(line[1:].rstrip())
update_M = False
if M not in a3m_lines:
a3m_lines[M] = []
a3m_lines[M].append(line)
file_a3m_lines = _parse_a3m_file_by_m(Path(a3m_file))
# Merge into main dict
for M, lines in file_a3m_lines.items():
if M not in a3m_lines:
a3m_lines[M] = []
a3m_lines[M].extend(lines)

a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
# Only include M values that exist in a3m_lines to avoid KeyError
a3m_lines = ["".join(a3m_lines[n]) for n in Ms if n in a3m_lines]

return (a3m_lines, template_paths) if use_templates else a3m_lines

Expand Down Expand Up @@ -750,6 +797,73 @@ def query_format_main(self):
index=False,
)

# Copy raw a3m files to raw_colabfold_output directory after batch processing
self._organize_raw_main_outputs_by_query()

def _organize_raw_main_outputs_by_query(self):
"""Copy raw main MSA a3m files to raw_colabfold_output directory.

This method extracts individual MSA sections from the batch a3m files
and saves them as {rep_id}/{filename}.a3m in raw_colabfold_output directory.
Only a3m files are copied, not tar.gz or m8 files.
"""
raw_main_dir = self.output_directory / "raw" / "main"
if not raw_main_dir.exists():
return

# Build M -> rep_id mapping
# M is the ColabFold/MMseqs2 internal sequence identifier (starts at 101)
# It's the number used in a3m file headers like >101, >102, etc.
m_to_rep_id = {m: rep_id for rep_id, m in self.colabfold_mapper.rep_id_to_m.items()}

if not m_to_rep_id:
logger.warning(f"No M to rep_id mapping found. rep_id_to_m: {self.colabfold_mapper.rep_id_to_m}")
return

logger.info(f"Found {len(m_to_rep_id)} M values in mapping: {list(m_to_rep_id.keys())}")

# Create raw_colabfold_output directory
raw_colabfold_output_dir = self.output_directory / "raw_colabfold_output"
raw_colabfold_output_dir.mkdir(parents=True, exist_ok=True)

# Find all .a3m files in raw/main directory dynamically
a3m_files = [f for f in raw_main_dir.iterdir() if f.is_file() and f.suffix == ".a3m"]

if not a3m_files:
logger.warning(f"No .a3m files found in {raw_main_dir}")
return

logger.info(f"Found {len(a3m_files)} a3m files to process: {[f.name for f in a3m_files]}")

# Process each a3m file found in raw/main
for a3m_path in a3m_files:
a3m_file = a3m_path.name

# Extract MSA sections by M value using the shared parsing function
# Filter by m_to_rep_id to only extract sections for sequences we care about
msa_sections = _parse_a3m_file_by_m(a3m_path, m_filter=m_to_rep_id)

# Save each MSA section to rep_id-specific directory
logger.info(f"Found {len(msa_sections)} MSA sections in {a3m_file}: {list(msa_sections.keys())}")
for M, lines in msa_sections.items():
if M not in m_to_rep_id:
logger.warning(f"M value {M} not found in m_to_rep_id mapping")
continue
rep_id = m_to_rep_id[M]

# Create rep_id-specific directory
rep_dir = raw_colabfold_output_dir / str(rep_id)
rep_dir.mkdir(parents=True, exist_ok=True)

# Save MSA section to rep_id directory
output_file = rep_dir / a3m_file
try:
with open(output_file, "w") as f:
f.writelines(lines)
logger.info(f"Saved MSA section for M={M}, rep_id={rep_id} to {output_file}")
except Exception as e:
logger.warning(f"Error writing {output_file}: {e}")

def query_format_paired(self):
"""Submits queries and formats the outputs for paired MSAs."""
paired_alignments_directory = self.output_directory / "paired"
Expand Down
Loading