Source code for lXtractor.protocols.superpose

"""
A sandbox module encapsulating high-level operations built on top of the core
functionality of `lXtractor`.
"""
import logging
import typing as t
from collections import abc
from itertools import combinations

import biotite.structure as bst
import numpy as np
from toolz import curry

from lXtractor.chain import ChainStructure
from lXtractor.chain import filter_selection_extended, subset_to_matching
from lXtractor.core.exceptions import MissingData, LengthMismatch, InitError
from lXtractor.util import apply
from lXtractor.util.seq import biotite_align
from lXtractor.util.structure import filter_to_common_atoms

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

# Staged input where the superposition target is an atom array:
# (identifier, atom array to superpose, optional atom array for the distance
# function). Produced by `_stage_inp` when `to_array` is true.
_InpSuperpose: t.TypeAlias = tuple[str, bst.AtomArray, bst.AtomArray | None]
# Staged input where the superposition target is wrapped into a chain
# structure: (identifier, chain structure to align and superpose, optional atom
# array for the distance function). Produced by `_stage_inp` when `to_array`
# is false.
_InpAlignSuperpose: t.TypeAlias = tuple[str, ChainStructure, bst.AtomArray | None]
# Raw result of superposing one pair:
# (fixed ID, mobile ID, RMSD, distance function output, transformation).
_OutSuperpose: t.TypeAlias = tuple[
    str, str, float, t.Any, tuple[np.ndarray, np.ndarray, np.ndarray]
]
# Type variable constrained to one of the two staged input kinds, so that both
# members of a staged pair are annotated with the same kind.
_StagedSupInp = t.TypeVar("_StagedSupInp", _InpSuperpose, _InpAlignSuperpose)
# A distance function: called with the "fixed" atom array and the superposed
# "mobile" atom array; its return value is passed through unchanged.
_DistFn: t.TypeAlias = abc.Callable[[bst.AtomArray, bst.AtomArray], t.Any]
# A declarative selection: (residue positions, atom names). Either element may
# be ``None``.
_Selection: t.TypeAlias = tuple[
    abc.Sequence[int] | None,
    abc.Sequence[abc.Sequence[str]] | abc.Sequence[str] | None,
]
# A functional selection: accepts a chain structure and returns an atom array.
_Selector: t.TypeAlias = abc.Callable[[ChainStructure], bst.AtomArray]

class SuperposeOutput(t.NamedTuple):
    """
    The result of superposing a single pair of structures.

    Field order matches the tuples returned by the pair-level superposition
    functions of this module, so an instance can be built by unpacking such
    a tuple.
    """

    # Identifier of the fixed structure.
    ID_fix: str
    # Identifier of the mobile structure.
    ID_mob: str
    # RMSD between the atoms used for the superposition.
    RmsdSuperpose: float
    # Output of the distance function (``None`` if it was not computed).
    Distance: t.Any
    # Transformation obtained from the superposition.
    Transformation: tuple[np.ndarray, np.ndarray, np.ndarray]


@curry
def _stage_inp(
    c: ChainStructure,
    selection_superpose: _Selection | _Selector,
    selection_dist: _Selection | _Selector | None,
    map_name: str | None,
    exclude_hydrogen: bool,
    to_array: bool,
    tolerate_missing: bool,
) -> _InpSuperpose | _InpAlignSuperpose:
    """
    Prepare ("stage") a chain structure for superposition.

    :param c: A chain structure to stage.
    :param selection_superpose: Either a ``(positions, atom names)`` tuple or
        a callable accepting `c` and returning an atom array. Defines atoms
        used for superposition.
    :param selection_dist: Same as `selection_superpose`, but defines atoms
        for the distance function. ``None`` means no such atoms are selected.
    :param map_name: Passed to :func:`filter_selection_extended` for tuple
        selections.
    :param exclude_hydrogen: Passed to :func:`filter_selection_extended` for
        tuple selections.
    :param to_array: If ``True``, return the selected superposition atoms as
        an atom array. Otherwise, wrap them into a new
        :class:`ChainStructure` with the chain ID of `c`.
    :param tolerate_missing: Passed to :func:`filter_selection_extended` for
        tuple selections.
    :return: A tuple with the ID of `c`, the staged superposition atoms, and
        the atoms selected for the distance function (or ``None``).
    :raises MissingData: If the superposition selection is empty.
    :raises InitError: If creating a chain structure from the selected atoms
        fails.
    """

    def select(selection: _Selection | _Selector) -> bst.AtomArray:
        # A non-tuple selection is a selector function applied directly.
        if not isinstance(selection, tuple):
            return selection(c)
        positions, atom_names = selection
        keep = filter_selection_extended(
            c,
            pos=positions,
            atom_names=atom_names,
            map_name=map_name,
            exclude_hydrogen=exclude_hydrogen,
            tolerate_missing=tolerate_missing,
        )
        return c.array[keep]

    sup_atoms = select(selection_superpose)
    dist_atoms = None if selection_dist is None else select(selection_dist)

    # Only the superposition selection is validated; the distance selection
    # is passed on as is.
    if len(sup_atoms) == 0:
        raise MissingData(f"Empty selection for superposition atoms in structure {c}")

    staged_id = c.id
    if to_array:
        return staged_id, sup_atoms, dist_atoms

    try:
        sub_chain = ChainStructure(sup_atoms, c.chain_id)
    except Exception as e:
        raise InitError(
            f"Failed to create ChainStructure from array {sup_atoms} for {c}"
        ) from e
    return staged_id, sub_chain, dist_atoms


def _yield_staged_pairs(
    fixed: abc.Iterable[ChainStructure],
    mobile: abc.Iterable[ChainStructure] | None,
    stage: abc.Callable[[ChainStructure], _StagedSupInp],
) -> abc.Generator[tuple[_StagedSupInp, _StagedSupInp], None, None]:
    _fixed = map(stage, fixed)
    if mobile is None:
        yield from combinations(_fixed, 2)
    else:
        _mobile = list(map(stage, mobile))
        for fs in _fixed:
            for ms in _mobile:
                yield fs, ms


def superpose_pair(
    pair: tuple[_InpSuperpose, _InpSuperpose], dist_fn: _DistFn | None
) -> _OutSuperpose:
    """
    A function performing superposition and rmsd calculation of already
    prepared :class:`AtomArray` objects. Each must have the same number of
    atoms.

    :param pair: A pair of staged inputs. A staged input is a tuple with an
        identifier, an atom array to superpose, and an optional atom array
        for the `dist_fn`.
    :param dist_fn: An optional distance function accepting two positional
        args: "fixed" atom array and superposed atom array. It is called only
        if both staged inputs carry an atom array for it; otherwise, the
        returned distance is ``None``.
    :return: a tuple with id_fixed, id_mobile, rmsd of the superposed atoms,
        calculated distance, and the transformation matrices.
    :raises LengthMismatch: If the arrays to superpose differ in the number of
        atoms.
    """
    # f_ for fixed, m_ for mobile
    (f_id, f_array_sup, f_array_dist), (m_id, m_array_sup, m_array_dist) = pair

    if len(f_array_sup) != len(m_array_sup):
        raise LengthMismatch(
            "For superposition, expected fixed and mobile array to have "
            f"the same number of atoms, but {len(f_array_sup)} != {len(m_array_sup)}"
        )

    # `superimpose` already returns the fitted copy of the mobile array, so
    # the transformation doesn't have to be applied to it a second time.
    m_array_fitted, transformation = bst.superimpose(f_array_sup, m_array_sup)
    rmsd_sup = bst.rmsd(f_array_sup, m_array_fitted)

    dist = None
    if (
        dist_fn is not None
        and f_array_dist is not None
        and m_array_dist is not None
    ):
        # The atoms for the distance function may differ from the superposed
        # ones, so they are moved using the same transformation.
        m_array_dist = bst.superimpose_apply(m_array_dist, transformation)
        dist = dist_fn(f_array_dist, m_array_dist)

    return f_id, m_id, rmsd_sup, dist, transformation
def align_and_superpose_pair(
    pair: tuple[_InpAlignSuperpose, _InpAlignSuperpose],
    dist_fn: _DistFn | None,
    skip_aln_if_match: str,
) -> _OutSuperpose:
    """
    Subset the chain structures in `pair` to matching residues via
    :func:`subset_to_matching` (using :func:`biotite_align` as the alignment
    method), then subset the resulting atom arrays to common atoms via
    :func:`filter_to_common_atoms`. The subsetted atom arrays are superposed
    using :func:`superpose_pair`.

    :param pair: A pair of staged inputs. Here, a staged input is a tuple with
        an identifier, a chain structure to superpose, and an optional atom
        array for the `dist_fn`.
    :param dist_fn: An optional distance function accepting two positional
        args: "fixed" atom array and superposed atom array.
    :param skip_aln_if_match: Passed as ``skip_if_match`` to
        :func:`subset_to_matching`.
    :return: a tuple with id_fixed, id_mobile, rmsd of the superposed atoms,
        calculated distance, and the transformation matrices.
    """
    fixed, mobile = pair
    fixed_id, fixed_chain, fixed_dist = fixed
    mobile_id, mobile_chain, mobile_dist = mobile

    fixed_matched, mobile_matched = subset_to_matching(
        fixed_chain,
        mobile_chain,
        skip_if_match=skip_aln_if_match,
        align_method=biotite_align,
        name="Mobile",
    )

    fixed_atoms = fixed_matched.array
    mobile_atoms = mobile_matched.array
    fixed_common, mobile_common = filter_to_common_atoms(
        fixed_atoms, mobile_atoms, allow_residue_mismatch=True
    )

    # The atoms for the distance function are passed on unchanged.
    staged_fixed = (fixed_id, fixed_atoms[fixed_common], fixed_dist)
    staged_mobile = (mobile_id, mobile_atoms[mobile_common], mobile_dist)
    return superpose_pair((staged_fixed, staged_mobile), dist_fn)
def superpose_pairwise(
    fixed: abc.Iterable[ChainStructure],
    mobile: abc.Iterable[ChainStructure] | None = None,
    selection_superpose: _Selection | _Selector = (None, None),
    selection_dist: _Selection | _Selector | None = None,
    dist_fn: _DistFn | None = None,
    *,
    strict: bool = True,
    map_name: str | None = None,
    exclude_hydrogen: bool = False,
    skip_aln_if_match: str = "len",
    verbose: bool = False,
    num_proc: int = 1,
    **kwargs,
) -> abc.Generator[SuperposeOutput, None, None]:
    """
    Superpose pairs of structures.
    Two modes are available:

        1. ``strict=True`` -- potentially faster and memory efficient, more
        parallelization friendly. In this case, after selection using
        the provided positions and atoms, the number of atoms between each
        fixed and mobile structure must match exactly.

        2. ``strict=False`` -- a "flexible" protocol. In this case, after the
        selection of atoms, there are two additional steps:

            1. Sequence alignment between the selected subsets. It's
            guaranteed to produce the same number of residues between fixed
            and mobile, which may be less than the initially selected number
            (see :func:`subset_to_matching`).

            2. Following this, subset each pair of residues between fixed and
            mobile to a common list of atoms (see
            :func:`filter_to_common_atoms
            <lXtractor.util.structure.filter_to_common_atoms>`).

    As a result, the "flexible" mode may be suitable for distantly related
    structures, while the "strict" mode may be used whenever it's guaranteed
    that the selection will produce the same sets of atoms between fixed and
    mobile.

    .. seealso::
        :func:`lXtractor.chain.filter_selection_extended` -- used to apply
        the selections.

    :param fixed: An iterable over chain structures that won't be moved.
    :param mobile: An iterable over chain structures to superpose onto
        fixed ones. If ``None``, will use the combinations of `fixed`.
    :param selection_superpose: A tuple with (residue positions, atom names)
        to select atoms for superposition, which will be applied to each
        `fixed` and `mobile` structure. If ``(None, None)``, will use all
        positions and atoms. Alternatively, a selector function accepting
        a chain structure and returning an atom array. If `strict` is
        ``False``, it will convert the selected atom array to a chain
        structure.
    :param selection_dist: Same as `selection_superpose`. In addition, accepts
        ``None`` to indicate an empty selection, in which case, `dist_fn`
        should also be ``None``.
    :param dist_fn: An optional distance function applied to a pair of
        superposed atom arrays, possibly different from the arrays selected
        for superposition, which is controlled via `selection_dist`.
    :param map_name: Mapping for positions in both selection arguments. If
        used, must exist within
        :attr:`Seq <lXtractor.core.chain.ChainStructure._seq>` of each fixed
        and mobile structure. A good candidate is a mapping to a reference
        sequence or :class:`Alignment <lXtractor.core.alignment.Alignment>`.
    :param exclude_hydrogen: Exclude all hydrogen atoms during selection.
    :param strict: Enable/disable the "strict" protocol. See the explanation
        above.
    :param skip_aln_if_match: Skip the sequence alignment if this field
        matches. Used only if `strict` is ``False``.
    :param verbose: Display progress bar. The expected number of pairs is
        provided only if it can be derived from the sizes of `fixed` and
        `mobile`.
    :param num_proc: The number of parallel processes. For large selections,
        may consume a lot of RAM, so caution advised.
    :param kwargs: Forwarded to :func:`lXtractor.util.apply`, which is invoked
        with ``use_joblib=True``. NOTE(review): earlier documentation referred
        to :meth:`ProcessPoolExecutor.map` parameters (`chunksize`,
        `timeout`) -- confirm which keyword arguments are supported by
        ``apply`` in this mode.
    :return: A generator of ``namedtuple`` outputs each containing the IDs of
        the superposed objects, the RMSD between superposed structures, the
        distance function output, and the transformation matrices.
    """
    # The strict protocol operates on atom arrays and doesn't tolerate missing
    # data during selection; the flexible one operates on chain structures.
    stage = _stage_inp(  # pylint: disable=no-value-for-parameter
        selection_superpose=selection_superpose,
        selection_dist=selection_dist,
        map_name=map_name,
        exclude_hydrogen=exclude_hydrogen,
        to_array=strict,
        tolerate_missing=not strict,
    )
    pairs = _yield_staged_pairs(fixed, mobile, stage)

    # The number of pairs is known only for sized inputs. The combinations
    # count applies solely when `mobile` is absent: a provided but unsized
    # `mobile` leaves the total unknown.
    n = None
    if verbose and isinstance(fixed, abc.Sized):
        if mobile is None:
            n = len(fixed) * (len(fixed) - 1) // 2
        elif isinstance(mobile, abc.Sized):
            n = len(fixed) * len(mobile)

    fn = (
        curry(superpose_pair)(dist_fn=dist_fn)
        if strict
        else curry(align_and_superpose_pair)(  # pylint: disable=no-value-for-parameter
            skip_aln_if_match=skip_aln_if_match, dist_fn=dist_fn
        )
    )
    results = apply(
        fn, pairs, verbose, "Superposing pairs", num_proc, n, use_joblib=True, **kwargs
    )
    yield from (SuperposeOutput(*res) for res in results)
# Executing this module directly as a script raises ``RuntimeError``.
if __name__ == "__main__":
    raise RuntimeError