aiaccel.torch.datasets package#

Submodules#

aiaccel.torch.datasets.cached_dataset module#

class aiaccel.torch.datasets.cached_dataset.CachedDataset(dataset: Dataset[T_co])[source]#

Bases: Dataset[T_co]

A dataset wrapper that caches the samples to improve performance.

Parameters:

dataset (Dataset) – The original dataset to be wrapped.

dataset#

The original dataset.

Type:

Dataset

manager#

The multiprocessing manager.

Type:

Manager

cache#

The cache dictionary to store the cached samples.

Type:

dict
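
A minimal usage sketch, assuming only the constructor shown above; the toy TensorDataset, its shapes, and the variable names are illustrative and not part of the package:

import torch
from torch.utils.data import TensorDataset

from aiaccel.torch.datasets import CachedDataset

# Any map-style dataset works; a small TensorDataset is used here for illustration.
base = TensorDataset(torch.randn(100, 16), torch.randint(0, 10, (100,)))
dataset = CachedDataset(base)

x, y = dataset[0]  # first access reads from `base` and stores the sample in the cache
x, y = dataset[0]  # repeated accesses are expected to be served from the cache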

aiaccel.torch.datasets.file_cached_dataset module#

class aiaccel.torch.datasets.file_cached_dataset.FileCachedDataset(dataset: Dataset[T_co], cache_path: str | Path)[source]#

Bases: Dataset[T_co]

A dataset wrapper that caches samples to disk to reduce memory usage.

This class wraps an existing torch.utils.data.Dataset and caches samples as pickle files in a specified directory.

Parameters:
  • dataset (Dataset[T]) – The dataset to wrap.

  • cache_path (str | Path) – Directory where cached samples will be stored.

__len__()[source]#

Returns the number of samples in the dataset.

__getitem__(index: int) -> Any

Retrieves a sample from the cache or the original dataset.
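
A minimal usage sketch, assuming only the constructor and methods shown above; the toy TensorDataset and the cache directory name are illustrative:

from pathlib import Path

import torch
from torch.utils.data import TensorDataset

from aiaccel.torch.datasets import FileCachedDataset

base = TensorDataset(torch.randn(1000, 64))

cache_dir = Path("./sample_cache")  # hypothetical cache directory
cache_dir.mkdir(parents=True, exist_ok=True)

dataset = FileCachedDataset(base, cache_dir)

sample = dataset[3]  # first access reads from `base` and is expected to write a pickle file
sample = dataset[3]  # later accesses should be loaded from the on-disk cache instead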

aiaccel.torch.datasets.hdf5_dataset module#

class aiaccel.torch.datasets.hdf5_dataset.HDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)[source]#

Bases: RawHDF5Dataset

A dataset class for loading data from an HDF5 file.

This class extends the RawHDF5Dataset class and provides a convenient way to load data from an HDF5 file and convert it into a dictionary of torch tensors.

Parameters:
  • dataset_path (Path | str) – The path to the HDF5 file.

  • grp_list (Path | str | list[str] | None, optional) – The list of groups to load from the file. If None, all groups in the file are loaded. If a string or Path, it should be the path to a file containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.

Returns:

A dictionary containing the data loaded from the HDF5 file, where the keys are the names of the data fields and the values are torch tensors.

Return type:

dict[str, torch.Tensor]
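
A minimal usage sketch; the file name data.h5 and its group layout are hypothetical:

from aiaccel.torch.datasets import HDF5Dataset

dataset = HDF5Dataset("data.h5")  # load all groups from a (hypothetical) HDF5 file

sample = dataset[0]    # expected to be a dict[str, torch.Tensor] for one group
print(sorted(sample))  # names of the data fields stored in that group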

class aiaccel.torch.datasets.hdf5_dataset.RawHDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)[source]#

Bases: Dataset[dict[str, Any]]

A dataset class for reading data from HDF5 files.

Parameters:
  • dataset_path (Union[Path, str]) – The path to the HDF5 dataset file.

  • grp_list (Union[Path, str, List[str], None], optional) – The list of groups to load from the dataset. If None, all groups in the dataset will be loaded. If a string or Path, it should be the path to a file containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.

Raises:

NotImplementedError – If grp_list is of an unsupported type.

dataset_path#

The path to the HDF5 dataset file.

Type:

Union[Path, str]

grp_list#

The list of groups to load from the dataset.

Type:

List[str]

f#

The HDF5 file object used for reading the dataset.

Type:

Optional[h5.File]
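
A minimal usage sketch; the file name data.h5 and the group names are hypothetical, and the list form of grp_list is shown:

from aiaccel.torch.datasets import RawHDF5Dataset

dataset = RawHDF5Dataset("data.h5", grp_list=["train/000", "train/001"])

print(len(dataset))  # assumed to equal the number of selected groups
sample = dataset[0]  # dict mapping dataset names within the group to their raw values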

aiaccel.torch.datasets.scatter_dataset module#

aiaccel.torch.datasets.scatter_dataset.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) → Subset[T][source]#

Splits a dataset into subsets and returns the subset corresponding to the current process rank.

Parameters:
  • dataset (Dataset[T]) – The input dataset to be split.

  • permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional) – A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None.

Returns:

The subset of the input dataset corresponding to the current process rank.

Return type:

Subset[T]
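
A minimal usage sketch, assuming a torch.distributed process group has already been initialized (e.g. via torchrun) so the current rank and world size are available; the toy dataset, batch size, and seed are illustrative:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from aiaccel.torch.datasets import scatter_dataset

full_dataset = TensorDataset(torch.randn(10_000, 32))

# Default behavior: each rank receives its own disjoint subset.
local_subset = scatter_dataset(full_dataset)

# Optional: supply a custom permutation over the index array before splitting.
def permute_fn(indices: np.ndarray) -> np.ndarray:
    return np.random.default_rng(0).permutation(indices)

local_subset = scatter_dataset(full_dataset, permute_fn=permute_fn)

loader = DataLoader(local_subset, batch_size=64, shuffle=True)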

Module contents#

class aiaccel.torch.datasets.CachedDataset(dataset: Dataset[T_co])[source]#

Bases: Dataset[T_co]

A dataset wrapper that caches the samples to improve performance.

Parameters:

dataset (Dataset) – The original dataset to be wrapped.

dataset#

The original dataset.

Type:

Dataset

manager#

The multiprocessing manager.

Type:

Manager

cache#

The cache dictionary to store the cached samples.

Type:

dict

class aiaccel.torch.datasets.FileCachedDataset(dataset: Dataset[T_co], cache_path: str | Path)[source]#

Bases: Dataset[T_co]

A dataset wrapper that caches samples to disk to reduce memory usage.

This class wraps an existing torch.utils.data.Dataset and caches samples as pickle files in a specified directory.

Parameters:
  • dataset (Dataset[T]) – The dataset to wrap.

  • cache_path (str | Path) – Directory where cached samples will be stored.

__len__()[source]#

Returns the number of samples in the dataset.

__getitem__(index: int) -> Any

Retrieves a sample from the cache or the original dataset.

class aiaccel.torch.datasets.HDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)[source]#

Bases: RawHDF5Dataset

A dataset class for loading data from an HDF5 file.

This class extends the RawHDF5Dataset class and provides a convenient way to load data from an HDF5 file and convert it into a dictionary of torch tensors.

Parameters:
  • dataset_path (Path | str) – The path to the HDF5 file.

  • grp_list (Path | str | list[str] | None, optional) – The list of groups to load from the file. If None, all groups in the file are loaded. If a string or Path, it should be the path to a file containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.

Returns:

A dictionary containing the data loaded from the HDF5 file, where the keys are the names of the data fields and the values are torch tensors.

Return type:

dict[str, torch.Tensor]

class aiaccel.torch.datasets.RawHDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)[source]#

Bases: Dataset[dict[str, Any]]

A dataset class for reading data from HDF5 files.

Parameters:
  • dataset_path (Union[Path, str]) – The path to the HDF5 dataset file.

  • grp_list (Union[Path, str, List[str], None], optional) – The list of groups to load from the dataset. If None, all groups in the dataset will be loaded. If a string or Path, it should be the path to a file containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.

Raises:

NotImplementedError – If grp_list is of an unsupported type.

dataset_path#

The path to the HDF5 dataset file.

Type:

Union[Path, str]

grp_list#

The list of groups to load from the dataset.

Type:

List[str]

f#

The HDF5 file object used for reading the dataset.

Type:

Optional[h5.File]

aiaccel.torch.datasets.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) → Subset[T][source]#

Splits a dataset into subsets and returns the subset corresponding to the current process rank.

Parameters:
  • dataset (Dataset[T]) – The input dataset to be split.

  • permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional) – A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None.

Returns:

The subset of the input dataset corresponding to the current process rank.

Return type:

Subset[T]