"""
Swiss Army knife
"""
import asyncio
import collections
import enum
import functools
import hashlib
import importlib
import inspect
import itertools
import os
import re
import types as _types
import logging # isort:skip
# Module-level logger named after this module
_log = logging.getLogger(__name__)
[docs]
@functools.cache
def is_running_in_development_environment():
    """
    Return whether this process runs in a development environment

    The ``UPSIES_DEV`` environment variable is read and passed to
    :class:`~.types.Bool`. If that raises :class:`ValueError`, ``False`` is
    returned.

    The return value is cached by :func:`functools.cache`, i.e. the environment
    variable is only evaluated on the first call.
    """
    # Imported here (not at module level) like in the original code
    from . import types  # noqa: F811 [*] Redefinition of unused `types`

    raw_value = os.environ.get('UPSIES_DEV', None)
    try:
        dev_mode = bool(types.Bool(raw_value))
    except ValueError:
        dev_mode = False
    return dev_mode
[docs]
def os_family():
    """
    Return the operating system family as a string

    :return: ``"windows"`` if :data:`os.name` is ``"nt"``, ``"unix"`` otherwise
    """
    if os.name == 'nt':
        return 'windows'
    return 'unix'
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
[docs]
class LazyModule(_types.ModuleType):
    """
    Module placeholder that imports the real module on first use

    The real module is imported when an attribute that does not exist on the
    placeholder is looked up or when :func:`dir` is called on it.

    :param str module: Name of the module
    :param mapping namespace: Usually the return value of `globals()`
    :param str name: Name of the module in `namespace`; defaults to `module`
    """

    def __init__(self, module, namespace, name=None):
        super().__init__(module)
        self._module = module
        self._namespace = namespace
        self._name = name or module

    def _load(self):
        # Import the real module and put it into the namespace under the
        # configured name
        real_module = importlib.import_module(self.__name__)
        self._namespace[self._name] = real_module

        # Copy the real module's attributes onto this placeholder. Because
        # __getattr__() is only called for failed lookups, anyone who still
        # holds a reference to the placeholder gets direct lookups from now on.
        self.__dict__.update(real_module.__dict__)
        return real_module

    def __getattr__(self, item):
        # Only called if normal attribute lookup failed
        return getattr(self._load(), item)

    def __dir__(self):
        return dir(self._load())
[docs]
def submodules(package):
    """
    Import and return all public submodules and subpackages of `package`

    Directory entries are skipped if they start with "_" or if they still
    contain a "." after a ".py" suffix was removed.

    :param str package: Fully qualified name of parent package,
        e.g. "upsies.imagehosts"

    :return: :class:`list` of imported module objects in the order reported by
        :func:`os.listdir`
    """
    # Strip this package's relative path from this file's directory to get the
    # directory that contains the top-level package
    module_dir = os.path.dirname(__file__)
    package_subpath = __package__.replace('.', os.sep)
    assert module_dir.endswith(package_subpath), f'{module_dir!r}.endswith({package_subpath!r})'
    project_root = module_dir[:-len(package_subpath)]

    # Directory of the requested package
    package_dir = os.path.join(project_root, package.replace('.', os.sep))

    imported = []
    for entry in os.listdir(package_dir):
        if entry.startswith('_'):
            # Private module/package
            continue
        module_name = entry.removesuffix('.py')
        if '.' in module_name:
            # Any remaining "." means the entry has some other file extension
            continue
        imported.append(
            importlib.import_module(name=f'.{module_name}', package=package)
        )
    return imported
[docs]
def subclasses(basecls, modules):
    """
    Collect all subclasses of `basecls` that are members of `modules`

    `basecls` itself is never included.

    :param type basecls: Class that all returned classes are a subclass of
    :param modules: Modules to search
    :type modules: list of module objects

    :return: :class:`tuple` of classes
    """
    def is_proper_subclass(candidate):
        return (
            candidate is not basecls
            and isinstance(candidate, type)
            and issubclass(candidate, basecls)
        )

    return tuple(
        member
        for module in modules
        for _name, member in inspect.getmembers(module)
        if is_proper_subclass(member)
    )
[docs]
def closest_number(n, ns, max=None, default=0):
    """
    Find the item in `ns` with the smallest absolute distance to `n`

    :param n: Given number
    :param ns: Sequence of allowed numbers
    :param max: Ignore any item in `ns` that is larger than `max`
    :param default: Return value in case `ns` is empty and `max` is ``None``

    :raise ValueError: if `max` is given and no item in `ns` is equal to or
        smaller than `max`
    """
    if max is None:
        candidates = ns
    else:
        candidates = tuple(filter(lambda candidate: candidate <= max, ns))
        if not candidates:
            raise ValueError(f'No number equal to or below {max}: {ns}')

    def distance(candidate):
        return abs(candidate - n)

    return min(candidates, key=distance, default=default)
[docs]
class MonitoredList(collections.abc.MutableSequence):
    """
    Mutable sequence backed by a :class:`list` that reports every change

    `callback` is called after an item was set, deleted or inserted.

    :param callback: Callable that gets the instance as a positional argument

    All other arguments are passed on to :class:`list`.
    """

    def __init__(self, *args, callback, **kwargs):
        self._callback = callback
        self._list = list(*args, **kwargs)

    # Read-only access does not call the callback

    def __len__(self):
        return len(self._list)

    def __getitem__(self, index):
        return self._list[index]

    # Every modification is followed by a callback call

    def __setitem__(self, index, value):
        self._list[index] = value
        self._callback(self)

    def __delitem__(self, index):
        del self._list[index]
        self._callback(self)

    def insert(self, index, value):
        """Insert `value` before `index` and call the callback"""
        self._list.insert(index, value)
        self._callback(self)

    def __eq__(self, other):
        # Compare the wrapped list so that plain lists can be equal
        return self._list == other

    def __repr__(self):
        return f'{type(self).__name__}({self._list!r}, callback={self._callback!r})'
[docs]
def is_sequence(obj):
    """
    Return whether `obj` is a sequence and not a string

    Only :class:`str` is excluded; any other
    :class:`collections.abc.Sequence` is accepted.
    """
    if isinstance(obj, str):
        return False
    return isinstance(obj, collections.abc.Sequence)
[docs]
def merge_dicts(a, b, path=()):
    """
    Recursively merge the mappings `a` and `b` into a new :class:`dict`

    If a key maps to a mapping in both `a` and `b`, both values are merged
    recursively. Otherwise, the value from `b` is used if `b` has that key and
    the value from `a` is used as a fallback.

    :param a: Mapping that provides the default values
    :param b: Mapping that provides the overriding values
    :param path: Keys that lead to the current nesting level; only passed on
        to recursive calls

    `a` and `b` are not modified.
    """
    mapping_type = collections.abc.Mapping
    result = {}
    for k in itertools.chain(a, b):
        if isinstance(a.get(k), mapping_type) and isinstance(b.get(k), mapping_type):
            # Both sides are nested: descend
            result[k] = merge_dicts(a[k], b[k], (*path, k))
        elif k in b:
            # `b` overrides `a`
            result[k] = b[k]
        elif k in a:
            # `a` provides the default
            result[k] = a[k]
    return result
[docs]
def deduplicate(seq, key=None):
    """
    Return a :class:`list` of the items in `seq` without duplicates

    The first occurrence of each item is kept and the original order is
    maintained.

    :param key: Callable that gets each item and returns a hashable identifier
        for that item; by default, the item itself is the identifier
    """
    identify = key if key is not None else (lambda item: item)

    seen = set()
    unique_items = []
    for element in seq:
        identifier = identify(element)
        if identifier in seen:
            continue
        seen.add(identifier)
        unique_items.append(element)
    return unique_items
[docs]
def as_groups(sequence, group_sizes, default=None):
    """
    Yield the items from `sequence` as equally sized tuples

    :param sequence: List of items to group
    :param group_sizes: Sequence of acceptable number of items in a group

        The group size that needs the fewest `default` items in the last
        group is used for all groups. If multiple group sizes need the same
        number of `default` items, the largest one is used.

    :param default: Value to pad last group with if
        ``len(sequence) % group_size != 0``

    Example:

    >>> sequence = range(1, 10)
    >>> for group in as_groups(sequence, [4, 5], default="_"):
    ...     print(group)
    (1, 2, 3, 4, 5)
    (6, 7, 8, 9, '_')
    >>> for group in as_groups(sequence, [3, 4], default="_"):
    ...     print(group)
    (1, 2, 3)
    (4, 5, 6)
    (7, 8, 9)
    """
    # Map number of padding items in the last group to the group sizes that
    # produce that much padding
    sizes_by_padding = {}
    for size in group_sizes:
        # Number of real items in the last group
        remainder = len(sequence) % size
        padding = size - remainder if remainder else 0
        sizes_by_padding.setdefault(padding, []).append(size)

    least_padding = sorted(sizes_by_padding)[0]
    chosen_size = max(sizes_by_padding[least_padding])

    # zip_longest() pulls from the same iterator `chosen_size` times per group
    shared_iterator = iter(sequence)
    columns = [shared_iterator] * chosen_size
    yield from itertools.zip_longest(*columns, fillvalue=default)
# Types that semantic_hash() refuses to hash if they are not also a string,
# mapping, sequence or set
_unsupported_semantic_hash_types = (
    collections.abc.Iterator,
    collections.abc.Iterable,
    collections.abc.Generator,
)


def semantic_hash(obj):
    """
    Return SHA256 hash for `obj` that stays the same between Python interpreter
    sessions

    `obj` is converted to a string and the UTF-8 encoding of that string is
    hashed:

    - Strings are used as they are.
    - Mappings are converted to a sorted list of ``(key, value)`` pairs.
    - The items of sequences and sets are converted, sorted and joined without
      any separator. This means the order of items does not affect the hash.
    - Anything else is passed to :class:`str`.

    https://github.com/schollii/sandals/blob/master/json_sem_hash.py

    :raise RuntimeError: if `obj` contains any other kind of iterable, e.g. an
        iterator or generator

    :return: Hexadecimal digest as :class:`str`
    """
    def stringify(thing):
        if isinstance(thing, str):
            return thing

        if isinstance(thing, collections.abc.Mapping):
            pairs = sorted(
                (stringify(k), stringify(v))
                for k, v in thing.items()
            )
            return stringify(pairs)

        if isinstance(thing, (collections.abc.Sequence, collections.abc.Set)):
            return ''.join(sorted(map(stringify, thing)))

        # This check must come last because strings, mappings, sequences and
        # sets are iterables as well
        if isinstance(thing, _unsupported_semantic_hash_types):
            raise RuntimeError(f'Unsupported type: {type(thing)}: {thing!r}')

        return str(thing)

    return hashlib.sha256(stringify(obj).encode('utf-8')).hexdigest()
[docs]
def run_task(coro, callback):
    """
    Run awaitable in background task and return immediately

    This method should be used to call coroutine functions and other awaitables
    in a synchronous context.

    The returned task must be collected (e.g. in a :class:`list`) and awaited or
    cancelled eventually.

    :param coro: Any awaitable object
    :param callback: Callable that is called with the returned task when
        `coro` returns, is cancelled or raises any other exception

    :raise ValueError: if `callback` is not callable

    :return: :class:`asyncio.Task` instance
    """
    if not callable(callback):
        # Report the offending argument (not the `callable` builtin)
        raise ValueError(f'Not callable: {callback!r}')

    def handle_task_done(task):
        # Call callback no matter what. Any exception from the task (including
        # cancellation) is not propagated here; `callback` gets the task and
        # can inspect its outcome.
        try:
            task.result()
        except BaseException:
            callback(task)
        else:
            callback(task)

    task = asyncio.create_task(coro)
    task.add_done_callback(handle_task_done)
    return task
[docs]
async def run_async(function, *args, **kwargs):
    """
    Call synchronous `function` in the running loop's default executor and
    return its return value

    `args` and `kwargs` are passed on to `function`.

    See :meth:`asyncio.BaseEventLoop.run_in_executor`.
    """
    # run_in_executor() does not accept keyword arguments for the callable,
    # so they are bound beforehand
    bound_call = functools.partial(function, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(None, bound_call)
# Unique sentinel that blocking_memoize() uses to detect a missing cache entry
# (``None`` cannot be used because it is a valid cached return value)
_NOTHING = object()
[docs]
def blocking_memoize(coro_func):
    """
    Decorator that caches the result of a coroutine function and serializes
    concurrent calls with the same arguments

    The first call with given arguments calls `coro_func`. Concurrent calls
    with the same arguments wait until the first call is finished and then get
    the result from the cache, just like any later call.

    Exceptions raised by `coro_func` are also cached and re-raised on subsequent
    calls.

    The decorated function provides a `clear_cache` method that removes any
    cached return values.

    .. note:: Because :class:`BaseException` is caught, an
       :class:`asyncio.CancelledError` raised while awaiting `coro_func` is
       cached like any other exception.

    .. note:: The cache key is created by ``semantic_hash``, which sorts the
       items of sequences, so calls that only differ in the order of their
       positional arguments share one cache entry.
    """
    results = {}
    # One lock per cache key; locks are created on demand
    locks = collections.defaultdict(asyncio.Lock)

    @functools.wraps(coro_func)
    async def wrapper(*args, **kwargs):
        key = semantic_hash((str(coro_func), args, kwargs))
        async with locks[key]:
            outcome = results.get(key, _NOTHING)
            if outcome is _NOTHING:
                # First call with these arguments
                try:
                    outcome = await coro_func(*args, **kwargs)
                except BaseException as e:
                    outcome = e
                results[key] = outcome

            # A cached exception instance is raised, anything else is returned
            if isinstance(outcome, BaseException):
                raise outcome
            return outcome

    def clear_cache():
        results.clear()

    wrapper.clear_cache = clear_cache
    return wrapper
[docs]
def flatten_nested_lists(thing):
    """
    Return the leaf items of `thing` as a flat :class:`list`

    :param thing: Arbitrarily nested iterables

    Strings and :class:`bytes` are treated as leaf items, not as iterables.

    If `thing` is not an iterable (and not a string), it is returned inside a
    list.
    """
    leaf_types = (str, bytes)

    def walk(node):
        is_leaf = (
            isinstance(node, leaf_types)
            or not isinstance(node, collections.abc.Iterable)
        )
        if is_leaf:
            yield node
            return
        for child in node:
            yield from walk(child)

    return list(walk(thing))
[docs]
class PrettyEnum(enum.Enum):
    """
    :class:`enum.Enum` with human-readable string representation

    :class:`str` turns a member's name into space-separated, capitalized words
    (e.g. ``HARDCODED_SUBTITLES`` becomes ``"Hardcoded Subtitles"``) and
    :meth:`from_string` converts such a string back to the member.
    """

    @classmethod
    def from_string(cls, string):  # noqa: F811 (Redefinition of unused `string`)
        """
        Convert human-readable string back to enum

        Spaces in `string` are replaced by underscores and the result is
        converted to upper case before it is looked up as a member name.

        >>> TrumpableReason.from_string(
        ...     str(TrumpableReason.HARDCODED_SUBTITLES)
        ... )
        <TrumpableReason.HARDCODED_SUBTITLES: 4>

        :raise ValueError: if `string` is not known
        """
        name = string.replace(' ', '_').upper()
        try:
            # Look up by member name so that only members can be returned and
            # not other class attributes (e.g. methods) that match `name`
            return cls[name]
        except KeyError:
            raise ValueError(f'Unknown {cls.__name__}: {string!r}') from None

    def __str__(self):
        return ' '.join(
            word.capitalize()
            for word in self.name.split('_')
        )
# We must import these here to prevent circular imports
from . import ( # noqa: E402 isort:skip
argtypes,
bbcode,
browser,
config,
country,
daemon,
disc,
fs,
html,
http,
image,
mediainfo,
predbs,
release,
signal,
string,
subproc,
torrent,
types,
update,
webdbs,
)