What is equivalent to pytorch lstm num_layers?

I’m a beginner in PyTorch. From the LSTM documentation, I learned that I can create a stacked LSTM with 3 layers like this:

import torch

layer = torch.nn.LSTM(128, 512, num_layers=3)

Then in the forward function, I can do:

def forward(x, state):
    x, state = layer(x, state)
    # state is (h_n, c_n); detach both so gradients stop at the batch boundary
    return x, (state[0].detach(), state[1].detach())

This lets me pass the state from batch to batch.
But if I want to implement the same stacked layers myself, using three separate single-layer LSTMs, what is the equivalent?
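For reference, here is a minimal sketch of that batch-to-batch loop with the 3-layer `nn.LSTM`, assuming the default `batch_first=False` layout and random tensors standing in for real batches. It shows that the carried state is a single `(h, c)` tuple in which each tensor has shape `(num_layers, batch, hidden_size)`, i.e. all three layers' states are packed together.

```python
import torch

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 10, 4, 128, 512

# Same layer as above: a 3-layer stacked LSTM.
layer = torch.nn.LSTM(input_size, hidden_size, num_layers=3)

def forward(x, state):
    x, state = layer(x, state)
    # Detach so gradients do not flow back into earlier batches.
    return x, (state[0].detach(), state[1].detach())

state = None  # nn.LSTM treats a missing state as all zeros
for step in range(3):
    x = torch.randn(seq_len, batch, input_size)  # stand-in for a real batch
    out, state = forward(x, state)

h, c = state
print(out.shape)           # output of the top layer only
print(h.shape, c.shape)    # one slice per layer along dim 0
```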

layer1 = torch.nn.LSTM(128, 512, num_layers=1)
# layers 2 and 3 receive the 512-dim output of the layer below
layer2 = torch.nn.LSTM(512, 512, num_layers=1)
layer3 = torch.nn.LSTM(512, 512, num_layers=1)

In this case, what should go into the forward function, and how do I get the returned state?
I also tried to look at the source code of the PyTorch LSTM, but its forward function calls a _VF module, and I cannot find where that is defined.
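As far as I know, `_VF` is a thin Python binding to PyTorch's compiled C++ kernels, so the LSTM math is not written out in Python there. The gate equations from the `nn.LSTM` documentation can be reproduced by hand, though. The sketch below is a hypothetical helper (`lstm_layer_manual` is not a PyTorch function) that runs one single-layer `nn.LSTM` step by step using its `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0` and `bias_hh_l0` parameters, assuming the documented gate order input, forget, cell, output; it should match `nn.LSTM` up to floating-point rounding.

```python
import torch

def lstm_layer_manual(x, state, lstm):
    """Hypothetical helper: run a single-layer nn.LSTM by hand.

    x: (seq_len, batch, input_size); state: (h0, c0), each (1, batch, hidden).
    Gates are stacked in the order i, f, g, o inside the weight matrices.
    """
    h, c = state[0][0], state[1][0]  # drop the num_layers dimension
    w_ih, w_hh = lstm.weight_ih_l0, lstm.weight_hh_l0
    b = lstm.bias_ih_l0 + lstm.bias_hh_l0
    outputs = []
    for x_t in x:  # iterate over time steps
        gates = x_t @ w_ih.T + h @ w_hh.T + b
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g          # new cell state
        h = o * torch.tanh(c)      # new hidden state (also this step's output)
        outputs.append(h)
    return torch.stack(outputs), (h.unsqueeze(0), c.unsqueeze(0))

torch.manual_seed(0)
lstm = torch.nn.LSTM(8, 16, num_layers=1)
x = torch.randn(5, 3, 8)
state = (torch.zeros(1, 3, 16), torch.zeros(1, 3, 16))

with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x, state)
    out, (h, c) = lstm_layer_manual(x, state, lstm)

print(torch.allclose(out, ref_out, atol=1e-5))
```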


If you define state as a list of the 3 layers’ states, where each entry is that layer’s own (h, c) tuple, then:

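def forward(x, state):
    # state[k] is layer k's (h, c) tuple, each of shape (1, batch, 512)
    x, s0 = layer1(x, state[0])
    x, s1 = layer2(x, state[1])
    x, s2 = layer3(x, state[2])
    # each s is a tuple, so detach its tensors individually
    return x, [tuple(t.detach() for t in s) for s in (s0, s1, s2)]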
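One way to convince yourself the two setups are equivalent is to copy the weights of a `num_layers=3` LSTM into three single-layer LSTMs and compare outputs. This is a sketch assuming the default `dropout=0` (with nonzero dropout, the stacked LSTM also applies dropout between layers, which the manual version would have to add). It relies on the documented parameter naming: layer `k` of a stacked LSTM owns `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}` and `bias_hh_l{k}`. It also shows how to convert between the packed `(3, batch, 512)` state and the per-layer list.

```python
import torch

torch.manual_seed(0)
seq_len, batch = 10, 4

stacked = torch.nn.LSTM(128, 512, num_layers=3)
layers = [
    torch.nn.LSTM(128, 512, num_layers=1),
    torch.nn.LSTM(512, 512, num_layers=1),  # input is layer 1's 512-dim output
    torch.nn.LSTM(512, 512, num_layers=1),
]

# Copy stacked layer k's parameters (suffix _l{k}) into single LSTM k (suffix _l0).
with torch.no_grad():
    for k, single in enumerate(layers):
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
            getattr(single, f"{name}_l0").copy_(getattr(stacked, f"{name}_l{k}"))

def forward(x, state):
    new_state = []
    for single, s in zip(layers, state):
        x, s = single(x, s)
        new_state.append(tuple(t.detach() for t in s))
    return x, new_state

x = torch.randn(seq_len, batch, 128)
h0 = torch.randn(3, batch, 512)
c0 = torch.randn(3, batch, 512)

ref_out, (ref_h, ref_c) = stacked(x, (h0, c0))

# Split the packed (3, batch, 512) state into one (1, batch, 512) slice per layer.
per_layer = [(h0[k:k + 1], c0[k:k + 1]) for k in range(3)]
out, state = forward(x, per_layer)

# Re-pack the per-layer states to compare with the stacked LSTM's state.
h = torch.cat([s[0] for s in state])
c = torch.cat([s[1] for s in state])
print(torch.allclose(out, ref_out, atol=1e-5))
print(torch.allclose(h, ref_h, atol=1e-5), torch.allclose(c, ref_c, atol=1e-5))
```

Either state layout works; the list-of-tuples form is simply what three separate modules naturally return.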