TSN (ECCV 2016) Paper Reproduction (2)

2020/08/04 · ActionRecognition · ~12,839 characters, about 37 minutes to read

Code notes for Temporal Segment Network (TSN, ECCV 2016): the Dataloader and the TSN model definition.

Dataloader

# main.py
if args.modality == 'RGB':
    data_length = 1
elif args.modality in ['Flow', 'RGBDiff']:
    data_length = 5

train_loader = torch.utils.data.DataLoader(
    TSNDataSet("", args.train_list, num_segments=args.num_segments,
               new_length=data_length, # 1 for RGB, 5 for Flow/RGBDiff (RGBDiff adds one more frame inside TSNDataSet)
               modality=args.modality,
               image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
               transform=torchvision.transforms.Compose([
                   train_augmentation,
                   Stack(roll=args.arch == 'BNInception'),
                   ToTorchFormatTensor(div=args.arch != 'BNInception'),
                   normalize,
               ])),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

Two transform flags are BNInception-specific: Stack(roll=True) reverses the channel order to BGR, and ToTorchFormatTensor(div=False) keeps pixel values in 0-255, since BNInception is a Caffe port trained on that input convention (consistent with its input_mean = [104, 117, 128] seen later). Now jump to the definition of TSNDataSet:

# dataset.py
import os
import numpy as np
from numpy.random import randint
from PIL import Image
import torch.utils.data as data
class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])
    
class TSNDataSet(data.Dataset):  # data = torch.utils.data
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 force_grayscale=False, random_shift=True, test_mode=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode

        if self.modality == 'RGBDiff':
            self.new_length += 1 # Diff needs one more image to calculate diff

        self._parse_list()

    def _load_image(self, directory, idx):
        if self.modality == 'RGB' or self.modality == 'RGBDiff':
            return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')]
        elif self.modality == 'Flow':
            x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L')
            y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L')

            return [x_img, y_img]

    def _parse_list(self):
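        # each line of the list file: "<frame_folder> <num_frames> <label>"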
        self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)]

    def _sample_indices(self, record):
        """
        :param record: VideoRecord
        :return: list
        """

        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1  # frame files are numbered from 1 (img_00001.jpg)

    def _get_val_indices(self, record):
        if record.num_frames > self.num_segments + self.new_length - 1:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1

    def _get_test_indices(self, record):

        tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)

        offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])

        return offsets + 1

    """
    !!! __getitem__ !!!
    """
    def __getitem__(self, index):
        record = self.video_list[index]

        if not self.test_mode:
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)

        return self.get(record, segment_indices)

    def get(self, record, indices):

        images = list()
        for seg_ind in indices:
            p = int(seg_ind)
            for i in range(self.new_length):
                seg_imgs = self._load_image(record.path, p)
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images)
        return process_data, record.label

    def __len__(self):
        return len(self.video_list)
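To make the training-time sampler concrete, here is a small standalone check of the _sample_indices arithmetic with hypothetical numbers (90 frames, 3 segments, new_length=1):

import numpy as np
from numpy.random import randint

num_frames, num_segments, new_length = 90, 3, 1                    # hypothetical values
average_duration = (num_frames - new_length + 1) // num_segments   # 30
offsets = np.multiply(list(range(num_segments)), average_duration) + \
          randint(average_duration, size=num_segments)
print(offsets + 1)   # e.g. [13 38 86]: one random frame from each 30-frame segment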

Model

# models.py 
from torch import nn
from ops.basic_ops import ConsensusModule, Identity
from transforms import *
from torch.nn.init import normal, constant

So let's look at ./ops/basic_ops.py:

import torch

class Identity(torch.nn.Module):
    def forward(self, input):
        return input

PyTorch builds its computation graph out of Variables and Functions. A Variable is like a node in the graph: it stores computed results (the activations of the forward pass and the gradients of the backward pass). A Function is like an edge: it operates on Variables and produces new Variables. Simply put, a Function is an operation on Variables, such as add, subtract, multiply, divide, relu, or pool, but it is more than a plain operation. Unlike ordinary Python or NumPy computation, a Function lives in the computation graph and must support backpropagation. So besides performing the operation itself (the forward pass), it also has to retain the forward inputs needed for gradient computation and implement the backward pass. (Since PyTorch 0.4, Variable has been merged into Tensor, but the Function mechanism remains.)

ConsensusModule

class SegmentConsensus(torch.autograd.Function):

    def __init__(self, consensus_type, dim=1):
        self.consensus_type = consensus_type
        self.dim = dim
        self.shape = None

    def forward(self, input_tensor):
        self.shape = input_tensor.size()
        if self.consensus_type == 'avg':
            output = input_tensor.mean(dim=self.dim, keepdim=True)
        elif self.consensus_type == 'identity':
            output = input_tensor
        else:
            output = None

        return output

    def backward(self, grad_output):
        if self.consensus_type == 'avg':
            grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim])
        elif self.consensus_type == 'identity':
            grad_in = grad_output
        else:
            grad_in = None

        return grad_in

class ConsensusModule(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(ConsensusModule, self).__init__()
        self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
        self.dim = dim

    def forward(self, input):
        return SegmentConsensus(self.consensus_type, self.dim)(input)
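Note that SegmentConsensus is written in the legacy autograd.Function style (a stateful instance with __init__ and instance-method forward/backward), which later PyTorch releases removed. A minimal sketch of just the 'avg' branch in the modern static-method API, for reference:

import torch

class SegmentConsensusFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, dim=1):
        ctx.dim = dim
        ctx.shape = input_tensor.size()
        return input_tensor.mean(dim=dim, keepdim=True)

    @staticmethod
    def backward(ctx, grad_output):
        # spread the incoming gradient evenly across the averaged segments
        grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim])
        return grad_in, None   # no gradient for the `dim` argument

# usage: out = SegmentConsensusFn.apply(x)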

Back to models.py:

class TSN(nn.Module):
    def __init__(self, num_class, num_segments, modality,
                 base_model='resnet101', new_length=None,
                 consensus_type='avg', before_softmax=True,
                 dropout=0.8,
                 crop_num=1, partial_bn=True): ...
    def _prepare_tsn(self, num_class): ...
    def _prepare_base_model(self, base_model): ...
    def train(self, mode=True): ...
    def partialBN(self, enable): ...
    def get_optim_policies(self): ...
    def forward(self, input): ...
    def _get_diff(self, input, keep_rgb=False): ...
    def _construct_flow_model(self, base_model): ...
    def _construct_diff_model(self, base_model, keep_rgb=False): ...

    @property
    def crop_size(self):
        return self.input_size

    @property
    def scale_size(self):
        return self.input_size * 256 // 224

    def get_augmentation(self): ...

TSN definition

First, a quick look at Python's introspection built-in getattr(); its third argument is a default returned when the attribute is missing:

>>> class Test(object):
...     val = 1
...
>>> Test.val
1
>>> getattr(Test, 'val')
1
>>> getattr(Test, 'va', 5)   # no attribute 'va', so the default 5 is returned
5
With that in hand, TSN.__init__:

def __init__(self, num_class, num_segments, modality,
             base_model='resnet101', new_length=None,
             consensus_type='avg', before_softmax=True,
             dropout=0.8,
             crop_num=1, partial_bn=True):
    
    super(TSN, self).__init__()
    self.modality = modality
    self.num_segments = num_segments
    self.reshape = True
    self.before_softmax = before_softmax
    self.dropout = dropout
    self.crop_num = crop_num
    self.consensus_type = consensus_type
    if not before_softmax and consensus_type != 'avg':
        raise ValueError("Only avg consensus can be used after Softmax")

    if new_length is None:
        self.new_length = 1 if modality == "RGB" else 5
    else:
        self.new_length = new_length

    print(("""
		Initializing TSN with base model: {}.
		TSN Configurations:
		input_modality:     {}
		num_segments:       {}
		new_length:         {}
		consensus_module:   {}
		dropout_ratio:      {}
    """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout)))

    self._prepare_base_model(base_model)

    feature_dim = self._prepare_tsn(num_class)

    if self.modality == 'Flow':
        print("Converting the ImageNet model to a flow init model")
        self.base_model = self._construct_flow_model(self.base_model)
        print("Done. Flow model ready...")
    elif self.modality == 'RGBDiff':
        print("Converting the ImageNet model to RGB+Diff init model")
        self.base_model = self._construct_diff_model(self.base_model)
        print("Done. RGBDiff model ready.")

    self.consensus = ConsensusModule(consensus_type) # default dim=1

    if not self.before_softmax:
        self.softmax = nn.Softmax()

    self._enable_pbn = partial_bn
    if partial_bn:
        self.partialBN(True)
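A hypothetical construction call matching the defaults above (the argument values here are purely illustrative):

model = TSN(num_class=101, num_segments=3, modality='RGB',
            base_model='resnet101', consensus_type='avg', dropout=0.8)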

Getting the backbone model

The definition of self._prepare_base_model(base_model):

def _prepare_base_model(self, base_model):

    if 'resnet' in base_model or 'vgg' in base_model: # default base_model='resnet101'
        self.base_model = getattr(torchvision.models, base_model)(True) # pretrained=True
        self.base_model.last_layer_name = 'fc'
        self.input_size = 224
        
        # the defaults below are for RGB; Flow and RGBDiff take the if/elif branches
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        if self.modality == 'Flow':
            self.input_mean = [0.5]
            self.input_std = [np.mean(self.input_std)]
        elif self.modality == 'RGBDiff':
            self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
            self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
    
    elif base_model == 'BNInception':
        import tf_model_zoo
        self.base_model = getattr(tf_model_zoo, base_model)()
        self.base_model.last_layer_name = 'fc'
        self.input_size = 224
        self.input_mean = [104, 117, 128]
        self.input_std = [1]

        if self.modality == 'Flow':
            self.input_mean = [128]
        elif self.modality == 'RGBDiff':
            self.input_mean = self.input_mean * (1 + self.new_length)

    elif 'inception' in base_model:
        import tf_model_zoo
        self.base_model = getattr(tf_model_zoo, base_model)()
        self.base_model.last_layer_name = 'classif'
        self.input_size = 299
        self.input_mean = [0.5]
        self.input_std = [0.5]
    else:
        raise ValueError('Unknown base model: {}'.format(base_model))

The next line, feature_dim = self._prepare_tsn(num_class), builds the final FC classification layer; it is defined as:

def _prepare_tsn(self, num_class):
    feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
    if self.dropout == 0:
        setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
        self.new_fc = None
    else:
        setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
        self.new_fc = nn.Linear(feature_dim, num_class)

    std = 0.001
    if self.new_fc is None:
        normal(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
        constant(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
    else:
        normal(self.new_fc.weight, 0, std)
        constant(self.new_fc.bias, 0)
    return feature_dim
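Note that normal and constant come from the pre-0.4 torch.nn.init API; in current PyTorch they are the in-place variants, so an equivalent of the else branch would be:

from torch.nn.init import normal_, constant_

normal_(self.new_fc.weight, 0, std)
constant_(self.new_fc.bias, 0)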

Near the end of __init__ come self._enable_pbn = partial_bn and if partial_bn: self.partialBN(True), where partialBN() is defined as:

def partialBN(self, enable):
    self._enable_pbn = enable

By itself this method looks useless: it merely re-sets a flag that was just assigned. The flag only takes effect in the overridden train() method (not shown here), which, when partial-BN is enabled, freezes every BatchNorm layer except the first so fine-tuning does not destroy the pretrained BN statistics.

Optical Flow model

def _construct_flow_model(self, base_model):
    # modify the convolution layers
    # Torch models are usually defined in a hierarchical way.
    # nn.modules.children() return all sub modules in a DFS manner
    modules = list(self.base_model.modules())
    first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
    conv_layer = modules[first_conv_idx]
    container = modules[first_conv_idx - 1]

    # modify parameters, assume the first blob contains the convolution kernels
    params = [x.clone() for x in conv_layer.parameters()]
    kernel_size = params[0].size()
    new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
    new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()

    new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
                         conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                         bias=True if len(params) == 2 else False)
    new_conv.weight.data = new_kernels
    if len(params) == 2:
        new_conv.bias.data = params[1].data # add bias if necessary
    layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name

    # replace the first convolution layer
    setattr(container, layer_name, new_conv)
    return base_model
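A quick sanity check of the kernel surgery above with hypothetical shapes (a ResNet conv1 and new_length=5):

import torch

w = torch.randn(64, 3, 7, 7)     # pretrained ImageNet RGB kernels
new_w = w.mean(dim=1, keepdim=True).expand(64, 2 * 5, 7, 7).contiguous()
print(new_w.shape)               # torch.Size([64, 10, 7, 7])
# each of the 10 stacked (x, y) flow channels starts from the averaged RGB kernel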

RGBDiff model

def _construct_diff_model(self, base_model, keep_rgb=False):
    # modify the convolution layers
    # Torch models are usually defined in a hierarchical way.
    # nn.modules.children() return all sub modules in a DFS manner
    modules = list(self.base_model.modules())
    first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
    conv_layer = modules[first_conv_idx]
    container = modules[first_conv_idx - 1]

    # modify parameters, assume the first blob contains the convolution kernels
    params = [x.clone() for x in conv_layer.parameters()]
    kernel_size = params[0].size()
    if not keep_rgb:
        new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
        new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
    else:
        new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
        new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
                                1)
        new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]

    new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
                         conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                         bias=True if len(params) == 2 else False)
    new_conv.weight.data = new_kernels
    if len(params) == 2:
        new_conv.bias.data = params[1].data  # add bias if necessary
    layer_name = list(container.state_dict().keys())[0][:-7]  # remove .weight suffix to get the layer name

    # replace the first convolution layer
    setattr(container, layer_name, new_conv)
    return base_model
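The analogous check for the RGBDiff surgery (hypothetical shapes, new_length=5): the new first conv takes 3 * 5 = 15 diff channels, plus the original 3 RGB channels when keep_rgb=True:

import torch

w = torch.randn(64, 3, 7, 7)                         # pretrained RGB kernels
diff_part = w.mean(dim=1, keepdim=True).expand(64, 15, 7, 7).contiguous()
new_w = torch.cat((w, diff_part), dim=1)             # keep_rgb=True keeps w too
print(new_w.shape)                                   # torch.Size([64, 18, 7, 7])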

Model forward

def forward(self, input):
    
    # self.new_length: 1 for RGB, 5 for Flow and RGBDiff (on the model side)
    # RGB: sample_len = 3*1 = 3; Flow: sample_len = 2*5 = 10
    sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
    if self.modality == 'RGBDiff':
        sample_len = 3 * self.new_length # 5 diff frames x 3 channels = 15
        input = self._get_diff(input)

    base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:]))

    if self.dropout > 0:
        base_out = self.new_fc(base_out)

    if not self.before_softmax:
        base_out = self.softmax(base_out)
    if self.reshape:
        base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])

    output = self.consensus(base_out)
    return output.squeeze(1)
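A shape walk-through of forward for RGB with hypothetical sizes (batch 8, 3 segments, 224x224 crops):

# input:     (8, 3*3, 224, 224)   segments stacked along the channel dim by Stack()
# view:      (24, 3, 224, 224)    segments folded into the batch dim
# base_out:  (24, num_class)      per-snippet scores from the backbone (+ new_fc)
# reshape:   (8, 3, num_class)
# consensus: (8, 1, num_class)    average over the 3 segments
# squeeze:   (8, num_class)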

The definition of input = self._get_diff(input):

def _get_diff(self, input, keep_rgb=False):
    input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2
    input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
    if keep_rgb:
        new_data = input_view.clone()
    else:
        new_data = input_view[:, :, 1:, :, :, :].clone()

    for x in reversed(list(range(1, self.new_length + 1))):
        if keep_rgb:
            new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
        else:
            new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]

    return new_data
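The tensor shapes through _get_diff for RGBDiff (hypothetical: batch 8, 3 segments, model new_length=5, so 6 loaded frames per segment):

# input:      (8, 3*6*3, 224, 224)    6 stacked RGB frames per segment
# input_view: (8, 3, 6, 3, 224, 224)  (batch, segment, frame, channel, H, W)
# new_data:   (8, 3, 5, 3, 224, 224)  frame[x] - frame[x-1] for x = 1..5
# the caller's view then flattens this to (24, 15, 224, 224), i.e. sample_len = 15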
