# Source code for sfini._util

# --- 80 characters -----------------------------------------------------------
# Created by: Laurie 2018/07/11

"""Common utilities for ``sfini``."""

import inspect
import sys
import typing as T
import logging as lg
import functools as ft
from collections import abc

import boto3
from botocore import credentials
from botocore import client as botocore_client

# Module-level logger, named after this module
_logger = lg.getLogger(__name__)
# Import side-effect: set the ``botocore`` logger's level to WARNING
lg.getLogger("botocore").setLevel(lg.WARNING)
# Longest name accepted by ``assert_valid_name``
MAX_NAME_LENGTH = 79
# Characters which ``assert_valid_name`` rejects anywhere in a name
INVALID_NAME_CHARACTERS = " \n\t<>{}[]?*\"#%\\^|~`$&,;:/"
# True when ``pytest`` has been imported; ``cached_property`` then also
# provides a setter and a deleter
DEBUG = "pytest" in sys.modules
# Recursive type alias: ``None``, a bool, string or number, or a list or
# string-keyed dictionary of such values
JSONable = T.Union[
    None,
    bool,
    str,
    int,
    float,
    T.List["JSONable"],
    T.Dict[str, "JSONable"]]


class DefaultParameter:
    """Default parameter for step-functions definition.

    Instances are falsey, and any two instances compare (and hash) equal.
    """

    def __bool__(self):
        return False

    def __eq__(self, other):
        # Only other default-parameter markers are equal. For anything else,
        # defer to ``other`` rather than claiming a result: a plain
        # ``object()`` must not compare equal
        if isinstance(other, DefaultParameter):
            return True
        return NotImplemented

    def __hash__(self):
        # Defining ``__eq__`` alone would make instances unhashable. All
        # instances are equal, so they must share a single hash
        return hash(DefaultParameter)

    def __str__(self):
        return "<unspecified>"

    def __repr__(self):
        return type(self).__name__ + "()"


def setup_logging(level: T.Optional[int] = None):
    """Setup logging for ``sfini``, if logs would otherwise be ignored.

    Calls ``logging.basicConfig`` with the ``sfini`` log format. When a level
    is provided, it is additionally applied to the root logger and to every
    handler attached to the root logger.

    Args:
        level: logging level (see ``logging``), default: leave unchanged
    """

    lg.basicConfig(
        format="%(asctime)s [%(levelname)8s] %(name)s: %(message)s",
        level=level)
    if level is None:
        return

    # Apply the level explicitly as well, independent of what
    # ``basicConfig`` did with it
    root = lg.getLogger()
    root.setLevel(level)
    for handler in root.handlers:
        handler.setLevel(level)


def cached_property(fn: T.Callable) -> property:
    """Turn a method into a property whose value is computed only once.

    Results are kept in a ``__cache__`` dictionary on the instance (created
    on first use), keyed by the method's name. When ``DEBUG`` is set, the
    property can also be assigned to and deleted, which overwrites or drops
    the cached value; this is for testing.

    Args:
        fn: method to decorate

    Returns:
        cached property
    """

    key = fn.__name__

    def _prepare(instance):
        # Create the per-instance cache the first time it is needed
        if hasattr(instance, "__cache__"):
            return
        instance.__cache__ = {}

    @ft.wraps(fn)
    def getter(self):
        _prepare(self)
        if key in self.__cache__:
            return self.__cache__[key]
        self.__cache__[key] = fn(self)
        return self.__cache__[key]

    if not DEBUG:
        return property(getter)

    # Testing support: let the cached value be injected or discarded
    def setter(self, value):
        _prepare(self)
        self.__cache__[key] = value

    def deleter(self):
        _prepare(self)
        del self.__cache__[key]

    return property(getter, fset=setter, fdel=deleter)


def assert_valid_name(name: str):
    """Check that a name is usable for an activity, state-machine or state.

    Args:
        name: name to analyse

    Raises:
        ValueError: name is longer than ``MAX_NAME_LENGTH``, or contains any
            of ``INVALID_NAME_CHARACTERS``
    """

    # Length is checked first: an over-long name is reported as such even
    # when it also contains invalid characters
    too_long = len(name) > MAX_NAME_LENGTH
    if too_long:
        raise ValueError("Name is too long: '%s'" % name)

    for char in INVALID_NAME_CHARACTERS:
        if char in name:
            raise ValueError("Name contains invalid characters: '%s'" % name)


def collect_paginated(
        fn: T.Callable[..., T.Dict[str, "JSONable"]],
        **kwargs: "JSONable"
) -> T.Dict[str, "JSONable"]:
    """Call SFN API paginated endpoint.

    Calls ``fn`` until "nextToken" isn't in the return value, collating
    results. The first page's response is used as the result: list values
    from each subsequent page are appended to the first page's list under
    the same key, and non-list values from subsequent pages are ignored.
    Pages are fetched iteratively, so the number of pages is not bounded by
    the interpreter's recursion limit.

    Args:
        fn: SFN API function
        **kwargs: arguments to ``fn``

    Returns:
        combined results of paginated API calls

    Raises:
        KeyError: a subsequent page has a list under a key which is missing
            from the first page
    """

    result = fn(**kwargs)
    page = result
    while "nextToken" in page:
        # Removing the token also keeps it out of the combined result
        kwargs["nextToken"] = page.pop("nextToken")
        page = fn(**kwargs)
        for key, value in page.items():
            if isinstance(value, list):
                result[key].extend(value)
    return result


def easy_repr(instance) -> str:
    """Use attributes to generate a string representation.

    Set class ``__repr__ = easy_repr``.

    The parameters of the class's signature are looked up as same-named
    attributes of ``instance``. Positional-only and required
    positional-or-keyword parameters are shown by value alone. Optional and
    keyword-only parameters are shown as ``name=value``, and are left out
    when the attribute equals the parameter's default. A sized value whose
    representation is longer than 80 characters is abbreviated to its
    length.

    Args:
        instance: object to get representation of

    Returns:
        object representation

    Raises:
        RuntimeError: the class's signature has var-positional or
            var-keyword parameters
    """

    sig = inspect.signature(type(instance))
    params = sig.parameters.values()

    # Can't yet process var-args
    has_var_pos = any(p.kind == p.VAR_POSITIONAL for p in params)
    has_var_kw = any(p.kind == p.VAR_KEYWORD for p in params)
    if has_var_pos or has_var_kw:
        raise RuntimeError("Can't use `easy_repr` with var-args yet")

    # Separate different kinds of parameters
    params_pos = [p for p in params if p.kind == p.POSITIONAL_ONLY]
    params_any = [p for p in params if p.kind == p.POSITIONAL_OR_KEYWORD]
    params_kw = [p for p in params if p.kind == p.KEYWORD_ONLY]

    # ``empty`` is a sentinel: test it by identity, so that a default value's
    # own ``__eq__`` can't make the parameter look required
    params_any_required = [p for p in params_any if p.default is p.empty]
    params_any_optional = [p for p in params_any if p.default is not p.empty]
    params_unnamed = params_pos + params_any_required
    params_named = params_any_optional + params_kw

    arg_strs = []
    for param in params_unnamed:
        attr_val = getattr(instance, param.name)
        arg_str = repr(attr_val)
        if len(arg_str) > 80 and isinstance(attr_val, abc.Sized):
            arg_str = "len %d" % len(attr_val)
        arg_strs.append(arg_str)
    for param in params_named:
        attr_val = getattr(instance, param.name)
        # Values left at their default carry no information
        if param.default is not param.empty and attr_val == param.default:
            continue
        arg_str = repr(attr_val)
        if len(arg_str) > 80 and isinstance(attr_val, abc.Sized):
            arg_str = "len(%s)=%d" % (param.name, len(attr_val))
        else:
            arg_str = "%s=%s" % (param.name, arg_str)
        arg_strs.append(arg_str)

    args_str = ", ".join(arg_strs)
    type_name = type(instance).__name__
    return "%s(%s)" % (type_name, args_str)


class AWSSession:
    """AWS session, for preconfigured communication with AWS.

    Args:
        session: session to use, default: a new ``boto3.Session``
    """

    def __init__(self, session: boto3.Session = None):
        self.session = session or boto3.Session()

    def __str__(self):
        fmt = "<access key: %s, region: %s>"
        return fmt % (self.credentials.access_key, self.region)

    __repr__ = easy_repr

    # NOTE: the return annotation below is evaluated before this method's
    # name is bound in the class body, so ``credentials`` there still refers
    # to the imported ``botocore.credentials`` module
    @cached_property
    def credentials(self) -> credentials.Credentials:
        """AWS session credentials."""
        return self.session.get_credentials()

    @cached_property
    def sfn(self) -> botocore_client.BaseClient:
        """Step Functions client."""
        return self.session.client("stepfunctions")

    @cached_property
    def region(self) -> str:
        """Session AWS region."""
        return self.session.region_name

    @cached_property
    def account_id(self) -> str:
        """Session's account's account ID."""
        return self.session.client("sts").get_caller_identity()["Account"]