aiaccel.torch.datasets package
Submodules
aiaccel.torch.datasets.cached_dataset module
- class aiaccel.torch.datasets.cached_dataset.CachedDataset(dataset: Dataset[T_co])
Bases: Dataset[T_co]
A dataset wrapper that caches samples in memory, so that repeated accesses to the same index are served from the cache instead of the wrapped dataset.
- Parameters:
dataset (Dataset) – The original dataset to be wrapped.
- dataset
The original dataset.
- Type:
Dataset
- manager
The multiprocessing manager.
- Type:
Manager
- cache
The cache dictionary to store the cached samples.
- Type:
dict
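The caching idea can be sketched in plain Python as follows. This is a minimal illustration, not aiaccel's implementation: the names InMemoryCachedDataset, IndexableDataset, and SlowSquares are hypothetical, and to stay dependency-free the sketch does not subclass torch.utils.data.Dataset (any object with __len__ and __getitem__ works). It also uses an ordinary per-process dict, whereas the manager attribute above suggests that CachedDataset keeps its cache in a multiprocessing manager so that it can be shared across processes.

```python
from typing import Generic, Protocol, TypeVar

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


class IndexableDataset(Protocol[T_co]):
    """Anything that behaves like a map-style dataset."""

    def __len__(self) -> int: ...

    def __getitem__(self, index: int) -> T_co: ...


class InMemoryCachedDataset(Generic[T]):
    """Hypothetical in-memory caching wrapper (illustration only)."""

    def __init__(self, dataset: IndexableDataset[T]) -> None:
        self.dataset = dataset
        # Plain dict: in this sketch each process keeps its own cache.
        self.cache: dict[int, T] = {}

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> T:
        # Read from the wrapped dataset only on the first access to this index.
        if index not in self.cache:
            self.cache[index] = self.dataset[index]
        return self.cache[index]


class SlowSquares:
    """Toy dataset that counts how often it is actually read."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.loads = 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> int:
        self.loads += 1
        return index * index


base = SlowSquares(4)
cached = InMemoryCachedDataset(base)
first_pass = [cached[i] for i in range(len(cached))]
second_pass = [cached[i] for i in range(len(cached))]  # served from the cache
print(first_pass, second_pass)  # [0, 1, 4, 9] [0, 1, 4, 9]
print(base.loads)               # 4
```

The second pass returns the same values without touching the wrapped dataset, which is why the load counter stays at 4.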
aiaccel.torch.datasets.file_cached_dataset module
- class aiaccel.torch.datasets.file_cached_dataset.FileCachedDataset(dataset: Dataset[T_co], cache_path: str | Path)
Bases: Dataset[T_co]
A dataset wrapper that caches samples to disk to reduce memory usage.
This class wraps an existing torch.utils.data.Dataset and caches samples as pickle files in a specified directory.
- Parameters:
dataset (Dataset[T_co]) – The dataset to wrap.
cache_path (str | Path) – Directory where cached samples will be stored.
- __getitem__(index: int) -> Any: Retrieves a sample from the cache, or from the original dataset if it has not been cached yet.
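A minimal sketch of the disk-caching pattern described above, assuming one pickle file per index named "<index>.pkl" inside cache_path. The file naming, the class name DiskCachedDataset, and the toy CountingDataset are hypothetical and not taken from aiaccel. Writing to a temporary file and then renaming it into place is a design choice of this sketch, so that a reader in another process never opens a half-written pickle.

```python
from __future__ import annotations

import os
import pickle
import tempfile
from pathlib import Path
from typing import Any


class DiskCachedDataset:
    """Hypothetical disk-caching wrapper (illustration only)."""

    def __init__(self, dataset: Any, cache_path: str | Path) -> None:
        self.dataset = dataset
        self.cache_path = Path(cache_path)
        self.cache_path.mkdir(parents=True, exist_ok=True)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        cache_file = self.cache_path / f"{index}.pkl"  # assumed naming scheme
        if cache_file.exists():
            # Cache hit: read the sample back from disk.
            with cache_file.open("rb") as f:
                return pickle.load(f)
        # Cache miss: load from the wrapped dataset, then persist the sample.
        sample = self.dataset[index]
        with tempfile.NamedTemporaryFile("wb", dir=self.cache_path, delete=False) as tmp:
            pickle.dump(sample, tmp)
        os.replace(tmp.name, cache_file)  # rename the finished file into place
        return sample


class CountingDataset:
    """Toy dataset that counts how often it is actually read."""

    def __init__(self) -> None:
        self.loads = 0

    def __len__(self) -> int:
        return 3

    def __getitem__(self, index: int) -> dict[str, int]:
        self.loads += 1
        return {"index": index, "value": index * 10}


with tempfile.TemporaryDirectory() as cache_dir:
    base = CountingDataset()
    writer = DiskCachedDataset(base, cache_dir)
    first = [writer[i] for i in range(len(writer))]
    # A fresh wrapper on the same directory reads every sample from disk.
    reader = DiskCachedDataset(base, cache_dir)
    second = [reader[i] for i in range(len(reader))]
    cached_files = sorted(p.name for p in Path(cache_dir).glob("*.pkl"))

print(second)        # [{'index': 0, 'value': 0}, {'index': 1, 'value': 10}, {'index': 2, 'value': 20}]
print(base.loads)    # 3
print(cached_files)  # ['0.pkl', '1.pkl', '2.pkl']
```

Because the cache lives on disk rather than in memory, it survives across wrapper instances: the second wrapper never calls the original dataset.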
aiaccel.torch.datasets.hdf5_dataset module
- class aiaccel.torch.datasets.hdf5_dataset.HDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)
Bases: RawHDF5Dataset
A dataset class for loading data from an HDF5 file.
This class extends the RawHDF5Dataset class and provides a convenient way to load data from an HDF5 file and convert it into a dictionary of torch tensors.
- Parameters:
dataset_path (Path | str) – The path to the HDF5 file.
grp_list (Path | str | list[str] | None, optional) – The groups to load, interpreted as in RawHDF5Dataset. Defaults to None.
- Returns:
Each sample (obtained by indexing the dataset) is a dictionary containing the data loaded from the HDF5 file, where the keys are the names of the data fields and the values are torch tensors.
- Return type:
dict[str, torch.Tensor]
- class aiaccel.torch.datasets.hdf5_dataset.RawHDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)
Bases: Dataset[dict[str, Any]]
A dataset class for reading data from HDF5 files.
- Parameters:
dataset_path (Union[Path, str]) – The path to the HDF5 dataset file.
grp_list (Union[Path, str, List[str], None], optional) – The list of groups to load from the dataset. If None, all groups in the dataset will be loaded. If a string or Path, it should be the path to a file containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.
- Raises:
NotImplementedError – If grp_list is of an unsupported type.
- dataset_path
The path to the HDF5 dataset file.
- Type:
Union[Path, str]
- grp_list
The list of groups to load from the dataset.
- Type:
List[str]
- f
The HDF5 file object used for reading the dataset.
- Type:
Optional[h5.File]
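The documented rules for grp_list can be expressed as a small pure-Python function. This is a sketch under assumptions, not aiaccel's code: resolve_grp_list is a hypothetical helper, the available group names are passed in explicitly (with h5py they could come from list(f.keys()) on the open file), and a group-list file is assumed to contain one group name per line.

```python
from __future__ import annotations

import tempfile
from collections.abc import Iterable
from pathlib import Path


def resolve_grp_list(
    grp_list: Path | str | list[str] | None,
    available_groups: Iterable[str],
) -> list[str]:
    """Hypothetical helper mirroring the documented grp_list semantics."""
    if grp_list is None:
        # None: use every group present in the file.
        return list(available_groups)
    if isinstance(grp_list, (str, Path)):
        # str or Path: read group names from a text file (assumed one per line).
        lines = Path(grp_list).read_text().splitlines()
        return [line.strip() for line in lines if line.strip()]
    if isinstance(grp_list, list):
        # list: the caller names the groups directly.
        return list(grp_list)
    raise NotImplementedError(f"Unsupported grp_list type: {type(grp_list).__name__}")


groups_in_file = ["sample_000", "sample_001", "sample_002"]

with tempfile.TemporaryDirectory() as tmp_dir:
    list_file = Path(tmp_dir) / "groups.txt"
    list_file.write_text("sample_000\nsample_002\n")
    from_file = resolve_grp_list(list_file, groups_in_file)

all_groups = resolve_grp_list(None, groups_in_file)
explicit = resolve_grp_list(["sample_001"], groups_in_file)
print(all_groups)  # ['sample_000', 'sample_001', 'sample_002']
print(from_file)   # ['sample_000', 'sample_002']
print(explicit)    # ['sample_001']
```

Any other type, such as a tuple of names, falls through to NotImplementedError, matching the Raises entry above.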
aiaccel.torch.datasets.scatter_dataset module
- aiaccel.torch.datasets.scatter_dataset.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) -> Subset[T]
Splits a dataset into subsets and returns the subset corresponding to the current process rank.
- Parameters:
dataset (Dataset[T]) – The input dataset to be split.
permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional) – A function that takes an array of indices and returns a permuted version of the array. If None, a default permutation function using np.random.Generator is used. Defaults to None.
- Returns:
The subset of the input dataset corresponding to the current process rank.
- Return type:
Subset[T]
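The splitting strategy can be sketched without PyTorch or NumPy. This is an illustration under stated assumptions rather than aiaccel's implementation: scatter_indices is a hypothetical helper, rank and world_size are passed explicitly instead of being queried from the distributed runtime, random.Random with a fixed seed stands in for the default np.random.Generator permutation, and contiguous near-equal chunking is assumed. What matters for correctness is that every rank applies the same permutation, so the chunks are disjoint and together cover the whole dataset. The resulting indices could then be passed to torch.utils.data.Subset(dataset, indices) to obtain a Subset[T] like the one scatter_dataset returns.

```python
from __future__ import annotations

import random
from collections.abc import Callable


def scatter_indices(
    num_samples: int,
    rank: int,
    world_size: int,
    permute_fn: Callable[[list[int]], list[int]] | None = None,
    seed: int = 0,
) -> list[int]:
    """Hypothetical helper: return the dataset indices owned by `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    indices = list(range(num_samples))
    if permute_fn is None:
        # Same seed on every rank, so all ranks agree on one permutation.
        random.Random(seed).shuffle(indices)
    else:
        indices = permute_fn(indices)
    # Contiguous near-equal chunks: the first (num_samples % world_size)
    # ranks receive one extra index, as numpy.array_split would do.
    base, extra = divmod(num_samples, world_size)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return indices[start:stop]


num_samples, world_size = 10, 3
shards = [scatter_indices(num_samples, r, world_size, seed=42) for r in range(world_size)]
print([len(s) for s in shards])                                   # [4, 3, 3]
print(sorted(i for s in shards for i in s) == list(range(num_samples)))  # True

# An identity permute_fn exposes the chunk boundaries.
identity_shards = [
    scatter_indices(num_samples, r, world_size, permute_fn=lambda idx: idx)
    for r in range(world_size)
]
print(identity_shards)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```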
Module contents
The following names are also available directly from the aiaccel.torch.datasets package. Their documentation is identical to the submodule entries above.
- class aiaccel.torch.datasets.CachedDataset(dataset: Dataset[T_co])
Same as aiaccel.torch.datasets.cached_dataset.CachedDataset.
- class aiaccel.torch.datasets.FileCachedDataset(dataset: Dataset[T_co], cache_path: str | Path)
Same as aiaccel.torch.datasets.file_cached_dataset.FileCachedDataset.
- class aiaccel.torch.datasets.HDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)
Same as aiaccel.torch.datasets.hdf5_dataset.HDF5Dataset.
- class aiaccel.torch.datasets.RawHDF5Dataset(dataset_path: Path | str, grp_list: Path | str | list[str] | None = None)
Same as aiaccel.torch.datasets.hdf5_dataset.RawHDF5Dataset.
- aiaccel.torch.datasets.scatter_dataset(dataset: Dataset[T], permute_fn: Callable[[ndarray[tuple[int, ...], dtype[int64]]], ndarray[tuple[int, ...], dtype[int64]]] | None = None) -> Subset[T]
Same as aiaccel.torch.datasets.scatter_dataset.scatter_dataset.