lightning/examples/fabric/tensor_parallel/data.py

22 lines
678 B
Python

import torch
from torch.utils.data import Dataset
class RandomTokenDataset(Dataset):
def __init__(self, vocab_size: int, seq_length: int):
self.vocab_size = vocab_size
self.seq_length = seq_length
self.tokens = torch.randint(
self.vocab_size,
size=(len(self), self.seq_length + 1),
# Set a seed to make this toy dataset the same on each rank
# Fabric will add a `DistributedSampler` to shard the data correctly
generator=torch.Generator().manual_seed(42),
)
def __len__(self) -> int:
return 128
def __getitem__(self, item: int):
return self.tokens[item]