aiaccel.torch.lightning.datamodules.SingleDataModule

class aiaccel.torch.lightning.datamodules.SingleDataModule(train_dataset_fn: Callable[..., Dataset[str]], val_dataset_fn: Callable[..., Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)

A PyTorch Lightning DataModule designed to handle training and validation datasets, with support for caching and dataset scattering.

train_dataset_fn

A callable that creates the training dataset.

Type:

Callable[..., Dataset[str]]

val_dataset_fn

A callable that creates the validation dataset.

Type:

Callable[..., Dataset[str]]

batch_size

The batch size for the DataLoader.

Type:

int

use_cache

Whether to cache the datasets. Defaults to False.

Type:

bool

use_scatter

Whether to scatter the datasets. Defaults to True.

Type:

bool

num_workers

Number of workers for the DataLoader. Defaults to 10.

Type:

int

common_args

Common arguments to pass to the dataset functions. Defaults to None.

Type:

dict[str, Any] | None
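The two dataset functions and common_args work together: the module presumably calls each dataset function with the shared arguments. The stdlib-only sketch below assumes that common_args is unpacked as keyword arguments and that None means "no shared arguments"; this call convention is not confirmed by this page. ToyDataset, make_train_dataset, make_val_dataset, and build_datasets are hypothetical names used only for illustration.

```python
from __future__ import annotations

from typing import Any, Callable


class ToyDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(self, items: list[str]) -> None:
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> str:
        return self.items[index]


def make_train_dataset(root: str, split: str = "train") -> ToyDataset:
    # Hypothetical factory: `root` would normally point at data on disk.
    return ToyDataset([f"{root}/{split}/{i}" for i in range(8)])


def make_val_dataset(root: str, split: str = "val") -> ToyDataset:
    # Hypothetical factory for the (smaller) validation split.
    return ToyDataset([f"{root}/{split}/{i}" for i in range(2)])


def build_datasets(
    train_dataset_fn: Callable[..., ToyDataset],
    val_dataset_fn: Callable[..., ToyDataset],
    common_args: dict[str, Any] | None = None,
) -> tuple[ToyDataset, ToyDataset]:
    # Assumption: None is treated as an empty set of shared keyword arguments.
    kwargs = common_args if common_args is not None else {}
    return train_dataset_fn(**kwargs), val_dataset_fn(**kwargs)


train_ds, val_ds = build_datasets(make_train_dataset, make_val_dataset, {"root": "/data"})
print(len(train_ds), len(val_ds))  # 8 2
print(val_ds[0])  # /data/val/0
```

Passing shared settings once through common_args keeps the two dataset functions from drifting apart (for example, pointing at different data roots).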

setup(stage: str | None) -> None: Prepares the datasets for training and validation. Only the “fit” stage is supported; any other stage raises a ValueError.
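The setup() contract described above (build the datasets for “fit”, reject every other stage) can be illustrated with a small stand-in class. This is a hypothetical sketch, not the aiaccel implementation: plain lists replace Dataset objects, rank and world_size are passed in explicitly, and "scattering" is assumed to mean that each distributed rank keeps a disjoint, strided shard of the data. SketchDataModule and _scatter are illustrative names.

```python
from __future__ import annotations

from typing import Callable


class SketchDataModule:
    """Hypothetical stand-in that mimics the documented setup() contract."""

    def __init__(
        self,
        train_dataset_fn: Callable[[], list[str]],
        val_dataset_fn: Callable[[], list[str]],
        use_scatter: bool = True,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        self.train_dataset_fn = train_dataset_fn
        self.val_dataset_fn = val_dataset_fn
        self.use_scatter = use_scatter
        # Assumption: rank/world_size are given explicitly here; a real module
        # would read them from the distributed environment.
        self.rank = rank
        self.world_size = world_size

    def _scatter(self, dataset: list[str]) -> list[str]:
        # Assumption: "scatter" = each rank keeps every world_size-th item,
        # starting at its own rank, so shards are disjoint.
        return dataset[self.rank :: self.world_size]

    def setup(self, stage: str | None) -> None:
        if stage != "fit":
            raise ValueError(f"Only the 'fit' stage is supported, got {stage!r}")
        self.train_dataset = self.train_dataset_fn()
        self.val_dataset = self.val_dataset_fn()
        if self.use_scatter:
            self.train_dataset = self._scatter(self.train_dataset)
            self.val_dataset = self._scatter(self.val_dataset)


dm = SketchDataModule(
    lambda: [f"t{i}" for i in range(6)],
    lambda: ["v0", "v1"],
    rank=1,
    world_size=2,
)
dm.setup("fit")
print(dm.train_dataset)  # ['t1', 't3', 't5']
print(dm.val_dataset)  # ['v1']

error_message = ""
try:
    dm.setup("test")  # any stage other than "fit" is rejected
except ValueError as err:
    error_message = str(err)
print(error_message)  # Only the 'fit' stage is supported, got 'test'
```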

train_dataloader() -> DataLoader

Returns the DataLoader for the training dataset.

val_dataloader() -> DataLoader

Returns the DataLoader for the validation dataset.

_create_dataloader(dataset, **kwargs: Any) -> DataLoader: Internal method that creates a DataLoader for the given dataset with the specified configuration.
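train_dataloader() and val_dataloader() both delegate to _create_dataloader, which suggests a shared-defaults pattern: common settings such as batch_size and num_workers come from the module, and each caller adds its own options. The sketch below builds only the keyword-argument dictionary, so it runs without PyTorch; in real code the result would be unpacked into torch.utils.data.DataLoader(dataset, **kwargs), whose batch_size, num_workers, shuffle, and drop_last parameters are real. The class name and the shuffle/drop_last choices are assumptions, not taken from aiaccel.

```python
from __future__ import annotations

from typing import Any


class LoaderConfigSketch:
    """Hypothetical helper: shared DataLoader defaults plus per-call overrides."""

    def __init__(self, batch_size: int, num_workers: int = 10) -> None:
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _create_dataloader_kwargs(self, **kwargs: Any) -> dict[str, Any]:
        # Module-level defaults come first so per-call kwargs can override them.
        defaults: dict[str, Any] = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
        }
        return {**defaults, **kwargs}

    def train_kwargs(self) -> dict[str, Any]:
        # Assumption: the training loader shuffles and drops the last partial batch.
        return self._create_dataloader_kwargs(shuffle=True, drop_last=True)

    def val_kwargs(self) -> dict[str, Any]:
        # Assumption: the validation loader keeps order and every sample.
        return self._create_dataloader_kwargs(shuffle=False, drop_last=False)


cfg = LoaderConfigSketch(batch_size=32)
train = cfg.train_kwargs()
val = cfg.val_kwargs()
print(train)  # {'batch_size': 32, 'num_workers': 10, 'shuffle': True, 'drop_last': True}
print(val["shuffle"])  # False
```

Keeping the defaults in one helper means a change to batch_size or num_workers applies to both loaders at once.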

__init__(train_dataset_fn: Callable[..., Dataset[str]], val_dataset_fn: Callable[..., Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)

prepare_data_per_node

If True, the process with LOCAL_RANK=0 on each node calls prepare_data(). Otherwise, only the process with NODE_RANK=0 and LOCAL_RANK=0 prepares data.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length on a local rank is allowed. Defaults to False.
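The prepare_data_per_node rule can be restated as a small predicate. should_prepare_data is a hypothetical helper written only to make the rule concrete; Lightning applies this logic internally. The example assumes two nodes with two processes each.

```python
def should_prepare_data(node_rank: int, local_rank: int, prepare_data_per_node: bool) -> bool:
    """Return True if the process at (node_rank, local_rank) would run prepare_data()."""
    if prepare_data_per_node:
        # One process per node: the one with LOCAL_RANK == 0.
        return local_rank == 0
    # One process for the whole job: NODE_RANK == 0 and LOCAL_RANK == 0.
    return node_rank == 0 and local_rank == 0


# All (node_rank, local_rank) pairs for two nodes with two processes each.
ranks = [(node, local) for node in range(2) for local in range(2)]
per_node = [r for r in ranks if should_prepare_data(*r, prepare_data_per_node=True)]
global_only = [r for r in ranks if should_prepare_data(*r, prepare_data_per_node=False)]
print(per_node)  # [(0, 0), (1, 0)]
print(global_only)  # [(0, 0)]
```

Per-node preparation suits clusters where each node has its own local disk; with shared storage, a single global preparer avoids duplicate downloads.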

Methods

__init__(train_dataset_fn, val_dataset_fn, ...)

from_datasets([train_dataset, val_dataset, ...])

Create an instance from torch.utils.data.Dataset.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a datamodule from a checkpoint.

load_state_dict(state_dict)

Called when loading a checkpoint; implement it to reload the datamodule state from the given state_dict.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_exception(exception)

Called when the trainer execution is interrupted by an exception.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

prepare_data()

Use this to download and prepare data.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict.

state_dict()

Called when saving a checkpoint; implement it to generate and save the datamodule state.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader()

An iterable or collection of iterables specifying test samples.

train_dataloader()

An iterable or collection of iterables specifying training samples.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

val_dataloader()

An iterable or collection of iterables specifying validation samples.
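The state_dict() and load_state_dict() hooks listed above form a save/restore pair. The following hypothetical sketch shows the round trip on a plain Python class with the same method names; it does not subclass LightningDataModule, and the epochs_seen field is invented for illustration. In an actual Lightning run, the Trainer calls these hooks while writing and reading checkpoints.

```python
from __future__ import annotations

from typing import Any


class StatefulSketch:
    """Hypothetical datamodule-like object implementing the two checkpoint hooks."""

    def __init__(self) -> None:
        self.epochs_seen = 0  # illustrative piece of state worth persisting

    def state_dict(self) -> dict[str, Any]:
        # Called when saving a checkpoint: return everything needed to restore.
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Called when loading a checkpoint: restore from the saved dictionary.
        self.epochs_seen = state_dict["epochs_seen"]


original = StatefulSketch()
original.epochs_seen = 3
saved = original.state_dict()  # stand-in for what would be written to a checkpoint

restored = StatefulSketch()
restored.load_state_dict(saved)
print(restored.epochs_seen)  # 3
```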

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The initial collection of hyperparameters saved with save_hyperparameters(). These contents are read-only.

name