Source code for feste.graph

from graphlib import TopologicalSorter
from typing import Any, Iterator, Mapping, TextIO

import dagviz
import networkx as nx
from dask.base import collections_to_dsk
from dask.base import unpack_collections as base_unpack_collections
from dask.core import get_dependencies
from dask.dot import dot_graph
from dask.order import order
from rich.pretty import pprint


class FesteGraph(Mapping):
    """A computational graph representing the flow described by the
    call of Feste tasks.

    The graph is held as a plain dictionary in :attr:`graph`; the mapping
    protocol methods of this class delegate to that dictionary.

    :param graph: initialize from a dictionary. The dictionary is stored
        by reference (not copied), so later calls to ``__setitem__`` or
        :meth:`update` also modify the object passed in.
    """

    def __init__(self, graph: dict[str, Any]):
        self.graph = graph

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return f"{type(self).__name__}({self.graph!r})"

    @classmethod
    def collect(cls, *args) -> tuple['FesteGraph', list, callable]:  # type:ignore
        """Create a Feste graph from a collection of objects.

        :param args: collection of objects.
        :return: Tuple (Graph, collections, repack function)
        """
        collections, repack = base_unpack_collections(*args)
        dsk = collections_to_dsk(collections, optimize_graph=False)
        # Copy into a plain dict so the new instance owns a mutable mapping.
        return cls(dict(dsk)), collections, repack

    def get_all_dependencies(self) -> dict[str, set[str]]:
        """Returns a dict with all dependencies.

        :return: a dictionary with one entry per graph key, whose value is
            the result of ``dask.core.get_dependencies`` for that key.
        """
        return {key: get_dependencies(self.graph, key) for key in self.graph}

    def topological_sorter(self) -> TopologicalSorter:
        """Returns a topological sorter from the graph dependencies.

        :return: a ``graphlib.TopologicalSorter`` built from the result of
            :meth:`get_all_dependencies`.
        """
        return TopologicalSorter(self.get_all_dependencies())

    def to_dict(self) -> dict:
        """Convert the graph to a dictionary.

        :return: a new (shallow-copied) dictionary with the graph contents.
        """
        return dict(self.graph)

    def print(self) -> None:
        """Print the internal graph representation."""
        pprint(self.graph)

    def order(self) -> dict[str, int]:
        """Return the execution order hint."""
        return order(self.graph)  # type: ignore

    def __getitem__(self, key: str) -> Any:
        """Gets an item from the graph.

        :param key: the graph key
        :raises KeyError: if the key is not present in the graph.
        """
        return self.graph[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Sets an item in the graph.

        :param key: param to set
        :param value: the value to set
        """
        self.graph[key] = value

    def __iter__(self) -> Iterator:
        """Iterate over the graph keys."""
        return iter(self.graph)

    def __len__(self) -> int:
        """Return the number of keys in the graph."""
        return len(self.graph)

    def update(self, graph_dict: dict[str, Any]) -> None:
        """Update the internal graph from a dictionary.

        :param graph_dict: dictionary to update from.
        """
        self.graph.update(graph_dict)

    def visualize(self, filename: str) -> None:
        """Export the graph into a file.

        :param filename: filename to export the graph (e.g. .pdf, .png)
        """
        dot_graph(dict(self), filename=filename,
                  verbose=True, collapse_outputs=False)

    def dagviz_metro(self, svg_handle: TextIO) -> None:
        """Writes a metro style dag visualization using dagviz.

        Nothing is returned; the rendered SVG is written to ``svg_handle``.

        :param svg_handle: the svg file type object
        """
        dependencies = self.get_all_dependencies()
        # NOTE(review): the dependency dict is handed to networkx as-is, so
        # edges presumably run from each key to the keys it depends on --
        # confirm that this is the direction wanted in the rendering.
        digraph = nx.DiGraph(dependencies)
        svg_content = dagviz.render_svg(digraph)
        svg_handle.write(svg_content)