"""Utility datastructures."""
# Core Library
import collections.abc
import logging
from copy import deepcopy
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
Union,
cast,
overload,
)
# First party
from mpu.datastructures.trie import Trie # noqa
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# Element type of an EList.
T = TypeVar("T")


class EList(list, Generic[T]):
    """
    Enhanced List.

    This class supports every operation a normal list supports. Additionally,
    it can be indexed with a list of indices, which returns the selected
    elements as a new EList.

    Examples
    --------
    >>> l = EList([2, 1, 0])
    >>> l[2]
    0
    >>> l[[2, 0]]
    [0, 2]
    >>> l[l]
    [0, 1, 2]
    """

    def __init__(self, *args: Iterable[T]):
        list.__init__(self, *args)

    def __getitem__(self, key):
        """
        Retrieve one or multiple elements.

        Parameters
        ----------
        key : int or slice or List[int] or List[List[int]] or ...
            A list key is resolved element by element (recursively, so
            nested lists of indices produce nested ELists). Every other key
            is passed on to ``list.__getitem__`` unchanged.

        Returns
        -------
        value : EList or element
            An EList if ``key`` is a list, otherwise whatever
            ``list.__getitem__`` returns for ``key``.
        """
        if isinstance(key, list):
            return EList([self[index] for index in key])
        return list.__getitem__(self, key)

    def remove_indices(self, indices: List[int]) -> "EList":
        """
        Return a copy of this list without the elements at the given indices.

        The list itself is not modified. Indices are compared against the
        positions ``0 .. len(self) - 1``; indices which do not occur are
        ignored.

        Parameters
        ----------
        indices : List[int]

        Returns
        -------
        filtered_list : EList
        """
        # A set gives O(1) membership tests instead of scanning `indices`
        # once per element. Fall back to the given sequence if it contains
        # unhashable entries, so such calls behave exactly as before.
        try:
            excluded = set(indices)
        except TypeError:
            excluded = indices
        return EList(
            [element for index, element in enumerate(self) if index not in excluded]
        )
def flatten(iterable: Iterable, string_flattening: bool = False) -> List:
    """
    Flatten a given iterable of iterables into one list.

    Parameters
    ----------
    iterable : Iterable
    string_flattening : bool
        If this is False, then strings are NOT flattened but kept as single
        elements. If it is True, every string - at any nesting depth - is
        split into its characters.

    Returns
    -------
    flat_list : List

    Examples
    --------
    >>> flatten([1, [2, [3]]])
    [1, 2, 3]
    >>> flatten(((1, 2), (3, 4), (5, 6)))
    [1, 2, 3, 4, 5, 6]
    >>> flatten(["ab", ["cd"]])
    ['ab', 'cd']
    >>> flatten(["ab", ["cd"]], string_flattening=True)
    ['a', 'b', 'c', 'd']
    """
    flat_list: List = []
    for item in iterable:
        if isinstance(item, str):
            # Strings are handled explicitly: iterating a one-character
            # string yields that same string again, so recursing into
            # strings would never terminate.
            if string_flattening:
                flat_list.extend(item)
            else:
                flat_list.append(item)
        elif isinstance(item, collections.abc.Iterable):
            # Pass the flag on, so that nested strings are treated the same
            # way as top-level ones.
            flat_list.extend(flatten(item, string_flattening))
        else:
            flat_list.append(item)
    return flat_list
def dict_merge(
    dict_left: Dict, dict_right: Dict, merge_method: str = "take_left_shallow"
) -> Dict:
    r"""
    Merge two dictionaries.

    This method does NOT modify dict_left or dict_right, and the returned
    dictionary does not share mutable values with either of them.

    Parameters
    ----------
    dict_left : Dict
    dict_right: Dict
    merge_method : {'take_left_shallow', 'take_left_deep', \
                    'take_right_shallow', 'take_right_deep', \
                    'sum'}
        * take_left_shallow: Use both dictionaries. If both have the same key,
          take the value of dict_left
        * take_left_deep : If both dictionaries have the same key and the value
          is a dict for both again, then merge those sub-dictionaries
        * take_right_shallow : See take_left_shallow
        * take_right_deep : See take_left_deep
        * sum : sum up both dictionaries. If one does not have a value for a
          key of the other, assume the missing value to be zero. Values which
          are dictionaries on both sides are summed up recursively.

    Returns
    -------
    merged_dict : Dict

    Raises
    ------
    NotImplementedError
        If merge_method is not one of the values listed above.

    Examples
    --------
    >>> dict_merge({'a': 1, 'b': 2}, {'c': 3}) == {'a': 1, 'b': 2, 'c': 3}
    True

    >>> out = dict_merge({'a': {'A': 1}},
    ...                  {'a': {'A': 2, 'B': 3}}, 'take_left_deep')
    >>> expected = {'a': {'A': 1, 'B': 3}}
    >>> out == expected
    True

    >>> out = dict_merge({'a': {'A': 1}},
    ...                  {'a': {'A': 2, 'B': 3}}, 'take_left_shallow')
    >>> expected = {'a': {'A': 1}}
    >>> out == expected
    True

    >>> out = dict_merge({'a': 1, 'b': {'c': 2}},
    ...                  {'b': {'c': 3, 'd': 4}},
    ...                  'sum')
    >>> expected = {'a': 1, 'b': {'c': 5, 'd': 4}}
    >>> out == expected
    True
    """
    if merge_method in ("take_right_shallow", "take_right_deep"):
        return _dict_merge_right(dict_left, dict_right, merge_method)
    if merge_method == "take_left_shallow":
        # "left wins" is "right wins" with the arguments swapped.
        return _dict_merge_right(dict_right, dict_left, "take_right_shallow")
    if merge_method == "take_left_deep":
        return _dict_merge_right(dict_right, dict_left, "take_right_deep")
    if merge_method == "sum":
        new_dict = deepcopy(dict_left)
        for key, value in dict_right.items():
            if key not in new_dict:
                # Copy, so the result never aliases values of dict_right.
                new_dict[key] = deepcopy(value)
            elif isinstance(value, dict) and isinstance(dict_left[key], dict):
                new_dict[key] = dict_merge(dict_left[key], value, merge_method="sum")
            else:
                new_dict[key] = dict_left[key] + value
        return new_dict
    raise NotImplementedError(f"merge_method='{merge_method}' is not known.")


def _dict_merge_right(dict_left: Dict, dict_right: Dict, merge_method: str) -> Dict:
    """
    Merge two dictionaries; on a key conflict the value of dict_right wins.

    See documentation of mpu.datastructures.dict_merge.

    Parameters
    ----------
    dict_left : Dict
    dict_right : Dict
    merge_method : {'take_right_shallow', 'take_right_deep'}
        With 'take_right_deep', values which are dictionaries on both sides
        are merged recursively instead of being replaced.

    Returns
    -------
    merged_dict : Dict
        A new dictionary which shares no mutable values with the inputs.
    """
    new_dict = deepcopy(dict_left)
    deep = merge_method == "take_right_deep"
    for key, value in dict_right.items():
        if (
            deep
            and key in new_dict
            and isinstance(dict_left[key], dict)
            and isinstance(value, dict)
        ):
            new_dict[key] = _dict_merge_right(dict_left[key], value, "take_right_deep")
        else:
            # New key or overriding value: copy it, so that later changes to
            # the result cannot leak into dict_right.
            new_dict[key] = deepcopy(value)
    return new_dict
def set_dict_value(dictionary: Dict, keys: List[Any], value: Any) -> Dict:
    """
    Store a value in a (nested) dictionary at the path given by a key list.

    Intermediate dictionaries which do not exist yet are created on the way.

    .. note:: Side-effects
        The given dictionary is edited in place; no copy is made.

    Parameters
    ----------
    dictionary : Dict
    keys : List[Any]
        Path to the value: all keys but the last one select (or create) the
        nested dictionaries, the last key is the one which gets assigned.
    value : Any

    Returns
    -------
    dictionary : dict
        The very same object that was passed in.

    Examples
    --------
    >>> d = {'a': {'b': {'c': 'x', 'f': 'g'}, 'd': 'e'}}
    >>> expected = {'a': {'b': {'c': 'foobar', 'f': 'g'}, 'd': 'e'}}
    >>> set_dict_value(d, ['a', 'b', 'c'], 'foobar') == expected
    True
    """
    node = dictionary
    parent_keys = keys[:-1]
    # Descend to the dictionary which holds the final key.
    for parent_key in parent_keys:
        node = node.setdefault(parent_key, {})
    leaf_key = keys[-1]
    node[leaf_key] = value
    return dictionary
def does_keychain_exist(dict_: Dict, list_: List) -> bool:
    """
    Check if a sequence of keys exist in a nested dictionary.

    The keys are looked up one after another, each one in the value found
    for the previous key. If a value on the way is not a mapping, the
    remaining keys cannot exist and False is returned.

    Parameters
    ----------
    dict_ : Dict[str/int/tuple, Any]
    list_ : List[str/int/tuple]

    Returns
    -------
    keychain_exists : bool
        True if every key of list_ was found (also for an empty list_).

    Examples
    --------
    >>> d = {'a': {'b': {'c': 'd'}}}
    >>> l_exists = ['a', 'b']
    >>> does_keychain_exist(d, l_exists)
    True
    >>> l_no_existent = ['a', 'c']
    >>> does_keychain_exist(d, l_no_existent)
    False
    >>> does_keychain_exist(d, ['a', 'b', 'c', 'd'])
    False
    """
    node = dict_
    for key in list_:
        # Without the mapping check, `key in node` would either raise
        # (e.g. for an int) or do a substring / element test (str, list).
        if not isinstance(node, collections.abc.Mapping) or key not in node:
            return False
        node = node[key]
    return True
class IntervalLike:
    """
    Anything like an interval or a union of an interval.

    This class is an informal interface: it does not use the ``abc`` module.
    Every method raises NotImplementedError and has to be overridden by the
    subclasses.
    """

    def is_empty(self) -> bool:
        """Return if the IntervalLike is empty."""
        raise NotImplementedError

    def issubset(self, other: "IntervalLike") -> bool:
        """
        Check if the interval "self" is completely inside of other.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        is_inside : bool
        """
        raise NotImplementedError

    def union(self, other: "IntervalLike") -> "IntervalLike":
        """
        Combine two Intervals.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        interval_union : IntervalLike
        """
        raise NotImplementedError

    def intersection(self, other: "IntervalLike") -> "IntervalLike":
        """
        Intersect two IntervalLike objects.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        intersected : IntervalLike
        """
        raise NotImplementedError
class Interval(IntervalLike):
    """
    Representation of an interval.

    Both end points belong to the interval. The empty interval is
    represented as left=None, right=None.

    Left and right have to be comparable.
    Typically, it would be numbers or dates.

    Parameters
    ----------
    left : Optional[Any]
    right : Optional[Any]

    Raises
    ------
    RuntimeError
        If exactly one of left / right is None, or if left > right.
    """

    def __init__(self, left: Optional[Any] = None, right: Optional[Any] = None):
        if (left is None) != (right is None):
            raise RuntimeError("Either left and right are None, or neither.")
        if left is not None and left > right:
            raise RuntimeError("left may not be bigger than right")
        self.left = left
        self.right = right

    def is_empty(self) -> bool:
        """Return if the interval is empty."""
        return self.left is None

    def union(self, other: IntervalLike) -> IntervalLike:
        """
        Combine two Intervals.

        Neither this interval nor `other` is modified. If one of the operands
        already covers the whole union, that very object is returned instead
        of a copy.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        interval_union : IntervalLike
            An Interval if the result is connected, otherwise an
            IntervalUnion. The union with an IntervalUnion is always an
            IntervalUnion.

        Raises
        ------
        NotImplementedError
            If other is neither an Interval nor an IntervalUnion, or if the
            end points cannot be ordered.
        """
        # Capture special cases
        if self.is_empty():
            return other
        if other.is_empty():
            return self

        if isinstance(other, Interval):
            # `lower` is the interval which starts first; on a tie it is
            # `other`.
            if other.left > self.left:
                lower, upper = self, other
            else:
                lower, upper = other, self

            if lower.right < upper.left:
                # Completely disjoint
                return IntervalUnion([upper, lower])
            elif lower.right <= upper.right:
                # Touching or overlapping
                return Interval(lower.left, upper.right)
            elif lower.right > upper.right:
                # lower is a superset of upper
                return lower
            else:
                # Only reachable if the end points are not totally ordered.
                raise NotImplementedError(f"Can't merge {self} and {other}")
        elif isinstance(other, IntervalUnion):
            # Empty members are skipped: they contribute nothing and have no
            # end points which could be sorted.
            members = [Interval(self.left, self.right)]
            members.extend(
                member for member in other.intervals if not member.is_empty()
            )
            merged = IntervalUnion(members)
            merged._simplify()
            return merged
        else:
            raise NotImplementedError(f"Can't merge {self} and {other}")

    @overload  # type: ignore[override]
    def intersection(self, other: "Interval") -> "Interval":
        ...

    @overload
    def intersection(self, other: "IntervalUnion") -> IntervalLike:
        ...

    def intersection(self, other: IntervalLike) -> IntervalLike:
        """
        Intersect two IntervalLike objects.

        Neither this interval nor `other` is modified. If one of the two
        intervals is the intersection, that very object is returned instead
        of a copy.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        intersected : IntervalLike
            An Interval if other is an Interval (or if one operand is
            empty), otherwise an IntervalUnion.
        """
        # Any intersection with an empty interval is empty
        if self.is_empty() or other.is_empty():
            return Interval(None, None)

        if isinstance(other, IntervalUnion):
            # Intersect with every member on our own instead of delegating
            # to other.intersection(self), which would change `other`.
            pieces = [self.intersection(member) for member in other.intervals]
            overlap = IntervalUnion(
                [piece for piece in pieces if not piece.is_empty()]
            )
            overlap._simplify()
            return overlap

        other = cast(Interval, other)

        # `lower` is the interval which starts first; on a tie it is `other`.
        if other.left > self.left:
            lower, upper = self, other
        else:
            lower, upper = other, self

        if lower.right < upper.left:
            # Completely disjoint
            return Interval(None, None)
        elif lower.right == upper.left:
            # next to each other: they share exactly one point
            return Interval(lower.right, lower.right)
        elif lower.right <= upper.right:
            return Interval(upper.left, lower.right)
        elif lower.right > upper.right:
            # lower is a superset of upper
            return upper
        else:
            # Only reachable if the end points are not totally ordered.
            raise NotImplementedError(f"Can't intersect {self} and {other}")

    def __repr__(self):
        """Get an unambiguous representation."""
        if self.is_empty():
            return "Interval()"
        return f"Interval({self.left}, {self.right})"

    def __str__(self):
        """Get an human-readable representation."""
        if self.is_empty():
            return "[]"
        return f"[{self.left}, {self.right}]"

    __and__ = intersection
    __or__ = union

    def __eq__(self, other) -> bool:
        """
        Check if other is equal to this object.

        Two interval-like objects are equal if each one is a subset of the
        other one. Objects of any other type are never equal.
        """
        if isinstance(other, (Interval, IntervalUnion)):
            return self.issubset(other) and other.issubset(self)
        return False

    def issubset(self, other: IntervalLike) -> bool:
        """
        Check if the interval "self" is completely inside of other.

        Parameters
        ----------
        other : IntervalLike

        Returns
        -------
        is_inside : bool

        Raises
        ------
        RuntimeError
            If other is neither an Interval nor an IntervalUnion.
        """
        if self.is_empty():
            return True
        if other.is_empty():
            # This could only be true, if self was empty as well
            # The order of those if blocks matters here!
            return False
        if isinstance(other, Interval):
            return other.left <= self.left <= self.right <= other.right
        if isinstance(other, IntervalUnion):
            # Merge touching / overlapping members first: [0, 2] is inside
            # of [0, 1] u [1, 2] although it is inside of neither member.
            # A fresh IntervalUnion is used, so `other` stays untouched.
            merged = IntervalUnion(
                [member for member in other.intervals if not member.is_empty()]
            )
            merged._simplify()
            return any(self.issubset(member) for member in merged.intervals)
        raise RuntimeError(
            "issubset is only defined on Interval and "
            "IntervalUnion, "
            f"but {type(other)} was given"
        )
class IntervalUnion(IntervalLike):
    """
    A union of Intervals.

    Parameters
    ----------
    intervals : list
        Each element is either an Interval or a sequence: an empty sequence
        stands for the empty interval, otherwise its first two elements are
        used as left and right end point.

    Raises
    ------
    TypeError
        If intervals is not a list.
    """

    def __init__(self, intervals):
        if not isinstance(intervals, list):
            raise TypeError(f"'{type(intervals)}' is not a list")
        self.intervals = []
        for interval in intervals:
            if isinstance(interval, Interval):
                self.intervals.append(interval)
            elif len(interval) == 0:
                self.intervals.append(Interval())
            else:
                self.intervals.append(Interval(interval[0], interval[1]))

    def is_empty(self) -> bool:
        """Return if the IntervalUnion is empty."""
        return all(interval.is_empty() for interval in self.intervals)

    def issubset(self, other: IntervalLike) -> bool:
        """
        Check if this IntervalUnion is completely inside of `other`.

        As a side effect, the members of this IntervalUnion get simplified.

        Parameters
        ----------
        other : Interval or IntervalUnion

        Returns
        -------
        is_inside : bool

        Raises
        ------
        RuntimeError
            If other is neither an Interval nor an IntervalUnion.
        """
        self._simplify()
        if isinstance(other, (Interval, IntervalUnion)):
            # If every interval of this is inside the interval `other`,
            # then this IntervalUnion is completely in `other`.
            return all(interval.issubset(other) for interval in self.intervals)
        raise RuntimeError(
            "issubset is only defined on Interval and IntervalUnion, "
            f"but {type(other)} was given"
        )

    def _get_keypoints(self) -> List[Any]:
        """
        Get all points which are relevant for this IntervalUnion.

        These are the end points of all non-empty members. Empty members are
        skipped, because their end points are None and cannot be sorted
        together with real end points.

        Returns
        -------
        keypoints : List[Any]
        """
        keypoints = []
        for interval in self.intervals:
            if interval.is_empty():
                continue
            keypoints.append(interval.left)
            keypoints.append(interval.right)
        return keypoints

    def _simplify(self) -> Optional[IntervalLike]:
        """
        Simplify the representation of the components.

        This means:

        1. Making sure that the minimum number of components is used
        2. The intervals are in order (by left element)

        Empty components are dropped. The simplification happens in place.

        Returns
        -------
        simplified_interval_union : IntervalUnion
            Please note that this is guaranteed to stay an IntervalUnion, even
            if it collapses to a single interval. If there are no components
            at all, None is returned.
        """
        if len(self.intervals) == 0:
            return None

        # Empty intervals have left=None, which cannot be used as sort key.
        non_empty = [
            interval for interval in self.intervals if not interval.is_empty()
        ]
        if not non_empty:
            self.intervals = []
            return self

        non_empty.sort(key=lambda n: n.left)
        simpler_intervals = [non_empty[0]]
        for interval in non_empty[1:]:
            combined = simpler_intervals[-1].union(interval)
            if isinstance(combined, Interval):
                # Touching / overlapping: merged into the previous component
                simpler_intervals[-1] = combined
            else:
                simpler_intervals.append(interval)
        self.intervals = simpler_intervals
        return self

    def union(self, other: IntervalLike) -> IntervalLike:
        """
        Return the union between this IntervalUnion and another object.

        This changes the object itself!

        Parameters
        ----------
        other : Interval or IntervalUnion

        Returns
        -------
        union : Interval or IntervalUnion
            This object, after the members of `other` were added.

        Raises
        ------
        RuntimeError
            If other is neither an Interval nor an IntervalUnion.
        """
        if isinstance(other, Interval):
            self.intervals.append(other)
        elif isinstance(other, IntervalUnion):
            self.intervals.extend(other.intervals)
        else:
            raise RuntimeError(f"Union with type={type(other)} not supported")
        self._simplify()
        return self

    def intersection(self, other: IntervalLike) -> IntervalLike:
        """
        Return the intersection between this IntervalUnion and another object.

        If `other` is an Interval, this changes the object itself! If `other`
        is an IntervalUnion, a new IntervalUnion is returned and this object
        is left unchanged.

        Parameters
        ----------
        other : Interval or IntervalUnion

        Returns
        -------
        intersection : Interval or IntervalUnion

        Raises
        ------
        RuntimeError
            If other is neither an Interval nor an IntervalUnion.
        """
        if isinstance(other, Interval):
            self.intervals = [
                interval.intersection(other) for interval in self.intervals
            ]
            # Members which do not overlap with `other` became empty;
            # _simplify drops them.
            self._simplify()
            return self
        elif isinstance(other, IntervalUnion):
            # Between two neighbouring keypoints nothing changes, so each
            # such elementary interval is either completely inside of both
            # unions or not part of the intersection.
            keypoints = sorted(self._get_keypoints() + other._get_keypoints())
            new_intervals = []
            for left, right in zip(keypoints, keypoints[1:]):
                interval = Interval(left, right)
                if interval.issubset(self) and interval.issubset(other):
                    new_intervals.append(interval)
            result = IntervalUnion(new_intervals)
            # Glue the elementary intervals back together.
            result._simplify()
            return result
        else:
            raise RuntimeError(f"Intersection with type={type(other)} not supported")

    def __repr__(self) -> str:
        """Get an unambiguous representation."""
        return "IntervalUnion(" + str(self.intervals) + ")"

    def __str__(self) -> str:
        """Get a human-readable representation."""
        return str(self.intervals)

    def __eq__(self, other: Any) -> bool:
        """
        Check if other is equal to this object.

        Two interval-like objects are equal if each one is a subset of the
        other one. Objects of any other type are never equal.
        """
        if isinstance(other, (IntervalUnion, Interval)):
            return self.issubset(other) and other.issubset(self)
        return False

    __and__ = intersection
    __or__ = union