aiaccel.torch.pipelines.BasePipeline

class aiaccel.torch.pipelines.BasePipeline(allow_tf32: bool = False, unk_args: list[str] = NOTHING)[source]

Base class for inference pipelines.

Note

This class is an experimental feature and may change in future releases.

Basic usage:

separate.py
from typing import Any

from argparse import ArgumentParser
from pathlib import Path

from omegaconf import OmegaConf as oc  # noqa: N813

import torch

import attrs

import soundfile as sf

from aiaccel.torch.lightning import load_checkpoint
from aiaccel.torch.pipelines import BasePipeline, reorder_fields


@attrs.define(slots=False, field_transformer=reorder_fields)
class SeparationPipeline(BasePipeline):
    """Source-separation inference pipeline built on ``BasePipeline``.

    The model and its config are loaded from ``checkpoint_path`` in
    :meth:`setup`; :meth:`process_one` then separates a single audio file read
    with ``soundfile``. Unknown command-line arguments are parsed with
    ``OmegaConf.from_cli`` and forwarded to ``load_checkpoint`` as
    ``overwrite_config``.

    Attributes:
        checkpoint_path: Checkpoint location passed to ``load_checkpoint``.
        device: Device passed to ``load_checkpoint``; inputs are moved here too.
        src_ext: Extension of input files (presumably used by
            ``BasePipeline`` when collecting files in batch mode — TODO confirm).
        dst_ext: Extension of output files (same caveat as ``src_ext``).
        overwrite_config: Config overrides forwarded to ``load_checkpoint``.
            It is not exposed as a CLI option; it is built from unknown args.
    """

    checkpoint_path: Path

    device: str = "cuda"

    src_ext: str = "wav"
    dst_ext: str = "wav"

    overwrite_config: dict[str, Any] | None = None

    def setup(self) -> None:
        """Load the model/config from ``checkpoint_path`` and put the model in eval mode."""
        self.model, self.config = load_checkpoint(
            self.checkpoint_path,
            device=self.device,
            overwrite_config=self.overwrite_config,
        )
        self.model.eval()

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the loaded model on ``wav`` and return its output."""
        return self.model(wav)

    @torch.inference_mode()
    def process_one(self, src_filename: Path, dst_filename: Path) -> None:
        """Separate one audio file and write the model output to ``dst_filename``.

        Args:
            src_filename: Input (mixture) audio file.
            dst_filename: Destination file for the separated audio.

        Raises:
            AssertionError: If the file's sample rate differs from ``config.sr``.
        """
        # ``soundfile`` provides ``read`` (there is no ``sf.load``); it returns
        # ``(data, samplerate)`` with data shaped (frames,) for mono files or
        # (frames, channels) otherwise.
        wav_mix, sr = sf.read(src_filename, dtype="float32")
        assert sr == self.config.sr, f"Sample rate mismatch: {sr} != {self.config.sr}"

        # Add a leading batch dimension and move the input to the model device.
        wav_mix = torch.from_numpy(wav_mix).unsqueeze(0).to(self.device)
        # NOTE(review): ``sf.write`` expects (frames, channels). If the model
        # returns (num_sources, frames) per item, a transpose is needed — confirm.
        wav_sep = self(wav_mix).squeeze(0).cpu().numpy()

        sf.write(dst_filename, wav_sep, sr)

    @classmethod
    def _prepare_parser(cls, fields: list[attrs.Attribute]) -> ArgumentParser:
        """Build the CLI parser, hiding ``overwrite_config`` from the options.

        ``overwrite_config`` is populated from unknown arguments instead; see
        ``_process_unk_args`` below.
        """
        return super()._prepare_parser(
            [f for f in fields if f.name != "overwrite_config"]
        )

    @classmethod
    def _process_unk_args(
        cls, unk_args: list[str], kwargs: dict[str, Any], parser: ArgumentParser
    ) -> dict[str, Any]:
        """Return ``kwargs`` extended with ``overwrite_config`` parsed from ``unk_args``.

        ``unk_args`` are parsed by ``OmegaConf.from_cli``, which accepts
        dotlist-style ``key=value`` overrides.
        """
        return kwargs | {"overwrite_config": oc.from_cli(unk_args)}


if __name__ == "__main__":
    # Parse the command line (subcommands such as ``one`` / ``batch``) and run.
    SeparationPipeline.main()
python separate.py one --help
usage: separate.py one [-h] [--device DEVICE] [--src_ext SRC_EXT] [--dst_ext DST_EXT] [--allow_tf32] checkpoint_path src_filename dst_filename

positional arguments:
checkpoint_path
src_filename
dst_filename

options:
-h, --help         show this help message and exit
--device DEVICE
--src_ext SRC_EXT
--dst_ext DST_EXT
--allow_tf32
# run inference for one file
python separate.py one ./sepformer/ ./mixture.wav ./result.wav

# run inference for all files in a directory
python separate.py batch ./sepformer/ ./mixtures/ ./results/
__init__(allow_tf32: bool = False, unk_args: list[str] = NOTHING) → None

Method generated by attrs for class BasePipeline.

Methods

__init__([allow_tf32, unk_args])

Method generated by attrs for class BasePipeline.

main()

process_one(src_filename, dst_filename)

setup()

Attributes

src_ext

dst_ext

allow_tf32

unk_args