The two ways of computing 'tanh' are shown below. Why is torch.tanh (1) so much faster than the direct expression (2)? I am confused. Also, where can I find the source code of torch.tanh in PyTorch? Is it written in C/C++?
import torch
import time

def tanh(x):
    # Direct expression: tanh(x) = (e^x - e^-x) / (e^x + e^-x)
    return (torch.exp(x) - torch.exp(-x)) / (torch.exp(x) + torch.exp(-x))

class Function(torch.nn.Module):
    def __init__(self):
        super(Function, self).__init__()
        self.Linear1 = torch.nn.Linear(3, 50)
        self.Linear2 = torch.nn.Linear(50, 50)
        self.Linear3 = torch.nn.Linear(50, 50)
        self.Linear4 = torch.nn.Linear(50, 1)

    def forward(self, x):
        # (1) using torch.tanh
        x = torch.tanh(self.Linear1(x))
        x = torch.tanh(self.Linear2(x))
        x = torch.tanh(self.Linear3(x))
        x = torch.tanh(self.Linear4(x))
        # (2) using the direct expression
        # x = tanh(self.Linear1(x))
        # x = tanh(self.Linear2(x))
        # x = tanh(self.Linear3(x))
        # x = tanh(self.Linear4(x))
        return x

func = Function()
x = torch.ones(1000, 3)
T1 = time.time()
for i in range(10000):
    y = func(x)
T2 = time.time()
print(T2 - T1)
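To compare the two activations on their own, without the Linear layers and without commenting code in and out, a sketch like the following times each function directly with the standard `timeit` module. It assumes `manual_tanh` is the question's hand-written formula (renamed so it does not shadow anything), and `compare` is a hypothetical helper; autograd is disabled with `torch.no_grad()` so only the forward computation is measured. Absolute numbers will depend on the machine, tensor size and thread settings.

```python
import timeit

import torch


def manual_tanh(x):
    # The question's formula: four exp calls, two negations, one subtraction,
    # one addition and one division, each producing a temporary tensor.
    return (torch.exp(x) - torch.exp(-x)) / (torch.exp(x) + torch.exp(-x))


def compare(n_rows=1000, n_cols=50, repeats=1000):
    """Hypothetical helper: time both versions on the same random input."""
    x = torch.randn(n_rows, n_cols)
    with torch.no_grad():
        # Both versions should agree up to floating-point rounding.
        if not torch.allclose(manual_tanh(x), torch.tanh(x), atol=1e-6):
            raise RuntimeError("manual_tanh and torch.tanh disagree")
        t_manual = timeit.timeit(lambda: manual_tanh(x), number=repeats)
        t_native = timeit.timeit(lambda: torch.tanh(x), number=repeats)
    return t_manual, t_native


if __name__ == "__main__":
    t_manual, t_native = compare()
    print(f"manual: {t_manual:.4f} s  torch.tanh: {t_native:.4f} s  "
          f"ratio: {t_manual / t_native:.1f}x")
```

Keeping the input inside moderate values (here `torch.randn`) matters for the correctness check: for large |x| the hand-written formula computes exp(x) in float32 and can overflow, while torch.tanh does not.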
Answer:
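Built-in math functions such as torch.tanh are implemented in highly optimized native code (C++ on the CPU, CUDA on the GPU). They can use SIMD instructions and multiple cores, they can run on GPUs, and they compute the whole result in a single pass over the data with a single output allocation.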
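One way to confirm that torch.tanh is native code rather than Python is to ask the standard `inspect` module for its source. In the sketch below, `has_python_source` is a hypothetical helper written only for this check; it should report that torch.tanh is a built-in with no Python source to show. The implementation lives in ATen, PyTorch's C++ tensor library; in the PyTorch GitHub repository, `aten/src/ATen/native/native_functions.yaml` lists the `tanh` operator and is one place to start tracing it to its CPU and CUDA kernels.

```python
import inspect

import torch


def has_python_source(fn):
    """Hypothetical helper: True if inspect can retrieve Python source for fn."""
    try:
        inspect.getsource(fn)
        return True
    except (TypeError, OSError):
        # Built-ins implemented in compiled code raise TypeError here.
        return False


# torch.tanh is exposed as a built-in from PyTorch's compiled library,
# not as a function written in Python.
print(type(torch.tanh).__name__)
print(inspect.isbuiltin(torch.tanh))
print(has_python_source(torch.tanh))
```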
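Your tanh function, by contrast, calls exp four times and also performs two negations, one subtraction, one addition and one division. Each of these nine operations is a separate pass over the data that allocates its own temporary tensor, which costs memory bandwidth and allocation time, and each one also pays the overhead of the Python interpreter and PyTorch's operator dispatch. Because the benchmark does not disable gradient tracking, autograd also records every one of those operations. Being 4 to 10 times slower is therefore reasonable.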
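The extra operations can be made visible by counting the autograd nodes each version records, since the question's benchmark runs with gradient tracking on (the Linear weights require gradients). This is a sketch: `manual_tanh` is the question's formula and `count_graph_nodes` is a hypothetical helper that walks `grad_fn.next_functions`, skipping the `AccumulateGrad` nodes that belong to leaf tensors. The hand-written version should record nine nodes (four Exp, two Neg, one Sub, one Add, one Div), while torch.tanh records a single Tanh node.

```python
import torch


def manual_tanh(x):
    # Hand-written tanh from the question.
    return (torch.exp(x) - torch.exp(-x)) / (torch.exp(x) + torch.exp(-x))


def count_graph_nodes(y):
    """Hypothetical helper: count autograd nodes behind y, excluding leaves."""
    seen, stack, count = set(), [y.grad_fn], 0
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # AccumulateGrad nodes only store gradients into leaf tensors.
        if type(node).__name__ != "AccumulateGrad":
            count += 1
        stack.extend(next_node for next_node, _ in node.next_functions)
    return count


x = torch.randn(1000, 50, requires_grad=True)
print("manual_tanh nodes:", count_graph_nodes(manual_tanh(x)))
print("torch.tanh nodes:", count_graph_nodes(torch.tanh(x)))
```

Each recorded node also corresponds to a temporary tensor in the forward pass, so the node count is a rough proxy for the extra memory traffic described above.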