# --- 80 characters -----------------------------------------------------------
# Created by: Laurie 2019/05/13
"""State definition bases and mix-ins."""
import typing as T
import logging as lg
from .. import _util
# Module-level logger, named after this module
_logger = lg.getLogger(__name__)
# Sentinel used as the default for optional parameters: the ``to_dict``
# methods below leave out any field whose value still equals this object
_default = _util.DefaultParameter()
# Names of the predefined Step Functions error codes, without their
# "States." prefix (``_validate_errors`` strips the prefix before looking an
# error name up here). The first eight entries are the original set; the rest
# are further predefined codes listed in the AWS Step Functions error-handling
# documentation.
STATES_ERRORS = (
    "ALL",
    "Timeout",
    "TaskFailed",
    "Permissions",
    "ResultPathMatchFailure",
    "ParameterPathFailure",
    "BranchFailed",
    "NoChoiceMatched",
    "HeartbeatTimeout",
    "DataLimitExceeded",
    "IntrinsicFailure",
    "Runtime",
    "ExceedToleratedFailureThreshold",
    "ItemReaderFailed",
    "ResultWriterFailed",
)
class State:
    """Abstract state.

    Optional arguments left at their default are omitted from the
    definition produced by ``to_dict``.

    Args:
        name: name of state
        comment: state description
        input_path: state input filter JSONPath, ``None`` for empty input
        output_path: state output filter JSONPath, ``None`` for discarded
            output
    """

    def __init__(
            self,
            name: str,
            comment: str = _default,
            input_path: T.Union[str, None] = _default,
            output_path: T.Union[str, None] = _default):
        self.name = name
        self.comment = comment
        self.input_path = input_path
        self.output_path = output_path

    def __str__(self):
        # Rendered as "<state name> [<class name>]"
        name = type(self).__name__
        return "%s [%s]" % (self.name, name)

    __repr__ = _util.easy_repr

    def add_to(self, states):
        """Add this state to a state-machine definition.

        Any child states will also be added to the definition.

        Args:
            states (dict[str, State]): state-machine states, updated in-place

        Raises:
            ValueError: a different state is already registered under this
                state's name
        """
        # Lazy logging arguments: the message (and ``__str__``) is only
        # evaluated when debug logging is actually enabled
        _logger.debug("Adding state to state-machine definition: '%s'", self)
        # Re-adding the same state is allowed; only a name clash with another
        # state is an error
        if states.get(self.name, self) != self:
            raise ValueError("State name '%s' already registered" % self.name)
        states[self.name] = self

    def to_dict(self) -> T.Dict[str, _util.JSONable]:
        """Convert this state to a definition dictionary.

        The "Type" field is the name of this state's class. "Comment",
        "InputPath" and "OutputPath" are only included when they were
        explicitly provided.

        Returns:
            definition
        """
        defn = {"Type": type(self).__name__}
        if self.comment != _default:
            defn["Comment"] = self.comment
        if self.input_path != _default:
            defn["InputPath"] = self.input_path
        if self.output_path != _default:
            defn["OutputPath"] = self.output_path
        return defn
class HasNext(State):
    """State able to advance mix-in.

    Args:
        name: name of state
        comment: state description
        input_path: state input filter JSONPath, ``None`` for empty input
        output_path: state output filter JSONPath, ``None`` for discarded
            output

    Attributes:
        next: next state to execute, or ``None`` if state is terminal
    """

    def __init__(
            self,
            name,
            comment=_default,
            input_path=_default,
            output_path=_default):
        super().__init__(
            name,
            comment=comment,
            input_path=input_path,
            output_path=output_path)
        # A freshly-created state is terminal until ``goes_to`` is called
        self.next: T.Union[State, None] = None

    def add_to(self, states):
        """Add this state, and its next state, to a state-machine definition.

        Args:
            states (dict[str, State]): state-machine states, updated in-place
        """
        super().add_to(states)
        # Skip names that are already registered, so chains of states which
        # loop back on themselves do not recurse forever
        if self.next is not None and self.next.name not in states:
            self.next.add_to(states)

    def goes_to(self, state: State):
        """Set next state after this state finishes.

        A warning is logged when an existing next state is replaced.

        Args:
            state: state to execute next
        """
        if self.next is not None:
            # Lazy logging arguments: formatting is left to the logging
            # framework
            _logger.warning("Overriding current next state: %s", self.next)
        self.next = state

    def remove_next(self):
        """Remove next state, making this state terminal."""
        self.next = None

    def to_dict(self):
        """Convert this state to a definition dictionary.

        Adds ``"End": True`` for a terminal state, otherwise ``"Next"`` with
        the name of the next state.

        Returns:
            definition
        """
        defn = super().to_dict()
        if self.next is None:
            defn["End"] = True
        else:
            defn["Next"] = self.next.name
        return defn
class HasResultPath(State):
    """Mix-in for states which store a result.

    Args:
        name: name of state
        comment: state description
        input_path: state input filter JSONPath, ``None`` for empty input
        output_path: state output filter JSONPath, ``None`` for discarded
            output
        result_path: task output location JSONPath, ``None`` for discarded
            output
    """

    def __init__(
            self,
            name,
            comment=_default,
            input_path=_default,
            output_path=_default,
            result_path: T.Union[str, None] = _default):
        # Forward the common state options untouched to the next initialiser
        common_options = {
            "comment": comment,
            "input_path": input_path,
            "output_path": output_path}
        super().__init__(name, **common_options)
        self.result_path = result_path

    def to_dict(self):
        """Convert this state to a definition dictionary.

        "ResultPath" is only included when a result path was explicitly
        provided.

        Returns:
            definition
        """
        definition = super().to_dict()
        result_path_given = self.result_path != _default
        if result_path_given:
            definition.update({"ResultPath": self.result_path})
        return definition
class CanRetry(State):
    """Retryable state mix-in.

    Args:
        name: name of state
        comment: state description
        input_path: state input filter JSONPath, ``None`` for empty input
        output_path: state output filter JSONPath, ``None`` for discarded
            output

    Attributes:
        retriers: error handler policies, as ``(errors, policy)`` pairs in
            the order they were added
    """

    def __init__(
            self,
            name,
            comment=_default,
            input_path=_default,
            output_path=_default):
        super().__init__(
            name,
            comment=comment,
            input_path=input_path,
            output_path=output_path)
        self.retriers: T.List[
            T.Tuple[T.Sequence[str], T.Dict[str, T.Any]]] = []

    def retry_for(
            self,
            errors: T.Sequence[str],
            interval: int = _default,
            max_attempts: int = _default,
            backoff_rate: float = _default):
        """Add a retry handler.

        The error codes are not checked here: they are validated when the
        definition is built by ``to_dict``.

        Args:
            errors: codes of errors for retry to be executed. See AWS Step
                Functions documentation
            interval: (initial) retry interval (seconds)
            max_attempts: maximum number of attempts before re-raising error
            backoff_rate: retry interval increase factor between attempts
        """
        policy = {
            "interval": interval,
            "max_attempts": max_attempts,
            "backoff_rate": backoff_rate}
        self.retriers.append((errors, policy))

    @staticmethod
    def _retrier_defn(
            errors: T.Sequence[str],
            policy: T.Dict[str, T.Any]
    ) -> T.Dict[str, _util.JSONable]:
        """Build retry handler definition.

        Policy values left at their default are omitted from the definition.

        Args:
            errors: codes of errors for retry handler to be invoked
            policy: retry handler policy

        Returns:
            definition

        Raises:
            ValueError: invalid error codes
        """
        _validate_errors(errors)
        defn = {"ErrorEquals": errors}
        if policy["interval"] != _default:
            defn["IntervalSeconds"] = policy["interval"]
        if policy["max_attempts"] != _default:
            defn["MaxAttempts"] = policy["max_attempts"]
        if policy["backoff_rate"] != _default:
            defn["BackoffRate"] = policy["backoff_rate"]
        return defn

    def _get_retrier_defns(self) -> T.List[T.Dict[str, _util.JSONable]]:
        """Build retry handler definitions.

        Returns:
            definitions, one per registered retry handler
        """
        return [self._retrier_defn(e, p) for e, p in self.retriers]

    def to_dict(self):
        """Convert this state to a definition dictionary.

        "Retry" is only included when at least one retry handler has been
        added.

        Returns:
            definition
        """
        defn = super().to_dict()
        retry = self._get_retrier_defns()
        if retry:
            defn["Retry"] = retry
        return defn
class CanCatch(State):
    """Exception catching state mix-in.

    Args:
        name: name of state
        comment: state description
        input_path: state input filter JSONPath, ``None`` for empty input
        output_path: state output filter JSONPath, ``None`` for discarded
            output

    Attributes:
        catchers: error handler policies, as ``(errors, policy)`` pairs in
            the order they were added
    """

    def __init__(
            self,
            name,
            comment=_default,
            input_path=_default,
            output_path=_default):
        super().__init__(
            name,
            comment=comment,
            input_path=input_path,
            output_path=output_path)
        self.catchers: T.List[
            T.Tuple[T.Sequence[str], T.Dict[str, T.Any]]] = []

    def add_to(self, states):
        """Add this state, and its catch targets, to a definition.

        Args:
            states (dict[str, State]): state-machine states, updated in-place
        """
        super().add_to(states)
        for _, policy in self.catchers:
            # Skip handler states whose name is already registered
            if policy["next_state"].name not in states:
                policy["next_state"].add_to(states)

    def catch(
            self,
            errors: T.Sequence[str],
            next_state: State,
            result_path: T.Union[str, None] = _default):
        """Add an error handler.

        A warning is logged when any of ``errors`` is already covered by a
        previously-added handler; the new handler is still added. The error
        codes themselves are validated when the definition is built by
        ``to_dict``.

        Args:
            errors: code of errors for catch clause to be executed. See AWS
                Step Functions documentation
            next_state: state to execute for catch clause
            result_path: error details location JSONPath
        """
        if any(any(e in excs_ for e in errors) for excs_, _ in self.catchers):
            # Lazy logging arguments: formatting is left to the logging
            # framework
            _logger.warning(
                "Handler has already accounted-for errors: %s", errors)
        policy = {"next_state": next_state, "result_path": result_path}
        self.catchers.append((errors, policy))

    @staticmethod
    def _catcher_defn(
            errors: T.Sequence[str],
            policy: T.Dict[str, T.Any]
    ) -> T.Dict[str, _util.JSONable]:
        """Build error handler definition.

        "ResultPath" is only included when a result path was explicitly
        provided.

        Args:
            errors: codes of errors for catch handler to be invoked
            policy: catch handler policy

        Returns:
            definition

        Raises:
            ValueError: invalid error codes
        """
        _validate_errors(errors)
        defn = {"ErrorEquals": errors, "Next": policy["next_state"].name}
        if policy["result_path"] != _default:
            defn["ResultPath"] = policy["result_path"]
        return defn

    def _get_catcher_defns(self) -> T.List[T.Dict[str, _util.JSONable]]:
        """Build error handler definitions.

        Returns:
            definitions, one per registered error handler
        """
        return [self._catcher_defn(e, p) for e, p in self.catchers]

    def to_dict(self):
        """Convert this state to a definition dictionary.

        "Catch" is only included when at least one error handler has been
        added.

        Returns:
            definition
        """
        defn = super().to_dict()
        catch = self._get_catcher_defns()
        if catch:
            defn["Catch"] = catch
        return defn
def _validate_errors(errors: T.Sequence[str]):
"""Validate error conditions.
Args:
errors: condition error codes
Raises:
ValueError: invalid condition
"""
if not errors:
raise ValueError("Cannot have no-error condition")
if "States.ALL" in errors and len(errors) > 1:
msg = "Cannot combine 'States.ALL' condition with other errors"
raise ValueError(msg)
for err in errors:
if err.startswith("States."):
if err[7:] not in STATES_ERRORS:
fmt = "States error name was '%s', must be one of: %s"
raise ValueError(fmt % (err[7:], STATES_ERRORS))