aiaccel.torch.datasets.scatter_dataset
- aiaccel.torch.datasets.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) → Subset[T]
Splits a dataset into subsets and returns the subset corresponding to the current process rank.
- Parameters:
dataset (Dataset[T]) – The input dataset to be split.
permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional) – A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None.
- Returns:
The subset of the input dataset corresponding to the current process rank.
- Return type:
Subset[T]
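The rank-based split described above can be illustrated with a minimal pure-Python sketch. This is not the aiaccel implementation: `scatter_indices` and `default_permute` are hypothetical names, plain lists stand in for the NumPy index arrays and for `Subset`, and the fixed seed and strided split are assumptions chosen so that every rank computes the same permutation and the shares do not overlap.

```python
from __future__ import annotations

import random
from collections.abc import Callable


def default_permute(indices: list[int]) -> list[int]:
    """Hypothetical default permutation: shuffle a copy with a fixed seed."""
    shuffled = list(indices)
    # Assumption: a fixed seed, so all ranks derive the identical ordering.
    random.Random(0).shuffle(shuffled)
    return shuffled


def scatter_indices(
    num_items: int,
    rank: int,
    world_size: int,
    permute_fn: Callable[[list[int]], list[int]] | None = None,
) -> list[int]:
    """Return the dataset indices assigned to `rank` (illustrative only)."""
    if permute_fn is None:
        permute_fn = default_permute
    indices = permute_fn(list(range(num_items)))
    # Strided slicing hands each rank a disjoint share of the permuted indices.
    return indices[rank::world_size]


# Example: 10 items split across 3 ranks, using the identity permutation.
for rank in range(3):
    print(rank, scatter_indices(10, rank, 3, permute_fn=lambda idx: idx))
```

In this sketch the rank and world size are passed in explicitly; the documented function takes neither argument, so it presumably reads them from the distributed environment. A custom `permute_fn` must return the same ordering on every process, otherwise the per-rank subsets may overlap or leave items out.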