SimCLR 代码分析

作者:patrickcty · 2020 年 11 月 10 日

介绍

这篇文章分析 PyTorch SimCLR 的代码,即论文 SimCLR: A Simple Framework for Contrastive Learning of Visual Representations 的 PyTorch 实现。SimCLR 使用自监督的方法来学习图像的特征表示。其中 loss 部分实现得非常巧妙,因此特地拿出来分析。

SimCLR/data_aug/dataset_wrapper.py

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets

# Seed NumPy's global RNG once at import time so that the np.random.shuffle
# call in get_train_validation_data_loaders yields the same train/validation
# split on every run.
np.random.seed(seed=0)


class DataSetWrapper(object):
    """Build SimCLR training/validation DataLoaders over STL10 ``train+unlabeled``.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: number of DataLoader worker processes.
        valid_size: fraction of the dataset held out for validation; the
            validation subset has ``floor(valid_size * len(dataset))`` samples.
        input_shape: image shape, either a tuple such as ``(96, 96, 3)`` or its
            string form ``"(96, 96, 3)"``. This class only reads
            ``input_shape[0]`` (crop size and blur-kernel size).
        s: colour-jitter strength multiplier.
    """

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        import ast

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        # String shapes are parsed with literal_eval: unlike eval(), it accepts
        # only Python literals and cannot execute code (raises ValueError).
        # Already-parsed shapes (e.g. a tuple) are used as-is.
        if isinstance(input_shape, str):
            input_shape = ast.literal_eval(input_shape)
        self.input_shape = input_shape

    def get_data_loaders(self):
        """Return ``(train_loader, valid_loader)`` over STL10 ``train+unlabeled``.

        Every sample is passed through ``SimCLRDataTransform``, so each item
        carries two independently augmented views of the same image. The
        dataset is stored under ``./data`` (``download=True``).
        """
        # Augmentation pipeline applied to each view
        data_augment = self._get_simclr_pipeline_transform()

        # Use the torchvision STL10 dataset
        train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                                       transform=SimCLRDataTransform(data_augment))

        # Randomly split into training and validation subsets
        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        """Return the composed augmentation pipeline described in the SimCLR paper.

        Random resized crop to ``input_shape[0]``, horizontal flip, colour
        jitter (p=0.8, strength scaled by ``s``), random grayscale (p=0.2),
        Gaussian blur, then conversion to a tensor.
        """
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        # NOTE(review): int(0.1 * size) is even for some sizes (e.g. 22 for 224);
        # confirm that GaussianBlur accepts even kernel sizes.
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        """Split ``train_dataset`` by random indices into train/validation loaders.

        The first ``floor(valid_size * len(train_dataset))`` indices of a
        shuffle drawn from NumPy's global RNG form the validation subset; the
        rest form the training subset. Both loaders use ``drop_last=True``.
        """
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # Samplers restrict each loader to its own subset of indices
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # shuffle must stay False when a sampler is supplied; the sampler
        # already draws indices in random order.
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)

        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    """Turn one sample into two independently augmented views.

    Args:
        transform: callable applied to the sample; it is invoked twice per
            sample, and the two results form the returned pair.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        """Return ``(xi, xj)``: two separate applications of ``self.transform``."""
        views = [self.transform(sample) for _ in range(2)]
        return tuple(views)

这部分主要需要注意的是:对同一个数据做两次随机增强,得到两个增强后的样本 (xi, xj),它们在后面的 loss 中构成一对正样本。

SimCLR/models/resnet_simclr.py

import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ResNetSimCLR(nn.Module):
    """ResNet backbone followed by a two-layer MLP projection head (SimCLR).

    Args:
        base_model: backbone name, ``"resnet18"`` or ``"resnet50"``.
        out_dim: output dimension of the projection head.

    ``forward(x)`` returns ``(h, z)``. ``h`` is the backbone output (all
    ResNet children except the last) flattened to ``(batch, num_ftrs)``.
    ``z`` is ``l2(relu(l1(h)))``.
    """

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        # NOTE: both backbones are constructed here even though only one is used.
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False)}

        resnet = self._get_basemodel(base_model)
        num_ftrs = resnet.fc.in_features

        # Drop the last child (the `fc` classifier) and keep the feature extractor.
        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the backbone stored under ``model_name`` in ``self.resnet_dict``.

        Raises:
            ValueError: if ``model_name`` is not a known backbone name.
        """
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            # Raise a real exception class: raising a plain str is itself a TypeError.
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from None
        print("Feature extractor:", model_name)
        return model

    def forward(self, x):
        h = self.features(x)
        # flatten from dim 1 instead of squeeze(): squeeze() would also remove
        # the batch dimension when the batch contains a single image.
        h = h.flatten(start_dim=1)

        # Two FC layers project the features; the output size is out_dim.
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        return h, x

SimCLR/loss/nt_xent.py

import torch
import numpy as np


class NTXentLoss(torch.nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    Args:
        device: device the cached negative-pair mask is moved to.
        batch_size: expected number N of samples per view. The (2N, 2N) mask for
            this size is built once; inputs with a different batch size get a
            mask built on the fly.
        temperature: divisor applied to the similarity logits.
        use_cosine_similarity: cosine similarity if True, dot product otherwise.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        # Not used by forward(); kept for attribute compatibility.
        self.softmax = torch.nn.Softmax(dim=-1)
        # Boolean (2N, 2N) mask that is True exactly at the negative-pair positions.
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity function selected by the flag."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self, batch_size=None):
        """Return a (2N, 2N) bool mask selecting negative pairs.

        Each image is augmented twice, giving 2N inputs and a 2N x 2N
        similarity matrix. The main diagonal (self-similarity) and the
        diagonals at offsets +N and -N (the two views of the same image)
        are False; every other entry is True.

        Args:
            batch_size: N; defaults to ``self.batch_size``.
        """
        n = self.batch_size if batch_size is None else batch_size
        size = 2 * n
        diag = np.eye(size)
        l1 = np.eye(size, size, k=-n)
        l2 = np.eye(size, size, k=n)
        mask = torch.from_numpy((diag + l1 + l2))
        # Invert: the remaining positions are the negatives.
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        # x.unsqueeze(1): (M, 1, C); y.T.unsqueeze(0): (1, C, N).
        # Contracting the last two dims of the first with the first two of the
        # second gives the (M, N) matrix of pairwise dot products.
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        # x.unsqueeze(1): (M, 1, C); y.unsqueeze(0): (1, N, C).
        # Broadcasting gives the (M, N) matrix of pairwise cosine similarities.
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        """Return the mean NT-Xent loss over the 2N rows of ``[zjs; zis]``.

        Args:
            zis, zjs: the two projected views, each with N rows. N is taken
                from ``zis.shape[0]``.
        """
        n = zis.shape[0]
        # Stack to 2N rows; rows k and k + N are two views of the same image.
        representations = torch.cat([zjs, zis], dim=0)

        # (2N, 2N) pairwise similarity matrix
        similarity_matrix = self.similarity_function(representations, representations)

        # Positives lie on the +N and -N diagonals (the main diagonal is
        # self-similarity and is ignored).
        l_pos = torch.diag(similarity_matrix, n)
        r_pos = torch.diag(similarity_matrix, -n)
        positives = torch.cat([l_pos, r_pos]).view(2 * n, 1)

        if n == self.batch_size:
            mask = self.mask_samples_from_same_repr
        else:
            mask = self._get_correlated_mask(n).to(similarity_matrix.device)
        # Boolean indexing reads row by row and every row has exactly 2N - 2
        # negatives, so this reshape keeps each row's negatives together. The
        # explicit width also covers N == 1 (no negatives), where -1 cannot be
        # inferred from an empty tensor.
        negatives = similarity_matrix[mask].view(2 * n, 2 * n - 2)

        # (2N, 2N - 1): column 0 of each row is its positive, the rest negatives.
        logits = torch.cat((positives, negatives), dim=1)
        logits = logits / self.temperature

        # Target class 0 = the positive in column 0; cross-entropy over each row
        # expands to the NT-Xent term for that row.
        labels = torch.zeros(2 * n, dtype=torch.long, device=logits.device)
        loss = self.criterion(logits, labels)

        return loss / (2 * n)

这是代码中最关键的部分:它在一个 batch 内构造正负样本,借助相似度矩阵的形式很巧妙地挑出了正样本和负样本,最后又巧妙地用交叉熵实现了下面的 loss。

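Contrastive loss(NT-Xent):

$$
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell_{2k-1,\,2k} + \ell_{2k,\,2k-1}\right]
$$

其中 $\mathrm{sim}$ 为余弦相似度(`use_cosine_similarity=False` 时为点积),$\tau$ 为 temperature。代码中第 $i$ 行与第 $i+N$ 行构成正样本对,对应公式中的 $(2k-1, 2k)$。logits 每一行的第 0 列是分子中的正样本,整行(正样本加 $2N-2$ 个负样本)即分母中所有 $k \neq i$ 的项。交叉熵以 `reduction="sum"` 求和后再除以 $2N$,即得到 $\mathcal{L}$。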