Batches

Terrace Batches are a way to object-orientify PyTorch code. With the vanilla PyTorch DataLoader, your dataset can return tensors (and tuples and dicts of tensors), but not instances of arbitrary classes, because PyTorch doesn’t know how to collate them. Not anymore!

Let’s suppose you’re dealing with a dataset of people. For each person, there is an image of their face and a string name. In the dark ages, you would have loaded the face and name into two separate tensors and passed them individually as arguments to your functions. Now, however, you can make a Person class.

import torch
import terrace as ter

MAX_NAME_LEN = 128
IMG_SIZE = 256
class Person(ter.Batchable):

    face: torch.Tensor
    name: torch.Tensor

    def __init__(self): # fake person data
        self.face = torch.zeros((3, IMG_SIZE, IMG_SIZE))
        self.name = torch.zeros((MAX_NAME_LEN,))

Since we’ve inherited from the Batchable class, Terrace knows how to collate multiple people into a single Batch[Person].

dave = Person()
rhonda = Person()
batch = ter.collate([dave, rhonda])
print(batch)

# Batches have a length and can be indexed, just like lists
print("Batch length:", len(batch))

# But you can also access their (batched) members, just like
# objects of the original class
print("Person name shape:", dave.name.shape)
print("Batch[Person] name shape:", batch.name.shape) # notice the extra batch dimension

print(batch[0]) # un-batchification

Batch[Person](
   face=Tensor(shape=torch.Size([2, 3, 256, 256]), dtype=torch.float32)
   name=Tensor(shape=torch.Size([2, 128]), dtype=torch.float32)
)
Batch length: 2
Person name shape: torch.Size([128])
Batch[Person] name shape: torch.Size([2, 128])
BatchView[Person](
   face=Tensor(shape=torch.Size([3, 256, 256]), dtype=torch.float32)
   name=Tensor(shape=torch.Size([128]), dtype=torch.float32)
)

Note

Notice that last line – when we index a batch, it returns a BatchView[Person], not a Person directly. This is for performance reasons – under the hood, a BatchView lazily indexes the members of its parent batch. The vast majority of the time, this BatchView will act exactly like a Person. If it doesn’t, please submit a bug report.
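To make the idea of a lazy view concrete, here is a minimal pure-Python sketch. It is not Terrace's implementation: `LazyBatch`, `LazyView`, and `Point` are hypothetical names, and plain lists stand in for tensors. The batch stores each field once, and a view only remembers its parent and an index, looking a member up at the moment it is accessed.

```python
from dataclasses import dataclass, fields


@dataclass
class Point:
    # Hypothetical item type; plain lists stand in for tensors.
    xy: list
    label: int


class LazyBatch:
    """Stores each field once, gathered across all items (illustration only)."""

    def __init__(self, items):
        self._n = len(items)
        for field in fields(items[0]):
            # Field-wise collation: collect one field from every item.
            setattr(self, field.name, [getattr(item, field.name) for item in items])

    def __len__(self):
        return self._n

    def __getitem__(self, index):
        # Nothing is copied here; the view only remembers where to look.
        return LazyView(self, index)


class LazyView:
    """Looks up members in the parent batch only when they are accessed."""

    def __init__(self, parent, index):
        self._parent = parent
        self._index = index

    def __getattr__(self, name):
        # Only reached for names not found on the view itself (the data fields).
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._parent, name)[self._index]


batch = LazyBatch([Point([0, 1], 7), Point([2, 3], 8)])
view = batch[1]
print(len(batch), view.xy, view.label)
```

Because the view holds a reference rather than a copy, creating it costs the same no matter how many members the class has. The trade-off is that a view is not an instance of the original class, which is the reason for the caveat above.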

Now we can use Terrace’s DataLoader to automatically batchify a custom dataset.

class PersonDataset(torch.utils.data.Dataset):

    def __len__(self):
        return 16

    def __getitem__(self, index):
        return Person()

batch_size = 8
dataset = PersonDataset()
loader = ter.DataLoader(dataset, batch_size=batch_size)
for batch in loader:
    print(batch)

Batch[Person](
   face=Tensor(shape=torch.Size([8, 3, 256, 256]), dtype=torch.float32)
   name=Tensor(shape=torch.Size([8, 128]), dtype=torch.float32)
)
Batch[Person](
   face=Tensor(shape=torch.Size([8, 3, 256, 256]), dtype=torch.float32)
   name=Tensor(shape=torch.Size([8, 128]), dtype=torch.float32)
)

Additionally, we can create new batches of people directly (e.g. in a generative model).

batch = ter.Batch(Person,
                  face=torch.zeros((batch_size, 3, IMG_SIZE, IMG_SIZE)),
                  name=torch.zeros((batch_size, MAX_NAME_LEN)))
print(batch)

Batch[Person](
   face=Tensor(shape=torch.Size([8, 3, 256, 256]), dtype=torch.float32)
   name=Tensor(shape=torch.Size([8, 128]), dtype=torch.float32)
)

Warning

Creating a batch directly from batched member data is dangerous because Terrace (currently) doesn’t do any checking to make sure you’ve input reasonable arguments.
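Until such checking exists, one option is to validate the members yourself before constructing the batch. The helper below is a hypothetical sketch, not part of Terrace; it only assumes that each member has a `shape` attribute whose first entry is the batch dimension, as a `torch.Tensor` does. Stand-in objects are used here so the snippet runs without PyTorch.

```python
from types import SimpleNamespace


def check_batch_dims(**members):
    """Return the shared leading dimension, or raise ValueError.

    Hypothetical helper, not part of Terrace. Accepts any objects with a
    ``shape`` attribute, such as torch tensors.
    """
    if not members:
        raise ValueError("no members given")
    sizes = {}
    for name, value in members.items():
        shape = tuple(value.shape)
        if len(shape) == 0:
            raise ValueError(f"member {name!r} has no batch dimension")
        sizes[name] = shape[0]
    if len(set(sizes.values())) > 1:
        raise ValueError(f"inconsistent batch sizes: {sizes}")
    return next(iter(sizes.values()))


# Stand-ins with a ``shape`` attribute; real code would pass torch tensors.
face = SimpleNamespace(shape=(8, 3, 256, 256))
name = SimpleNamespace(shape=(8, 128))
checked_size = check_batch_dims(face=face, name=name)
print(checked_size)
```

Calling something like this right before `ter.Batch(Person, face=..., name=...)` would catch the most common mistake, a mismatched batch dimension. It does not check dtypes or the per-item shapes.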

Graphs

In addition to batchifying everyday data, Terrace has special graph functionality. With GraphBatches, Terrace provides a higher-level object-oriented abstraction over DGL graphs. (In the future, PyG might be added as a backend as well.)

You can create Batchable subclasses for both node and edge data. Here’s how.

class Atom(ter.Batchable):

    # let's suppose atoms have a 3D position
    # and an atomic mass

    position: torch.Tensor
    mass: torch.Tensor

    def __init__(self):
        """ Fill with dummy data """
        self.position = torch.zeros((3,))
        self.mass = torch.zeros((1,))

class Bond(ter.Batchable):

    order: torch.Tensor

    def __init__(self):
        self.order = torch.zeros((1,))

# to create a Terrace graph, we need the node data, edge indexes,
# and (optionally) edge data
ndata = [ Atom(), Atom(), Atom() ]
edges = [ (0, 1), (0, 2)]
edata = [ Bond(), Bond() ]

mol = ter.Graph(ndata, edges, edata)
print(mol)

# we can access the graph's node and edge data batches with ndata and edata
print(mol.ndata)
print(mol.edata)

# If we want, we can also get the underlying DGL graph.
# This is necessary when we want to create wrapper modules
# for the DGL model classes
print(mol.dgl())

Graph(
   ndata=Batch[Atom](
      position=Tensor(shape=torch.Size([3, 3]), dtype=torch.float32)
      mass=Tensor(shape=torch.Size([3, 1]), dtype=torch.float32)
   )
   edata=Batch[Bond](
      order=Tensor(shape=torch.Size([4, 1]), dtype=torch.float32)
   )
)
Batch[Atom](
   position=Tensor(shape=torch.Size([3, 3]), dtype=torch.float32)
   mass=Tensor(shape=torch.Size([3, 1]), dtype=torch.float32)
)
Batch[Bond](
   order=Tensor(shape=torch.Size([4, 1]), dtype=torch.float32)
)
Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'position': Scheme(shape=(3,), dtype=torch.float32), 'mass': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'order': Scheme(shape=(1,), dtype=torch.float32)})

Of course, we can combine graph batches and regular batches into arbitrarily complex nested structures.

class MolAndData(ter.Batchable):

    # all Batchable classes are dataclasses,
    # so we don't actually need a constructor

    mol: ter.Graph[Atom, Bond]
    data: torch.Tensor

data = torch.zeros((8,))
mol_and_data = MolAndData(mol, data)
batch = ter.collate([ mol_and_data, mol_and_data, mol_and_data])
print(batch)

Batch[MolAndData](
   mol=GraphBatch(
      ndata=Batch[Atom](
         position=Tensor(shape=torch.Size([9, 3]), dtype=torch.float32)
         mass=Tensor(shape=torch.Size([9, 1]), dtype=torch.float32)
      )
      edata=Batch[Bond](
         order=Tensor(shape=torch.Size([12, 1]), dtype=torch.float32)
      )
   )
   data=Tensor(shape=torch.Size([3, 8]), dtype=torch.float32)
)
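Notice that the graph members do not gain a new leading dimension the way `data` does: three 3-atom molecules become one graph with 9 nodes. A common way to batch graphs like this is a disjoint union, where node data is concatenated and each graph's edge indices are shifted by the number of nodes before it. The sketch below illustrates that technique in plain Python. `batch_graphs` is a hypothetical helper, not Terrace's or DGL's code, and the example edge list assumes each bond is stored once per direction, which is a guess and not something stated above.

```python
def batch_graphs(graphs):
    """Merge (num_nodes, edges) pairs into one disjoint graph.

    Hypothetical helper for illustration; not Terrace's or DGL's code.
    """
    offset = 0
    all_edges = []
    for num_nodes, edges in graphs:
        # Shift node indices so each graph keeps its own block of nodes.
        all_edges.extend((src + offset, dst + offset) for src, dst in edges)
        offset += num_nodes
    return offset, all_edges


# One 3-node molecule; each bond listed in both directions (an assumption).
toy_mol = (3, [(0, 1), (1, 0), (0, 2), (2, 0)])
total_nodes, merged_edges = batch_graphs([toy_mol, toy_mol, toy_mol])
print(total_nodes, len(merged_edges))
print(merged_edges[4:8])  # edges of the second molecule, shifted by 3
```

No edge ever crosses between the blocks, so message passing over the merged graph treats each molecule independently.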

This is a very simple example, but feel free to go wild.

Advanced features

To enable your code to be even more object-oriented, Terrace allows you to define member functions for your batches. If you define a member function in your Batchable class with the name batch_{func_name}, batches of your class will all have the member function {func_name}. Here’s how we can modify the Person class from above to use this feature.

class Person(ter.Batchable):

    face: torch.Tensor
    name: torch.Tensor

    def say_hi(self):
        print("Hello, I'm a person")

    def batch_say_hi(self):
        print(f"Hello, I'm a batch of {len(self)} people")

    def __init__(self): # fake person data
        self.face = torch.zeros((3, IMG_SIZE, IMG_SIZE))
        self.name = torch.zeros((MAX_NAME_LEN,))

person = Person()
batch = ter.collate([Person(), Person(), Person()])

person.say_hi()
batch.say_hi()

Hello, I'm a person
Hello, I'm a batch of 3 people
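
If you are curious how a naming convention like this can be wired up, here is one possible mechanism in plain Python. It is only a sketch under our own assumptions, not a description of Terrace's internals: `DispatchingBatch` and `Greeter` are hypothetical names, and the batch forwards any unknown method name to the item class's `batch_`-prefixed function, passing the batch itself as `self`.

```python
class DispatchingBatch:
    """Forwards ``batch.foo()`` to ``item_type.batch_foo(batch)`` if defined."""

    def __init__(self, item_type, items):
        self._item_type = item_type
        self._items = list(items)

    def __len__(self):
        return len(self._items)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if name.startswith("_"):
            raise AttributeError(name)
        func = getattr(self._item_type, "batch_" + name, None)
        if func is None:
            raise AttributeError(name)
        # Call the batch-level function with the batch standing in for ``self``.
        return lambda *args, **kwargs: func(self, *args, **kwargs)


class Greeter:
    def say_hi(self):
        return "Hello, I'm one greeter"

    def batch_say_hi(self):
        # ``self`` is the batch here, so ``len(self)`` is the batch size.
        return f"Hello, I'm a batch of {len(self)} greeters"


greeters = DispatchingBatch(Greeter, [Greeter(), Greeter(), Greeter()])
print(Greeter().say_hi())
print(greeters.say_hi())
```

The point to take away is that inside a `batch_` method, `self` refers to the whole batch, so members such as `self.face` would carry the extra batch dimension.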