Source code for terrace.graph

from functools import partial
from typing import Dict, List, Sequence, Set, Tuple, Optional, Type, TypeVar, Generic, Any, Union
from dataclassy import dataclass
import dgl
import torch

from .batch import BatchView, BatchViewBase, Batchable, Batch, BatchBase, collate

@dataclass
class TypeTree:
    """Tree of types describing the nested structure of a flattened batch.

    ``flatten_batch`` builds one node per batch level and one per attribute;
    ``make_batch_from_unflat_dict`` reads it back to rebuild the batch.
    """

    # Item type at this level: ``Batch.item_type()`` for a batch node, or
    # ``torch.Tensor`` for a tensor attribute.
    type: Type
    # Child trees keyed by attribute name; left empty for tensor attributes.
    subtypes: Dict[str, "TypeTree"]
def flatten_batch(b: Batch) -> Tuple[Dict[str, torch.Tensor], TypeTree]:
    """Flatten a (possibly nested) batch into one dict from keys to tensors.

    dgl stores ``ndata`` and ``edata`` as flat dicts of tensors, so nested
    batches are encoded by joining attribute names with ``/``. The batch
    types are returned alongside so the structure can be rebuilt later.

    Args:
        b: the batch to flatten.

    Returns:
        A ``(flat_dict, type_tree)`` pair: the flattened tensors and the
        tree of item types for ``b`` and every nested attribute.

    Raises:
        NotImplementedError: if an attribute is neither a ``torch.Tensor``
            nor a ``Batch``.
    """
    ret = {}
    type_tree = TypeTree(b.item_type(), {})

    for key, val in b.asdict().items():
        # Leaf attribute: store the tensor under its own name.
        if isinstance(val, torch.Tensor):
            ret[key] = val
            type_tree.subtypes[key] = TypeTree(torch.Tensor, {})
            continue

        if not isinstance(val, Batch):
            raise NotImplementedError(f"flatten_batch is only currently implemented for torch.Tensor and Batches, but found attribute {key} with type {type(val)}")

        # Nested batch: flatten recursively and prefix every key with ours.
        sub_dict, subtype_tree = flatten_batch(val)
        type_tree.subtypes[key] = subtype_tree
        ret.update({f"{key}/{key2}": val2 for key2, val2 in sub_dict.items()})

    return ret, type_tree
def unflatten_dict(flat):
    """Helper for unflatten_batch: rebuild a nested dict from 'key/subkey' keys.

    Each key of ``flat`` is split on ``/``; every component but the last names
    a nested dict (created on demand) and the last names the entry that
    receives the value.

    Args:
        flat: mapping whose keys are ``/``-separated paths.

    Returns:
        A new nested dict. A value that is the string ``'None'`` is stored
        as ``None``; every other value is stored unchanged.
    """
    ret = {}
    for key, val in flat.items():
        *parents, leaf = key.split("/")

        # Walk (and create where missing) the chain of intermediate dicts.
        container = ret
        for subkey in parents:
            container = container.setdefault(subkey, {})

        # Only a real string can be the 'None' placeholder. Checking the type
        # first avoids calling ``==`` on array-like values, whose elementwise
        # comparison result is not a plain bool.
        if isinstance(val, str) and val == 'None':
            val = None
        container[leaf] = val
    return ret
def make_batch_from_unflat_dict(unflat_dict, type_tree):
    """Helper for unflatten_batch: build a Batch from an unflattened dict.

    Args:
        unflat_dict: nested dict whose values are tensors or further dicts.
        type_tree: type tree matching ``unflat_dict``; ``type_tree.type`` is
            the item type of the batch to build and ``type_tree.subtypes``
            holds the trees for nested dicts.

    Returns:
        ``Batch(type_tree.type, **attrs)`` where nested dicts have been
        converted to batches recursively.

    Raises:
        AssertionError: if a value is neither a dict nor a ``torch.Tensor``.
    """
    batch_type = type_tree.type

    def build(key, val):
        # Tensors are used as-is; dicts become nested batches.
        if isinstance(val, torch.Tensor):
            return val
        if isinstance(val, dict):
            return make_batch_from_unflat_dict(val, type_tree.subtypes[key])
        raise AssertionError(f"Expected only dicts and tensors, but got {key}={val}")

    kwargs = {key: build(key, val) for key, val in unflat_dict.items()}
    return Batch(batch_type, **kwargs)
def unflatten_batch(flat_dict: Dict[str, torch.Tensor], type_tree: TypeTree) -> Batch:
    """Reverse flatten_batch.

    Takes a flattened dict (keys like ``'key/subkey'``) together with the
    type tree produced when flattening, and returns a proper Batch.
    """
    return make_batch_from_unflat_dict(unflatten_dict(flat_dict), type_tree)
# Type variable for the per-node data type of a graph (must be Batchable).
N = TypeVar('N', bound=Batchable)
# Type variable for the per-edge data type; the bound is Optional because a
# graph may be built without any edge data.
E = TypeVar('E', bound=Optional[Batchable])
@dataclass
class GraphBase(Generic[N, E]):
    """State and accessors shared by single graphs and graph batches.

    Wraps a dgl graph together with the type trees needed to turn dgl's flat
    ``ndata`` / ``edata`` dicts back into structured batches.
    """

    # Underlying dgl graph (a single graph or the result of ``dgl.batch``).
    _dgl_graph: dgl.DGLGraph
    # Type tree used to unflatten ``_dgl_graph.ndata``.
    _node_type_tree: TypeTree
    # Type tree used to unflatten ``_dgl_graph.edata``; None means no edge data.
    _edge_type_tree: Optional[TypeTree]

    @property
    def ndata(self) -> Batch[N]:
        """Node data, rebuilt as a Batch from the dgl graph's ``ndata``."""
        return unflatten_batch(self._dgl_graph.ndata, self._node_type_tree)

    @property
    def edata(self) -> Optional[Batch[E]]:
        """Edge data as a Batch, or None when there is no edge type tree."""
        if self._edge_type_tree is None:
            return None
        return unflatten_batch(self._dgl_graph.edata, self._edge_type_tree)

    @property
    def edges(self) -> List[Tuple[int, int]]:
        """All edges of the dgl graph as ``(src, dst)`` pairs of Python ints."""
        src, dst = self._dgl_graph.edges()
        # One bulk conversion per endpoint tensor instead of one ``int()``
        # call per element.
        return list(zip(src.tolist(), dst.tolist()))

    def dgl(self) -> dgl.DGLGraph:
        """Return the underlying dgl graph."""
        return self._dgl_graph

    def to(self, device):
        """Return a copy of this object whose dgl graph lives on ``device``.

        The copy is created without calling ``__init__``; every attribute
        except ``_dgl_graph`` is shared with the original.
        """
        cls_ = type(self)
        ret = cls_.__new__(cls_)
        for key, val in self.__dict__.items():
            ret.__dict__[key] = val.to(device) if key == "_dgl_graph" else val
        return ret

    def __repr__(self) -> str:
        indent = " "
        ret = f"{self.__class__.__name__}(\n"
        # Continuation lines of each nested repr are indented so the output
        # stays aligned under the class name.
        ret += indent + "ndata=" + repr(self.ndata).replace("\n", "\n" + indent) + "\n"
        ret += indent + "edata=" + repr(self.edata).replace("\n", "\n" + indent) + "\n"
        ret += ")"
        return ret
class Graph(GraphBase[N, E], Batchable):
    """Wrapper around dgl graph allowing easier access to data."""

    @staticmethod
    def get_batch_type():
        """Return the class used to batch graphs (``GraphBatch``)."""
        return GraphBatch

    def __init__(self,
                 nodes: Sequence[N],
                 edges: Sequence[Tuple[int, int]],
                 edata: Optional[List[E]] = None,
                 directed: bool = False,
                 device='cpu'):
        """Build a dgl graph from node items, an edge list and optional edge items.

        If directed is false, both permutations of the edges will be added,
        and each edge's data item is stored once per direction.

        Args:
            nodes: one item per node; collated into the graph's node data.
            edges: ``(src, dst)`` node-index pairs.
            edata: optional items, one per entry of ``edges``.
            directed: when True the edges are stored exactly as given.
            device: device the dgl graph is created on; node and edge
                features are moved there as well.
        """
        src_list = []
        dst_list = []
        # Only the undirected case needs a rebuilt edge-data list (two
        # entries per input edge); a directed graph uses ``edata`` as given.
        mirrored_edata = [] if (edata is not None and not directed) else None

        for i, (n1, n2) in enumerate(edges):
            src_list.append(n1)
            dst_list.append(n2)
            if directed:
                continue
            # Reverse direction, placed right after the forward edge.
            src_list.append(n2)
            dst_list.append(n1)
            if mirrored_edata is not None:
                mirrored_edata.extend((edata[i], edata[i]))

        new_edata = edata if directed else mirrored_edata

        dgl_graph = dgl.graph((torch.tensor(src_list), torch.tensor(dst_list)),
                              num_nodes=len(nodes),
                              idtype=torch.int32,
                              device=device)

        node_batch = collate(nodes)
        node_flat_dict, node_type_tree = flatten_batch(node_batch)
        for key, val in node_flat_dict.items():
            # The graph lives on ``device``; move the features there too so
            # the assignment does not fail on a device mismatch.
            dgl_graph.ndata[key] = val.to(device)

        if new_edata is not None:
            edge_batch = collate(new_edata)
            edge_flat_dict, edge_type_tree = flatten_batch(edge_batch)
            for key, val in edge_flat_dict.items():
                dgl_graph.edata[key] = val.to(device)
        else:
            edge_type_tree = None

        super().__init__(dgl_graph, node_type_tree, edge_type_tree)
# Type variable for the graph type held by GraphBatch / GraphBatchView.
G = TypeVar('G', bound=Graph)
class GraphBatch(BatchBase[G], GraphBase):
    """A batch of graphs backed by a single batched dgl graph.

    Indexing with an int yields a ``GraphBatchView`` onto one member graph.
    """

    # Names that __getattribute__ resolves directly, without looking for a
    # ``batch_<name>`` method on the graph type.
    _internal_attribs = [
        "_graph_type",
        "node_slices",
        "edge_slices",
        "__dict__",
    ]

    _graph_type: Type[Graph]
    node_slices: List[slice]
    edge_slices: List[slice]

    def __init__(self, items: List[Graph[N, E]]):
        """Batch ``items`` with ``dgl.batch``.

        The type trees and the graph type are taken from the first item.
        """
        dgl_batch = dgl.batch([item.dgl() for item in items])
        first = items[0]
        super().__init__(dgl_batch, first._node_type_tree, first._edge_type_tree)
        self._graph_type = type(first)

        # node and edge slices determine which nodes/edges belong to which
        # graph in the batch
        self.node_slices = []
        self.edge_slices = []
        tot_n = 0
        tot_e = 0
        for n, e in zip(self.dgl().batch_num_nodes(), self.dgl().batch_num_edges()):
            n = int(n)
            e = int(e)
            self.node_slices.append(slice(tot_n, tot_n + n))
            self.edge_slices.append(slice(tot_e, tot_e + e))
            tot_n += n
            tot_e += e

    def __len__(self) -> int:
        """Number of graphs in the batch."""
        return self.dgl().batch_size

    def __getitem__(self, index: int) -> Graph[N, E]:
        """Return a view onto the ``index``-th graph of the batch.

        Raises:
            NotImplementedError: if ``index`` is not an int.
            IndexError: if ``index`` lies outside ``[-len(self), len(self))``.
        """
        if not isinstance(index, int):
            raise NotImplementedError()
        size = len(self)
        # Check both ends: a too-negative index would otherwise produce a
        # view that only fails later, on first access.
        if not -size <= index < size:
            raise IndexError(f"graph index {index} out of range for batch of size {size}")
        return GraphBatchView(self, index)

    def item_type(self):
        """Return the type of the graphs stored in this batch."""
        return self._graph_type

    def __getattribute__(self, name: str) -> Any:
        """Resolve attributes, exposing the graph type's ``batch_*`` methods.

        Internal and instance attributes are returned directly. Otherwise,
        if the graph type defines a callable ``batch_<name>``, it is returned
        bound to this batch as its first argument. Anything else falls back
        to normal lookup.
        """
        if name in GraphBatch._internal_attribs or name in self.__dict__:
            return object.__getattribute__(self, name)
        batch_method_name = "batch_" + name
        # _graph_type may not be set yet while __init__ is still running.
        if hasattr(self, "_graph_type") and hasattr(self._graph_type, batch_method_name):
            batch_method = getattr(self._graph_type, batch_method_name)
            if callable(batch_method):
                return partial(batch_method, self)
        return object.__getattribute__(self, name)
class GraphBatchView(BatchViewBase[G]):
    """View onto a single graph inside a ``GraphBatch``.

    Node and edge data are sliced out of the parent batch on access, using
    the batch's ``node_slices`` / ``edge_slices``.
    """

    _internal_attribs = ["_batch", "_index"]

    _batch: Batch[G]
    _index: Union[int, slice]  # either int or slice
    _node_type_tree: TypeTree
    _edge_type_tree: TypeTree

    def __init__(self, batch: Batch, index: Union[int, slice]):
        """Store the parent batch and index, and copy the batch's type trees."""
        self._batch = batch
        self._index = index
        self._node_type_tree = batch._node_type_tree
        self._edge_type_tree = batch._edge_type_tree

    @property
    def ndata(self) -> Batch[N]:
        """Node data of the viewed graph."""
        idxs = self._batch.node_slices[self._index]
        return self._batch.ndata[idxs]

    @property
    def edata(self) -> Optional[Batch[E]]:
        """Edge data of the viewed graph, or None if the batch has none."""
        # Bind once: each access to the batch's edata rebuilds it from the
        # dgl graph.
        batch_edata = self._batch.edata
        if batch_edata is None:
            return None
        idxs = self._batch.edge_slices[self._index]
        return batch_edata[idxs]

    @property
    def edges(self) -> List[Tuple[int, int]]:
        """Edges of the viewed graph, with node indices local to that graph."""
        node_idxs = self._batch.node_slices[self._index]
        n0 = node_idxs.start
        # Batch-level node indices owned by this graph, taken straight from
        # the slice bounds.
        own_nodes = range(n0, node_idxs.stop)
        ret = []
        for n1, n2 in self._batch.edges:
            if n1 in own_nodes:
                # An edge never connects nodes of two different graphs.
                assert n2 in own_nodes
                ret.append((n1 - n0, n2 - n0))
        return ret

    def dgl(self):
        """Return the viewed graph as its own dgl graph, by unbatching the parent."""
        return dgl.unbatch(self._batch.dgl())[self._index]

    def get_type(self):
        """Return the graph type recorded on the parent batch."""
        return self._batch._graph_type

    def __repr__(self):
        return GraphBase.__repr__(self)

    @staticmethod
    def get_batch_type():
        """Return the class used to batch graphs (``GraphBatch``)."""
        return GraphBatch