I'm trying to upsample an RGB image in the frequency domain using PyTorch. I'm using
The upscaled image:
Another interesting thing to note is the maximum and minimum pixel values after performing the IFFT: they are 2.2729 and -1.8376 respectively. Ideally, they should be 1.0 and 0.0.
Can someone please explain what's wrong here?
Answer:
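The usual convention for the DFT is to treat the first sample as the 0 Hz (DC) component, with the negative frequencies wrapped around to the end of the array. Zero-padding only makes sense when the 0 Hz component is in the center, so that the zeros are added at the highest frequencies. Most FFT tools provide a shift function that circularly shifts the result so the 0 Hz component is in the center. In PyTorch you need to call torch.fft.fftshift after the FFT, and torch.fft.ifftshift right before the inverse FFT to put the 0 Hz component back in the upper-left corner. If you pad without shifting, the image's low-frequency content ends up at high frequencies, so the result oscillates from pixel to pixel and goes far outside [0, 1].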
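To see why the shift matters, here is a small 1-D sketch (a hypothetical 8-sample signal, in float64 so rounding error is negligible) that pads the spectrum with and without fftshift. It only uses torch.fft.fft, torch.fft.ifft, torch.fft.fftfreq, torch.fft.fftshift and torch.fft.ifftshift; the zero_pad helper is a hypothetical name used for illustration. With the shift, every second output sample reproduces the input exactly; without it, the samples come back multiplied by an alternating +1/-1 pattern, i.e. the content has been moved to a high frequency.

```python
import torch

# Hypothetical 8-sample real signal; float64 keeps rounding error negligible.
x = torch.tensor([0.1, 0.5, 0.9, 0.7, 0.3, 0.2, 0.6, 0.8], dtype=torch.float64)
n = x.shape[0]

# Raw FFT bin order: 0 Hz first, negative frequencies wrapped to the end.
print(torch.fft.fftfreq(n))                      # [0, .125, .25, .375, -.5, -.375, -.25, -.125]
# After fftshift the 0 Hz bin sits in the middle of the array.
print(torch.fft.fftshift(torch.fft.fftfreq(n)))  # [-.5, -.375, -.25, -.125, 0, .125, .25, .375]

# Forward FFT scaled by 1/n, so the unscaled inverse keeps the amplitude.
X = torch.fft.fft(x, norm="forward")


def zero_pad(spec: torch.Tensor, pad: int) -> torch.Tensor:
    """Hypothetical helper: put `spec` in the middle of a zero array of length len(spec) + 2*pad."""
    out = torch.zeros(spec.shape[0] + 2 * pad, dtype=spec.dtype)
    out[pad:pad + spec.shape[0]] = spec
    return out


# Naive: pad the unshifted spectrum, which relocates 0 Hz to a high-frequency bin.
naive = torch.fft.ifft(zero_pad(X, n // 2), norm="forward").real

# Shifted: centre 0 Hz, pad, then move 0 Hz back to index 0 before the inverse FFT.
shifted = torch.fft.ifftshift(zero_pad(torch.fft.fftshift(X), n // 2))
upsampled = torch.fft.ifft(shifted, norm="forward").real

print(upsampled[::2])  # matches x
print(naive[::2])      # x multiplied by +1, -1, +1, -1, ...
```

The alternating sign in the naive version, plus the high-frequency values it produces between the original samples, is the kind of pattern that pushes pixel values well outside [0, 1].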
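```python
import torch
import torch.nn.functional as F
import cv2
import numpy as np

img = cv2.imread('orig.png')  # H x W x 3, uint8, BGR channel order
if img is None:
    raise FileNotFoundError('orig.png')
# HWC uint8 -> CHW float32 in [0, 1]
torch_img = torch.from_numpy(img).to(torch.float32).permute(2, 0, 1) / 255.

# note the fftshift: move 0 Hz to the centre of the two spatial dims only
fft = torch.fft.fftshift(torch.fft.fft2(torch_img, norm="forward"), dim=(-2, -1))
fr = fft.real
fi = fft.imag
# zero-pad the high frequencies: half the width left/right, half the height top/bottom
pad = (fft.shape[-1] // 2, fft.shape[-1] // 2, fft.shape[-2] // 2, fft.shape[-2] // 2)
fr = F.pad(fr, pad, mode='constant', value=0)
fi = F.pad(fi, pad, mode='constant', value=0)

# note the ifftshift: move 0 Hz back to index (0, 0) before the inverse FFT
fft_hires = torch.fft.ifftshift(torch.complex(fr, fi), dim=(-2, -1))
# norm="forward" on both calls: the 1/N factor is applied only in fft2, so the
# inverse does not divide by the larger padded size and brightness is preserved
inv = torch.fft.ifft2(fft_hires, norm="forward").real
print(inv.max(), inv.min())  # can still overshoot [0, 1] slightly near sharp edges (ringing)

img = inv.permute(1, 2, 0).clamp(0, 1)
img = (255 * img).round().numpy().astype(np.uint8)
cv2.imwrite('hires.png', img)
```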
This produces the following hires.png:
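If you do this often, the same steps can be wrapped in a small function. This is a sketch under a few assumptions: the input is a real (..., H, W) tensor with even H and W (with odd sizes, padding by N // 2 on each side gives 2N - 1 samples rather than 2N), and it builds the padded complex spectrum with torch.zeros plus slice assignment instead of padding the real and imaginary parts separately. fft_upsample2x is a hypothetical name. Because this interpolation keeps the original samples, every second pixel of the result should equal the input, which makes an easy sanity check.

```python
import torch


def fft_upsample2x(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: upsample a real (..., H, W) tensor to (..., 2H, 2W)
    by zero-padding its centred spectrum. Assumes H and W are even."""
    h, w = img.shape[-2:]
    # Forward FFT scaled by 1/(H*W); shift only the two spatial dims.
    spec = torch.fft.fftshift(torch.fft.fft2(img, norm="forward"), dim=(-2, -1))
    # Complex zero array of twice the size with the spectrum in the middle,
    # so zeros are added at the highest frequencies only.
    padded = torch.zeros(*img.shape[:-2], 2 * h, 2 * w, dtype=spec.dtype, device=img.device)
    padded[..., h // 2:h // 2 + h, w // 2:w // 2 + w] = spec
    # Move 0 Hz back to (0, 0); the unscaled inverse keeps the original brightness.
    padded = torch.fft.ifftshift(padded, dim=(-2, -1))
    return torch.fft.ifft2(padded, norm="forward").real


# Usage on a random stand-in for a CHW image in [0, 1].
torch.manual_seed(0)
img = torch.rand(3, 16, 20, dtype=torch.float64)
up = fft_upsample2x(img)
print(up.shape)                                # torch.Size([3, 32, 40])
print(torch.allclose(up[..., ::2, ::2], img))  # True: original pixels are kept
print(torch.allclose(up.mean(), img.mean()))   # True: the DC term (mean) is unchanged
```

Working on the complex tensor directly avoids splitting it into real and imaginary parts and recombining them with torch.complex.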