aiaccel.torch.lightning.load_checkpoint

aiaccel.torch.lightning.load_checkpoint(model_path: str | Path, config_name: str = 'merged_config.yaml', device: str = 'cuda', overwrite_config: DictConfig | ListConfig | dict[Any, Any] | list[Any] | None = None) → tuple[Module, DictConfig | ListConfig]

Load a PyTorch Lightning model from a pre-trained checkpoint.

This function loads a model from a specified path, which can be a local directory or a Hugging Face repository. It also loads the associated configuration file and allows for optional configuration overrides. The model can be set to evaluation mode if specified.
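A minimal usage sketch for a local checkpoint (the directory path below is illustrative; it is assumed to contain the model weights and the default merged_config.yaml):

>>> from aiaccel.torch.lightning import load_checkpoint
>>> model, config = load_checkpoint("file:///path/to/checkpoint_dir", device="cpu")
>>> model.eval()  # standard torch call; switches the returned Module to evaluation mode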

Parameters:
  • model_path (str | Path) – The path to the model directory or Hugging Face repository. For local paths, use the format “file://<absolute_path>” or a plain path. For Hugging Face, use the format “hf://<repo_id>”.

  • config_name (str) – The name of the configuration file to load. Default is “merged_config.yaml”.

  • device (str) – The device to map the model to. Default is “cuda”.

  • overwrite_config (DictConfig | ListConfig | dict | list | None) – Optional overrides applied to the loaded configuration. Default is None.
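The sketch below loads from a Hugging Face repository and overrides part of the configuration; the repo id and override keys are assumptions for illustration, not values defined by the API:

>>> model, config = load_checkpoint(
...     "hf://some-org/some-model",                       # hypothetical repo id
...     config_name="merged_config.yaml",
...     device="cuda",
...     overwrite_config={"trainer": {"precision": 32}},  # assumed config keys
... )

As in the signature above, the call returns the loaded Module together with the resolved configuration (a DictConfig or ListConfig).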