对多层神经网络连接的理解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv = nn.Conv2d(2, 1, 1)


net = CNN()


class Test(nn.Module):
def __init__(self, cnn):
super(Test, self).__init__()
self.conv = cnn.conv

def forward(self, x):
x = self.conv(x)
x = F.relu(x)
return x


a = torch.randn([1, 2, 3, 4])
test = Test(net)
res = test(a)
print(res)

Test类直接接受是CNN()实例化后的net,相当于把CNN()的网络结构作为参数传给了Test(),这样在CNN()中就不用写forward(),即使写了也不起作用,因为对于test对象来说,它的网络结构是Test()定义的,其前向传播必须在Test()中进行定义。