Source code for aiaccel.torch.datasets.scatter_dataset

import numpy.typing as npt
from typing import TypeVar

from collections.abc import Callable

import numpy as np

import torch.distributed as dist
from torch.utils.data import Dataset, Subset

T = TypeVar("T")


[docs] def scatter_dataset( dataset: Dataset[T], permute_fn: Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None = None, ) -> Subset[T]: """ Splits a dataset into subsets and returns the subset corresponding to the current process rank. Args: dataset (Dataset[T]): The input dataset to be split. permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional): A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None. Returns: Subset[T]: The subset of the input dataset corresponding to the current process rank. """ if permute_fn is None: permute_fn = np.random.Generator(np.random.PCG64(0)).permutation world_size = dist.get_world_size() rank = dist.get_rank() dataset_size = len(dataset) # type: ignore[arg-type] total_size = int(np.ceil(dataset_size / world_size)) * world_size indices = permute_fn(np.arange(dataset_size)) repeated_indices = np.concatenate([indices, indices[: total_size - dataset_size]]) split_indices = np.split(repeated_indices, world_size) return Subset(dataset, list(split_indices[rank]))