aiaccel.torch.pipelines.BasePipeline¶
- class aiaccel.torch.pipelines.BasePipeline(allow_tf32: bool = False, unk_args: list[str] = NOTHING)[source]¶
Base class for inference pipelines.
Note
Note that this class is an experimental feature and may change in the future.
Basic usage:
# separate.py
from typing import Any

from argparse import ArgumentParser
from pathlib import Path

from omegaconf import OmegaConf as oc  # noqa: N813

import torch

import attrs
import soundfile as sf

from aiaccel.torch.lightning import load_checkpoint
from aiaccel.torch.pipelines import BasePipeline, reorder_fields


@attrs.define(slots=False, field_transformer=reorder_fields)
class SeparationPipeline(BasePipeline):
    """Example pipeline that loads a checkpoint and applies its model to audio files.

    Command-line arguments that the parser does not recognise are collected
    into ``overwrite_config`` (see ``_process_unk_args``) and handed to
    ``load_checkpoint``.
    """

    checkpoint_path: Path

    device: str = "cuda"

    src_ext: str = "wav"
    dst_ext: str = "wav"

    # Filled from unknown CLI arguments, not from a dedicated parser option.
    overwrite_config: dict[str, Any] | None = None

    def setup(self) -> None:
        """Load the model and its config from ``checkpoint_path`` and switch to eval mode."""
        self.model, self.config = load_checkpoint(
            self.checkpoint_path,
            device=self.device,
            overwrite_config=self.overwrite_config,
        )
        self.model.eval()

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Apply the loaded model to ``wav`` and return its output."""
        return self.model(wav)

    @torch.inference_mode()
    def process_one(self, src_filename: Path, dst_filename: Path) -> None:
        """Read ``src_filename``, run the model on it, and write the result to ``dst_filename``.

        Raises:
            AssertionError: If the file's sample rate differs from ``self.config.sr``.
        """
        # soundfile exposes ``read`` (returning ``(data, samplerate)``); it has no ``load``.
        wav_mix, sr = sf.read(src_filename, dtype="float32")
        assert sr == self.config.sr, f"Sample rate mismatch: {sr} != {self.config.sr}"

        # Add a leading dimension before the model call and strip it again afterwards.
        wav_mix = torch.from_numpy(wav_mix).unsqueeze(0).to(self.device)
        wav_sep = self(wav_mix).squeeze(0).cpu().numpy()

        sf.write(dst_filename, wav_sep, sr)

    @classmethod
    def _prepare_parser(cls, fields: list[attrs.Attribute]) -> ArgumentParser:
        """Build the parser from every field except ``overwrite_config``."""
        return super()._prepare_parser(
            [field for field in fields if field.name != "overwrite_config"]
        )

    @classmethod
    def _process_unk_args(
        cls, unk_args: list[str], kwargs: dict[str, Any], parser: ArgumentParser
    ) -> dict[str, Any]:
        """Return ``kwargs`` extended with ``overwrite_config`` parsed from ``unk_args``."""
        return kwargs | {"overwrite_config": oc.from_cli(unk_args)}


if __name__ == "__main__":
    SeparationPipeline.main()
python separate.py one --help
usage: separate.py one [-h] [--device DEVICE] [--src_ext SRC_EXT] [--dst_ext DST_EXT] [--allow_tf32] checkpoint_path src_filename dst_filename

positional arguments:
  checkpoint_path
  src_filename
  dst_filename

options:
  -h, --help         show this help message and exit
  --device DEVICE
  --src_ext SRC_EXT
  --dst_ext DST_EXT
  --allow_tf32
# run inference for one file (checkpoint_path is the first positional argument)
python separate.py one ./sepformer/ ./mixture.wav ./result.wav

# run inference for all files in a directory
python separate.py batch ./sepformer/ ./mixtures/ ./results/
- __init__(allow_tf32: bool = False, unk_args: list[str] = NOTHING) None¶
Method generated by attrs for class BasePipeline.
Methods
__init__([allow_tf32, unk_args]): Method generated by attrs for class BasePipeline.
main()
process_one(src_filename, dst_filename)
setup()

Attributes
src_ext
dst_ext
allow_tf32
unk_args