微信公众号:OpenCV学堂
关注获取更多计算机视觉与深度学习知识
引言
Pytorch刚刚发布的最新版本1.10上面支持使用STN网络,帮助CNN网络获取旋转不变性特征。而且只需要在原来的CNN网络中改动十行左右代码即可获得加持,从而让训练生成的分类或者对象检测网络具有更好的稳定性。
STN网络
STN(Spatial Transformer Network)网络主要分为两个部分组成,一个是CNN部分、另外一个FC(全链接)部分。可以在图象分类、对象检测等视觉任务中使用。
上图中左侧a是输入图象、b是定位预测、c是空间变换之后、d是预测值。其中空间变换网络的结构如下:
网络输出参数theta值取决于变换的种类、对正常的几何变换是2x3的矩阵就是输出6个参数。表示如下:
支持STN的CNN图象分类演示
代码在原来CNN-Mnist的演示的基础上,添加STN网络支持即可,STN网络主要分为三个部分,第一个部分是CNN定位、第二个部分生成变换矩阵,最后一个部分使用输出的变换矩阵对输入完成变换之后再输入到正常的CNN网络部分就可以了。请注意!涉及的两个函数grid_affine跟grid_sample都是pytorch1.10版本才有的。最终实现的网络模型代码如下:
class MnistCNNNet(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
def forward(self, x):
# transform the input
x = self.stn(x)
# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net().to(device)
模型的训练代码如下
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 500 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
测试代码如下:
def test():
with torch.no_grad():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
.format(test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
STN网络运行结果:
参考:
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
https://arxiv.org/pdf/1506.02025.pdf
扫码查看OpenCV+Pytorch系统化学习路线图
推荐阅读
CV全栈开发者说 - 从传统算法到深度学习怎么修炼
2021入坑Pytorch框架学习,我该从哪开始...
Pytorch轻松实现经典视觉任务
教程推荐 | Pytorch框架CV开发-从入门到实战
OpenCV4 C++学习 必备基础语法知识三
OpenCV4 C++学习 必备基础语法知识二
OpenCV4.5.4 人脸检测+五点landmark新功能测试
OpenCV4.5.4人脸识别详解与代码演示