# coding=utf-8
"""Small test harness: a module that runs input through a stack of ``nn.Linear`` layers."""
import torch.nn as nn


class TestModule(nn.Module):
    """Wrap a :class:`DecodeLayer` and apply each of its layers to the input in order.

    Args:
        start_layer_index: Stored on the instance; not currently used by ``forward``.
        end_layer_index: Stored on the instance; not currently used by ``forward``.
        *args: Forwarded to :class:`torch.nn.Module`.
        **kwargs: Forwarded to :class:`torch.nn.Module`.
    """

    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): these indices were previously discarded. They look intended to
        # select a sub-range of layers; confirm the intended semantics (inclusive vs.
        # exclusive end) before using them in forward().
        self.start_layer_index = start_layer_index
        self.end_layer_index = end_layer_index
        self.model = DecodeLayer()

    def forward(self, x):
        """Feed ``x`` through every layer of ``self.model.layers`` in order and return the output."""
        # DecodeLayer is a plain nn.Module (not a container) and is not iterable;
        # the layers live in its ModuleList attribute.
        for layer in self.model.layers:
            x = layer(x)
        return x


class DecodeLayer(nn.Module):
    """Hold ten ``nn.Linear(10, 10)`` layers in an ``nn.ModuleList`` named ``layers``.

    Args:
        *args: Forwarded to :class:`torch.nn.Module`.
        **kwargs: Forwarded to :class:`torch.nn.Module`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ModuleList registers each Linear as a submodule so its parameters are tracked.
        self.layers = nn.ModuleList(nn.Linear(10, 10) for _ in range(10))


if __name__ == "__main__":
    test_module = TestModule(0, 3)
    for name, _param in test_module.named_parameters():
        print(name)