阅前说明
前面已经出现的代码用 … 代替。
本文仅解析train部分的代码(inference的部分会后续更新)。
不对网络结构做过多解释,默认已经熟悉 mrcnn 的结构以及读过这篇论文了。

另:inference 部分已更新,见:siris 显著性排序网络代码解读(inference过程)

文章目录

  • 第一部分 训练mrcnn网络
    • obj_sal_seg_branch/train.py
    • obj_sal_seg_branch.SOSNet
      • model.load_weights
      • model.train 及该方法中调用的其他方法
  • 第二部分 预处理 提取特征
    • pre_process/pre_process_obj_feat_GT.py
    • pre_process.PreProcNet
      • model.detect
    • pre_process.Dataset
    • pre_process.DataGenerator
  • 第三部分 训练显著性排序网络
  • 最后的一点思考

第一部分 训练mrcnn网络

obj_sal_seg_branch/train.py

根据 README.md ,首先运行 obj_sal_seg_branch/train.py 。下面来看这个py文件的内容:

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your locationif __name__ == '__main__':command = "train"config = ObjSegMaskConfig()config.display()log_path = "logs/"model = SOSNet(mode="training", config=config, model_dir=log_path)

首先获得一个 SOSNet 的实例对象 model

if __name__ == '__main__':...# Start from pre-trained weights# Load weightsmodel_weights = "../weights/mask_rcnn_coco.h5"  # Make sure this is correct or change to location of weight path# Exclude layers - since we change the number of classesexclude_layers = ["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]print("Exclude Layers: ", exclude_layers)print("Loading weights ", model_weights)model.load_weights(model_weights, by_name=True, exclude=exclude_layers)

设置不参与训练的层,放在 exclude_layers 中。这里,mrcnn 相关的输出是不要的。(通过后面的代码提示,因为显著性排序不需要关注显著性物体的类别,所以只分为两类:显著性物体或背景。所以注释中写“since we change the number of classes”)

随后调用了 modelload_weights 方法,这个方法不载入被 exclude 的层的权重。

if __name__ == '__main__':...if command == "train":print("Start Training...")# Train Datasetdataset_train = Obj_Sal_Seg_Dataset(DATASET_ROOT, "train")# Val Datasetdataset_val = Obj_Sal_Seg_Dataset(DATASET_ROOT, "val")# ********** Training  **********# Image Augmentation# Right/Left flip 50% of the timeaugmentation = imgaug.augmenters.Fliplr(0.5)# Training - Stage 1print("Training network heads")model.train(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE,epochs=40,layers='heads',augmentation=augmentation)# Training - Stage 2# Fine tune all layersprint("Fine tune all layers")model.train(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE / 10,epochs=200,layers='all',augmentation=augmentation)

command == "train" 是必然的,因为一开始就赋值了。
然后调用了 obj_sal_seg_branch/Obj_Sal_Seg_Dataset ,得到训练和验证数据。数据增广后开始训练。
训练分为两个阶段,第一阶段训练网络头部,第二阶段微调所有层。

接下来看看 obj_sal_seg_branch/train.py 涉及到的两个重要类:obj_sal_seg_branch.SOSNetobj_sal_seg_branch/Obj_Sal_Seg_Dataset

obj_sal_seg_branch.SOSNet

首先看 __init__ 函数:

class SOSNet():def __init__(self, mode, config, model_dir):self.mode = modeself.config = configself.model_dir = model_dirself.set_log_dir()self.keras_model = Model_Sal_Seg.build_saliency_seg_model(config, mode)

前几句平平无奇,然后搞了一个 self.keras_model ,这是一个 Model 类型的对象,再点进去看,在 training 模式下,这个 model 的 input 和 output 如下:

inputs = [input_image, input_image_meta,input_rpn_match, input_rpn_bbox, input_gt_class_ids, input_gt_boxes, input_gt_masks]
if not config.USE_RPN_ROIS:inputs.append(input_rois)outputs = [rpn_class_logits, rpn_class, rpn_bbox,feat_pyr_net_class_logits, feat_pyr_net_class, feat_pyr_net_bbox, obj_seg_masks,rpn_rois, output_rois,rpn_class_loss, rpn_bbox_loss,obj_sal_seg_class_loss, obj_sal_seg_bbox_loss,obj_sal_seg_mask_loss]

没错,就是 mrcnn 的那一套,把中间结果(包括各种rois)和 loss 都当做output输出了。

这个 SOSNet 其实就是这个 model 的一个包装类,SOSNet 中的方法,一部分是为了方便训练这个model的其中一些层,包括:

  • train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
    augmentation=None, custom_callbacks=None)
  • set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1)
  • compile(self, learning_rate, momentum)
  • load_weights(self, filepath, by_name=False, exclude=None)

还有些方法是为了方便利用训练结果:

  • detect(self, images, verbose=0)
  • mold_inputs(self, images)
  • unmold_detections(self, detections, mrcnn_mask, original_image_shape, image_shape, window)
  • get_anchors(self, image_shape)

model.load_weights

还是从最主要的开始看,据上文所述 obj_sal_seg_branch/train 首先调用了 model.load_weights 方法。

调用时传参:

model.load_weights(model_weights, by_name=True, exclude=exclude_layers)

结合 load_weights 的源码。
首先导包:

def load_weights(self, filepath, by_name=False, exclude=None):"""Modified version of the corresponding Keras function withthe addition of multi-GPU support and the ability to excludesome layers from loading.exclude: list of layer names to exclude"""import h5py# Conditional import to support versions of Keras before 2.2# TODO: remove in about 6 months (end of 2018)try:from keras.engine import savingexcept ImportError:# Keras before 2.2 used the 'topology' namespace.from keras.engine import topology as savingif h5py is None:raise ImportError('`load_weights` requires h5py.')

然后判断是否有要剔除的层。而传参时 exclude 不为空,所以将 by_name 置为 True。(事实上在传参的时候这个参数也是True,这句话只是为了代码的稳健性。不写也没事)

def load_weights(self, filepath, by_name=False, exclude=None):...if exclude:by_name = True

然后根据路径获取一个h5py文件。(对 h5py 文件不熟悉的可以参考这个博客,想知道更多的关于这个文件的操作方法请参考这个博客。)

def load_weights(self, filepath, by_name=False, exclude=None):...f = h5py.File(filepath, mode='r')if 'layer_names' not in f.attrs and 'model_weights' in f:f = f['model_weights']

这句是为了多GUP训练

def load_weights(self, filepath, by_name=False, exclude=None):...# In multi-GPU training, we wrap the model. Get layers# of the inner model because they have the weights.keras_model = self.keras_modellayers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \else keras_model.layers

然后过滤掉被排除的 layers,其中 filter 是python的内置函数,第一个参数是函数,第二个参数是可迭代, filter 会把后一个参数的每个数据输入函数中判断,将其中True的结果返回。

def load_weights(self, filepath, by_name=False, exclude=None):...# Exclude some layersif exclude:layers = filter(lambda l: l.name not in exclude, layers)

然后调用 keras 的 api 载入权重、关闭 f 文件、更新日志

def load_weights(self, filepath, by_name=False, exclude=None):...if by_name:saving.load_weights_from_hdf5_group_by_name(f, layers)else:saving.load_weights_from_hdf5_group(f, layers)if hasattr(f, 'close'):f.close()# Update the log directoryself.set_log_dir(filepath)

model.train 及该方法中调用的其他方法

载入权重之后,有两轮训练操作。来看看两次调用时候的传参:

# Training - Stage 1print("Training network heads")model.train(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE,epochs=40,layers='heads',augmentation=augmentation)# Training - Stage 2# Fine tune all layersprint("Fine tune all layers")model.train(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE / 10,epochs=200,layers='all',augmentation=augmentation)

来看这个方法的源码:

首先预定义了 layer 的正则表达式,放在一个字典里。然后利用传进来的 layer 作为 key 值,获取字典的 value。

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,augmentation=None, custom_callbacks=None):assert self.mode == "training", "Create model in training mode."# TODO: Update# Pre-defined layer regular expressionslayer_regex = {# Only Heads"heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",# From a specific Res-Net stage and up"3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)","4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)","5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",# All layers"all": ".*",}if layers in layer_regex.keys():  # 根据 keys 获得 valuelayers = layer_regex[layers]

然后把训练数据和测试数据作为参数传入 DataGenerator.data_generator ,获得 train_generatorval_generator

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,augmentation=None, custom_callbacks=None):...# Data generatorstrain_generator = DataGenerator.data_generator(train_dataset, self.config, shuffle=True,augmentation=augmentation,batch_size=self.config.BATCH_SIZE)val_generator = DataGenerator.data_generator(val_dataset, self.config, shuffle=True,batch_size=self.config.BATCH_SIZE)logs_path = self.log_dir + "/training.log"

然后设置回调函数,对回调函数不了解的可以参考 《deep learning with python》,keras作者写的那本,第七章。

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,augmentation=None, custom_callbacks=None):...# Callbackscallbacks = [keras.callbacks.TensorBoard(log_dir=self.log_dir,histogram_freq=0, write_graph=True, write_images=False),keras.callbacks.ModelCheckpoint(self.checkpoint_path,verbose=1, save_weights_only=True),keras.callbacks.CSVLogger(logs_path, separator=",", append=True),]# Add custom callbacks to the listif custom_callbacks:callbacks += custom_callbacks

正式开始训练。先将目标层的参数设置为可训练,然后 compile 模型,然后调用 fit 训练模型。以上都是基本流程。在这里用的 set_trainablecompile 是自己定义的,fit_generator 是keras的API。

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,augmentation=None, custom_callbacks=None):...# Trainlog("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))log("Checkpoint Path: {}".format(self.checkpoint_path))self.set_trainable(layers)self.compile(learning_rate, self.config.LEARNING_MOMENTUM)# Work-around for Windows: Keras fails on Windows when using# multiprocessing workers. See discussion here:# https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009if os.name is 'nt':workers = 0else:workers = multiprocessing.cpu_count()self.keras_model.fit_generator(train_generator,initial_epoch=self.epoch,epochs=epochs,steps_per_epoch=self.config.STEPS_PER_EPOCH,callbacks=callbacks,validation_data=val_generator,validation_steps=self.config.VALIDATION_STEPS,max_queue_size=100,workers=workers,use_multiprocessing=True,)self.epoch = max(self.epoch, epochs)

然后我们来看看自定义的两个方法:

def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):"""Sets model layers as trainable if their names matchthe given regular expression."""# Print message on the first call (but not on recursive calls)if verbose > 0 and keras_model is None:log("Selecting layers to train")keras_model = keras_model or self.keras_model# In multi-GPU training, we wrap the model. Get layers# of the inner model because they have the weights.layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \else keras_model.layersfor layer in layers:# Is the layer a model?if layer.__class__.__name__ == 'Model':print("In model: ", layer.name)self.set_trainable(layer_regex, keras_model=layer, indent=indent + 4)continueif not layer.weights:continue# Is it trainable?trainable = bool(re.fullmatch(layer_regex, layer.name))# Update layer. If layer is a container, update inner layer.if layer.__class__.__name__ == 'TimeDistributed':layer.layer.trainable = trainableelse:layer.trainable = trainable# Print trainable layer namesif trainable and verbose > 0:log("{}{:20}   ({})".format(" " * indent, layer.name,layer.__class__.__name__))
def compile(self, learning_rate, momentum):"""Gets the model ready for training. Adds losses, regularization, andmetrics. Then calls the Keras compile() function."""# Compile SGDoptimizer = keras.optimizers.SGD(lr=learning_rate, momentum=momentum,clipnorm=self.config.GRADIENT_CLIP_NORM)# Add Losses# First, clear previously set losses to avoid duplicationself.keras_model._losses = []self.keras_model._per_input_losses = {}loss_names = ["rpn_class_loss", "rpn_bbox_loss","obj_sal_seg_class_loss", "obj_sal_seg_bbox_loss", "obj_sal_seg_mask_loss"]for name in loss_names:layer = self.keras_model.get_layer(name)if layer.output in self.keras_model.losses:continueloss = (tf.reduce_mean(layer.output, keepdims=True)* self.config.LOSS_WEIGHTS.get(name, 1.))self.keras_model.add_loss(loss)# Add L2 Regularization# Skip gamma and beta weights of batch normalization layers.reg_losses = [keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)for w in self.keras_model.trainable_weightsif 'gamma' not in w.name and 'beta' not in w.name]self.keras_model.add_loss(tf.add_n(reg_losses))# Compileself.keras_model.compile(optimizer=optimizer,loss=[None] * len(self.keras_model.outputs))# Add metrics for lossesfor name in loss_names:if name in self.keras_model.metrics_names:continuelayer = self.keras_model.get_layer(name)self.keras_model.metrics_names.append(name)loss = (tf.reduce_mean(layer.output, keepdims=True)* self.config.LOSS_WEIGHTS.get(name, 1.))self.keras_model.metrics_tensors.append(loss)

好了,到这里为止,关于 README.md 中首先要运行的 obj_sal_seg_branch/train.py 部分的代码就全部解读完毕。

总结来说完成的任务是:对原 mask r-cnn 去 mrcnn_ 头,在新数据集上先对个别层训练,然后对所有层微调。

第二部分 预处理 提取特征

pre_process/pre_process_obj_feat_GT.py

根据 README.md,第二步是运行 pre_process/pre_process_obj_feat_GT.py 文件。那么来看看它的源码吧:

首先是一堆路径的设置。这个类需要运行两次,第二次运行的时候把 data_split = "train" 注释掉,换成 data_split = "val"

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your locationif __name__ == '__main__':# add pre-trained weight path - backbone pre-trained on salient objects (binary, no rank)weight_path = ""# Run Script twice to generate pre-processed object features of GT objects for "train" and "val" data_splitsdata_split = "train"# data_split = "val"out_path = PRE_PROC_DATA_ROOT + "pre_process_feat/" + data_split + "/"if not os.path.exists(out_path):os.makedirs(out_path)mode = "inference"config = RankModelConfig()log_path = "logs/"

然后通过 Model_Obj_Feat.build_obj_feat_model(config) 获得一个 Model 类型的实例

if __name__ == '__main__':...keras_model = Model_Obj_Feat.build_obj_feat_model(config)model_name = "Obj_Feat_Net"

简单看看这个 Model 类型对象实例( keras_model )的结构:

def build_obj_feat_model(config):# *********************** INPUTS ***********************input_image = Input(shape=(config.NET_IMAGE_SIZE, config.NET_IMAGE_SIZE, 3), name="input_image")input_image_meta = Input(shape=[config.IMAGE_META_SIZE], name="input_image_meta")input_obj_rois = Input(shape=(config.SAL_OBJ_NUM, 4), name="input_obj_rois")# Normalize coordinatesobj_rois = Lambda(lambda x: fpn_model_utils.norm_boxes_graph(x, K.shape(input_image)[1:3]))(input_obj_rois)# *********************** BACKBONE FEATURES ***********************# Generate Backbone features# backbone_feat = [P2, P3, P4, P5]# rpn_features = [P2, P3, P4, P5, P6]# P2: (?, 256, 256, 256)# P3: (?, 128, 128, 256)# P4: (?, 64, 64, 256)# P5: (?, 32, 32, 256)backbone_feat = generate_backbone_features(input_image, config)P2, P3, P4, P5 = backbone_feat# *********************** SALIENT OBJECT MASK BRANCH ***********************# Produce Object Segment Masksobj_seg_masks = ObjectSegmentationMaskBranch.build_fpn_mask_graph(obj_rois, backbone_feat,input_image_meta,config.MASK_POOL_SIZE,config.NUM_CLASSES,train_bn=config.TRAIN_BN)# ROIAlign ed Object Featuresobj_features = pyr_roi_align_graph(obj_rois, backbone_feat, input_image_meta,config.POOL_SIZE,train_bn=config.TRAIN_BN,fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)# Modelinputs = [input_image, input_image_meta, input_obj_rois]outputs = [obj_seg_masks, obj_features, P5]model = Model(inputs=inputs, outputs=outputs, name="obj_feat_model")return model

这部分先从基础网络中获得 [P2, P3, P4, P5] 特征集合,然后将特征:

  • 输入 build_fpn_mask_graph 得到 mask(在代码中存放在 obj_seg_masks 变量中)。
  • 输入 pyr_roi_align_graph 得到 ROIAlign 后的特征(在代码中存放在 obj_features 变量中 )。

所以输出为: outputs = [obj_seg_masks, obj_features, P5] 。为啥要输出这仨?因为在后面的显著性排序网络中,需要用到这些特征。

然后我们再回到 pre_process/pre_process_obj_feat_GT.py 文件中,接着上面看代码:

if __name__ == '__main__':...model = PreProcNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

刚看完一个 keras_model ,马上又来一个 model ,跟第一部分中 SOSNetbuild_saliency_seg_model build 的模型的一个包装类一样,这个 model 对应的类 PreProcNet 是 keras_model(即 build_obj_feat_model build 的模型)的一个包装类,同样里面也装了一些方便 detect 时候调用的方法。

我们先接着把 pre_process/pre_process_obj_feat_GT.py 文件中剩余的看完,再来解读 PreProcNet 类中那些方法(不用看也可以根据前面 SOSNet 的例子猜出来,肯定也是有载入权重的方法、detect方法)

继续 pre_process/pre_process_obj_feat_GT.py :
调用model中的载入权重的方法。

if __name__ == '__main__':...# Load weightsprint("Loading weights ", weight_path)model.load_weights(weight_path, by_name=True)

开头就给 mode 赋值了,肯定会进入这个分支。
调用 pre_process.Dataset 获得数据,然后把 dataset 传入 pre_process.DataGenerator 。后文将对这俩类做详解。

if __name__ == '__main__':...if mode == "inference":# ********** Create Datasets# Train/Val Datasetdataset = Dataset(DATASET_ROOT, data_split)predictions = []num = len(dataset.img_ids)for i in range(num):image_id = dataset.img_ids[i]print(i + 1, " / ", num, " - ", image_id)input_data, gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = DataGenerator.load_inference_data_obj_feat_gt(dataset, image_id, config)

调用 model 中的 detect 方法(对啊,这也是为啥设置成 ‘inference’ 模式,因为这部分是预处理,得到特征方便后面的显著性排序网络。而特征的得到,是靠 detect )

if __name__ == '__main__':...if mode == "inference":...for i in range(num):...result = model.detect(input_data, verbose=1)

把之前从 DataGenerator 中获得的其它数据都放进 result 里面。

if __name__ == '__main__':...if mode == "inference":...for i in range(num):...result["gt_ranks"] = gt_ranksresult["sel_not_sal_obj_idx_list"] = sel_not_sal_obj_idx_listresult["shuffled_indices"] = shuffled_indicesresult["chosen_obj_idx_order_list"] = chosen_obj_idx_order_list

最后把 result 存入本地文件。其中 pickle 是 python 中一个的工具,它能够实现任意对象与文本之间的相互转化,也可以实现任意对象与二进制之间的相互转化。也就是说,pickle 可以实现 Python 对象的存储及恢复。

if __name__ == '__main__':...if mode == "inference":...for i in range(num):...o_p = out_path + image_idwith open(o_p, "wb") as f:pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)

到这里 pre_process_obj_feat_GT.py 的内容就结束了,总结来说,就是获得了提取特征的model对象。

然后通过Dataset 和 DataGenerator 获得一部分数据(gt_ranks、 sel_not_sal_obj_idx_list、shuffled_indices、chosen_obj_idx_order_list),

然后调用 detect 获得另一部分数据,或者说特征(obj_masks、obj_feat、P5)。最后把这些数据都用 pickle 存到本地。

接下来返回去看看在这个过程中涉及到的 detect 方法的具体代码、以及 Dataset 和 DataGenerator 的具体代码。

pre_process.PreProcNet

load_weights 和前面基本一样,不重复说了。
看下 detect(self, input_data, verbose=0)

model.detect

源码很短,真的很短。
返回一个字典,携带了预测结果,也就是:

  • obj_masks:[batch, roi_count, height, width, num_classes]
  • obj_feat:pooled features
  • P5:[batch, 32, 32, 256]
# Detection performed per single image
def detect(self, input_data, verbose=0):assert self.mode == "inference", "Create model in inference mode."if verbose:log("Processing image")log("image", input_data[0])detections = self.keras_model.predict(input_data, verbose=0)# Process detectionobj_masks, obj_feat, P5 = detectionsresult = {}result["obj_masks"] = obj_masksresult["obj_feat"] = obj_featresult["P5"] = P5return result

pre_process.Dataset

先看 __init__

class Dataset(object):def __init__(self, dataset_root, data_split):self.dataset_root = dataset_root                    # Root folder of Datasetself.data_split = data_splitself.load_dataset()

__init__ 中调用了 self.load_dataset() ,来看看这个方法的源码:

代码也很短,简洁明快。首先导入图片 id ,存入 self.img_ids ,然后导入排序的 ground_truth,存入 self.gt_rank_orders

最后导入显著性物体的分割数据,这个数据是存在 jason 文件中的。这里面可以得到 obj_bbox , obj_seg , _sal_obj_idx_list , _not_sal_obj_idx_list

def load_dataset(self):print("\nLoading Dataset...")image_file = self.data_split + "_images.txt"# Get list of image idsimage_path = os.path.join(self.dataset_root, image_file)with open(image_path, "r") as f:image_names = [line.strip() for line in f.readlines()]self.img_ids = image_namesprint(self.img_ids)# Load Rank Orderrank_order_root = self.dataset_root + "rank_order/" + self.data_split + "/"self.gt_rank_orders = self.load_rank_order_data(rank_order_root)# Load Object Dataobj_seg_data_path = self.dataset_root + "obj_seg_data_" + self.data_split + ".json"self.obj_bboxes, self.obj_seg, self.sal_obj_idx_list, self.not_sal_obj_idx_list = self.load_object_seg_data(obj_seg_data_path)

其中详细的调用方法,比如 self.load_rank_order_data 到底是怎么 load 的就不在此详论了。

pre_process.DataGenerator

这里面有两个方法,分别是:

  • load_inference_data_obj_feat
  • load_inference_data_obj_feat_gt

在 pre_process/pre_process_obj_feat_GT.py 中调用的是带 _gt 的那个,所以暂时先只讲 load_inference_data_obj_feat_gt
看源码:

先通过 dataset 获得一系列数据,包括:①image、②gt_ranks、③sel_not_sal_obj_idx_list、④shuffled_indices、⑤chosen_obj_idx_order_list、⑥object_roi_masks

def load_inference_data_obj_feat_gt(dataset, image_id, config):image = dataset.load_image(image_id)gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = dataset.load_gt_rank_order(image_id)object_roi_masks = dataset.load_object_roi_masks(image_id, sel_not_sal_obj_idx_list)

随后调用 fpn_network.utils 中的工具方法,进行一系列处理:

  • 对图片进行 corp、resize 处理
  • 对 mask 也进行 resize 处理
  • 根据 obj_mask 获得 obj_bbox
def load_inference_data_obj_feat_gt(dataset, image_id, config):image = dataset.load_image(image_id)...original_shape = image.shapeimage, window, scale, padding, crop = utils.resize_image(image,min_dim=config.IMAGE_MIN_DIM,min_scale=config.IMAGE_MIN_SCALE,max_dim=config.IMAGE_MAX_DIM,mode=config.IMAGE_RESIZE_MODE)obj_mask = utils.resize_mask(object_roi_masks, scale, padding, crop)# bbox: [num_instances, (y1, x1, y2, x2)]obj_bbox = utils.extract_bboxes(obj_mask)

chosen_obj_idx_order_list 是 dataset 返回的一个数据,这是根据 ground truth 的显著性排序选出固定个(由config决定是多少个)个体,下面这段代码把这些被选中的 sal_obj 合成 batch

def load_inference_data_obj_feat_gt(dataset, image_id, config):image = dataset.load_image(image_id)...# *********************** FILL REST, SHUFFLE ORDER ***********************# order is in salient objects then non-salient objectsbatch_obj_roi = np.zeros(shape=(config.SAL_OBJ_NUM, 4), dtype=np.int32)for i in range(len(chosen_obj_idx_order_list)):_idx = chosen_obj_idx_order_list[i]batch_obj_roi[_idx] = obj_bbox[i]

然后是对图片做标准化、生成 image_meta,这跟mask r-cnn里面的操作差不多。

然后就是合成 batch,返回这些数据。

def load_inference_data_obj_feat_gt(dataset, image_id, config):image = dataset.load_image(image_id)...# Normalize imageimage = model_utils.mold_image(image.astype(np.float32), config)# Active classesactive_class_ids = np.ones([config.NUM_CLASSES], dtype=np.int32)img_id = image_idimg_id = int(img_id[-12:])# Image meta dataimage_meta = model_utils.compose_image_meta(img_id, original_shape, image.shape,window, scale, active_class_ids)# Expand input dimensions to consider batchimage = np.expand_dims(image, axis=0)image_meta = np.expand_dims(image_meta, axis=0)batch_obj_roi = np.expand_dims(batch_obj_roi, axis=0)return [image, image_meta, batch_obj_roi], gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list

到此为止,第二部分的代码就解读完成了。总结来说,第二部分的任务是提取特征以及对图片预处理,目的是方便之后输入显著性排序网络。

第三部分 训练显著性排序网络

training 过程的最后一步是运行 train.py 文件

来看看这个文件的源码;
前面是一些设置,然后调用 Model_SAM_SMM.build_saliency_rank_model 获得 keras_model ,后文将对这个 model 的结构进行解析。然后将这个 keras_model 作为参数获得 ASRNet 的对象。显然跟前面套路一样,ASRNetModel_SAM_SMM.build_saliency_rank_model 的包装类,里面还封装了一些载入权重、训练、检测之类的方法。跟前面差不多,这里就不细讲了。

# Path to dataset
DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location# Path to pre-processed data - object features
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your locationif __name__ == '__main__':weight_path = ""    # add pre-trained weight pathcommand = "train"config = RankModelConfig()log_path = "logs/"mode = "training"print("Loading Rank Model")keras_model = Model_SAM_SMM.build_saliency_rank_model(config, mode)model_name = "Rank_Model_SAM_SMM"model = ASRNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

然后就是调用 ASRNet 里面的方法,载入权重。获得 dataset 和 dataGenerator 之后训练。逻辑都差不多。

if __name__ == '__main__':...# Load weightsprint("Loading weights ", weight_path)model.load_weights(weight_path, by_name=True)# Train/Evaluate Modelif command == "train":print("Start Training...")# ********** Create Datasets# Train Datasettrain_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "train")# Val Datasetval_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "val")# ********** Parameters# Image Augmentation# Right/Left flip 50% of the time# augmentation = imgaug.augmenters.Fliplr(0.5)augmentation = None# ********** Create Data generatorstrain_generator = DataGenerator.data_generator(train_dataset, config, shuffle=True,augmentation=augmentation,batch_size=config.BATCH_NUM)val_generator = DataGenerator.data_generator(val_dataset, config, shuffle=True,batch_size=config.BATCH_NUM)# ********** Training  **********model.train(train_generator, val_generator,learning_rate=config.LEARNING_RATE,epochs=40,layers='all')

下面放一下 Model_SAM_SMM.build_saliency_rank_model 得到的 model 的结构(唉,看不清的话,可以下载原图(提取码1111))
或者自己调用下面这句也生成同样的效果:

utils.plot_model(model, 'model.png', show_shapes=True)


最后的分类器的结构:

然后这部分到这里也结束了,具体的显著性分类网络的代码比较简单,就不解析了。(或许之后有时间会更新)

最后的一点思考

感觉整个网络用的 Dense 层非常多,但是一个分成 6 类(5个显著性等级和1个背景)的分类器,感觉没必要?但是我也没做实验,不晓得具体情况。

另外神经网络做排序,其实是不太合适的(参考《deep learning with python》一书的观点),在这里是把排序问题转换成了一个分类任务。虽然也可以实现目的。

以上。欢迎评论区讨论。(如果有人看的话。。。)

inference 部分已更新,见:siris 显著性排序网络代码解读(inference过程)

siris 显著性排序网络代码解读(training过程)Inferring Attention Shift Ranks of Objects for Image Saliency相关推荐

  1. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  2. ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士. https://zhuanlan.zhihu.com/p/54289848 温故而知新,理 ...

  3. dlib人脸识别代码解读

    文章目录 一 人脸关键点检测器的训练 1.1 原理 1.1.1 级联回归公式 1.1.2 回归方程求解 1.1.3 分裂点 1.2 源代码 1.3 代码解读 1.3.1 预处理阶段 1.3.2 训练阶 ...

  4. 基于实例分割方法的端到端车道线检测 论文+代码解读

    Towards End-to-End Lane Detection: an Instance Segmentation Approach 论文原文 https://arxiv.org/pdf/1802 ...

  5. DAMO-YOLO全流程代码解读

    一.数据集相关代码解读 创建dataloader(damo/dataset/build.py) 在damo/apis/detector_trainer.py的158行,及174-203行中,DAMO- ...

  6. BigGAN代码解读(gpt3.5的帮助)——谱正则化部分

    BigGAN代码解读(gpt4.0的帮助)--谱正则化部分 作者个人记录学习 BigGAN中使用谱归一化对训练过程进行优化,在github中的代码中,使用了自己编写的谱归一化对卷积层.线性层以及Emb ...

  7. mask rcnn 超详细代码解读(一)

    mask r-cnn 代码解读(一) 文章目录 1 代码架构 2 model.py 的结构 3 train过程代码解析 3.1 Resnet Graph 3.2 Region Proposal Net ...

  8. 类ChatGPT逐行代码解读(2/2):从零起步实现ChatLLaMA和ColossalChat

    本文为<类ChatGPT逐行代码解读>系列的第二篇,上一篇是:如何从零起步实现Transformer.ChatGLM 本文两个模型的特点是加了RLHF 第六部分 LLaMA的RLHF版:C ...

  9. 图像分割套件PaddleSeg全面解析(一)train.py代码解读

    首先祝贺百度团队百度斩获NeurIPS2020挑战赛冠军,https://www.jiqizhixin.com/articles/2020-12-09-2. 在此次比赛中使用的是基于飞桨深度学习框架开 ...

最新文章

  1. 理解什么是MyBatis?
  2. 模板 - 2 - SAT问题
  3. RPi 2B UART作为调试口或者普通串口
  4. python 动画场景_Python GUI教程(十五):在PyQt5中使用动画
  5. 8.1 概述-机器学习笔记-斯坦福吴恩达教授
  6. mysql 命令 g_MySQL命令行的几个用法
  7. 「机械」4大传动方式优劣对比:机械、电气、气压、液压
  8. windows 获取系统CPU和进程CPU 内存等信息
  9. bat文件获取当前时间并格式化输出
  10. 由乱序播放说开了去-数组的打乱算法Fisher–Yates Shuffle
  11. winserver 08 64位安装sql05 64位提示asp版本注册
  12. linux下tomcat缓存磁盘文件,Linux环境下清理Tomcat缓存
  13. python traceback报错_怎么屏蔽Python Traceback错误信息
  14. word vba设置表格样式
  15. Java Icon图标的使用
  16. “无法为保留分区分配驱动器号”的解决
  17. 今日头条推荐算法原理全文详解之三
  18. 我们整理了20个Python项目,送给正在求职的你
  19. js关系图库:aworkflow
  20. 思绪,飘在青山绿水间

热门文章

  1. 如何理解并掌握 Java 数据结构
  2. 初中数学老师计算机培训反思,初中数学老师培训心得体会优秀范文五篇
  3. CTF-show-爆破
  4. suse linux关机命令行,suse linux 关机命令
  5. 企业IT信息化的方案设计
  6. 909.在线mp3音量调整
  7. 面试7轮,结果对接的HR离职了……
  8. CRISP-DM模型
  9. 单片机python教程_如何入门单片机/嵌入式
  10. PSP1000/2000/3000 PSPgo全主机介绍(1)