Source code for lXtractor.variables.manager

"""
``Manager`` handles variable calculations, such as:

#. Variable manipulations (assignment, deletions, and resetting).
#. Calculation of variables. Simply manages the calculation process, whereas
    calculators (:class:`lXtractor.variables.calculator.GenericCalculator`
    for instance) do the heavy lifting.
#. Aggregation of the calculation results, either
    :meth:`from_chains <Manager.aggregate_from_chains>` or
    :meth:`from_iterable <Manager.aggregate_from_it>`.

"""
import logging
import typing as t
from collections import abc
from itertools import chain, repeat, tee

import numpy as np
import pandas as pd
from more_itertools import unzip, peekable, split_when, chunked
from toolz import curry
from tqdm.auto import tqdm

import lXtractor.chain as lxc
from lXtractor.core import Ligand
from lXtractor.core.config import DefaultConfig
from lXtractor.core.exceptions import MissingData
from lXtractor.core.structure import GenericStructure
from lXtractor.variables.base import (
    SequenceVariable,
    StructureVariable,
    Variables,
    AbstractCalculator,
    LigandVariable,
)

# TODO: get proper target by using `seq_name` from a seq var definition

T = t.TypeVar("T")
LigInp: t.TypeAlias = tuple[lxc.ChainStructure, Ligand]
Inp: t.TypeAlias = lxc.ChainSequence | lxc.ChainStructure | LigInp
InpT = t.TypeVar("InpT", lxc.ChainStructure, lxc.ChainSequence, LigInp)
InpV: t.TypeAlias = SequenceVariable | StructureVariable | LigandVariable
CalcRes: t.TypeAlias = tuple[Inp, InpV, bool, t.Any]
StagedSeq: t.TypeAlias = tuple[
    lxc.ChainSequence,
    abc.Sequence[t.Any],
    abc.Sequence[SequenceVariable],
    abc.Mapping[int, int] | None,
]
StagedStr: t.TypeAlias = tuple[
    lxc.ChainStructure,
    GenericStructure,
    abc.Sequence[StructureVariable],
    abc.Mapping[int, int] | None,
]
StagedLig: t.TypeAlias = tuple[
    tuple[lxc.ChainStructure, Ligand],
    Ligand,
    abc.Sequence[LigandVariable],
    abc.Mapping[int, int] | None,
]

LOGGER = logging.getLogger(__name__)


def _update_variables(vs: Variables, upd: abc.Iterable[InpV]) -> Variables:
    for v in upd:
        vs[v] = None
    return vs


[docs] def get_mapping(obj: t.Any, map_name: str | None, map_to: str | None) -> dict | None: """ Obtain mapping from a Chain*-type object. >>> s = lxc.ChainSequence.from_string('ABCD', name='_seq') >>> s.add_seq('some_map', [5, 6, 7, 8]) >>> s.add_seq('another_map', ['D', 'B', 'C', 'A']) >>> get_mapping(s, 'some_map', None) {5: 1, 6: 2, 7: 3, 8: 4} >>> get_mapping(s, 'another_map', 'some_map') {'D': 5, 'B': 6, 'C': 7, 'A': 8} :param obj: Chain*-type object. If not a Chain*-type object, raises `AttributeError`. :param map_name: The name of a map to create the mapping from. If ``None``, the resulting mapping is ``None``. :param map_to: The name of a map to create a mapping to. If ``None``, will default to the real sequence indices (1-based) for a :class:`ChainSequence <lXtractor.core.chain.ChainSequence>` object and to the structure actual numbering for the :class:`ChainStructure <lXtractor.core.chain.ChainStructure>`. :return: A dictionary mapping from the `map_name` sequence to `map_to` sequence. """ if map_name is None: return None if not isinstance(obj, lxc.ChainSequence): try: seq = obj.seq except AttributeError as e: raise MissingData(f"Object {obj} is missing `seq` attribute") from e else: seq = obj fr = seq[map_name] if map_to is None: if isinstance(obj, lxc.ChainStructure): to = seq[DefaultConfig["mapnames"]["enum"]] else: to = range(1, len(fr) + 1) else: to = seq[map_to] return dict(filter(lambda x: x[0] is not None, zip(fr, to, strict=True)))
@t.overload def _get_vs(obj: lxc.ChainStructure, missing) -> list[StructureVariable]: ... @t.overload def _get_vs(obj: lxc.ChainSequence, missing) -> list[SequenceVariable]: ... def _get_vs( obj: Inp, missing: bool ) -> list[SequenceVariable] | list[StructureVariable]: if missing: return [v for v, r in obj.variables.items() if r is None] return list(obj.variables) def _filter_type(xs: abc.Iterable[t.Any], _t: t.Type[T]) -> abc.Iterator[T]: return filter(lambda x: isinstance(x, _t), xs) def _split_objects( objs: abc.Iterable[Inp], ) -> tuple[list[lxc.ChainSequence], list[lxc.ChainStructure], list[Ligand]]: types = [lxc.ChainSequence, lxc.ChainStructure, LigandVariable] seqs, strs, ligs = ( list(_filter_type(xs, _t)) for xs, _t in zip(tee(objs, len(types)), types) ) return seqs, strs, ligs def _split_variables( vs: abc.Sequence[InpV], ) -> tuple[list[SequenceVariable], list[StructureVariable], list[LigandVariable]]: seq_vs, str_vs, lig_vs = ( list(_filter_type(vs, _t)) for _t in [SequenceVariable, StructureVariable, LigandVariable] ) return seq_vs, str_vs, lig_vs @t.overload def stage( obj: lxc.ChainStructure, vs, *, missing, seq_name, map_name, map_to ) -> StagedStr: ... @t.overload def stage( obj: lxc.ChainSequence, vs, *, missing, seq_name, map_name, map_to ) -> StagedSeq: ... @t.overload def stage( obj: lxc.ChainSequence, vs, *, missing, seq_name, map_name, map_to ) -> StagedSeq: ... @t.overload def stage(obj: LigInp, vs, *, missing, seq_name, map_name, map_to) -> StagedLig: ...
[docs] def stage( obj: InpT, vs: abc.Sequence[InpV] | None, *, missing: bool = True, seq_name: str = DefaultConfig["mapnames"]["seq1"], map_name: str | None = None, map_to: str | None = None, ) -> StagedStr | StagedSeq | StagedLig: """ Stage object for calculation. If it's a chain sequence, will stage some sequence/mapping within it. If it's a chain structure, will stage the atom array. :param obj: A chain sequence or structure or structure-ligand pair to calculate the variables on. :param vs: A sequence of variables to calculate. :param missing: If ``True``, calculate only those assigned variables that are missing. :param seq_name: If `obj` is the chain sequence, the sequence name is used to obtain an actual sequence (``obj[seq_name]``). :param map_name: The mapping name to obtain the mapping keys. If ``None``, the resulting mapping will be ``None``. :param map_to: The mapping name to obtain the mapping values. See :func:`get_mapping` for details. :return: A tuple with four elements: 1. Original object. 2. Staged target passed to a variable for calculation. 3. A sequence of sequence or structural variables. 4. An optional mapping. """ target: lxc.ChainStructure | abc.Sequence | None def stage_vs_and_mapping(cs: lxc.ChainStructure | lxc.ChainSequence): return ( *_split_variables(vs or _get_vs(cs, missing)), get_mapping(cs, map_name, map_to), ) match obj: case lxc.ChainStructure(): target = find_structure(obj) _, _vs, _, m = stage_vs_and_mapping(obj) case lxc.ChainSequence(): target = obj[seq_name] _vs, _, _, m = stage_vs_and_mapping(obj) case (lxc.ChainStructure(), Ligand()): target = obj[1] _, _, _vs, m = stage_vs_and_mapping(obj[0]) case _: raise TypeError(f"Invalid object type {type(obj)}") return obj, target, _vs, m
[docs] def find_structure(s: lxc.ChainStructure) -> GenericStructure | None: """ Recursively search for structure up the ancestral tree. :param s: An arbitrary chain structure. :return: The first non-empty atom array up the parent chain. """ structure = s.structure parent = s.parent while structure is None and parent is not None: structure = parent.structure parent = parent.parent return None or structure
[docs] class Manager: """ Manager of variable calculations, handling assignment, aggregation, and, of course, the calculations themselves. """ __slots__ = ("verbose",)
[docs] def __init__(self, verbose: bool = False): """ :param verbose: Display progress bar. """ self.verbose = verbose
[docs] def assign(self, vs: abc.Sequence[InpV], chains: abc.Iterable[Inp]): """ Assign variables to chains sequences/structures. :param vs: A sequence of variables. :param chains: An iterable over chain sequences/structures. :return: No return. Will store assigned variables within the `variables` attribute. """ seq_vs, str_vs, lig_vs = _split_variables(vs) seqs, strs, ligs = _split_objects(chains) staged_objs: abc.Iterable[tuple[Inp, abc.Sequence[InpV]]] = chain( zip(seqs, repeat(seq_vs)), zip(strs, repeat(str_vs)), zip(ligs, repeat(lig_vs)), ) if self.verbose: staged_objs = tqdm(staged_objs, desc="Assigning variables") for o, _vs in staged_objs: o.variables.update({v: None for v in _vs})
[docs] def remove(self, chains: abc.Iterable[Inp], vs: abc.Sequence[InpV] | None = None): """ Remove variables from the `variables` container. :param chains: An iterable over chain sequences/structures. :param vs: A sequence of variables to remove. If not provided, will remove all variables. :return: No return. """ def _take_key(v): return vs is not None and v in vs or vs is None if self.verbose: chains = tqdm(chains, desc="Removing variables") for c in chains: keys = list(filter(_take_key, c.variables)) for k in keys: c.variables.pop(k)
[docs] def reset(self, chains: abc.Iterable[Inp], vs: abc.Sequence[InpV] | None = None): """ Similar to :meth:`remove`, but instead of deleting, resets variable calculation results. :param chains: An iterable over chain sequences/structures. :param vs: A sequence of variables to reset. If not provided, will reset all variables. :return: No return. """ def _take_key(v): return vs is not None and v in vs or vs is None if self.verbose: chains = tqdm(chains, desc="Resetting variable results") for c in chains: keys = list(filter(_take_key, c.variables)) for k in keys: c.variables[k] = None
[docs] def aggregate_from_chains(self, chains: abc.Iterable[Inp]) -> pd.DataFrame: """ Aggregate calculation results from the `variables` container of the provided chains. >>> from lXtractor.variables.sequential import SeqEl >>> s = lxc.ChainSequence.from_string('abcd', name='_seq') >>> manager = Manager() >>> manager.assign([SeqEl(1)], [s]) >>> df = manager.aggregate_from_chains([s]) >>> len(df) == 1 True >>> list(df.columns) ['VariableID', 'VariableResult', 'ObjectID', 'ObjectType'] :param chains: An iterable over chain sequences/structures. :return: A dataframe with `ObjectID`, `ObjectType`, and calculation results. """ def get_vs(obj: Inp) -> pd.DataFrame: vs_df = obj.variables.as_df() vs_df["ObjectID"] = obj.id vs_df["ObjectType"] = obj.__class__.__name__ return vs_df vs: abc.Iterable[pd.DataFrame] = filter( lambda x: len(x) > 0, map(get_vs, chains) ) vs = peekable(vs) if vs.peek(None) is None: return pd.DataFrame( columns=["VariableID", "VariableResult", "ObjectID", "ObjectType"] ) if self.verbose: vs = tqdm(vs, desc="Aggregating variables") return pd.concat(vs, ignore_index=True)
[docs] def aggregate_from_it( self, results: abc.Iterable[CalcRes], vs_to_cols: bool = True, replace_errors: bool = True, replace_errors_with: t.Any = np.NaN, num_vs: int | None = None, ) -> pd.DataFrame | dict[str, list]: """ Aggregate calculation results directly from :meth:`calculate` output. :param results: An iterable over calculation results. :param vs_to_cols: If ``True``, will attempt to use the wide format for the final results with variables as columns. Otherwise, will use the long format with fixed columns: "ObjectID", "VariableID", "VariableCalculated", and "VariableResult". Note that for the wide format to work, all objects and their variables must have unique IDs. :param replace_errors: When calculation failed, replace the calculation results with certain value. :param replace_errors_with: Use this value to replace erroneous calculation results. :param num_vs: The number of variables per object. Providing this will significantly increase the aggregation speed. :return: A table with results in long or short format. """ def substitute_error(res): if res[2]: return res return res[0], res[1], res[2], replace_errors_with def substitute_ids(res): if isinstance(res[0], tuple): obj_id = res[0][1].id else: obj_id = res[0].id return obj_id, res[1].id, res[2], res[3] def wrap_into_series(res_chunk): idx = chain(["ObjectID"], (res[1] for res in res_chunk)) vs = chain([res_chunk[0][0]], (res[-1] for res in res_chunk)) return pd.Series(vs, idx) if self.verbose: results = tqdm(results, "Accumulating calculations") if replace_errors: results = map(substitute_error, results) if vs_to_cols: results = map(substitute_ids, results) if num_vs is not None: chunks = chunked(results, num_vs) else: chunks = split_when(results, lambda x, y: x[0] != y[0]) wrapped = map(wrap_into_series, chunks) df = pd.DataFrame(wrapped) else: colnames = ["Object", "Variable", "VariableCalculated", "VariableResult"] df = pd.DataFrame(dict(zip(colnames, map(list, unzip(results))))) return df
[docs] def stage( self, chains: abc.Iterable[Inp], vs: abc.Sequence[InpV] | None, **kwargs ) -> abc.Generator[StagedSeq | StagedStr, None, None]: """ Stage objects for calculations (e.g., using :meth:`calculate`). It's a useful method if using a different calculation method and/or parallelization strategy within a `Calculator` class. .. seealso:: :func:`stage` :meth:`calculate` >>> from lXtractor.variables.sequential import SeqEl >>> s = lxc.ChainSequence.from_string('ABCD', name='_seq') >>> m = Manager() >>> staged = list(m.stage([s], [SeqEl(1)])) >>> len(staged) == 1 True >>> staged[0] (_seq|1-4, 'ABCD', [SeqEl(p=1,_rtype='str',seq_name='seq1')], None) :param chains: An iterable over chain sequences/structures. :param vs: A sequence of variables. If not provided, will use assigned variables (see :meth:`assign`). :param kwargs: Passed to :func:`stage`. :return: An iterable over tuples holding data for variables' calculation. """ yield from map(curry(stage)(vs=vs, **kwargs), chains)
[docs] def calculate( self, objs: abc.Iterable[Inp], vs: abc.Sequence[InpV] | None, calculator: AbstractCalculator, *, save: bool = False, **kwargs, ) -> abc.Generator[CalcRes, None, None]: """ Handles variable calculations: 1. Stage calculations (see :meth:`stage`). 2. Calculate variables using the provided calculator. 3. (Optional) save the calculation results to variables container. 4. Output (stream) calculation results. Note that 3 and 4 are done lazily as calculation results from the calculator become available. >>> from lXtractor.variables.calculator import GenericCalculator >>> from lXtractor.variables.sequential import SeqEl >>> s = lxc.ChainSequence.from_string('ABCD', name='_seq') >>> m = Manager() >>> c = GenericCalculator() >>> list(m.calculate([s],[SeqEl(1)],c)) [(_seq|1-4, SeqEl(p=1,_rtype='str',seq_name='seq1'), True, 'A')] >>> list(m.calculate([s],[SeqEl(5)],c))[0][-2:] (False, 'Missing index 4 in sequence') :param objs: An iterable over chain sequences/structures. :param vs: A sequence of variables. If not provided, will use assigned variables (see :meth:`assign`). :param calculator: A calculator object -- some callable with the right signature handling the calculations. :param save: Save calculation results to variables. Will overwrite any existing matching variables. :param kwargs: Passed to :meth:`stage`. :return: A generator over tuples: 1. Original object. 2. Variable. 3. Flag indicated whether the calculation was successful. 4. The calculation result (or the error message). """ objs, targets, variables, mappings = unzip(self.stage(objs, vs, **kwargs)) variables1, variables2 = tee(variables) calculated = calculator(targets, variables1, mappings) for obj, _vs, results in zip(objs, variables2, calculated, strict=True): for v, (is_calculated, res) in zip(_vs, results, strict=True): if save: if is_calculated: obj.variables[v] = res else: obj.variables[v] = None yield obj, v, is_calculated, res
if __name__ == "__main__": raise RuntimeError