# Source code for upsies.utils.argtypes

"""
CLI argument types

All types return normalized values and raise ValueError for invalid values.

A custom error message can be provided by raising
:class:`argparse.ArgumentTypeError`.
"""

import argparse
import functools
import os

from .. import errors, utils
from . import fs, types

# `natsort` is wrapped in utils.LazyModule; presumably the real module is only
# imported on first use (TODO confirm LazyModule semantics). It is used for
# natural sorting of file paths in files_with_extension().
natsort = utils.LazyModule(module='natsort', namespace=globals())


# Module-level alias of argparse.ArgumentTypeError (presumably so users of this
# module don't need to import argparse themselves).
ArgumentTypeError = argparse.ArgumentTypeError
"""
Exception that should be raised by any callable that is passed to
:func:`argparse.ArgumentParser.add_argument` as `type` if it gets an invalid
value
"""


def comma_separated(argtype):
    """
    Multiple comma-separated values

    Whitespace around each item is stripped and empty items (e.g. from
    ``"a,,b"`` or a trailing comma) are skipped.

    :param argtype: Any callable that returns a validated object for one of the
        comma-separated values or raises :class:`ValueError`,
        :class:`TypeError` or :class:`argparse.ArgumentTypeError`

    :return: Function that takes the raw argument and returns a :class:`list`
        of `argtype` return values. :class:`ValueError` and :class:`TypeError`
        from `argtype` are turned into :class:`argparse.ArgumentTypeError`;
        :class:`argparse.ArgumentTypeError` passes through unchanged.
    """

    def comma_separated(value):
        items = (item.strip() for item in str(value).split(','))
        converted = []
        for item in filter(None, items):
            try:
                converted.append(argtype(item))
            except (ValueError, TypeError) as e:
                raise argparse.ArgumentTypeError(f'Invalid value: {item}') from e
        return converted

    return comma_separated
def content(value):
    """
    Existing path to release file(s)

    `value` is validated with :func:`release` and must then pass
    :func:`existing_path`.
    """
    return existing_path(release(value))
def existing_path(value):
    """
    Path to existing path

    :return: `value` as :class:`str`

    :raise argparse.ArgumentTypeError: if nothing exists at `value`
    """
    path = str(value)
    if os.path.exists(path):
        return path
    raise argparse.ArgumentTypeError(f'No such file or directory: {value}')
def imagehost(value):
    """
    Name of an image hosting service from :mod:`~.imagehosts`

    `value` is matched case-insensitively against
    :func:`~.imagehosts.imagehost_names` and returned in lower case.

    :raise argparse.ArgumentTypeError: if `value` is not a known image hosting
        service
    """
    from .. import imagehosts
    # Compare lower-cased on both sides. The old check was case-sensitive, which
    # made the lower-casing of the return value pointless and rejected
    # e.g. "ImgBox". Every previously accepted value still returns the same result.
    name = str(value).lower()
    if name in (known.lower() for known in imagehosts.imagehost_names()):
        return name
    raise argparse.ArgumentTypeError(f'Unsupported image hosting service: {value}')
def imagehosts(value):
    """
    Comma-separated list of names of image hosting services from
    :mod:`~.imagehosts`

    Each non-empty, whitespace-stripped item is validated with :func:`imagehost`.

    :return: :class:`list` of validated names
    """
    stripped = (name.strip() for name in value.split(','))
    return [imagehost(name) for name in stripped if name]
def bool_or_none(value):
    """
    Convert `value` to :class:`~.types.Bool` or `None` if `value` is `None`

    :raise argparse.ArgumentTypeError: if :class:`~.types.Bool` rejects `value`
        with :class:`ValueError`
    """
    if value is not None:
        try:
            return types.Bool(value)
        except ValueError as e:
            raise argparse.ArgumentTypeError(e) from e
    return None
def integer(value, *, min=None, max=None):
    """
    Natural number (:class:`float` is rounded down)

    :param int min: Minimum value
    :param int max: Maximum value

    :raise argparse.ArgumentTypeError: if :class:`~.types.Integer` raises
        :class:`ValueError`
    """
    try:
        converter = types.Integer(min=min, max=max)
        return converter(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(e) from e
@functools.cache
def make_integer(*, min, max):
    """
    Return function that takes a number and passes it to :func:`integer`
    together with `min` and `max`

    Results are cached, so identical `min`/`max` pairs return the same function.

    :param int min: Minimum allowed value (passed to :func:`integer`)
    :param int max: Maximum allowed value (passed to :func:`integer`)
    """
    return functools.partial(integer, min=min, max=max)
@functools.cache
def files_with_extension(extension, *, allow_no_hits=True):
    """
    Return function that recursively searches a directory for files with `extension`

    If the returned function gets a file path with the wanted extension, it is
    simply returned.

    The returned function returns a :class:`tuple` of matching file paths,
    naturally sorted by their case-folded basename.

    :param str extension: Wanted file name extension (e.g. "png"), compared
        case-insensitively
    :param bool allow_no_hits: If `False`, :exc:`argparse.ArgumentTypeError` is
        raised if no matching files are found. If `True`, an empty tuple is
        returned instead.

    The returned function also raises :exc:`argparse.ArgumentTypeError` if a
    matching file is not readable or if an existing non-directory path has the
    wrong extension.
    """

    def is_match(filepath):
        # Extensions are compared case-insensitively ("PNG" matches "png").
        if fs.file_extension(filepath).casefold() == extension.casefold():
            # A matching but unreadable file is reported, not silently skipped.
            try:
                fs.assert_file_readable(filepath)
            except errors.ContentError as e:
                raise argparse.ArgumentTypeError(e) from e
            else:
                return True
        return False

    def files_with_extension(value):
        matching_files = []
        if os.path.isdir(value):
            for dirpath, _dirnames, filenames in os.walk(value):
                for filename in filenames:
                    filepath = os.path.join(dirpath, filename)
                    if is_match(filepath):
                        matching_files.append(filepath)
            if not matching_files and not allow_no_hits:
                raise argparse.ArgumentTypeError(f'{value}: No {extension} files found')
        else:
            # A nonexistent path produces no matches; it only raises below
            # if `allow_no_hits` is False.
            if os.path.exists(value):
                if is_match(value):
                    matching_files = (value,)
                else:
                    msg = f'Expected file extension {extension}'
                    if ext := fs.file_extension(value):
                        msg += f', not {ext}'
                    msg += f': {value}'
                    raise argparse.ArgumentTypeError(msg)
            if not matching_files and not allow_no_hits:
                raise argparse.ArgumentTypeError(f'{value}: Not a {extension} file')

        # os.walk() order is not meaningful, so sort naturally by basename.
        return tuple(natsort.natsorted(
            matching_files,
            key=lambda filepath: fs.basename(filepath).casefold(),
        ))

    return files_with_extension
@functools.cache
def one_of(values):
    """
    Return function that returns an item of `values` or raises
    :class:`argparse.ArgumentTypeError`

    Because this factory is cached with :func:`functools.cache`, `values` must
    be hashable (e.g. a :class:`tuple`).

    :param values: Allowed values
    """
    allowed = tuple(values)

    def one_of_values(value):
        if value not in allowed:
            raise argparse.ArgumentTypeError(f'Invalid value: {value}')
        return value

    return one_of_values
def regex(value):
    """
    :class:`re.Pattern` object

    :raise argparse.ArgumentTypeError: if :class:`~.types.Regex` raises
        :class:`ValueError`
    """
    try:
        pattern = types.Regex(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(e) from e
    return pattern
def release(value):
    """
    Same as :func:`content`, but doesn't have to exist

    :return: `value` as :class:`str`

    :raise argparse.ArgumentTypeError: if
        :func:`~.predbs.assert_not_abbreviated_filename` raises
        :class:`~.errors.SceneAbbreviatedFilenameError`
    """
    from . import predbs
    path = str(value)
    try:
        predbs.assert_not_abbreviated_filename(path)
    except errors.SceneAbbreviatedFilenameError as e:
        raise argparse.ArgumentTypeError(e) from e
    return path
def predb_name(value):
    """
    Name of a scene release database from :mod:`~.utils.predbs`

    `value` is matched case-insensitively against
    :func:`~.predbs.predb_names` and returned in lower case.

    :raise argparse.ArgumentTypeError: if `value` is not a known scene release
        database
    """
    from . import predbs
    # Case-insensitive match, consistent with the lower-cased return value.
    name = str(value).lower()
    if name in (known.lower() for known in predbs.predb_names()):
        return name
    raise argparse.ArgumentTypeError(f'Unsupported scene release database: {value}')
def predb(value):
    """
    :class:`~.PredbApiBase` instance from a corresponding
    :attr:`~.PredbApiBase.name`

    `value` is lower-cased before it is passed to :func:`~.predbs.predb`.

    :raise argparse.ArgumentTypeError: if :func:`~.predbs.predb` raises
        :class:`ValueError`
    """
    from . import predbs
    try:
        name = value.lower()
        return predbs.predb(name)
    except ValueError as e:
        raise argparse.ArgumentTypeError(e) from e
def timestamp(value):
    """
    Turn `value` into :class:`types.Timestamp`

    :raise argparse.ArgumentTypeError: if
        :meth:`~.types.Timestamp.from_string` raises :class:`ValueError` or
        :class:`TypeError`
    """
    try:
        parsed = types.Timestamp.from_string(value)
    except (ValueError, TypeError) as e:
        raise argparse.ArgumentTypeError(e) from e
    return parsed
def tracker(value):
    """
    Name of a tracker from :mod:`~.trackers`

    `value` is matched case-insensitively against
    :func:`~.trackers.tracker_names` and returned in lower case.

    :raise argparse.ArgumentTypeError: if `value` is not a known tracker
    """
    from .. import trackers
    # Case-insensitive match, consistent with the lower-cased return value.
    name = str(value).lower()
    if name in (known.lower() for known in trackers.tracker_names()):
        return name
    raise argparse.ArgumentTypeError(f'Unsupported tracker: {value}')
def webdb(value):
    """
    Name of a movie/series database from :mod:`~.webdbs`

    `value` is matched case-insensitively against
    :func:`~.webdbs.webdb_names` and returned in lower case.

    :raise argparse.ArgumentTypeError: if `value` is not a known database
    """
    from . import webdbs
    # Case-insensitive match, consistent with the lower-cased return value.
    name = str(value).lower()
    if name in (known.lower() for known in webdbs.webdb_names()):
        return name
    raise argparse.ArgumentTypeError(f'Unsupported database: {value}')
@functools.cache
def webdb_id(webdb_name):
    """
    Return function that finds a web DB ID in a string, e.g. a URL

    :param str webdb_name: Name of a web DB, e.g. "imdb"

    The returned function converts its argument to :class:`str` and passes it
    to :meth:`~.WebDbApiBase.get_id_from_text`. A truthy result is returned;
    otherwise :class:`argparse.ArgumentTypeError` is raised.
    """
    from . import webdbs
    db = webdbs.webdb(webdb_name)

    def webdb_id(value):
        found_id = db.get_id_from_text(str(value))
        if not found_id:
            raise argparse.ArgumentTypeError(f'Invalid {db.label} ID: {value}')
        return found_id

    return webdb_id
def subtitle(value):
    """
    :class:`~.Subtitle` instance from language code

    :raise argparse.ArgumentTypeError: if the parsed subtitle's language is
        ``"?"``, i.e. the language code was not recognized
    """
    parsed = utils.mediainfo.text.Subtitle.from_string(value)
    if parsed.language != '?':
        return parsed
    raise argparse.ArgumentTypeError(f'Unknown language code: {value}')
def episodes(value):
    """
    :class:`~.Episodes` instance and episode title as :class:`str`

    If no episode title is specified, the second return value is `None`.

    Example values:

        - ``"S03"``
        - ``"S03E04"``
        - ``"S03E04 The Episode Title"``

    :raise argparse.ArgumentTypeError: if the season/episode info is invalid or
        if an episode title is given but not exactly one episode is specified
    """
    # The first space separates the season/episode info from the optional title.
    # Without a space, split() returns [value].
    parts = value.split(' ', maxsplit=1)

    # Season/Episode info.
    info = parts[0]
    if not utils.release.Episodes.is_episodes_info(info):
        raise argparse.ArgumentTypeError(f'Invalid season/episode info: {value}')
    episodes = utils.release.Episodes.from_string(info)

    # Episode title.
    episode_title = parts[1].strip() if len(parts) >= 2 else ''
    if not episode_title:
        return episodes, None

    # An episode title only makes sense for exactly one season containing
    # exactly one episode.
    if len(episodes) != 1 or sum(len(eps) for eps in episodes.values()) != 1:
        raise argparse.ArgumentTypeError(
            f'Exactly one episode required if episode title is specified: {value}'
        )
    return episodes, episode_title
def release_year(value):
    """
    :class:`~.ReleaseYear` instance

    :raise argparse.ArgumentTypeError: if :class:`~.types.ReleaseYear` raises
        :class:`ValueError`
    """
    try:
        year = types.ReleaseYear(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(e) from e
    return year