aiaccel.torch.lightning.datamodules.SingleDataModule
- class aiaccel.torch.lightning.datamodules.SingleDataModule(train_dataset_fn: Callable[[...], Dataset[str]], val_dataset_fn: Callable[[...], Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)

A PyTorch Lightning DataModule designed to handle training and validation datasets, with support for caching and dataset scattering. A usage sketch follows the attribute and method summaries below.

- train_dataset_fn
A callable function to create the training dataset.
Type: Callable[..., Dataset[str]]
- val_dataset_fn
A callable function to create the validation dataset.
Type: Callable[..., Dataset[str]]
- batch_size
The batch size for the DataLoader.
Type: int
- use_cache
Whether to cache the datasets. Defaults to False.
Type: bool
- use_scatter
Whether to scatter the datasets. Defaults to True.
Type: bool
- num_workers
Number of workers for the DataLoader. Defaults to 10.
Type: int
- common_args
Common arguments to pass to the dataset functions. Defaults to None.
Type: dict[str, Any] | None
- setup(stage: str | None) -> None
Prepares the datasets for training and validation. Only supports the "fit" stage; raises a ValueError if the stage is not "fit".
- _create_dataloader(dataset, **kwargs: Any) -> DataLoader
Internal method to create a DataLoader for a given dataset with the specified configuration.
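A minimal construction sketch, not taken from the aiaccel docs: the toy dataset and factory lambdas below are placeholders, and any callables that return a torch.utils.data.Dataset and accept the common_args keywords would work the same way.

```python
# Sketch only: ToyDataset and the factory lambdas are hypothetical
# stand-ins, not part of aiaccel.
from torch.utils.data import Dataset

from aiaccel.torch.lightning.datamodules import SingleDataModule


class ToyDataset(Dataset[str]):
    """Placeholder dataset with 100 string samples."""

    def __len__(self) -> int:
        return 100

    def __getitem__(self, index: int) -> str:
        return f"sample-{index}"


datamodule = SingleDataModule(
    # Each factory is called later by setup("fit"); common_args, if given,
    # is forwarded to both factories as keyword arguments.
    train_dataset_fn=lambda **kwargs: ToyDataset(),
    val_dataset_fn=lambda **kwargs: ToyDataset(),
    batch_size=32,
    num_workers=4,
)
```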
- __init__(train_dataset_fn: Callable[[...], Dataset[str]], val_dataset_fn: Callable[[...], Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)
- prepare_data_per_node
If True, each node's LOCAL_RANK=0 process will call prepare_data(). Otherwise only the NODE_RANK=0, LOCAL_RANK=0 process will prepare data.
- allow_zero_length_dataloader_with_multiple_devices
If True, a dataloader with zero length within a local rank is allowed. Defaults to False.
Methods

- __init__(train_dataset_fn, val_dataset_fn, ...)
- from_datasets([train_dataset, val_dataset, ...]): Create an instance from torch.utils.data.Dataset.
- load_from_checkpoint(checkpoint_path[, ...]): Primary way of loading a datamodule from a checkpoint.
- load_state_dict(state_dict): Called when loading a checkpoint; implement to reload datamodule state given the datamodule state_dict.
- on_after_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch after it is transferred to the device.
- on_before_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch before it is transferred to the device.
- on_exception(exception): Called when the trainer execution is interrupted by an exception.
- predict_dataloader(): An iterable or collection of iterables specifying prediction samples.
- prepare_data(): Use this to download and prepare data.
- save_hyperparameters(*args[, ignore, frame, ...]): Save arguments to the hparams attribute.
- setup(stage): Called at the beginning of fit (train + validate), validate, test, or predict.
- state_dict(): Called when saving a checkpoint; implement to generate and save datamodule state.
- teardown(stage): Called at the end of fit (train + validate), validate, test, or predict.
- test_dataloader(): An iterable or collection of iterables specifying test samples.
- train_dataloader(): An iterable or collection of iterables specifying training samples.
- transfer_batch_to_device(batch, device, ...): Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
- val_dataloader(): An iterable or collection of iterables specifying validation samples.
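Continuing the construction sketch above, the "fit"-stage hooks can be driven by hand, which is mainly useful for debugging; normally lightning.Trainer.fit(model, datamodule=datamodule) makes these calls for you.

```python
# Sketch: exercising the "fit"-stage hooks manually on the datamodule
# built earlier.
datamodule.setup(stage="fit")      # any stage other than "fit" raises ValueError
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

batch = next(iter(train_loader))   # plain torch DataLoader iteration
print(type(batch), len(train_loader))
```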
Attributes

- CHECKPOINT_HYPER_PARAMS_KEY
- CHECKPOINT_HYPER_PARAMS_NAME
- CHECKPOINT_HYPER_PARAMS_TYPE
- hparams: The collection of hyperparameters saved with save_hyperparameters().
- hparams_initial: The collection of hyperparameters saved with save_hyperparameters().
- name
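For completeness, an illustration of the inherited hparams accessors; whether SingleDataModule actually records its init arguments via save_hyperparameters() is an assumption, not something this page states.

```python
# Assumption: this is the generic LightningDataModule hparams API;
# SingleDataModule may or may not populate these containers.
print(datamodule.hparams)          # hyperparameters saved via save_hyperparameters(), empty if none
print(datamodule.hparams_initial)  # the snapshot taken when they were first saved
```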