aiaccel.torch.lightning.datamodules package#
Submodules#
aiaccel.torch.lightning.datamodules.single_datamodule module#
- class aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule(train_dataset_fn: Callable[[...], Dataset[str]], val_dataset_fn: Callable[[...], Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)[source]#
Bases: LightningDataModule
A PyTorch Lightning DataModule for handling training and validation datasets, with support for caching and dataset scattering. A construction sketch follows the attribute summary below.
- train_dataset_fn#
A callable that creates the training dataset.
- Type:
Callable[…, Dataset[str]]
- val_dataset_fn#
A callable that creates the validation dataset.
- Type:
Callable[…, Dataset[str]]
- batch_size#
The batch size for the DataLoader.
- Type:
int
- use_cache#
Whether to cache the datasets. Defaults to False.
- Type:
bool
- use_scatter#
Whether to scatter the datasets. Defaults to True.
- Type:
bool
- num_workers#
Number of workers for the DataLoader. Defaults to 10.
- Type:
int
- common_args#
Common arguments to pass to the dataset functions. Defaults to None.
- Type:
dict[str, Any] | None
- setup(stage: str | None) -> None: Prepares the datasets for training and validation. Only supports the “fit” stage; raises a ValueError otherwise.
- _create_dataloader(dataset, **kwargs: Any) -> DataLoader: Internal method that creates a DataLoader for a given dataset with the specified configuration.
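For illustration, a minimal construction sketch. build_dataset, its root and split parameters, and the data path are hypothetical, and common_args is assumed to be forwarded as keyword arguments to both dataset factories:

```python
from functools import partial
from pathlib import Path

from torch.utils.data import Dataset

from aiaccel.torch.lightning.datamodules import SingleDataModule


def build_dataset(root: Path, split: str) -> Dataset:
    """Hypothetical dataset factory; any callable returning a Dataset works."""
    ...


datamodule = SingleDataModule(
    train_dataset_fn=partial(build_dataset, split="train"),
    val_dataset_fn=partial(build_dataset, split="val"),
    batch_size=32,
    num_workers=4,
    common_args={"root": Path("data/")},  # assumed to be passed to both factories
)
```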
- setup(stage: str | None) → None [source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
```python
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this: state assigned here is not replicated across processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
```
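For SingleDataModule specifically, only the “fit” stage is supported, as noted in the method summary above. A minimal sketch, continuing the datamodule from the construction example:

```python
datamodule.setup("fit")    # prepares the training and validation datasets
# datamodule.setup("test")  # any other stage raises ValueError
```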
- train_dataloader() → DataLoader[Any] [source]#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set the Trainer’s reload_dataloaders_every_n_epochs argument to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data(). The following hooks are called in order:
- fit()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
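As a rough illustration of consuming the training loader by hand (the Trainer normally does this for you; continues the sketch above):

```python
# The loader reflects the constructor settings (batch_size=32, num_workers=4 above).
train_loader = datamodule.train_dataloader()

first_batch = next(iter(train_loader))  # batch structure depends on the dataset
```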
- val_dataloader() → DataLoader[Any] [source]#
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set the Trainer’s reload_dataloaders_every_n_epochs argument to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data(). The following hooks are called in order:
- fit()
- validate()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
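In typical use, the DataModule is handed to a Lightning Trainer, which calls setup("fit") and both dataloader hooks internally. A minimal sketch, assuming a hypothetical LightningModule named MyModel:

```python
import lightning as L

model = MyModel()  # hypothetical LightningModule
trainer = L.Trainer(max_epochs=10)

# fit() triggers prepare_data(), setup("fit"), train_dataloader(), and val_dataloader()
trainer.fit(model, datamodule=datamodule)
```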
Module contents#
- class aiaccel.torch.lightning.datamodules.SingleDataModule(train_dataset_fn: Callable[[...], Dataset[str]], val_dataset_fn: Callable[[...], Dataset[str]], batch_size: int, use_cache: bool = False, use_scatter: bool = True, num_workers: int = 10, common_args: dict[str, Any] | None = None)[source]#
Bases: LightningDataModule
Package-level re-export of aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule; see the full documentation above.