torch_ext/fi/test_module.py
2025-01-04 13:47:42 +08:00

29 lines
697 B
Python

# coding=utf-8
import torch.nn as nn
class TestModule(nn.Module):
    """Wrapper module that runs its input through the layers of a ``DecodeLayer``.

    The wrapped ``DecodeLayer`` is stored as ``self.model``, so parameter
    names reported by ``named_parameters()`` are prefixed with ``model.``.
    """

    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        """Build the wrapped ``DecodeLayer``.

        Args:
            start_layer_index: Accepted but currently unused.
            end_layer_index: Accepted but currently unused.
            *args: Forwarded to ``nn.Module.__init__``.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        # NOTE(review): the two indices presumably select a sub-range of
        # layers, but nothing reads them yet -- confirm the intended use.
        super().__init__(*args, **kwargs)
        self.model = DecodeLayer()

    def forward(self, x):
        """Apply every layer in ``self.model.layers`` to ``x`` in order.

        Args:
            x: Input handed to the first layer.

        Returns:
            The output of the last layer.
        """
        # ``self.model`` is a plain nn.Module and is not iterable; the
        # iterable container is its ``layers`` ModuleList.
        for layer in self.model.layers:
            x = layer(x)
        return x
class DecodeLayer(nn.Module):
    """Container module holding ten ``nn.Linear(10, 10)`` layers.

    The layers live in the ``layers`` attribute (an ``nn.ModuleList``), so
    their parameters are registered as ``layers.0.weight``, ``layers.0.bias``
    and so on up to index 9.
    """

    def __init__(self, *args, **kwargs):
        """Create the layer list; all arguments go to ``nn.Module.__init__``."""
        super().__init__(*args, **kwargs)
        # Build the layers in index order, then register them in one step.
        linear_layers = [nn.Linear(10, 10) for _ in range(10)]
        self.layers = nn.ModuleList(linear_layers)
if __name__ == "__main__":
    # Smoke check: build the module and list the name of every registered
    # parameter, one per line.
    test_module = TestModule(0, 3)
    for name, _parameter in test_module.named_parameters():
        print(name)