Training with threshold in PyTorch

Time:03-10

I have a neural network that produces a single value for a given input. I need to use this value to threshold another array, and the result of that threshold operation is used to compute a loss function. The threshold value is not known beforehand and has to be learned during training. Here is an MWE:

import torch

x = torch.randn(10, 1)  # Say this is the output of the network (10 is my batch size)
data_array = torch.randn(10, 2)  # This is the data I need to threshold
ground_truth = torch.randn(10, 2)  # This is the ground truth
mse_loss = torch.nn.MSELoss()  # Loss function

# Threshold
thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise

# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # Throws error here

Since the thresholding operation (the comparison data_array >= x) returns a boolean tensor with no gradient attached, backward() throws an error.
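A quick check of that diagnosis, as a minimal sketch: the comparison produces a bool tensor that is cut off from the autograd graph even when x requires grad. Giving x requires_grad=True (to mimic a real network output) is an assumption beyond the MWE above.

```python
import torch

# Stand-in for the network output; requires_grad=True mimics a real network (assumption)
x = torch.randn(10, 1, requires_grad=True)
data_array = torch.randn(10, 2)

mask = data_array >= x  # comparison ops return bool tensors and are not differentiable
print(mask.dtype)          # torch.bool
print(mask.requires_grad)  # False: the mask is disconnected from x

thresholded_vals = data_array * mask
print(thresholded_vals.requires_grad)  # False, so no gradient can reach x

loss = torch.nn.MSELoss()(thresholded_vals, torch.randn(10, 2))
raised = False
try:
    loss.backward()
except RuntimeError as err:
    # The loss has no grad_fn, so autograd refuses to run backward
    raised = True
    print("backward() failed:", err)
```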

How does one train a network in such a case?

Answer:

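Your threshold function is not differentiable with respect to the threshold: the comparison data_array >= x yields a boolean mask with no gradient, so autograd never computes a gradient for x. That is why your example does not work. If both inputs require grad, backward() runs, but only data_array receives a gradient: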

import torch

x = torch.randn(10, 1, requires_grad=True)  # Say this is the output of the network (10 is my batch size)
data_array = torch.randn(10, 2, requires_grad=True)  # This is the data I need to threshold
ground_truth = torch.randn(10, 2)  # This is the ground truth
mse_loss = torch.nn.MSELoss()  # Loss function

# Threshold
thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise

# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # No error now: the loss still depends on data_array, which requires grad
print(x.grad)
print(data_array.grad)

Output:

None #<- for the threshold x
tensor([[ 0.1088, -0.0617],  #<- for the data_array
        [ 0.1011,  0.0000],
        [ 0.0000,  0.0000],
        [-0.0000, -0.0000],
        [ 0.2047,  0.0973],
        [-0.0000,  0.2197],
        [-0.0000,  0.0929],
        [ 0.1106,  0.2579],
        [ 0.0743,  0.0880],
        [ 0.0000,  0.1112]])
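To make the threshold trainable, one common workaround (not part of the answer above) is to replace the hard mask with a smooth approximation such as torch.sigmoid((data_array - x) / temperature). The soft mask approaches the hard step as the temperature shrinks, but it stays differentiable with respect to x. This is a minimal sketch; the temperature of 0.1 is a hypothetical value you would need to tune.

```python
import torch

torch.manual_seed(0)

x = torch.randn(10, 1, requires_grad=True)  # stand-in for the network's threshold output
data_array = torch.randn(10, 2)
ground_truth = torch.randn(10, 2)
mse_loss = torch.nn.MSELoss()

temperature = 0.1  # hypothetical: smaller is closer to a hard step but gives steeper gradients

# Soft mask in (0, 1): close to 1 where data_array is well above x, close to 0 well below it
soft_mask = torch.sigmoid((data_array - x) / temperature)
thresholded_vals = data_array * soft_mask

loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()
print(x.grad)  # a (10, 1) tensor instead of None
```

At inference time you can still apply the hard comparison with the learned threshold; the sigmoid only exists to give training a gradient for x.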