Skip to content

sgnts.transforms.converter

Converter dataclass

Bases: TSTransform

Change the data type or the device of the data.

Parameters:

Name Type Description Default
backend str

str, the backend to convert the data to. Supported backends: ['numpy'|'torch']

'numpy'
dtype str

str, the data type to convert the data to. Supported dtypes: ['float64'|'float32'|'float16']

'float32'
device str

str, the device to convert the data to. Supported devices: if backend = 'numpy', only supports device = 'cpu'; if backend = 'torch', supports device = ['cpu'|'cuda'|'cuda:<GPU number>'], where <GPU number> is the GPU device number.

'cpu'
Source code in sgnts/transforms/converter.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@dataclass
class Converter(TSTransform):
    """Change the data type or the device of the data.

    Args:
        backend:
            str, the backend to convert the data to. Supported backends:
            ['numpy'|'torch']
        dtype:
            str, the data type to convert the data to. For the 'torch' backend
            the supported names are ['float64'|'float32'|'float16'] (a
            ``torch.dtype`` instance is accepted as well); the name is replaced
            by the matching ``torch.dtype`` during initialization. For the
            'numpy' backend the value is handed unchanged to
            ``ndarray.astype``.
        device:
            str, the device to convert the data to. Supported devices:
            if backend = 'numpy', only supports device = 'cpu', if backend = 'torch',
            supports device = ['cpu'|'cuda'|'cuda:<GPU number>'] where <GPU number> is
            the GPU device number.

    Raises:
        ValueError:
            if the backend is unknown, if device is not 'cpu' for the 'numpy'
            backend, or if dtype is not a supported torch data type.
        ImportError:
            if backend = 'torch' but PyTorch is not installed.
    """

    backend: str = "numpy"
    dtype: str = "float32"
    device: str = "cpu"

    def __post_init__(self):
        """Validate the configuration and pair each source pad with a sink pad."""
        # Kept as an assert (not a raise) so the exception type callers may
        # already rely on does not change.
        assert set(self.source_pad_names) == set(self.sink_pad_names), (
            f"Source and sink pad names must match. "
            f"Source: {self.source_pad_names}, Sink: {self.sink_pad_names}"
        )
        super().__post_init__()

        if self.backend == "numpy":
            if self.device != "cpu":
                raise ValueError("Converting to numpy only supports device as cpu")
        elif self.backend == "torch":
            if not TORCH_AVAILABLE:
                raise ImportError(
                    "PyTorch is not installed. Install it with 'pip install "
                    "sgn-ts[torch]'"
                )

            if isinstance(self.dtype, str):
                # Built here rather than at module level: the torch attributes
                # are only touched once torch is known to be importable.
                torch_dtypes = {
                    "float64": torch.float64,
                    "float32": torch.float32,
                    "float16": torch.float16,
                }
                if self.dtype not in torch_dtypes:
                    raise ValueError(
                        "Supported torch data types: float64, float32, float16"
                    )
                self.dtype = torch_dtypes[self.dtype]
            elif not isinstance(self.dtype, torch.dtype):
                raise ValueError("Unknown dtype")
        else:
            raise ValueError("Supported backends: 'numpy' or 'torch'")

        # Map every source pad to the sink pad that carries the same short
        # name (the part after the last ':' of the source pad's full name).
        self.pad_map = {
            p: self.sink_pad_dict[f"{self.name}:snk:{p.name.split(':')[-1]}"]
            for p in self.source_pads
        }

    def _convert_data(self, data):
        """Convert one non-gap data array to the configured backend.

        Args:
            data:
                the buffer payload, either a numpy array or a torch tensor

        Returns:
            The data as a numpy array (backend = 'numpy') or as a torch tensor
            placed on ``self.device`` (any other backend value), cast to
            ``self.dtype``.

        Raises:
            ValueError:
                if data is neither a numpy array nor a torch tensor.
            ImportError:
                if a torch conversion is requested but PyTorch is missing.
        """
        if self.backend == "numpy":
            if isinstance(data, np.ndarray):
                # numpy to numpy; copy=False avoids a copy when the dtype
                # already matches
                return data.astype(self.dtype, copy=False)
            # Check availability first so that, without PyTorch, unsupported
            # input reaches the ValueError below instead of failing on the
            # ``torch`` lookup itself.
            if TORCH_AVAILABLE and isinstance(data, torch.Tensor):
                # torch to numpy
                return data.detach().cpu().numpy().astype(self.dtype, copy=False)
            raise ValueError("Unsupported data type")

        if not TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is not installed. Install it with 'pip "
                "install sgn-ts[torch]'"
            )

        if isinstance(data, np.ndarray):
            # numpy to torch
            return torch.from_numpy(data).to(self.dtype).to(self.device)
        if isinstance(data, torch.Tensor):
            # torch to torch
            return data.to(self.dtype).to(self.device)
        raise ValueError("Unsupported data type")

    def new(self, pad: SourcePad) -> TSFrame:
        """Build the output frame for a source pad.

        Takes the prepared frame of the sink pad paired with ``pad``, clears
        that slot, and converts every non-gap buffer; gap buffers are passed
        on with ``data=None``.

        Args:
            pad:
                SourcePad, the source pad requesting a frame

        Returns:
            TSFrame, a frame with the converted buffers and the input frame's
            metadata and EOS flag
        """
        sink_pad = self.pad_map[pad]
        frame = self.preparedframes[sink_pad]
        self.preparedframes[sink_pad] = None

        outbufs = [
            SeriesBuffer(
                offset=buf.offset,
                sample_rate=buf.sample_rate,
                data=None if buf.is_gap else self._convert_data(buf.data),
                shape=buf.shape,
            )
            for buf in frame
        ]

        return TSFrame(
            buffers=outbufs,
            metadata=frame.metadata,
            EOS=frame.EOS,
        )