I’m a beginner in PyTorch. From the LSTM documentation, I learned that I can create a stacked LSTM with 3 layers by:
layer = torch.nn.LSTM(128, 512, num_layers=3)
Then in the forward function, I can do:
def forward(x, state):
    # state is a tuple (h, c) of hidden and cell states
    x, (h, c) = layer(x, state)
    # detach each tensor; a tuple has no detach()
    return x, (h.detach(), c.detach())
And I can pass state from batch to batch.
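To make the batch-to-batch state passing concrete, here is a minimal sketch of how that loop might look. The random `batches` list and the choice of `state=None` for the first batch are assumptions for illustration, not part of the original setup; the input layout is the default (seq_len, batch, features).

```python
import torch

torch.manual_seed(0)

# Same configuration as above: 128 input features, 512 hidden units, 3 layers.
layer = torch.nn.LSTM(128, 512, num_layers=3)

def forward(x, state=None):
    # state is None for the first batch, otherwise a tuple (h, c),
    # each of shape (num_layers, batch, hidden_size).
    x, (h, c) = layer(x, state)
    # Detach so gradients do not flow back into earlier batches.
    return x, (h.detach(), c.detach())

# Hypothetical data: 4 batches of shape (seq_len=10, batch=2, features=128).
batches = [torch.randn(10, 2, 128) for _ in range(4)]

state = None
for batch in batches:
    out, state = forward(batch, state)

print(out.shape)       # torch.Size([10, 2, 512])
print(state[0].shape)  # torch.Size([3, 2, 512])
```

Passing `None` lets the LSTM start from zero states, so the first batch needs no special handling.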
But if I create 3 separate LSTM layers and want to implement the same stack myself, what is the equivalent?
layer1 = torch.nn.LSTM(128, 512, num_layers=1)
layer2 = torch.nn.LSTM(512, 512, num_layers=1)  # input is layer1's 512-dim output
layer3 = torch.nn.LSTM(512, 512, num_layers=1)  # input is layer2's 512-dim output
In this case, what should go into the forward function, and how do I get the returned state?
I also tried to look at the source code of PyTorch's LSTM, but its forward function calls into a _VF module, and I cannot find where that is defined.
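If you define state as a list of the 3 layers’ states, one (h, c) tuple per layer, then:
def forward(x, state):
    # each layer gets its own state, not the whole list
    x, (h0, c0) = layer1(x, state[0])
    x, (h1, c1) = layer2(x, state[1])
    x, (h2, c2) = layer3(x, state[2])
    # detach each tensor; the (h, c) tuple itself has no detach()
    return x, [(h0.detach(), c0.detach()), (h1.detach(), c1.detach()), (h2.detach(), c2.detach())]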
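The same idea can be wrapped in a module so the per-layer states travel together. This is only a sketch under stated assumptions: the class name `StackedLSTM`, the use of `nn.ModuleList`, and the `state=None` default are illustrative choices, and layers 2 and 3 are assumed to take the 512-dimensional output of the previous layer as input.

```python
import torch
from torch import nn

class StackedLSTM(nn.Module):
    """Single-layer LSTMs chained to mimic nn.LSTM(128, 512, num_layers=3)."""

    def __init__(self, input_size=128, hidden_size=512, num_layers=3):
        super().__init__()
        # Only the first layer sees the raw input; later layers consume
        # the previous layer's hidden_size-dimensional output.
        sizes = [input_size] + [hidden_size] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.LSTM(size, hidden_size, num_layers=1) for size in sizes
        )

    def forward(self, x, state=None):
        # state: list with one (h, c) tuple per layer, or None for the first batch.
        if state is None:
            state = [None] * len(self.layers)
        new_state = []
        for lstm, layer_state in zip(self.layers, state):
            x, (h, c) = lstm(x, layer_state)
            # Detach so gradients do not flow back into earlier batches.
            new_state.append((h.detach(), c.detach()))
        return x, new_state

torch.manual_seed(0)
model = StackedLSTM()

state = None
for _ in range(3):
    x = torch.randn(10, 2, 128)  # hypothetical batch: (seq_len, batch, features)
    out, state = model(x, state)

print(out.shape)          # torch.Size([10, 2, 512])
print(len(state))         # 3
print(state[0][0].shape)  # torch.Size([1, 2, 512])
```

Keeping the layers separate like this gives access to each layer's output, for example to add dropout or a residual connection between layers, which the single `num_layers=3` module does not expose directly.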