How to translate this small part of TensorFlow code into PyTorch?
```python
import tensorflow as tf

def transforms(x):
    # stft returns a spectrogram for each sample and each EEG signal
    # input x contains 3 signals, apply stft to each
    # and get an array with shape [samples, num_of_eeg, time_stamps, freq]
    # change dims and return [samples, time_stamps, freq, num_of_eeg]
    spectrograms = tf.signal.stft(x, frame_length=32, frame_step=4, fft_length=64)
    spectrograms = tf.abs(spectrograms)
    return tf.einsum("...ijk->...jki", spectrograms)
```
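The `einsum` in the last line only reorders axes: with `i` = EEG channel, `j` = frame and `k` = frequency, `"...ijk->...jki"` moves the channel axis to the end. For a 4-D tensor this is the same as a plain axis permutation. The small PyTorch check below illustrates this; the tensor shape is arbitrary and chosen only for the example.

```python
import torch

# Arbitrary example shape: [samples, num_of_eeg, time_stamps, freq]
spec = torch.rand(2, 3, 5, 33)

via_einsum = torch.einsum("...ijk->...jki", spec)
via_permute = spec.permute(0, 2, 3, 1)  # move the EEG axis (dim 1) to the end

print(via_einsum.shape)                      # torch.Size([2, 5, 33, 3])
print(torch.equal(via_einsum, via_permute))  # True
```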
Answer:
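You can find the documentation for the PyTorch STFT implementation under `torch.stft`. The rest is straightforward, but the arguments do not map one-to-one by position: `frame_length` maps to `win_length`, `frame_step` to `hop_length`, and `fft_length` to `n_fft`. `tf.signal.stft` also applies a periodic Hann window and does not pad the signal by default, so pass `window=torch.hann_window(32)` and `center=False`; recent PyTorch versions also require `return_complex=True` for real input. `torch.stft` only accepts 1-D or 2-D input and returns `[..., freq, frames]` (the reverse of TensorFlow's `[..., frames, freq]`), so the EEG axis has to be folded into the batch and the `einsum` pattern changes. One difference remains: when `win_length < n_fft`, `torch.stft` pads the window on both sides to `n_fft` and slides 64-sample frames, so its frames start 16 samples later and there are 8 fewer of them than in TensorFlow; the magnitudes of the overlapping frames match. It should be: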
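```python
import torch

def transforms(x: torch.Tensor) -> torch.Tensor:
    """Return the magnitude (Fourier) spectrogram as [samples, time_stamps, freq, num_of_eeg]."""
    samples, num_eeg, length = x.shape
    # torch.stft only takes [batch, time], so fold the EEG axis into the batch
    spectrograms = torch.stft(x.reshape(-1, length), n_fft=64, hop_length=4, win_length=32,
                              window=torch.hann_window(32, dtype=x.dtype, device=x.device),
                              center=False, return_complex=True)  # [samples * eeg, freq, frames]
    spectrograms = torch.abs(spectrograms).reshape(samples, num_eeg, *spectrograms.shape[-2:])
    return torch.einsum("...ikj->...jki", spectrograms)  # [samples, frames, freq, eeg]
```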
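To reproduce `tf.signal.stft` frame for frame, one option is to skip `torch.stft` and do essentially what `tf.signal.stft` does with its defaults: cut full frames of `frame_length` samples every `frame_step` samples, multiply each by a periodic Hann window, and take an `rfft` zero-padded at the end to `fft_length`. This is a minimal sketch assuming TensorFlow's defaults (`pad_end=False`, `window_fn=tf.signal.hann_window`); the helper name `tf_like_stft_magnitude` is hypothetical. Because its output already has TensorFlow's `[..., frames, freq]` layout, the original `einsum` pattern can be reused unchanged.

```python
import torch

def tf_like_stft_magnitude(x: torch.Tensor, frame_length: int = 32,
                           frame_step: int = 4, fft_length: int = 64) -> torch.Tensor:
    """Hypothetical helper: magnitude STFT mimicking tf.signal.stft defaults.

    x: [..., time]  ->  [..., frames, fft_length // 2 + 1]
    """
    # Full frames only (like pad_end=False): [..., frames, frame_length]
    frames = x.unfold(-1, frame_length, frame_step)
    # Periodic Hann window, matching tf.signal.hann_window's default
    window = torch.hann_window(frame_length, periodic=True, dtype=x.dtype, device=x.device)
    # rfft with n=fft_length zero-pads each windowed frame at the end
    return torch.fft.rfft(frames * window, n=fft_length).abs()


def transforms(x: torch.Tensor) -> torch.Tensor:
    # x: [samples, num_of_eeg, time] -> [samples, time_stamps, freq, num_of_eeg]
    spectrograms = tf_like_stft_magnitude(x)
    return torch.einsum("...ijk->...jki", spectrograms)


x = torch.randn(2, 3, 256)
print(transforms(x).shape)  # torch.Size([2, 57, 33, 3])
```

`Tensor.unfold` returns a view, so the framing step itself copies no data; the copy happens only when the window is applied.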