Source code for feste.task

# This module contains modified code from Dask which is
# licensed under BSD 3-Clause License for the following holder:
# Copyright (c) 2014, Anaconda, Inc. and contributors.

import inspect
import operator
import types
from typing import Any, Optional

from dask.base import is_dask_collection, replace_name_in_key
from dask.core import quote
from dask.delayed import Delayed, right, tokenize, unpack_collections, unzip
from dask.highlevelgraph import HighLevelGraph
from dask.utils import apply, funcname
from tlz import concat, curry

from feste import context
from feste.compute import compute
from feste.optimization import Optimization


class FesteDelayed(Delayed):
    """Lazy-evaluation node in Feste's graph.

    Subclass of Dask's ``Delayed`` whose call path checks Feste's ``eager``
    context flag and whose derived nodes are built through ``feste_task`` /
    ``FesteDelayed`` instead of the Dask originals.
    """

    def __call__(self, *args, pure=None, dask_key_name=None, **kwargs):  # type:ignore
        """Call the wrapped object now (eager mode) or record the call lazily.

        Args:
            *args: Positional arguments for the call.
            pure: Forwarded to ``feste_task`` when the call is recorded.
            dask_key_name: Optional explicit key for the recorded call.
            **kwargs: Keyword arguments for the call.

        Returns:
            The result of ``self._obj(*args, **kwargs)`` when the ``eager``
            context flag is truthy, otherwise whatever the task built by
            ``feste_task(apply, pure=pure)`` returns when called.
        """
        if context.get("eager"):
            return self._obj(*args, **kwargs)

        deferred_apply = feste_task(apply, pure=pure)
        # The key override is forwarded only when the caller supplied one.
        naming = {} if dask_key_name is None else {"dask_key_name": dask_key_name}
        return deferred_apply(self, args, kwargs, **naming)

    def compute(self, **kwargs) -> Any:  # type: ignore
        """Evaluate this node with Feste's ``compute`` and return its value.

        ``kwargs`` are passed straight through to ``compute``; the call is
        made with ``traverse=False`` and must yield exactly one result.
        """
        [value] = compute(self, traverse=False, **kwargs)
        return value

    def __repr__(self) -> str:
        return f"FesteDelayed({self.key!r})"

    def _rebuild(self, dsk, *, rename=None) -> Any:  # type: ignore
        """Create a new ``FesteDelayed`` over ``dsk``, optionally renaming the key.

        Args:
            dsk: Graph for the rebuilt node.
            rename: When truthy, passed to ``replace_name_in_key`` together
                with the current key to obtain the new key.
        """
        if rename:
            new_key = replace_name_in_key(self.key, rename)
        else:
            new_key = self.key

        # A layer name is only recorded for a HighLevelGraph with one layer.
        single_layer = isinstance(dsk, HighLevelGraph) and len(dsk.layers) == 1
        layer_name = next(iter(dsk.layers)) if single_layer else None
        return FesteDelayed(new_key, dsk, self._length, layer=layer_name)

    @classmethod
    def _get_binary_operator(cls, op, inv=False) -> Any:  # type: ignore
        """Return a callable that applies ``op`` through a pure Feste task.

        When ``inv`` is true the operator is first wrapped with ``right``.
        """
        wrapped_op = right(op) if inv else op
        lazy_op = feste_task(wrapped_op, pure=True)
        return lambda *args, **kwargs: lazy_op(*args, **kwargs)

    # Unary operators are produced by the same factory as binary ones.
    _get_unary_operator = _get_binary_operator
class FesteDelayedLeaf(FesteDelayed):
    """Leaf node that wraps a single object, typically a callable.

    This is very similar to the DelayedLeaf in Dask, with the difference
    that the call is adjusted to include eager mode execution.
    """

    __slots__ = ("_obj", "_pure", "_nout")

    def __init__(self, obj: Any, key: Any,
                 pure: Optional[bool] = None,
                 nout: Optional[int] = None) -> None:
        """Store the wrapped object and its task options.

        Args:
            obj: The object (usually a function) this leaf wraps.
            key: Graph key under which ``obj`` is registered.
            pure: Passed on to ``call_function`` when the leaf is called.
            nout: Passed on to ``call_function`` when the leaf is called.
        """
        super().__init__(key, None)
        self._obj = obj
        self._pure = pure
        self._nout = nout

    @property
    def dask(self) -> Any:
        """Single-entry graph mapping this leaf's key to the wrapped object."""
        return HighLevelGraph.from_collections(
            self._key, {self._key: self._obj}, dependencies=()
        )

    def __call__(self, *args, **kwargs) -> Any:  # type: ignore
        """Run the wrapped object eagerly, or record the call as a new node.

        In eager mode the wrapped object is invoked directly. The
        scheduling-only keywords ``dask_key_name`` and ``pure`` are removed
        first: the lazy path (``call_function``) pops them as well, so the
        wrapped callable never receives them there, and forwarding them in
        eager mode would make a call that works lazily fail with an
        unexpected-keyword ``TypeError``.
        """
        if context.get("eager"):
            # ``kwargs`` is a fresh dict for this call, so popping is safe.
            kwargs.pop("dask_key_name", None)
            kwargs.pop("pure", None)
            return self._obj(*args, **kwargs)

        return call_function(
            self._obj, self._key, args, kwargs, pure=self._pure, nout=self._nout
        )

    @property
    def __name__(self) -> Any:
        """Name of the wrapped object."""
        return self._obj.__name__

    @property
    def __doc__(self):  # type: ignore
        return self._obj.__doc__
@curry
def feste_task(obj: Any, name: Optional[Any] = None,
               pure: Optional[bool] = None,
               nout: Optional[int] = None,
               traverse: bool = True) -> FesteDelayed:
    """Function and decorator that can be used to introduce the
    lazy-evaluation nodes of computation using Feste's graph.

    Args:
        obj: Object to wrap. An existing ``FesteDelayed`` is returned as is.
        name: Optional graph key. When falsy, a key is derived from the
            object's name (or type name) and a token.
        pure: Included in the token and stored on the resulting leaf.
        nout: ``None`` or a non-negative ``int``; stored on the result.
        traverse: When false and ``obj`` is not a Dask collection, ``obj``
            is quoted instead of being unpacked with ``unpack_collections``.

    Returns:
        A ``FesteDelayedLeaf`` when unpacking leaves ``obj`` untouched,
        otherwise a ``FesteDelayed`` backed by a one-layer graph.

    Raises:
        ValueError: If ``nout`` is neither ``None`` nor a non-negative int.
    """
    if isinstance(obj, FesteDelayed):
        return obj

    if is_dask_collection(obj) or traverse:
        task, collections = unpack_collections(obj)
    else:
        task = quote(obj)
        collections = set()

    # ``type(...) is int`` (not isinstance) so that bools are rejected too.
    if not (nout is None or (type(nout) is int and nout >= 0)):
        # An f-string is used rather than ``"%s" % nout``: with the ``%``
        # operator a tuple-valued ``nout`` is interpreted as the argument
        # tuple and raises TypeError instead of the intended ValueError.
        raise ValueError(
            f"nout must be None or a non-negative integer, got {nout!s}"
        )

    if task is obj:
        # Nothing was unpacked: the object itself becomes a leaf node.
        if not name:
            try:
                prefix = obj.__name__
            except AttributeError:
                prefix = type(obj).__name__
            token = tokenize(obj, nout, pure=pure)
            name = f"{prefix}-{token}"
        return FesteDelayedLeaf(obj, name, pure=pure, nout=nout)

    if not name:
        name = f"{type(obj).__name__}-{tokenize(task, pure=pure)}"
    layer = {name: task}
    graph = HighLevelGraph.from_collections(name, layer,
                                            dependencies=collections)
    return FesteDelayed(name, graph, nout)
def call_function(func, func_token, args,  # type:ignore
                  kwargs, pure=None, nout=None) -> Any:
    """Build the ``FesteDelayed`` node that represents ``func(*args, **kwargs)``.

    Args:
        func: Callable placed in the task.
        func_token: Value fed to ``tokenize`` in place of ``func`` when a
            key has to be generated.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call. ``dask_key_name`` and
            ``pure`` are popped from this dict, so it is modified in place.
        pure: Default for ``pure`` when ``kwargs`` does not carry one.
        nout: Passed to ``FesteDelayed`` as ``length``.

    Returns:
        A ``FesteDelayed`` whose graph holds the single new task.
    """
    key_override = kwargs.pop("dask_key_name", None)
    pure = kwargs.pop("pure", pure)

    if key_override is not None:
        name = key_override
    else:
        prefix = funcname(func)
        token = tokenize(func_token, *args, pure=pure, **kwargs)
        name = f"{prefix}-{token}"

    unpacked_args, nested_deps = unzip(map(unpack_collections, args), 2)
    dependencies = list(concat(nested_deps))

    if not kwargs:
        task = (func,) + unpacked_args
    else:
        # Keyword arguments require going through ``apply``.
        unpacked_kwargs, kwarg_deps = unpack_collections(kwargs)
        dependencies.extend(kwarg_deps)
        task = (apply, func, list(unpacked_args), unpacked_kwargs)

    graph = HighLevelGraph.from_collections(
        name, {name: task}, dependencies=dependencies
    )
    return FesteDelayed(name, graph, length=nout)
class FesteBase:
    """Feste Base class that is used for backends.

    Every backend that is added in Feste needs to inherit from this class
    as it will add support for eager execution and optimizations.
    """

    def __init__(self) -> None:
        # In eager mode, delayed members are swapped for the objects they wrap.
        if context.get("eager"):
            self._replace_delayed()

    @classmethod
    def optimizations(cls) -> list[Optimization]:
        """Return the optimizations for this backend; the base has none."""
        return []

    def _replace_delayed(self) -> None:
        """Replace ``FesteDelayed`` members of this instance with their ``_obj``.

        Members that are themselves ``FesteDelayed`` are replaced by the
        wrapped object. Bound methods whose underlying function is a
        ``FesteDelayed`` are replaced by the wrapped function re-bound to
        this instance.
        """
        for attr_name, member in inspect.getmembers(self):
            if isinstance(member, FesteDelayed):
                setattr(self, attr_name, member._obj)

            if inspect.ismethod(member) and isinstance(member.__func__, FesteDelayed):
                plain_function = member.__func__._obj
                setattr(self, attr_name, types.MethodType(plain_function, self))
# Register the operator overloads on FesteDelayed so that expressions on
# nodes go through FesteDelayed's own operator factories.
for op in (
    # unary
    operator.abs,
    operator.neg,
    operator.pos,
    operator.invert,
    # arithmetic
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.pow,
    # bitwise
    operator.and_,
    operator.or_,
    operator.xor,
    operator.lshift,
    operator.rshift,
    # comparison
    operator.eq,
    operator.ge,
    operator.gt,
    operator.ne,
    operator.le,
    operator.lt,
    # indexing
    operator.getitem,
):
    FesteDelayed._bind_operator(op)

# matmul is bound separately; an AttributeError here is deliberately ignored.
try:
    FesteDelayed._bind_operator(operator.matmul)
except AttributeError:
    pass