aiaccel.torch.lightning.datamodules package#

Submodules#

aiaccel.torch.lightning.datamodules.single_datamodule module#

class aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule(train_dataset_fn: Callable[[...], Dataset[str]], val_dataset_fn: Callable[[...], Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)[source]#

Bases: LightningDataModule

A PyTorch Lightning DataModule that handles training and validation datasets, with support for caching and dataset scattering.

train_dataset_fn#

A callable that creates the training dataset.

Type:

Callable[…, Dataset[str]]

val_dataset_fn#

A callable that creates the validation dataset.

Type:

Callable[…, Dataset[str]]

batch_size#

The batch size for the DataLoader.

Type:

int

use_cache#

Whether to cache the datasets. Defaults to False.

Type:

bool

use_scatter#

Whether to scatter the datasets across distributed processes. Defaults to True.

Type:

bool

num_workers#

Number of workers for the DataLoader. Defaults to 10.

Type:

int

common_args#

Common arguments to pass to the dataset functions. Defaults to None.

Type:

dict[str, Any] | None

setup(stage: str | None) -> None: Prepares the datasets for training and validation. Only the “fit” stage is supported; any other stage raises a ValueError.

train_dataloader() -> DataLoader: Returns the DataLoader for the training dataset.

val_dataloader() -> DataLoader: Returns the DataLoader for the validation dataset.

_create_dataloader(dataset, **kwargs: Any) -> DataLoader: Internal method that creates a DataLoader for a given dataset with the specified configuration.
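For orientation, a minimal construction sketch (not taken from the aiaccel docs): MyDataset and its root/split parameters are hypothetical, and common_args is assumed to be forwarded as keyword arguments to both dataset callables, per the attribute description above.

from functools import partial

from torch.utils.data import Dataset

from aiaccel.torch.lightning.datamodules import SingleDataModule


class MyDataset(Dataset[str]):  # hypothetical dataset; not part of aiaccel
    def __init__(self, root: str, split: str) -> None:
        self.items = [f"{root}/{split}/{i}" for i in range(8)]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> str:
        return self.items[idx]


datamodule = SingleDataModule(
    train_dataset_fn=partial(MyDataset, split="train"),
    val_dataset_fn=partial(MyDataset, split="val"),
    batch_size=4,
    num_workers=2,
    common_args={"root": "/path/to/data"},  # assumed to be forwarded to both callables
)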

setup(stage: str | None) → None[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data() runs on a single process,
        # so any state assigned here is not replicated to other processes
        self.something = load_state()  # `load_state` is a stand-in helper

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
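For SingleDataModule itself, the inherited hook is narrowed as described in the class summary: only the “fit” stage is accepted. A hedged illustration, reusing the datamodule from the construction sketch above:

datamodule.setup("fit")     # builds the train/val datasets from the factory callables
# datamodule.setup("test")  # any stage other than "fit" raises ValueError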
train_dataloader() → DataLoader[Any][source]#

An iterable or collection of iterables specifying training samples.

For more information about using multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.
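A brief sketch of that split in a DataModule subclass (a generic illustration, not the aiaccel implementation; the random tensors stand in for real downloaded data):

import lightning.pytorch as pl
import torch
from torch.utils.data import TensorDataset


class PatternDataModule(pl.LightningDataModule):
    def prepare_data(self) -> None:
        # runs on a single process per node: download/tokenize here, assign no state
        pass  # e.g. download an archive to local disk

    def setup(self, stage: str | None) -> None:
        # runs on every process: build datasets and assign state safely
        data = torch.randn(100, 28)
        self.train_set = TensorDataset(data[:80])
        self.val_set = TensorDataset(data[80:])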

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
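When dataloaders should be rebuilt each epoch, the Trainer argument referenced above can be set; a minimal sketch:

import lightning.pytorch as pl

trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)  # rebuild dataloaders every epoch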

val_dataloader() → DataLoader[Any][source]#

An iterable or collection of iterables specifying validation samples.

For more information about using multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). The relevant entry points and hooks are:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
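Putting the pieces together: during fit(), Lightning calls setup(“fit”) on the datamodule and then requests train_dataloader() and val_dataloader(). A hedged end-to-end sketch, where model stands for any user-defined LightningModule and datamodule comes from the construction sketch above:

import lightning.pytorch as pl

trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, datamodule=datamodule)  # `model`: a user-defined LightningModule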

Module contents#

class aiaccel.torch.lightning.datamodules.SingleDataModule[source]#

Bases: LightningDataModule

Alias of aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule, re-exported at the package level. See the module documentation above for the full class, attribute, and method reference.