SimCLR Code Analysis
Introduction
This post walks through PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations, which learns feature representations in a self-supervised way. The loss is implemented particularly cleverly, so it is worth a close look.
SimCLR/data_aug/dataset_wrapper.py
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets

np.random.seed(0)


class DataSetWrapper(object):

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        # data augmentation
        data_augment = self._get_simclr_pipeline_transform()
        # use the official STL-10 dataset
        train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                                       transform=SimCLRDataTransform(data_augment))
        # randomly split into training and validation sets
        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        # a rich set of augmentations is applied here
        # get a set of data augmentation transformations as described in the SimCLR paper.
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        # randomly split into training and validation sets
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)
        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)
        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        # augment the same sample twice to obtain two inputs
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj
The key point in this file is that each sample is augmented twice, yielding two correlated views of the same image.
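As a minimal usage sketch (the constructor arguments here are my own illustrative assumptions, not values taken from the repo's config), iterating the training loader yields a pair of views per batch:

# Sketch only: argument values are assumptions, not the repo's actual config.
dataset = DataSetWrapper(batch_size=256, num_workers=4, valid_size=0.05,
                         input_shape="(96, 96, 3)", s=1)
train_loader, valid_loader = dataset.get_data_loaders()

for (xis, xjs), _ in train_loader:
    # xis and xjs hold two independent augmentations of the same images,
    # each of shape (256, 3, 96, 96) for STL-10
    break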
SimCLR/models/resnet_simclr.py
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ResNetSimCLR(nn.Module):

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False)}

        resnet = self._get_basemodel(base_model)
        num_ftrs = resnet.fc.in_features

        # everything except the final FC layer serves as the encoder
        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
            print("Feature extractor:", model_name)
            return model
        except KeyError:
            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")

    def forward(self, x):
        h = self.features(x)
        h = h.squeeze()

        # two FC layers project the features;
        # in this implementation the final output is a 256-dim vector
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        return h, x
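A quick shape check (my own sketch, not part of the repo; out_dim=256 follows the comment above):

import torch

model = ResNetSimCLR(base_model="resnet18", out_dim=256)
x = torch.randn(8, 3, 96, 96)  # dummy batch of STL-10-sized images
h, z = model(x)
print(h.shape)  # torch.Size([8, 512]): encoder features h
print(z.shape)  # torch.Size([8, 256]): projection z fed to the NT-Xent loss
# note: h.squeeze() would also drop the batch dimension if the batch size were 1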
SimCLR/loss/nt_xent.py
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        # marks the positions of the negative samples
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        # every image is augmented twice, giving 2 * self.batch_size inputs;
        # pairwise similarities form a 2N x 2N matrix, in which the main
        # diagonal and the diagonals offset by +N and -N correspond to the
        # same image, i.e. to the positives
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        # the complement marks the negatives
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)
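    # Illustration (my addition, not in the original file): with batch_size
    # N = 2, diag + l1 + l2 is the 4 x 4 matrix
    #     1 0 1 0
    #     0 1 0 1
    #     1 0 1 0
    #     0 1 0 1
    # and the returned mask is its complement, so each row keeps exactly
    # the 2N - 2 = 2 negative positions.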
    @staticmethod
    def _dot_simililarity(x, y):
        # x shape: (M, 1, C)
        # y shape: (1, C, N)
        # v shape: (M, N)
        # with a batch, the similarities form a matrix
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (M, 1, C)
        # y shape: (1, N, C)
        # v shape: (M, N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v
    def forward(self, zis, zjs):
        # stack into 2N rows; rows i and i + N come from the same sample
        representations = torch.cat([zjs, zis], dim=0)
        # compute pairwise similarities, giving a 2N x 2N matrix
        similarity_matrix = self.similarity_function(representations, representations)
        # extract the positives; the main diagonal holds each sample's
        # similarity with itself and can be ignored
        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
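        # Illustration (my addition, not in the original file): with N = 2 the
        # rows/columns are ordered [zj_0, zj_1, zi_0, zi_1], so
        #   l_pos = diag(sim, +2) = [s(zj_0, zi_0), s(zj_1, zi_1)]
        #   r_pos = diag(sim, -2) = [s(zi_0, zj_0), s(zi_1, zj_1)]
        # i.e. every sample's similarity with its own second view.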
        # view reshapes the concatenation into a (2N, 1) column of positives
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        # the negatives, likewise reshaped, form a 2N x (2N - 2) matrix
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        # concatenated, the logits are 2N x (2N - 1); the first entry of each
        # row (column 0) is the positive, the rest are negatives
        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        # labels of length 2N, all zero: class 0 is the positive in column 0
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        # expanding the cross-entropy recovers exactly the loss in the paper
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)
This is the heart of the code: within a single batch, positives and negatives are mined very elegantly through the similarity matrix, and cross-entropy is then used, equally elegantly, to implement the loss in the following form.
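For one positive pair (i, j), the NT-Xent loss from the paper is

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

averaged over all 2N anchors, which is exactly what the cross-entropy over [positive, negatives] computes. A minimal sanity check of this equivalence (my own sketch, not from the repo; it assumes the NTXentLoss class above is in scope):

import torch
import torch.nn.functional as F

batch_size, dim, tau = 4, 8, 0.5
torch.manual_seed(0)
zis = F.normalize(torch.randn(batch_size, dim), dim=1)
zjs = F.normalize(torch.randn(batch_size, dim), dim=1)

loss_fn = NTXentLoss(device='cpu', batch_size=batch_size,
                     temperature=tau, use_cosine_similarity=True)
loss = loss_fn(zis, zjs)

# direct computation of the formula above
z = torch.cat([zjs, zis], dim=0)              # same ordering as forward()
sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=-1) / tau
n = 2 * batch_size
manual = 0.0
for i in range(n):
    j = (i + batch_size) % n                  # index of i's positive
    mask = torch.ones(n, dtype=torch.bool)
    mask[i] = False                           # exclude self-similarity
    manual = manual - torch.log(torch.exp(sim[i, j]) / torch.exp(sim[i][mask]).sum())
print(loss.item(), (manual / n).item())       # the two values agree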