PyTorch/Lightning Toolkit

Datasets

CachedDataset(dataset)

A dataset wrapper that caches samples in memory so that each sample is loaded from the underlying dataset only once.
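The caching idea can be sketched as a thin `Dataset` wrapper; this is a minimal illustration, not the toolkit's actual implementation (class name and cache layout are assumed):

```python
from torch.utils.data import Dataset


class InMemoryCachedDataset(Dataset):
    """Sketch: store each sample on first access, reuse it afterwards."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = {}  # index -> sample

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Hit the underlying dataset only on the first access to this index.
        if index not in self._cache:
            self._cache[index] = self.dataset[index]
        return self._cache[index]
```

This trades memory for speed; repeated epochs over the same data skip any expensive loading or preprocessing in the wrapped dataset.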

FileCachedDataset(dataset, cache_path)

A dataset wrapper that caches samples to disk to reduce memory usage.
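The disk-backed variant can be sketched similarly; the file layout and names below are hypothetical, chosen only to show the memory/speed trade-off:

```python
import os
import pickle

from torch.utils.data import Dataset


class DiskCachedDataset(Dataset):
    """Sketch: pickle each sample under cache_path on first access."""

    def __init__(self, dataset, cache_path):
        self.dataset = dataset
        self.cache_path = cache_path
        os.makedirs(cache_path, exist_ok=True)

    def __len__(self):
        return len(self.dataset)

    def _sample_file(self, index):
        return os.path.join(self.cache_path, f"{index}.pkl")

    def __getitem__(self, index):
        path = self._sample_file(index)
        if os.path.exists(path):
            # Served from disk: memory usage stays flat across epochs.
            with open(path, "rb") as f:
                return pickle.load(f)
        sample = self.dataset[index]
        with open(path, "wb") as f:
            pickle.dump(sample, f)
        return sample
```

Unlike the in-memory wrapper, this keeps only one sample in RAM at a time, at the cost of disk I/O on every access.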

HDF5Dataset(dataset_path[, grp_list])

A dataset class for loading data from an HDF5 file.

RawHDF5Dataset(dataset_path[, grp_list])

A dataset class for reading data from HDF5 files.

scatter_dataset(dataset[, permute_fn])

Splits a dataset into subsets and returns the subset corresponding to the current process rank.
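The splitting logic can be sketched with `torch.utils.data.Subset`; the behavior below (optional permutation, contiguous equal-sized chunks, tail dropped) is an assumption for illustration, not necessarily what the toolkit does:

```python
from torch.utils.data import Subset


def scatter_indices(num_samples, rank, world_size, permute_fn=None):
    indices = list(range(num_samples))
    if permute_fn is not None:
        indices = permute_fn(indices)
    # Drop the tail so every rank receives the same number of samples.
    per_rank = num_samples // world_size
    start = rank * per_rank
    return indices[start:start + per_rank]


def scatter_dataset_sketch(dataset, rank, world_size, permute_fn=None):
    """Return the subset of `dataset` assigned to this process rank."""
    return Subset(
        dataset, scatter_indices(len(dataset), rank, world_size, permute_fn)
    )
```

In a real distributed run, `rank` and `world_size` would come from the process group rather than being passed explicitly.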

Functional

linear_sum_assignment(cost_matrix[, maximize])

Solve the linear sum assignment problem for a batch of cost matrices.
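A batched version can be sketched by looping SciPy's solver over the batch dimension; the toolkit's implementation may be vectorized or torch-native, so treat this only as a reference for the expected semantics:

```python
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment as _lsa


def batched_linear_sum_assignment(cost_matrix, maximize=False):
    """cost_matrix: (batch, n, m) tensor -> (row_idx, col_idx) index tensors."""
    rows, cols = [], []
    for cost in cost_matrix.detach().cpu().numpy():
        r, c = _lsa(cost, maximize=maximize)
        rows.append(r)
        cols.append(c)
    return torch.as_tensor(np.stack(rows)), torch.as_tensor(np.stack(cols))
```

Each `(row_idx[b, k], col_idx[b, k])` pair marks one matched entry of the `b`-th cost matrix.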

Learning Rate Schedulers

SequentialLR(optimizer, schedulers_fn, ...)

A wrapper around torch.optim.lr_scheduler.SequentialLR that takes a list of factory functions to create the constituent schedulers.
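The motivation for factory functions is that each constituent scheduler needs the optimizer at construction time; a minimal sketch of the wrapping idea (the toolkit's exact signature may differ):

```python
import torch
from torch.optim.lr_scheduler import ConstantLR, ExponentialLR, SequentialLR


def build_sequential_lr(optimizer, schedulers_fn, milestones):
    """Build each scheduler from the optimizer, then chain them."""
    schedulers = [fn(optimizer) for fn in schedulers_fn]
    return SequentialLR(optimizer, schedulers=schedulers, milestones=milestones)


params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
sched = build_sequential_lr(
    opt,
    [
        lambda o: ConstantLR(o, factor=0.5, total_iters=2),  # short warmup phase
        lambda o: ExponentialLR(o, gamma=0.9),               # then exponential decay
    ],
    milestones=[2],
)
```

This lets a configuration file declare the scheduler chain before any optimizer exists.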

Inference Pipeline Helpers

BasePipeline([allow_tf32, unk_args])

Base class for inference pipelines.

reorder_fields(cls, fields)

Reorder attrs fields such that fields without default values come first, then fields with default values.
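The reordering rule can be illustrated with the stdlib dataclasses module (the toolkit operates on attrs classes, but the sorting logic is the same: fields lacking a default must precede fields that have one):

```python
import dataclasses


def reorder_fields_sketch(fields):
    """Stable-sort dataclasses.Field objects: fields without defaults first."""
    def has_default(f):
        return (f.default is not dataclasses.MISSING
                or f.default_factory is not dataclasses.MISSING)
    # False (no default) sorts before True (has default); order is otherwise kept.
    return sorted(fields, key=has_default)
```

This matters because both attrs and dataclasses reject a class whose non-default fields follow default ones, so reordering restores a valid declaration order.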

Lightning Utilities

OptimizerLightningModule(optimizer_config)

LightningModule subclass for models that use custom optimizers and schedulers.

OptimizerConfig(optimizer_generator[, ...])

Configuration for the optimizer and scheduler in a LightningModule.

build_param_groups(named_params, groups)

Build parameter groups for the optimizer based on the provided patterns.
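One common pattern-matching scheme can be sketched as follows; the pattern format and group options here are assumptions, not the toolkit's documented interface:

```python
import re

import torch


def build_param_groups_sketch(named_params, groups):
    """groups: list of (regex, options) pairs, e.g. (r"bias$", {"weight_decay": 0.0}).

    Each parameter joins the first group whose pattern matches its name;
    unmatched parameters fall into a default group with no extra options.
    """
    buckets = [{"params": [], **opts} for _, opts in groups]
    default = {"params": []}
    for name, param in named_params:
        for i, (pattern, _) in enumerate(groups):
            if re.search(pattern, name):
                buckets[i]["params"].append(param)
                break
        else:
            default["params"].append(param)
    # Optimizers reject empty groups, so drop any that matched nothing.
    return [g for g in buckets + [default] if g["params"]]
```

A typical use is disabling weight decay for biases and normalization parameters while keeping it for weights.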

load_checkpoint(model_path[, config_name, ...])

Load a PyTorch Lightning model from a pre-trained checkpoint.

ABCIEnvironment()

Environment class for ABCI.

Lightning Datamodules

SingleDataModule(train_dataset_fn, ...[, ...])

A PyTorch Lightning DataModule that manages training and validation datasets, with support for sample caching and dataset scattering across processes.

H5py Utilities

HDF5Writer()

Abstract base class for writing data to an HDF5 file.
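The abstract-writer pattern can be sketched with `abc`; the method name and the dict standing in for an h5py group are illustrative assumptions, not the toolkit's real interface:

```python
import abc


class HDF5WriterSketch(abc.ABC):
    """Sketch: subclasses decide how each item is converted and stored."""

    @abc.abstractmethod
    def write(self, group, name, data):
        """Write one item into `group` under the key `name`."""


class DictWriter(HDF5WriterSketch):
    def write(self, group, name, data):
        # With h5py this would be group.create_dataset(name, data=data);
        # here a plain dict stands in for the HDF5 group.
        group[name] = data
```

Subclassing keeps the file-handling boilerplate in one place while each concrete writer controls dtype conversion and dataset layout.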