activation should be relu/gelu, not {}
Package:
torch

Exception Class:
RuntimeError
Raise code:
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
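For illustration only, the private helper can be called directly to see the mapping. _get_activation_fn is not part of the public API, and its import path (assumed here from the link below) may differ between torch versions:

import torch.nn.functional as F
from torch.nn.modules.transformer import _get_activation_fn  # private helper; path assumed, may vary by version

print(_get_activation_fn("relu") is F.relu)   # True
print(_get_activation_fn("gelu") is F.gelu)   # True
# _get_activation_fn("rel") would raise the RuntimeError shown above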
Links to the raise (1):
https://github.com/pytorch/pytorch/blob/e56d3b023818f54553f2dc5d30b6b7aaf6b6a325/torch/nn/modules/transformer.py#L428

Ways to fix:
When initializing TransformerDecoderLayer, the activation parameter must be given a valid value: either "relu" or "gelu".
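A minimal sketch, not part of torch: a hypothetical make_decoder_layer wrapper that validates the string against the two values accepted by the raise code above, so a typo fails early with a clearer message:

import torch
from torch import nn

VALID_ACTIVATIONS = {"relu", "gelu"}  # the only strings accepted by _get_activation_fn in this version

def make_decoder_layer(activation="relu"):
    # Validate before construction so a typo like "rel" surfaces immediately.
    if activation not in VALID_ACTIVATIONS:
        raise ValueError("activation must be one of {}, got {!r}".format(VALID_ACTIVATIONS, activation))
    return nn.TransformerDecoderLayer(d_model=512, nhead=8, activation=activation)

decoder_layer = make_decoder_layer("relu")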
Reproducing the error:
pipenv install torch

import torch
from torch import nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512,
                                           nhead=8,
                                           activation="rel")  # the valid values are "relu" and "gelu"
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = decoder_layer(tgt, memory)
print(out.shape)
The error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-36493082f32e> in <module>()
      4 decoder_layer = nn.TransformerDecoderLayer(d_model=512,
      5                                             nhead=8,
----> 6                                             activation="rel")
      7 memory = torch.rand(10, 32, 512)
      8 tgt = torch.rand(20, 32, 512)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/transformer.py in __init__(self, d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, device, dtype)
    380         self.dropout3 = Dropout(dropout)
    381 
--> 382         self.activation = _get_activation_fn(activation)
    383 
    384     def __setstate__(self, state):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/transformer.py in _get_activation_fn(activation)
    426         return F.gelu
    427 
--> 428     raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

RuntimeError: activation should be relu/gelu, not rel
Fixed:
import torch
from torch import nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512,
                                           nhead=8,
                                           activation="relu")
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = decoder_layer(tgt, memory)
print(out.shape)
Output:
torch.Size([20, 32, 512])
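The other value listed in the raise code, "gelu", should work the same way; a quick check (the output shape is expected to match):

import torch
from torch import nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512,
                                           nhead=8,
                                           activation="gelu")
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = decoder_layer(tgt, memory)
print(out.shape)  # expected: torch.Size([20, 32, 512])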