# --- 80 characters -----------------------------------------------------------
# Created by: Laurie 2018/07/12
"""Activity task polling and execution.
You can provide you're own workers: the interface to the activities is
public. This module's worker implementation uses threading, and is
designed to be resource-managed outside of Python.
"""
import json
import uuid
import time
import socket
import threading
import traceback
import typing as T
import logging as lg
from botocore import exceptions as bc_exc
from . import _util
_logger = lg.getLogger(__name__)
_host_name = socket.getfqdn(socket.gethostname())
[docs]class WorkerCancel(KeyboardInterrupt):
"""Workflow execution interrupted by user."""
def __init__(self, *args, **kwargs):
msg = (
"Activity execution cancelled by user. "
"This could be due to a `KeyboardInterrupt` during execution, "
"or the worker was killed during task polling.")
super().__init__(msg, *args, **kwargs)
[docs]class TaskExecution:
"""Execute a task, providing heartbeats and catching failures.
Args:
activity (sfini.activity.CallableActivity): activity to execute
task of
task_token: task token for execution identification
task_input: task input
session: session to use for AWS communication
"""
def __init__(
self,
activity,
task_token: str,
task_input: _util.JSONable,
*,
session: _util.AWSSession = None):
self.activity = activity
self.task_token = task_token
self.task_input = task_input
self.session = session or _util.AWSSession()
self._heartbeat_thread = threading.Thread(target=self._heartbeat)
self._request_stop = False
def __str__(self):
return "%s - %s" % (self.activity.name, self.task_token)
__repr__ = _util.easy_repr
def _send(self, send_fn: T.Callable, **kw):
"""Send execution update to SFN."""
if self._request_stop:
_logger.warning("Skipping sending update as we've already quit")
return
return send_fn(taskToken=self.task_token, **kw)
def _report_exception(self, exc: BaseException):
"""Report failure."""
_logger.info("Reporting task failure for '%s'" % self, exc_info=exc)
tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
self._send(
self.session.sfn.send_task_failure,
error=type(exc).__name__,
cause="".join(tb))
self._request_stop = True
[docs] def report_cancelled(self):
"""Cancel a task execution: stop interaction with SFN."""
fmt = "Reporting task failure for '%s' due to cancellation"
_logger.info(fmt % self)
self._send(
self.session.sfn.send_task_failure,
error=WorkerCancel.__name__,
cause=str(WorkerCancel()))
self._request_stop = True
def _report_success(self, res: _util.JSONable):
"""Report success."""
fmt = "Reporting task success for '%s' with output: %s"
_logger.debug(fmt % (self, res))
self._send(self.session.sfn.send_task_success, output=json.dumps(res))
self._request_stop = True
def _send_heartbeat(self):
"""Send a heartbeat."""
_logger.debug("Sending heartbeat for '%s'" % self)
try:
self._send(self.session.sfn.send_task_heartbeat)
except bc_exc.ClientError as e:
self._request_stop = True
if e.response["Error"]["Code"] != "TaskTimedOut":
raise
_logger.error("Task execution '%s' timed-out" % self)
def _heartbeat(self):
"""Run heartbeat sending."""
heartbeat = self.activity.heartbeat
while True:
t = time.time()
if self._request_stop:
break
self._send_heartbeat()
time.sleep(heartbeat - (time.time() - t))
[docs] def run(self):
"""Run task."""
self._heartbeat_thread.start()
t = time.time()
try:
res = self.activity.call_with(self.task_input)
except KeyboardInterrupt:
self.report_cancelled()
return
except Exception as e:
self._report_exception(e)
return
fmt = "Task '%s' completed in %.6f seconds"
_logger.debug(fmt % (self, time.time() - t))
self._report_success(res)
[docs]class Worker:
"""Worker to poll for activity task executions.
Args:
activity (sfini.activity.CallableActivity): activity to poll and
run executions of
name: name of worker, used for identification, default: a
combination of UUID and host's FQDN
session: session to use for AWS communication
"""
_task_execution_class = TaskExecution
def __init__(
self,
activity,
name: str = None,
*,
session: _util.AWSSession = None):
self.activity = activity
self.name = name or "%s-%s" % (_host_name, str(str(uuid.uuid4()))[:8])
self.session = session or _util.AWSSession()
self._poller = threading.Thread(target=self._worker)
self._request_finish = False
self._exc = None
def __str__(self):
return "%s [%s]" % (self.name, self.activity.name)
__repr__ = _util.easy_repr
def _execute_on(self, task_input: _util.JSONable, task_token: str):
"""Execute the provided task.
Args:
task_input: activity task execution input
task_token: task execution identifier
"""
_logger.debug("Got task input: %s" % task_input)
execution = self._task_execution_class(
self.activity,
task_token,
task_input,
session=self.session)
if self._request_finish:
execution.report_cancelled()
else:
execution.run()
def _poll_and_execute(self):
"""Poll for tasks to execute, then execute any found."""
while not self._request_finish:
fmt = "Polling for activity '%s' executions"
_logger.debug(fmt % self.activity)
resp = self.session.sfn.get_activity_task(
activityArn=self.activity.arn,
workerName=self.name)
if resp.get("taskToken", None) is not None:
input_ = json.loads(resp["input"])
self._execute_on(input_, resp["taskToken"])
def _worker(self):
"""Run polling, catching exceptins."""
try:
self._poll_and_execute()
except (Exception, KeyboardInterrupt) as e:
_logger.warning("Polling/execution failed", exc_info=e)
self._exc = e # send exception to main thread
self._request_finish = True
[docs] def start(self):
"""Start polling."""
from . import activity
if not isinstance(self.activity, activity.CallableActivity):
raise TypeError("Activity '%s' cannot be executed" % self.activity)
_util.assert_valid_name(self.name)
_logger.info("Worker '%s': waiting on final poll to finish" % self)
self._poller.start()
[docs] def join(self):
"""Block until polling exit."""
try:
self._poller.join()
except KeyboardInterrupt:
_logger.info("Quitting polling due to KeyboardInterrupt")
self._request_finish = True
return
except Exception:
self._request_finish = True
raise
if self._exc is not None:
raise self._exc
[docs] def end(self):
"""End polling."""
_logger.info("Worker '%s': waiting on final poll to finish" % self)
self._request_finish = True
[docs] def run(self):
"""Run worker to poll for and execute specified tasks."""
self.start()
self.join()