# coding=utf-8
import torch.nn as nn


class TestModule(nn.Module):
    """Wrap a ``DecodeLayer`` and run its sub-layers sequentially.

    Args:
        start_layer_index: Stored as ``self.start_layer_index``; not yet used
            when selecting layers.
        end_layer_index: Stored as ``self.end_layer_index``; not yet used
            when selecting layers.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): these indices look like they are meant to select a
        # slice of ``self.model.layers`` (e.g. for partitioning the stack);
        # they are kept for inspection only until that intent is confirmed.
        self.start_layer_index = start_layer_index
        self.end_layer_index = end_layer_index
        self.model = DecodeLayer()

    def forward(self, x):
        """Apply every layer of ``self.model`` to ``x`` in order.

        Args:
            x: Input accepted by each module in ``self.model.layers``.

        Returns:
            The output of the last layer.
        """
        # ``DecodeLayer`` itself is a plain ``nn.Module`` and is not iterable;
        # its sub-layers live in the ``layers`` ``nn.ModuleList``.
        for module in self.model.layers:
            x = module(x)
        return x


class DecodeLayer(nn.Module):
    """A stack of equally sized ``nn.Linear`` layers applied in sequence.

    Args:
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        num_layers: Number of linear layers to stack (default 10).
        hidden_size: Input and output width of every layer (default 10).
    """

    def __init__(self, *args, num_layers: int = 10, hidden_size: int = 10, **kwargs):
        super().__init__(*args, **kwargs)
        # ``nn.ModuleList`` registers each layer as a submodule, so its
        # parameters are exposed as ``layers.<i>.weight`` / ``layers.<i>.bias``.
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )

    def forward(self, x):
        """Feed ``x`` through each linear layer in order.

        Args:
            x: Input whose last dimension equals ``hidden_size``.

        Returns:
            The output of the last layer (``x`` unchanged if ``num_layers`` is 0).
        """
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    # Build the wrapper module and list the name of every registered parameter.
    test_module = TestModule(0, 3)
    for param_name, _param in test_module.named_parameters():
        print(param_name)