"""
The module defines the :class:`ChainList` - a list of `Chain*`-type objects that
behaves like a regular list but has additional bells and whistles tailored
towards `Chain*` data structures.
"""
from __future__ import annotations
import operator as op
import typing as t
from collections import abc
from functools import partial
from itertools import chain, zip_longest, tee, groupby
import pandas as pd
from more_itertools import nth, peekable, unique_everseen
import lXtractor.core.segment as lxs
from lXtractor.chain.base import is_chain_type_iterable, is_chain_type
from lXtractor.core.base import Ord, ApplyT
from lXtractor.core.config import DefaultConfig
from lXtractor.core.exceptions import MissingData
from lXtractor.util import apply
if t.TYPE_CHECKING:
from lXtractor.chain import ChainSequence, ChainStructure, Chain
CT = t.TypeVar("CT", ChainStructure, ChainSequence, Chain)
CS = t.TypeVar("CS", ChainStructure, ChainSequence)
CTU: t.TypeAlias = ChainSequence | ChainStructure | Chain
else:
CT = t.TypeVar("CT")
T = t.TypeVar("T")
__all__ = ("ChainList", "add_category")
[docs]
def add_category(c: t.Any, cat: str):
"""
:param c: A Chain*-type object.
:param cat: Category name.
:return:
"""
if hasattr(c, "meta"):
meta = c.meta
else:
raise TypeError(f"Failed to find .meta attr in {c}")
field = DefaultConfig["metadata"]["category"]
if field not in meta:
meta[field] = cat
else:
existing = meta[field].split(",")
if cat not in existing:
meta[field] += f",{cat}"
def _check_chain_types(objs: abc.Sequence[T]):
if not is_chain_type_iterable(objs):
raise TypeError("A sequence of objects is not a Chain*-type sequence")
[docs]
class ChainList(abc.MutableSequence[CT]):
# TODO: consider implementing pattern-based search over whole sequence
# or a sequence region.
# For the above, consider filtering to hits.
# It may be beneficial to implement this functionality for ChainSequence.
"""
A mutable single-type collection holding either :class:`Chain`'s,
or :class:`ChainSequence`'s, or :class:`ChainStructure`'s.
Object's funtionality relies on this type purity.
Adding of / contatenating with objects of a different type shall
raise an error.
It behaves like a regular list with additional functionality.
>>> from lXtractor.chain import ChainSequence
>>> s = ChainSequence.from_string('SEQUENCE', name='S')
>>> x = ChainSequence.from_string('XXX', name='X')
>>> x.meta['category'] = 'x'
>>> cl = ChainList([s, s, x])
>>> cl
[S|1-8, S|1-8, X|1-3]
>>> cl[0]
S|1-8
>>> cl['S']
[S|1-8, S|1-8]
>>> cl[:2]
[S|1-8, S|1-8]
>>> cl['1-3']
[X|1-3]
Adding/appending/removing objects of a similar type is easy and works
similar to a regular list.
>>> cl += [s]
>>> assert len(cl) == 4
>>> cl.remove(s)
>>> assert len(cl) == 3
Categories can be accessed as attributes or using ``[]`` syntax
(similar to the `Pandas.DataFrame` columns).
>>> cl.x
[X|1-3]
>>> cl['x']
[X|1-3]
While creating a chain list, using a `groups` parameter will assign
categories to sequences.
Note that such operations return a new :class:`ChainList` object.
>>> cl = ChainList([s, x], categories=['S', ['X1', 'X2']])
>>> cl.S
[S|1-8]
>>> cl.X2
[X|1-3]
>>> cl['X1']
[X|1-3]
"""
__slots__ = ("_chains",)
[docs]
def __init__(
self,
chains: abc.Iterable[CT],
categories: abc.Iterable[str | abc.Iterable[str]] | None = None,
):
"""
:param chains: An iterable over ``Chain*``-type objects.
:param categories: An optional list of categories.
If provided, they will be assigned to inputs' `meta` attributes.
"""
if not isinstance(chains, list):
chains = list(chains)
_check_chain_types(chains)
if categories is not None:
for c, cat in zip(chains, categories, strict=True):
if isinstance(cat, str):
add_category(c, cat)
else:
for _cat in cat:
add_category(c, _cat)
#: Protected container. One should NOT change it directly.
self._chains: list[CT] = chains
@property
def categories(self) -> abc.Set[str]:
"""
:return: A set of categories inferred from `meta` of encompassed
objects.
"""
return set(chain.from_iterable(map(lambda c: c.categories, self)))
def __len__(self) -> int:
return len(self._chains)
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, ChainList):
return False
if isinstance(other, abc.Sized):
if len(self) != len(other):
return False
return all(o1 == o2 for o1, o2 in zip(self, other))
return False
@t.overload
def __getitem__(self, index: int) -> CT:
...
@t.overload
def __getitem__(self, index: slice) -> ChainList[CT]:
...
@t.overload
def __getitem__(self, index: str) -> ChainList[CT]:
...
def __getitem__(self, index: t.SupportsIndex | slice | str) -> CT | ChainList[CT]:
match index:
case int():
return self._chains.__getitem__(index)
case slice():
return ChainList(self._chains[index])
case str():
if index in self.categories:
return self.filter_category(index)
return self.filter(lambda x: index in x.id) # type: ignore
case _:
raise TypeError(f"Incorrect index type {type(index)}")
def __getattr__(self, name: str):
"""
See the example in pandas:
https://github.com/pandas-dev/pandas/blob/
61e0db25f5982063ba7bab062074d55d5e549586/pandas/core/generic.py#L5811
"""
if name == "__setstate__":
raise AttributeError(name)
if name.startswith("__"):
object.__getattribute__(self, name)
if name in self.categories:
return self.filter(lambda c: any(cat == name for cat in c.categories))
raise AttributeError
@t.overload
def __setitem__(self, index: int, value: CT):
...
@t.overload
def __setitem__(self, index: slice, value: abc.Iterable[CT]):
...
def __setitem__(self, index: t.SupportsIndex | slice, value: CT | abc.Iterable[CT]):
if len(self) == 0:
raise MissingData("Not possible to use __setitem__ when ChainList is empty")
self_type = type(self._chains[0])
if is_chain_type(value):
other_type = type(value)
else:
if is_chain_type_iterable(value):
# Doesn't accept unions for some reason
value = peekable(value) # type: ignore
other_type = type(value.peek())
else:
raise TypeError("Incompatible value type")
if self_type is other_type or id(self_type) == id(other_type):
self._chains[index] = value # type: ignore # overloading failure
else:
raise TypeError(
f"Value type {other_type} conflicts with existing "
f"items type {self_type}"
)
def __delitem__(self, index: t.SupportsIndex | int | slice):
self._chains.__delitem__(index)
def __contains__(self, item: object) -> bool:
if isinstance(item, str):
for c in self:
if c.id == item:
return True
return False
return item in self._chains
def __add__(self, other: ChainList | abc.Iterable):
match other:
case ChainList():
if len(self._chains) > 0:
_check_chain_types([self._chains[0], *other])
return ChainList(self._chains + other._chains)
case abc.Iterable():
if len(self._chains) > 0:
other = list(other)
_check_chain_types([self._chains[0], *other])
return ChainList(self._chains + list(other))
case _:
raise TypeError(f"Unsupported type {type(other)}")
def __repr__(self) -> str:
return self._chains.__repr__()
def __iter__(self) -> abc.Iterator[CT]:
return iter(self._chains)
[docs]
def index(self, value: CT, start: int = 0, stop: int | None = None) -> int:
stop = stop or len(self)
return self._chains.index(value, start, stop)
[docs]
def insert(self, index: int, value: CT):
if len(self) > 0:
_check_chain_types([self[0], value])
self._chains.insert(index, value)
[docs]
def iter_children(self) -> abc.Generator[ChainList[CT], None, None]:
"""
Simultaneously iterate over topological levels of children.
>>> from lXtractor.chain import ChainSequence
>>> s = ChainSequence.from_string('ABCDE', name='A')
>>> child1 = s.spawn_child(1, 4)
>>> child2 = child1.spawn_child(2, 3)
>>> x = ChainSequence.from_string('XXXX', name='X')
>>> child3 = x.spawn_child(1, 3)
>>> cl = ChainList([s, x])
>>> list(cl.iter_children())
[[A|1-4<-(A|1-5), X|1-3<-(X|1-4)], [A|2-3<-(A|1-4<-(A|1-5))]]
:return: An iterator over chain lists of children levels.
"""
# Mypy thinks zip_longest produces tuples of `object` types
# probably due to "*"
yield from map(
lambda xs: ChainList(chain.from_iterable(xs)), # type: ignore
zip_longest(*map(lambda c: c.iter_children(), self._chains), fillvalue=[]),
)
[docs]
def iter_ids(self) -> abc.Iterator[str]:
"""
Iterate over ids of this chain list.
:return: An iterator over chain ids.
"""
for c in self._chains:
yield c.id
[docs]
def get_level(self, n: int) -> ChainList[CT]:
"""
Get a specific level of a hierarchical tree starting from this list::
l0: this list
l1: children of each child of each object in l0
l2: children of each child of each object in l1
...
:param n: The level index (0 indicates this list).
Other levels are obtained via :meth:`iter_children`.
:return: A chain list of object corresponding to a specific topological
level of a child tree.
"""
if n == 0:
return self
return nth(self.iter_children(), n - 1, default=ChainList([]))
[docs]
def collapse_children(self) -> ChainList[CT]:
"""
Collapse all children of each object in this list into a single
chain list.
>>> from lXtractor.chain import ChainSequence
>>> s = ChainSequence.from_string('ABCDE', name='A')
>>> child1 = s.spawn_child(1, 4)
>>> child2 = child1.spawn_child(2, 3)
>>> cl = ChainList([s]).collapse_children()
>>> assert isinstance(cl, ChainList)
>>> cl
[A|1-4<-(A|1-5), A|2-3<-(A|1-4<-(A|1-5))]
:return: A chain list of all children.
"""
return ChainList(chain.from_iterable(self.iter_children()))
[docs]
def collapse(self) -> ChainList[CT]:
"""
Collapse all objects and their children within this list into a new
chain list. This is a shortcut for
``chain_list + chain_list.collapse_children()``.
:return: Collapsed list.
"""
return self + self.collapse_children()
[docs]
def iter_sequences(self) -> abc.Generator[ChainSequence, None, None]:
"""
:return: An iterator over :class:`ChainSequence`'s.
"""
# mypy doesn't know the type is known at runtime
from lXtractor import chain as lxc
if len(self) > 0:
x = self[0]
if isinstance(x, (lxc.Chain, lxc.ChainStructure)):
yield from (c.seq for c in self._chains)
else:
yield from iter(self._chains)
else:
yield from iter([])
[docs]
def iter_structures(self) -> abc.Generator[ChainStructure, None, None]:
"""
:return: An generator over :class:`ChainStructure`'s.
"""
# mypy doesn't know the type is known at runtime
from lXtractor import chain as lxc
if len(self) > 0:
x = self[0]
if isinstance(x, lxc.Chain):
yield from chain.from_iterable(c.structures for c in self._chains)
elif isinstance(x, lxc.ChainStructure):
yield from iter(self._chains)
else:
yield from iter([])
else:
yield from iter([])
[docs]
def iter_structure_sequences(self) -> abc.Generator[ChainSequence, None, None]:
"""
:return: Iterate over :attr:`ChainStructure._seq` attributes.
"""
yield from (s.seq for s in self.iter_structures())
@property
def sequences(self) -> ChainList[ChainSequence]:
"""
:return: Get all :attr:`lXtractor.core.chain.Chain._seq` or
`lXtractor.core.chain.sequence.ChainSequence` objects within this
chain list.
"""
return ChainList(self.iter_sequences())
@property
def structures(self) -> ChainList[ChainStructure]:
return ChainList(self.iter_structures())
@property
def structure_sequences(self) -> ChainList[ChainSequence]:
return ChainList(self.iter_structure_sequences())
@property
def ids(self) -> list[str]:
"""
:return: A list of ids for all chains in this list.
"""
return list(self.iter_ids())
@staticmethod
def _get_seg_matcher(
s: str,
) -> abc.Callable[[ChainSequence, lxs.Segment, t.Optional[str]], bool]:
def matcher(
seq: ChainSequence, seg: lxs.Segment, map_name: t.Optional[str] = None
) -> bool:
if map_name is not None:
# Get elements in the _seq whose mapped sequence matches
# seg boundaries
start_item = seq.get_closest(map_name, seg.start)
end_item = seq.get_closest(map_name, seg.end, reverse=True)
if start_item is None or end_item is None:
return False
start = start_item._asdict()[map_name]
end = end_item._asdict()[map_name]
# If not such elements -> no match
# Create a new temporary segment using the mapped boundaries
_seq: lxs.Segment | ChainSequence = lxs.Segment(start, end)
else:
_seq = seq
match s:
case "overlap":
return _seq.overlaps(seg)
case "bounded":
return _seq.bounded_by(seg)
case "bounding":
return _seq.bounds(seg)
case _:
raise ValueError(f"Invalid matching mode {s}")
return matcher
@staticmethod
def _get_pos_matcher(
ps: abc.Iterable[Ord],
) -> abc.Callable[[ChainSequence, t.Optional[str]], bool]:
def matcher(seq: ChainSequence, map_name: t.Optional[str] = None) -> bool:
obj: abc.Sequence | ChainSequence = seq
if map_name:
obj = seq[map_name]
return all(p in obj for p in ps)
return matcher
def _filter_seqs(
self,
seqs: abc.Iterable[ChainSequence],
match_type: str,
s: lxs.Segment | abc.Iterable[Ord],
map_name: t.Optional[str],
) -> abc.Iterator[bool]:
if isinstance(s, lxs.Segment):
match_fn = partial(
self._get_seg_matcher(match_type), seg=s, map_name=map_name
)
else:
match_fn = partial(self._get_pos_matcher(s), map_name=map_name)
return map(match_fn, seqs)
def _filter_str(
self,
structures: abc.Iterable[ChainStructure],
match_type: str,
s: lxs.Segment | abc.Collection[Ord],
map_name: t.Optional[str],
) -> abc.Iterator[bool]:
return self._filter_seqs(
map(lambda x: x._seq, structures), match_type, s, map_name
)
[docs]
def filter_pos(
self,
s: lxs.Segment | abc.Collection[Ord],
*,
match_type: str = "overlap",
map_name: str | None = None,
) -> ChainList[CS]:
"""
Filter to objects encompassing certain consecutive position regions
or arbitrary positions' collections.
For :class:`Chain` and :class:`ChainStructure`, the filtering is over
`_seq` attributes.
:param s: What to search for:
#. ``s=Segment(start, end)`` to find all objects encompassing
certain region.
#. ``[pos1, posX, posN]`` to find all objects encompassing the
specified positions.
:param match_type: If `s` is `Segment`, this value determines the
acceptable relationships between `s` and each
:class:`ChainSequence`:
#. "overlap" -- it's enough to overlap with `s`.
#. "bounding" -- object is accepted if it bounds `s`.
#. "bounded" -- object is accepted if it's bounded by `s`.
:param map_name: Use this map within to map positions of `s`.
For instance, to each for all elements encompassing region 1-5 of
a canonical sequence, one would use
.. code-block:: python
chain_list.filter_pos(
s=Segment(1, 5), match_type="bounding",
map_name="map_canonical"
)
:return: A list of hits of the same type.
"""
from lXtractor import chain as lxc
if len(self) > 0:
x = self[0]
if isinstance(x, lxc.Chain):
objs, fn = self.iter_sequences(), self._filter_seqs
elif isinstance(x, lxc.ChainSequence):
objs, fn = iter(self), self._filter_seqs
else:
objs, fn = iter(self), self._filter_str
else:
return ChainList([])
objs1, objs2 = tee(objs)
mask = fn(objs1, match_type, s, map_name)
return ChainList(
map(op.itemgetter(1), filter(lambda x: x[0], zip(mask, objs2)))
)
[docs]
def filter(self, pred: abc.Callable[[CT], bool]) -> ChainList[CT]:
"""
>>> from lXtractor.chain import ChainSequence
>>> cl = ChainList(
... [ChainSequence.from_string('AAAX', name='A'),
... ChainSequence.from_string('XXX', name='X')]
... )
>>> cl.filter(lambda c: c.seq1[0] == 'A')
[A|1-4]
:param pred: Predicate callable for filtering.
:return: A filtered chain list (new object).
"""
return ChainList(filter(pred, self))
[docs]
def filter_category(self, name: str) -> ChainList:
"""
:param name: Category name.
:return: Filtered objects having this category within their
``meta["category"]``.
"""
return self.filter(lambda c: any(cat == name for cat in c.categories))
[docs]
def apply(
self,
fn: ApplyT,
verbose: bool = False,
desc: str = "Applying to objects",
num_proc: int = 1,
) -> ChainList[CT]:
"""
Apply a function to each object and return a new chain list of results.
:param fn: A callable to apply.
:param verbose: Display progress bar.
:param desc: Progress bar description.
:param num_proc: The number of CPUs to use. ``num_proc <= 1`` indicates
sequential processing.
:return: A new chain list with application results.
"""
return ChainList(apply(fn, self._chains, verbose, desc, num_proc))
[docs]
def drop_duplicates(
self, key: abc.Callable[[CT], t.Hashable] | None = lambda x: x.id
) -> t.Self:
"""
:param key: A callable accepting the single element and returning some
hashable object associated with that element.
:return: A new list with unique elements as judged by the `key`.
"""
return self.__class__(unique_everseen(self._chains, key=key))
[docs]
def summary(self, **kwargs) -> pd.DataFrame:
return pd.concat([c.summary(**kwargs) for c in self])
[docs]
def groupby(self, key: abc.Callable[[CT], T]) -> abc.Iterator[tuple[T, t.Self]]:
"""
Group sequences in this list by a given key.
:param key: Some callable accepting a single chain and returning
a grouper value.
:return: An iterator over pairs ``(group, chains)``, where ``chains``
is a chain list of chains that belong to ``group``.
"""
for g, gg in groupby(self._chains, key):
yield g, self.__class__(gg)
[docs]
def sort(self, key: abc.Callable[[CT], T] = lambda x: x.id) -> ChainList[CT]:
return self.__class__(sorted(self._chains, key=key))
def _wrap_children(children: abc.Iterable[CT] | None) -> ChainList[CT]:
if children:
if not isinstance(children, ChainList):
assert is_chain_type_iterable(children)
return ChainList(children)
return children
return ChainList([])
if __name__ == "__main__":
raise RuntimeError