How to translate this small part of TensorFlow code into pyTorch?

Time:11-23

import tensorflow as tf

def transforms(x):
    # stft returns a spectrogram for each sample and each EEG channel
    # input x contains 3 signals; apply stft to each
    # and get an array with shape [samples, num_of_eeg, time_stamps, freq]
    # change dims and return [samples, time_stamps, freq, num_of_eeg]
    spectrograms = tf.signal.stft(x, frame_length=32, frame_step=4, fft_length=64)
    spectrograms = tf.abs(spectrograms)
    return tf.einsum("...ijk->...jki", spectrograms)

CodePudding user response:

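The PyTorch counterpart is torch.stft; see its documentation for the details. The rest is mostly straightforward, but watch three things. The arguments map as frame_length -> win_length, frame_step -> hop_length and fft_length -> n_fft. torch.stft accepts only 1D or 2D input, so the leading dims have to be flattened first. And it returns [batch, freq, time_stamps], not [..., time_stamps, freq], so the einsum pattern changes. Also note that torch.stft defaults to a rectangular window and center=True, while tf.signal.stft defaults to a Hann window with no padding, so the values and the number of time stamps will differ from TensorFlow's unless you set those as well. It should be something like:

import torch

def transforms(x: torch.Tensor) -> torch.Tensor:
    """Return magnitude spectrograms as [samples, time_stamps, freq, num_of_eeg]."""
    # torch.stft accepts only 1D or 2D input, so fold the leading dims into one batch dim.
    flat = x.reshape(-1, x.shape[-1])
    spectrograms = torch.stft(flat, n_fft=64, hop_length=4, win_length=32, return_complex=True)
    spectrograms = torch.abs(spectrograms).reshape(*x.shape[:-1], *spectrograms.shape[-2:])
    # torch.stft puts freq before time_stamps, so the last three dims are reversed.
    return torch.einsum("...ijk->...kji", spectrograms)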
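
If the output needs to line up with the TensorFlow function's numbers and frame count, and not just its layout, the window and padding have to be set explicitly. Below is a rough sketch of one way to do that, not a verified drop-in replacement. The helper name tf_like_stft_magnitude is made up for illustration, and the sketch assumes tf.signal.stft is called with its defaults (Hann window, pad_end=False) and that fft_length - frame_length is even. torch.stft centres a short window inside each n_fft-long frame, so the signal is zero-padded by the same offset to make each frame start where TensorFlow's would. Only magnitudes are expected to agree; the phases would differ.

```python
import torch
import torch.nn.functional as F


def tf_like_stft_magnitude(x: torch.Tensor, frame_length: int = 32,
                           frame_step: int = 4, fft_length: int = 64) -> torch.Tensor:
    """Magnitude STFT of x[..., time], returned as [..., frames, freq].

    Hypothetical helper; assumes (fft_length - frame_length) is even.
    """
    lead_shape = x.shape[:-1]
    flat = x.reshape(-1, x.shape[-1])  # torch.stft takes 1D or 2D input only
    # Zero-pad by the offset at which torch places the window inside each frame.
    pad = (fft_length - frame_length) // 2
    flat = F.pad(flat, (pad, pad))
    spec = torch.stft(
        flat,
        n_fft=fft_length,
        hop_length=frame_step,
        win_length=frame_length,
        window=torch.hann_window(frame_length, dtype=x.dtype, device=x.device),
        center=False,         # no reflect padding around the signal
        return_complex=True,
    )
    mag = spec.abs().transpose(-1, -2)  # [batch, freq, frames] -> [batch, frames, freq]
    return mag.reshape(*lead_shape, *mag.shape[-2:])


def transforms(x: torch.Tensor) -> torch.Tensor:
    """[samples, num_of_eeg, time] -> [samples, time_stamps, freq, num_of_eeg]."""
    # The helper already returns [..., time_stamps, freq], so the original
    # einsum pattern from the TensorFlow code can be reused unchanged.
    return torch.einsum("...ijk->...jki", tf_like_stft_magnitude(x))
```

For an input of shape [2, 3, 100] this should give 1 + (100 - 32) // 4 = 18 time stamps and 64 // 2 + 1 = 33 frequency bins. It is still worth comparing against the TensorFlow output on a few real samples before relying on it.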