1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| import torch import torch.nn as nn import torch.nn.functional as F
class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv = nn.Conv2d(2, 1, 1)
net = CNN()
class Test(nn.Module): def __init__(self, cnn): super(Test, self).__init__() self.conv = cnn.conv
def forward(self, x): x = self.conv(x) x = F.relu(x) return x
a = torch.randn([1, 2, 3, 4]) test = Test(net) res = test(a) print(res)
|
Test类直接接受是CNN()实例化后的net,相当于把CNN()的网络结构作为参数传给了Test(),这样在CNN()中就不用写forward(),即使写了也不起作用,因为对于test对象来说,它的网络结构是Test()定义的,其前向传播必须在Test()中进行定义。