Source code for terrace.batch

from functools import partial
from torch.utils._pytree import tree_map
from typing import Any, Generic, Optional, Type, TypeVar, Union, List, Tuple
from dataclassy import dataclass
import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from .categorical_tensor import CategoricalTensor, NoStackCatTensor
from .meta_utils import recursive_map


@dataclass
class Batchable:
    """Base class for every object that can be collated into a batch.

    A subclass may also define static methods named
    ``collate_{attribute}`` and ``index_{attribute}`` to customise how a
    single attribute is collated and indexed.
    """

    @staticmethod
    def get_batch_type():
        """Return the class used to batch instances of this type.

        Override this in a subclass to supply a custom batch type.
        """
        return Batch

    def asdict(self):
        """Return the attributes that a batch of these objects should carry.

        The default exposes the instance ``__dict__`` itself (not a copy).
        Override this to choose a different set of attributes.
        """
        return self.__dict__
# Type variable standing for the item type held by a batch or batch view.
T = TypeVar('T')
class BatchBase(Generic[T]):
    """Interface shared by the batch containers defined in this module.

    ``cuda`` and ``cpu`` delegate to ``self.to``, which this class does not
    define; concrete batch classes are expected to supply it.
    """

    def item_type(self) -> Type[Batchable]:
        """Return the type of each item in the batch.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def cuda(self):
        """Shorthand for ``self.to("cuda")``."""
        return self.to("cuda")

    def cpu(self):
        """Shorthand for ``self.to("cpu")``."""
        return self.to("cpu")
def _batch_repr(val): if isinstance(val, torch.Tensor): val_str = f"Tensor(shape={val.shape}, dtype={val.dtype})" else: val_str = repr(val) return val_str
class Batch(BatchBase[T]):
    """Collated container holding one batched value per item attribute.

    A batch is built either from a non-empty list of items (each attribute
    is collated across the items) or from an item type plus already-batched
    attribute values given as keyword arguments. The batched values live in
    the instance ``__dict__`` next to the bookkeeping fields listed in
    ``_internal_attribs``.
    """

    # Names reserved for the batch's own bookkeeping; an item attribute with
    # one of these names is rejected when batching a list.
    _internal_attribs = [
        "_batch_len",
        "_batch_type",
        "__dict__"
    ]

    _batch_len: int
    _batch_type: Type[Batchable]

    def __init__(self, items: Optional[Union[List[T], Type[T]]], **kwargs):
        """Create a batch.

        Args:
            items: either a list of items to collate, or the item type when
                the batched attribute values are passed via ``kwargs``.
            **kwargs: already-batched attribute values; only used when
                ``items`` is a type.
        """
        if isinstance(items, type):
            self._init_from_type(items, **kwargs)
        else:
            self._init_from_list(items)

    def _init_from_list(self, items: List[T]):
        """Collate every attribute of ``items`` into this batch.

        Raises:
            ValueError: if the items are themselves batches, or if an item
                attribute uses a name reserved by ``Batch``.
        """
        assert len(items) > 0, "Cannot create a Batch from an empty list"
        template = items[0]
        template_type = template.__class__
        # todo: very hacky
        # Views report the type of the item they stand in for.
        if isinstance(template, BatchViewBase):
            template_type = template.get_type()
        self._batch_len = len(items)
        self._batch_type = template_type

        if isinstance(template, BatchBase):
            raise ValueError("Batchifying batches is not supported at the moment")

        # Gather each attribute across all items, keyed by the first item's
        # attribute names.
        attribs = {key: [] for key in template.asdict().keys()}
        for item in items:
            for key, attrib in item.asdict().items():
                attribs[key].append(attrib)

        for key, attrib_list in attribs.items():
            if key in Batch._internal_attribs:
                raise ValueError(f"{key} is used internally by Batch, so it shouldn't be a member of {template_type}")
            # A per-attribute ``collate_{key}`` on the item type takes
            # precedence over the generic collate.
            collate_method = "collate_" + key
            if hasattr(template_type, collate_method):
                collated = getattr(template_type, collate_method)(attrib_list)
            else:
                collated = collate(attrib_list)
            self.__dict__[key] = collated

    def _init_from_type(self, batch_type: Type, **kwargs):
        """Adopt already-batched values; the length comes from the first one.

        Raises:
            ValueError: if no attribute values are given.
        """
        self._batch_type = batch_type
        if not kwargs:
            raise ValueError("Can't determine batch size of empty batch")
        self._batch_len = len(next(iter(kwargs.values())))
        for key, val in kwargs.items():
            self.__dict__[key] = val

    def __len__(self) -> int:
        return self._batch_len

    def __getitem__(self, index) -> Union["BatchView[T]", Any]:
        """Look up an attribute, a single item view, or a sub-batch.

        * ``str`` -- return the stored value under that name.
        * ``int`` -- return a lazy ``BatchView`` of that item. Negative
          indices within range are passed through unchanged.
        * anything else (e.g. a slice) -- re-collate the selected items.

        Raises:
            IndexError: if an integer index is outside ``[-len, len)``.
        """
        if isinstance(index, str):
            return self.__dict__[index]
        if isinstance(index, int):
            size = len(self)
            # Check both bounds: an out-of-range negative index would
            # otherwise yield a view that only fails on attribute access.
            if index >= size or index < -size:
                raise IndexError(f"batch index {index} out of range for batch of size {size}")
            return BatchView(self, index)
        return collate([self[i] for i in range(len(self))[index]])

    def item_type(self) -> Type[T]:
        """ Returns the type of each item in the batch """
        return self._batch_type

    def attribute_names(self) -> List[str]:
        """ Returns the names of all the batched attributes """
        return [
            key for key in self.__dict__.keys()
            if key not in Batch._internal_attribs
        ]

    def asdict(self):
        """ Convert to dict (batched attributes only, no bookkeeping) """
        return {
            key: val for key, val in self.__dict__.items()
            if key not in Batch._internal_attribs
        }

    def to(self, device):
        """Return a new ``Batch`` with the attribute values moved to ``device``.

        Each value is passed through ``recursive_map`` with a function that
        calls ``.to(device)`` on anything exposing a ``to`` attribute and
        leaves everything else untouched.
        NOTE(review): ``recursive_map`` is presumably applied to nested
        leaves -- confirm against ``meta_utils``.
        """
        to_dict = {
            key: recursive_map(lambda x: x.to(device) if hasattr(x, 'to') else x, val)
            for key, val in self.asdict().items()
        }
        return Batch(self._batch_type, **to_dict)

    def __repr__(self):
        indent = " "
        ret = f"Batch[{self._batch_type.__name__}](\n"
        for key, val in self.asdict().items():
            # Indent continuation lines of multi-line values.
            val_str = _batch_repr(val).replace("\n", "\n" + indent)
            ret += indent + f"{key}={val_str}\n"
        ret += ")"
        return ret

    def __getattribute__(self, name: str) -> Any:
        """Resolve attributes, exposing ``batch_{name}`` methods of the item type.

        Bookkeeping names and stored attributes resolve normally. Otherwise,
        if the item type defines a callable ``batch_{name}``, it is returned
        bound to this batch; failing that, normal lookup applies.
        """
        if name in Batch._internal_attribs or name in self.__dict__:
            return object.__getattribute__(self, name)

        batch_method_name = "batch_" + name
        # ``_batch_type`` may not be set yet while the batch is being built.
        if hasattr(self, "_batch_type") and hasattr(self._batch_type, batch_method_name):
            batch_method = getattr(self._batch_type, batch_method_name)
            if callable(batch_method):
                return partial(batch_method, self)

        return object.__getattribute__(self, name)
class BatchViewBase(Generic[T], Batchable):
    """Base class for views onto the items of a batch.

    It adds no behaviour of its own.
    """
def _get_methods_for_type(type_): return {func: getattr(type_, func) for func in dir(type_) if callable(getattr(type_, func)) and not func.startswith("__")}
class BatchView(BatchViewBase[T]):
    """ View of an item in a batch. Should act like said item in most
    circumstances. We use views instead of creating actual items because,
    for many use cases, lazily indexing batches is much faster.

    Attribute values are computed on first access by indexing the batch's
    stored value and are then cached in the view's ``__dict__``.
    """

    # Names that always resolve through normal attribute lookup.
    _internal_attribs = [
        "_batch",
        "_index",
        "_get_methods"
    ]

    _batch: Batch[T]
    _index: Union[int, slice]  # either int or slice

    def __init__(self, batch: Batch, index: Union[int, slice]):
        self._batch = batch
        self._index = index

    def _get_methods(self):
        """Return the item type's methods that ``BatchView`` does not define."""
        type_ = self._batch._batch_type
        # Computed once instead of once per candidate key.
        own_methods = _get_methods_for_type(BatchView)
        return {
            key: val for key, val in _get_methods_for_type(type_).items()
            if key not in own_methods
        }

    def __getattribute__(self, name: str) -> Any:
        """Resolve ``name`` as if this view were the item it points at.

        Lookup order: view internals and ``BatchView`` methods, cached
        values, batched attributes (indexed, then cached), and finally
        methods borrowed from the item type (int index) or ``batch_{name}``
        methods (slice index). Anything else uses normal lookup.
        """
        if name == "__dict__" or name in BatchView._internal_attribs or name in _get_methods_for_type(BatchView):
            return object.__getattribute__(self, name)

        # Previously computed attribute values are cached here.
        if name in self.__dict__:
            return object.__getattribute__(self, name)

        if name in self._batch.__dict__ and not name in Batch._internal_attribs:
            attrib = getattr(self._batch, name)
            item_type = self._batch.item_type()
            index_method = "index_" + name
            if hasattr(item_type, index_method):
                # A custom ``index_{name}`` on the item type takes precedence.
                ret = getattr(item_type, index_method)(attrib, self._index)
            elif isinstance(attrib, tuple):
                ret = tuple([item[self._index] for item in attrib])
            elif isinstance(attrib, dict):
                ret = {key: val[self._index] for key, val in attrib.items()}
            else:
                ret = attrib[self._index]
            self.__dict__[name] = ret
            return ret

        if isinstance(self._index, int):
            # assume we are acting like a T
            methods = self._get_methods()
            if name in methods:
                return partial(methods[name], self)
        else:
            # assume we are acting like a Batch[T]
            # todo: make this whole "what am I " process more robust
            batch_method_name = "batch_" + name
            if hasattr(self, "_batch") and hasattr(self.get_type(), batch_method_name):
                batch_method = getattr(self.get_type(), batch_method_name)
                if callable(batch_method):
                    return partial(batch_method, self)

        return object.__getattribute__(self, name)

    def __getitem__(self, index):
        """Return a view of one item from the slice this view covers.

        Raises:
            ValueError: if this view refers to a single item, not a slice.
            NotImplementedError: if ``index`` selects a sub-slice.
        """
        if isinstance(self._index, int):
            raise ValueError("Can only index into a BatchView if the batchview is view of a slice of a batch")
        # Resolve ``index`` relative to the positions this view covers, not
        # relative to the whole underlying batch.
        new_idx = range(len(self._batch))[self._index][index]
        if isinstance(new_idx, range):
            # need to test this
            raise NotImplementedError
        assert isinstance(new_idx, int)
        return BatchView(self._batch, new_idx)

    def asdict(self):
        """Return every batched attribute, indexed for this view."""
        return {
            key: getattr(self, key) for key in self._batch.attribute_names()
        }

    def __repr__(self):
        # Todo: this is basically copied from batch __repr__. Make both
        # call a generalized function
        indent = " "
        ret = f"BatchView[{self.get_type().__name__}](\n"
        for key, val in self.asdict().items():
            val_str = _batch_repr(val).replace("\n", "\n" + indent)
            ret += indent + f"{key}={val_str}\n"
        ret += ")"
        return ret

    def get_type(self):
        """Return the item type of the underlying batch."""
        return self._batch._batch_type
class LazyBatch(BatchBase[T]):
    """ This is mainly used if you're recollating after indexing a batch
    ( eg collate([batch[0]])). If you use collate([batch[0]], lazy=True)
    (recommended), it will return a LazyBatch. This is nice because it
    will only re-collate the attributes of the batch on the fly.
    """

    _items: List[T]
    _batch_type: Type[Batchable]

    _internal_attribs = [
        "_items",
        "_batch_type"
    ]

    def __init__(self, items):
        """Wrap ``items`` (a non-empty sequence) without collating anything yet."""
        self._items = items
        first = items[0]
        # Views stand in for another type. Resolve it the same way
        # ``Batch._init_from_list`` does, so ``batch_*`` methods are looked
        # up on the real item type rather than on the view class.
        if isinstance(first, BatchViewBase):
            self._batch_type = first.get_type()
        else:
            self._batch_type = type(first)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]

    def __getattribute__(self, name):
        """Resolve ``name``, collating item attributes only on demand.

        Names starting with ``_`` use normal lookup. Otherwise a callable
        ``batch_{name}`` on the item type is returned bound to this batch.
        Failing that, if the first item has the attribute, it is collated
        lazily across all items. Anything else uses normal lookup.
        """
        if name.startswith("_"):
            return object.__getattribute__(self, name)

        batch_method_name = "batch_" + name
        if hasattr(self, "_batch_type") and hasattr(self._batch_type, batch_method_name):
            batch_method = getattr(self._batch_type, batch_method_name)
            if callable(batch_method):
                return partial(batch_method, self)

        if hasattr(self[0], name):
            return collate([getattr(item, name) for item in self._items], lazy=True)
        return object.__getattribute__(self, name)

    def __repr__(self):
        return "LazyBatch"

    def item_type(self) -> Type[Batchable]:
        """ Returns the type of each item in the batch """
        return self._batch_type

    def to(self, device):
        """Not supported for lazy batches.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class NoStackTensor(torch.Tensor):
    """ This is used when you want to collate tensors into a list
    instead of stack them. E.g. when you have different shapes.

    The subclass object itself wraps a copy of the input on the ``meta``
    device; the original tensor is kept on the ``tensor`` attribute and is
    what torch functions actually operate on.
    """

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_subclass(cls, tensor.to('meta'))

    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the wrapped tensors and re-wrap tensor results.

        Args:
            func: the torch function being dispatched.
            types: the tensor-like types involved (unused here).
            args: positional arguments, possibly nested containers.
            kwargs: keyword arguments; ``None`` is treated as no kwargs.
        """
        # Avoid a shared mutable default and accept an explicit ``None``,
        # which would otherwise fail at ``**kwargs`` below.
        if kwargs is None:
            kwargs = {}

        def unwrap(x):
            return x.tensor if isinstance(x, NoStackTensor) else x

        def wrap(x):
            return NoStackTensor(x) if isinstance(x, torch.Tensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        out = func(*args, **kwargs)
        return tree_map(wrap, out)
def collate(batch: Any, lazy=False) -> Any:
    """ turn a list of items into a batch of items. Replacement for
    pytorch's default collate. This is what we use in the custom
    DataLoader class.

    Dispatch is on the first element of ``batch``:

    * ``Batchable`` -- build the item type's batch class (or a
      ``LazyBatch`` when ``lazy`` is set and that class is ``Batch``).
    * tuple / list -- collate position by position and rebuild the same
      container type; namedtuples are rebuilt field by field.
    * dict -- collate key by key into a plain dict.
    * ``NoStackTensor`` / ``NoStackCatTensor`` -- return the list as is.
    * ``CategoricalTensor`` -- ``torch.stack``.
    * anything else -- ``default_collate``, falling back to the list
      unchanged when that raises ``TypeError``.

    Args:
        batch: a non-empty sequence of items, or an existing ``Batch``.
        lazy: passed down recursively; see the ``Batchable`` case above.
    """
    # performance optimization -- if we've already batched something, no
    # need to do it again
    if isinstance(batch, Batch):
        return batch

    example = batch[0]
    if isinstance(example, Batchable):
        batch_type = type(example).get_batch_type()
        if lazy and batch_type == Batch:
            return LazyBatch(batch)
        return batch_type(batch)
    elif isinstance(example, (tuple, list)):
        ret = [
            collate([b[i] for b in batch], lazy)
            for i in range(len(example))
        ]
        # A namedtuple constructor takes one argument per field, so passing
        # the list as a single argument would raise TypeError.
        if isinstance(example, tuple) and hasattr(example, "_fields"):
            return type(example)(*ret)
        return type(example)(ret)
    elif isinstance(example, dict):
        return {
            key: collate([item[key] for item in batch], lazy)
            for key in example.keys()
        }
    elif isinstance(example, (NoStackTensor, NoStackCatTensor)):
        return batch
    elif isinstance(example, CategoricalTensor):
        return torch.stack(batch)
    else:
        try:
            return default_collate(batch)
        except TypeError:
            # Deliberate best effort: items default_collate cannot handle
            # are returned as the original list.
            return batch
class DataLoader(torch.utils.data.DataLoader):
    """Dataloader that correctly batchifies Batchable data.

    Identical to ``torch.utils.data.DataLoader`` except that this module's
    ``collate`` is always used as the ``collate_fn``.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collate,
            **kwargs,
        )