Source code for feste.optimization

import operator
from abc import ABC, abstractmethod
from typing import Any, Callable

from dask.core import istask
from dask.delayed import tokenize
from tlz import groupby

from feste.graph import FesteGraph


def make_getitem_task(object: Any, index: int) -> Any:
    """This function will create a new __getitem__ task, which is used to get
    single values from the return of fused calls.

    :param object: the object to get the item from
    :param index: which index to get
    :return: a task tuple (function, object, index)
    """
    return (operator.getitem, object, index)
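
# A minimal sketch of the task tuple convention used above (illustrative
# only; the key "fuse-batch-example" is hypothetical):
#
#     task = make_getitem_task("fuse-batch-example", 1)
#     # task == (operator.getitem, "fuse-batch-example", 1)
#
# When the scheduler resolves "fuse-batch-example" to the batched result,
# e.g. ["hello", "world"], evaluating the task yields "world".
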
class Optimization(ABC):
    """Optimization abstract class.

    This class represents an optimization that can be applied
    on the Feste graph.
    """
    @abstractmethod
    def apply(self, graph: FesteGraph) -> FesteGraph:
        """Apply the optimization to the graph and return a modified graph.

        :param graph: Feste graph to optimize
        :return: optimized graph
        """
        raise NotImplementedError
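
# A minimal sketch of a concrete subclass (illustrative only, not part of
# Feste): an identity pass that returns the graph unchanged. A real
# optimization would rewrite the task dictionary before returning it.
class _IdentityOptimization(Optimization):
    def apply(self, graph: FesteGraph) -> FesteGraph:
        # No rewriting performed: hand the graph back untouched.
        return graph
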
class Optimizer:
    """This is the Feste optimizer: it receives a list of optimizations
    and applies them to a Feste graph.

    :param optimizations: list of Feste optimizations.
    """
    def __init__(self, optimizations: list[Optimization]):
        self.optimizations = optimizations
    def apply(self, graph: FesteGraph) -> FesteGraph:
        """Apply all optimizations to the Feste graph.

        :param graph: graph to optimize
        :return: optimized graph (with all optimizations applied)
        """
        for optim in self.optimizations:
            graph = optim.apply(graph)
        return graph
    @classmethod
    def from_backends(cls) -> 'Optimizer':
        """Create the optimizer using all optimizations from classes
        that inherit from the backend FesteBase class."""
        # TODO: Avoid circular imports from task
        from feste.task import FesteBase
        optimizations = []
        subclasses = FesteBase.__subclasses__()
        for subclass in subclasses:
            optimizations.extend(subclass.optimizations())
        optimizer = Optimizer(optimizations)
        return optimizer
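
# A usage sketch (illustrative only; `graph` stands for a FesteGraph built
# elsewhere by Feste):
#
#     optimizer = Optimizer([_IdentityOptimization()])
#     optimized_graph = optimizer.apply(graph)
#
# Alternatively, Optimizer.from_backends() assembles the list automatically
# from the optimizations declared by every FesteBase backend subclass.
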
class BatchOptimization(Optimization):
    """This is a static optimization that batches calls ahead of execution.
    Another optimization is done during scheduling, as tasks might become
    ready before/after one another.

    :param rewrite_rules: rules that describe how to change a single call
        into a batched call for APIs that support it.
    """
    def __init__(self, rewrite_rules: dict[Callable, Callable]) -> None:
        self.rewrite_rules = rewrite_rules
    def apply(self, graph: FesteGraph) -> FesteGraph:
        # Get all tasks from the graph
        tasks = []
        for key, task in dict(graph).items():
            if not istask(task):
                continue
            # Add the key as a suffix of the task arguments,
            # as we need to keep track of which key
            # the task belongs to.
            suffix_task = task + (key,)
            tasks.append(suffix_task)

        # Group by <function / object>, so we don't call
        # the batch for different objects (different parameters).
        # Note: we group using object identity here rather than
        # equality across objects.
        task_groups = groupby(lambda x: (x[0], x[1]), tasks)

        new_tasks = {}
        # Group key = (function, object)
        # Group tasks = [(function, object, parameter, key), ...]
        for group_key, group_tasks in task_groups.items():
            # Check if batching is possible
            if len(group_tasks) <= 1:
                continue

            # Check if the function is in the rewrite rules
            rewrite_rule_keys = self.rewrite_rules.keys()
            if group_key[0] not in rewrite_rule_keys:
                continue

            # Build the argument list for the task
            # TODO: need to support more than one arg
            arg_key_list = [(task[2], task[3]) for task in group_tasks]
            unzipped_arg_key_list = zip(*arg_key_list)
            arg_list, key_order = unzipped_arg_key_list

            # New task using the rewrite rule
            new_function = self.rewrite_rules[group_key[0]]
            new_task = (new_function, group_key[1]) + (list(arg_list),)
            key_name = "fuse-batch-" + tokenize(new_task)
            new_tasks[key_name] = new_task

            # Replace each original call with a getitem on the
            # batched response call.
            for index, task in enumerate(group_tasks):
                new_task = make_getitem_task(key_name, index)
                graph.update({key_order[index]: new_task})

        graph.update(new_tasks)
        return graph
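
# A sketch of what this pass does to a graph (illustrative only; `SomeAPI`,
# `complete`, `complete_batch` and the plain-dict graph below are hypothetical
# stand-ins for a real backend API and FesteGraph):
#
#     api = SomeAPI()  # one shared object: calls are only fused per identical object
#     graph = {
#         "call-1": (complete, api, "prompt a"),
#         "call-2": (complete, api, "prompt b"),
#     }
#     optim = BatchOptimization(rewrite_rules={complete: complete_batch})
#     graph = optim.apply(graph)
#
# Afterwards the graph holds a single fused task
#     "fuse-batch-<token>": (complete_batch, api, ["prompt a", "prompt b"])
# while the original keys become getitem tasks pointing back into it:
#     "call-1": (operator.getitem, "fuse-batch-<token>", 0)
#     "call-2": (operator.getitem, "fuse-batch-<token>", 1)
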