aiaccel.torch.datasets.scatter_dataset

aiaccel.torch.datasets.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) → Subset[T]

Splits a dataset into subsets and returns the subset corresponding to the current process rank.
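The description does not spell out how the indices are divided among ranks. The following is a minimal NumPy sketch of one plausible strategy, not aiaccel's actual code. It permutes all indices, cuts them into `world_size` near-equal contiguous chunks with `np.array_split`, and keeps chunk number `rank`. The helper name `scatter_indices`, the explicit `rank` and `world_size` arguments (used instead of querying a process group), and the fixed-seed default permutation are assumptions made for illustration.

```python
from __future__ import annotations

from collections.abc import Callable

import numpy as np
import numpy.typing as npt


def scatter_indices(
    n: int,
    rank: int,
    world_size: int,
    permute_fn: Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None = None,
) -> npt.NDArray[np.int64]:
    """Hypothetical helper: return the dataset indices owned by `rank`.

    Assumes permute-then-split into contiguous, near-equal chunks; the real
    aiaccel implementation may pad, seed, or stride differently.
    """
    if permute_fn is None:
        # Assumed default: a fixed-seed generator, so every rank computes the
        # same permutation without any communication between processes.
        permute_fn = np.random.default_rng(0).permutation
    indices = permute_fn(np.arange(n, dtype=np.int64))
    # np.array_split tolerates uneven division: chunk sizes differ by at most one.
    return np.array_split(indices, world_size)[rank]


# Example: 10 samples over 3 ranks give disjoint shards of sizes 4, 3 and 3.
shards = [scatter_indices(10, r, 3) for r in range(3)]
```

Under this sketch, a training process would wrap its shard as `torch.utils.data.Subset(dataset, idx.tolist())`, taking `rank` and `world_size` from `torch.distributed.get_rank()` and `torch.distributed.get_world_size()`. Keeping the default permutation deterministic is what lets the shards stay disjoint even though each process computes them independently.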

Parameters:
  • dataset (Dataset[T]) – The input dataset to be split.

  • permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional) – A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None.
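Any callable with the `permute_fn` signature can replace the default. Two illustrative choices are sketched below; the names `identity_permute` and `make_seeded_permute` and the seed value are hypothetical. Assuming each process calls `permute_fn` on its own, the function should return the same permutation on every rank, otherwise the per-rank subsets could overlap or skip samples.

```python
from collections.abc import Callable

import numpy as np
import numpy.typing as npt


def identity_permute(indices: npt.NDArray[np.int64]) -> npt.NDArray[np.int64]:
    """Hypothetical permute_fn that keeps the original index order (no shuffling)."""
    return indices


def make_seeded_permute(
    seed: int,
) -> Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]]:
    """Hypothetical factory for a shuffle that is identical on every rank for a given seed."""

    def permute(indices: npt.NDArray[np.int64]) -> npt.NDArray[np.int64]:
        # A fresh generator per call makes the result reproducible across calls and ranks.
        return np.random.default_rng(seed).permutation(indices)

    return permute
```

For example, `scatter_dataset(dataset, permute_fn=make_seeded_permute(42))` would request a reproducible shuffle, while `permute_fn=identity_permute` would keep the dataset's original order before splitting.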

Returns:

The subset of the input dataset corresponding to the current process rank.

Return type:

Subset[T]