Source code for terrace.module

from typing import Tuple, Union
import torch
import torch.nn as nn

from .categorical_tensor import CategoricalTensor

class Module(nn.Module):
    """Base class for modules whose submodules and parameters are built lazily.

    Submodules and parameters are created during the first forward pass via
    :meth:`make` and :meth:`make_param`, and handed back in the same creation
    order on every later pass. A ``forward`` implementation must call
    :meth:`start_forward` before its first ``make``/``make_param`` call.
    """

    def __init__(self):
        super().__init__()
        # Set by start_forward once a previous pass has created something.
        self._initialized = False
        self._started_forward = False
        # Position of the next submodule to hand out during the current pass.
        self._submodule_index = 0
        # ModuleList registers the submodules with torch; the plain list
        # mirrors it and is what ``make`` reads through ``__dict__``.
        self._submodules = nn.ModuleList()
        self._submodule_list = []
        # Position of the next parameter to hand out during the current pass.
        self._param_index = 0
        self._params = nn.ParameterList()
        # (submodule_index, param_index) pairs recorded on the building pass.
        self._checkpoints = []
        self._checkpoint_index = 0

    def start_forward(self):
        """Reset the per-pass counters; call this first thing in ``forward``.

        If the previous pass handed out at least one submodule or parameter,
        the module is marked initialized so later passes reuse what was built
        instead of building again.
        """
        # todo: these boolean values getting unwieldy -- cut down!
        # State is accessed through __dict__ directly, skipping nn.Module's
        # attribute hooks (presumably for speed -- confirm with the author).
        state = self.__dict__
        # Check the parameter counter as well as the submodule counter:
        # otherwise a module that only ever calls make_param is never marked
        # initialized and appends a fresh parameter on every forward pass.
        if state["_submodule_index"] != 0 or state["_param_index"] != 0:
            state["_initialized"] = True
        state["_submodule_index"] = 0
        state["_param_index"] = 0
        state["_started_forward"] = True
        state["_checkpoint_index"] = 0

    def make(self, cls, *args, **kwargs):
        """Return the next submodule, constructing it on the building pass.

        Args:
            cls: Callable used to construct the submodule when the module is
                not yet initialized.
            *args: Positional arguments forwarded to ``cls``.
            **kwargs: Keyword arguments forwarded to ``cls``.

        Returns:
            The submodule stored at the current submodule index.

        Raises:
            RuntimeError: If :meth:`start_forward` has not been called.
        """
        state = self.__dict__
        if not state["_started_forward"]:
            raise RuntimeError("You must call start_forward before you call make")
        if not state["_initialized"]:
            submod = cls(*args, **kwargs)
            self._submodules.append(submod)
            self._submodule_list.append(submod)
        submod = state["_submodule_list"][state["_submodule_index"]]
        state["_submodule_index"] += 1
        return submod

    def make_param(self, cls, *args, **kwargs):
        """Return the next parameter, constructing it on the building pass.

        Args:
            cls: Callable used to construct the parameter when the module is
                not yet initialized.
            *args: Positional arguments forwarded to ``cls``.
            **kwargs: Keyword arguments forwarded to ``cls``.

        Returns:
            The parameter stored at the current parameter index.

        Raises:
            RuntimeError: If :meth:`start_forward` has not been called.
        """
        if not self._started_forward:
            raise RuntimeError("You must call start_forward before you call make_param")
        if not self._initialized:
            self._params.append(cls(*args, **kwargs))
        param = self._params[self._param_index]
        self._param_index += 1
        return param

    def checkpoint(self):
        """Record or restore the submodule/parameter counters.

        While the module is not initialized, the current counters are
        appended to the checkpoint list. Once initialized, each call instead
        sets the counters to the values recorded by the corresponding call
        (in call order) and advances to the next checkpoint.
        """
        if self._initialized:
            self._submodule_index, self._param_index = self._checkpoints[self._checkpoint_index]
            self._checkpoint_index += 1
        else:
            self._checkpoints.append((self._submodule_index, self._param_index))

    def loop_start(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def loop_body(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def is_initialized(self):
        """Return whether anything has been built or handed out so far.

        True when the module was marked initialized by :meth:`start_forward`,
        or when the current pass has already advanced either counter.
        """
        return self._initialized or self._submodule_index > 0 or self._param_index > 0

    def parameters(self, recurse: bool = True):
        """Return ``nn.Module.parameters(recurse)`` for an initialized module.

        Raises:
            AssertionError: If the module has not been run on data yet, since
                its lazily created parameters would be missing.
        """
        assert self.is_initialized(), "Terrace Module needs to be run on data before parameters method can be called"
        return super().parameters(recurse)
class WrapperModule(Module):
    """Lazy module that stores constructor arguments for later use.

    The positional and keyword arguments given here are kept on the instance
    as ``args`` and ``kwargs`` so a subclass ``forward`` can pass them on when
    it builds the wrapped layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args, self.kwargs = args, kwargs
class LazyLinear(WrapperModule):
    """
    ``nn.Linear`` whose input feature count is taken from the last dimension
    of the input it is called on. torch >=1.13 ships its own version; this
    implementation was written as an exercise and so that older torch
    versions work too.
    """

    def forward(self, x):
        """Apply the lazily built linear layer to ``x``.

        The stored constructor arguments follow the inferred input size when
        the ``nn.Linear`` is created.
        """
        self.start_forward()
        input_width = x.shape[-1]
        linear = self.make(nn.Linear, input_width, *self.args, **self.kwargs)
        return linear(x)
class LazyEmbedding(Module):
    """
    LazyEmbedding uses the ``num_classes`` from CategoricalTensors to
    determine embedding weight size. Note that it assumes tensors have shape
    (..., N) where N is the number of categorical features. So, in most cases
    where you have a batch of single categorical features, you must give the
    embedding a tensor of shape (B, 1).

    Admittedly this is a bit weird, but it does nicely extend to cases where
    you have multiple categorical features. In this case, the embedding
    creates an ``nn.Embedding`` layer for each feature and concatenates the
    results together. Thus the output of this layer has shape (B, E*N), where
    E is the ``embedding_dim``.
    """

    def __init__(self, embedding_dims: Union[Tuple[int, ...], int]):
        """
        If ``embedding_dims`` is a tuple, it specifies the per-feature
        embedding dimension (must have the same length as the ``num_classes``
        of the input tensor). If it's an int, we use the same dimension for
        the embedding of all the features.
        """
        super().__init__()
        self.embedding_dims = embedding_dims

    def forward(self, x: CategoricalTensor):
        """Embed every feature along the last axis of ``x`` and concatenate.

        Returns:
            The per-feature embeddings joined along the last dimension.
        """
        self.start_forward()
        # Expand a single int into one entry per feature.
        if isinstance(self.embedding_dims, int):
            dims = [self.embedding_dims] * len(x.num_classes)
        else:
            dims = self.embedding_dims
        # One nn.Embedding per feature, created in feature order so the lazy
        # replay in ``make`` hands back the same layers on later passes.
        embedded = [
            self.make(nn.Embedding, x.num_classes[idx], dims[idx])(x.tensor[..., idx])
            for idx in range(x.shape[-1])
        ]
        return torch.cat(embedded, -1)
class LazyMultiheadAttention(WrapperModule):
    """``nn.MultiheadAttention`` whose dimensions are read from its inputs.

    The embedding size comes from the last dimension of the query, and the
    ``kdim``/``vdim`` keyword arguments from the last dimensions of the key
    and value.
    """

    def forward(self, q, k, v):
        """Run the lazily built attention module on ``(q, k, v)``.

        Note that ``self.kwargs`` is updated in place with the ``kdim`` and
        ``vdim`` entries on every call.
        """
        self.start_forward()
        query_width = q.shape[-1]
        stored_kwargs = self.kwargs
        stored_kwargs["kdim"] = k.shape[-1]
        stored_kwargs["vdim"] = v.shape[-1]
        attention = self.make(
            nn.MultiheadAttention, query_width, *self.args, **stored_kwargs
        )
        return attention(q, k, v)
class LazyLayerNorm(WrapperModule):
    """``nn.LayerNorm`` whose normalized shape is read from its input.

    The normalized shape is every dimension of the input after the first.
    """

    def forward(self, x):
        """Apply the lazily built layer norm to ``x``."""
        self.start_forward()
        trailing_shape = x.shape[1:]
        layer_norm = self.make(nn.LayerNorm, trailing_shape, *self.args, **self.kwargs)
        return layer_norm(x)