接下来将按照顺序讲解每一个文件的作用

ablation

ab_mixmatch.py

这段代码定义了额外的标志来测试MixMatch方法实现的不同部分。MixMatch算法是一种半监督学习方法,利用标记和未标记的数据来训练模型。

import functools
import osfrom absl import app
from absl import flags
from easydict import EasyDict
from libml import layers, utils, models
from libml.data_pair import DATASETS
from libml.layers import MixMode
import tensorflow as tf

这段代码是一个 Python 代码文件的一部分,它使用了一些常用的 Python 库和自定义库来实现深度学习的数据处理和模型训练。

下面是这段代码的主要作用和功能:

1、导入必要的Python库:

  • functools:用于高阶函数编程。
  • os:用于与操作系统进行交互,例如获取环境变量和文件路径等。
  • absl:一个用于 Python 应用程序的命令行参数解析器。
  • easydict:提供了一种更加方便的字典方式来访问字典对象中的元素。

2.导入自定义库:

  • libml:这是一个自定义的 Python 库,包含了一些用于深度学习的数据处理和模型训练的模块。在这段代码中,我们使用了 layersutilsmodels 模块。
  • libml.data_pair:这是一个自定义的 Python 模块,它包含了一些用于深度学习的数据处理的方法。

3.定义一个 MixMode 枚举变量,用于表示数据集混合的模式。

4.使用 TensorFlow 2.x 版本的 API 构建深度学习模型。

FLAGS = flags.FLAGS

这一行代码定义了一个全局变量 FLAGS,它是 absl.flags.FLAGS 对象的一个实例。这个实例用于存储和管理命令行参数,以便在 Python 应用程序中使用这些参数。

在使用 absl.flags 库时,首先需要创建一个 FLAGS 对象实例,然后可以使用 DEFINE_xxx() 方法来定义命令行参数。在程序中引用这些参数时,可以通过 FLAGS.xxx 的方式来访问它们的值。

class AblationMixMatch(models.MultiModel):def augment(self, x, l, beta, **kwargs):assert 0, 'Do not call.'def guess_label(self, y, classifier, T, getter, **kwargs):del kwargslogits_y = [classifier(yi, training=True, getter=getter) for yi in y]logits_y = tf.concat(logits_y, 0)# Compute predicted probability distribution py.p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])p_model_y = tf.reduce_mean(p_model_y, axis=0)# Compute the target distribution.p_target = tf.pow(p_model_y, 1. / T)p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)return EasyDict(p_target=p_target, p_model=p_model_y)def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')l_in = tf.placeholder(tf.int32, [None], 'labels')wd *= lrw_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)augment = MixMode(mixmode)classifier = functools.partial(self.classifier, **kwargs)classifier(x_in, training=True)  # Instantiate network.ema = tf.train.ExponentialMovingAverage(decay=ema)ema_op = ema.apply(utils.model_vars())ema_getter = functools.partial(utils.getter_ema, ema)y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)guess = self.guess_label(tf.split(y, nu), classifier,getter=ema_getter if use_ema_guess else None, **kwargs)ly = tf.stop_gradient(guess.p_target)lx = tf.one_hot(l_in, self.nclass)xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])x, y = xy[0], xy[1:]labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)del xy, labels_xybatches = layers.interleave([x] + y, batch)skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)logits = [classifier(batches[0], training=True)]post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]for batchi in batches[1:]:logits.append(classifier(batchi, training=True))logits = layers.interleave(logits, batch)logits_x = logits[0]logits_y = tf.concat(logits[1:], 0)loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)loss_xe = tf.reduce_mean(loss_xe)loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))loss_l2u = tf.reduce_mean(loss_l2u)tf.summary.scalar('losses/xe', loss_xe)tf.summary.scalar('losses/l2u', loss_l2u)post_ops.append(ema_op)post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)with tf.control_dependencies([train_op]):train_op = tf.group(*post_ops)# Tuning op: only retrain batch norm.skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)classifier(batches[0], training=True)train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)if v not in skip_ops])return EasyDict(x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,classify_raw=tf.nn.softmax(classifier(x_in, training=False)),classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

这段代码实现了一个名为AblationMixMatch的模型类,继承自MultiModel类,用于训练一个混合标签的半监督学习模型。其中的函数和参数的作用如下:

  • augment()函数:用于数据增强,但是这个函数的实现是抛出异常,不会被调用,因此可以被认为是空函数。
  • guess_label()函数:用于猜测标签,将无标签数据集的输出与训练好的分类器结合,计算猜测的标签。
  • model()函数:定义了整个模型的结构和训练过程,接受一系列的超参数,包括nu(无标签数据集的数量)、w_match(混合损失的权重)、warmup_kimg(warmup的步长)、batch(batch size)、lr(学习率)、wd(权重衰减系数)、ema(指数平均的系数)、beta(数据增强的系数)等。

在model()函数中,首先定义了输入placeholder的形状,然后对分类器进行了实例化,同时对模型进行了初始化和平均操作。接下来,对数据进行了数据增强和混合,然后将增强后的数据分别送入分类器中进行训练,并计算交叉熵和L2损失。最后定义了训练和调参操作,以及输出分类器的原始输出和指数平均后的输出。

具体来说,就是

    def augment(self, x, l, beta, **kwargs):assert 0, 'Do not call.'

这个方法是一个占位符,代码中没有实际使用到。它被定义在 AblationMixMatch 类中作为一个抽象方法。如果这个方法被调用,代码会抛出一个异常,提示不应该直接调用它。这种设计方式通常是为了让子类必须实现这个方法,而不是使用父类的默认实现。在本例中,它的目的可能是为了强制子类实现一个数据增强的方法。

    def guess_label(self, y, classifier, T, getter, **kwargs):del kwargslogits_y = [classifier(yi, training=True, getter=getter) for yi in y]logits_y = tf.concat(logits_y, 0)# Compute predicted probability distribution py.p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])p_model_y = tf.reduce_mean(p_model_y, axis=0)# Compute the target distribution.p_target = tf.pow(p_model_y, 1. / T)p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)return EasyDict(p_target=p_target, p_model=p_model_y)

这段代码是实现了一个模型对给定的标签 y 进行预测,其使用了一个分类器对标签进行推断,同时还有一个温度参数 T,用于控制预测的概率分布平滑程度。

首先,对每个标签 y,使用分类器classifier得到一个输出 logits_y,将所有的logits_y在第0个维度上进原始的logits_y是一个列表,得到一个新的张量。假设原始的logits_y每个元素都是形状为[batch_size, num_classes]的张量,那么拼接后的张量形状为[(len(logits_y) * batch_size), num_classes],其中len(logits_y)表示logits_y列表的长度。(这一步得到的是所有logits_y的值,)

然后,使用 softmax 函数将其转换为概率分布 p_model_y

接着,将所有标签的概率分布 p_model_y 求平均得到整个数据集的概率分布(这里的平均就是按列先求和再平均)。

最后,使用温度参数 T 对整个数据集的概率分布进行平滑,得到目标分布 p_target。该函数的返回值包含了目标分布 p_target 和整个数据集的概率分布 p_model_y。

  • 对张量p_model_y中的每个元素取T次方根(即1/T次幂),得到一个新的张量p_target。如果原始的p_model_y是一个形状为[num_classes]的张量,那么经过tf.pow操作后,得到的新张量p_target形状仍为[num_classes],其中每个元素的值是原始张量对应元素的T次方根。
  • 是对张量p_target在第1个维度(即num_classes维度)上进行归一化,得到一个新的张量p_target。具体地说,如果原始的p_target是一个形状为[batch_size, num_classes]的张量,那么经过reduce_sum操作后,得到的是一个形状为[batch_size, 1]的张量,其中每个元素是原始张量在该维度上的和。接着,使用除法操作将原始张量中的每个元素除以对应的和,从而得到新张量p_target。

    def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')l_in = tf.placeholder(tf.int32, [None], 'labels')wd *= lrw_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)augment = MixMode(mixmode)classifier = functools.partial(self.classifier, **kwargs)classifier(x_in, training=True)  # Instantiate network.ema = tf.train.ExponentialMovingAverage(decay=ema)ema_op = ema.apply(utils.model_vars())ema_getter = functools.partial(utils.getter_ema, ema)y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)guess = self.guess_label(tf.split(y, nu), classifier,getter=ema_getter if use_ema_guess else None, **kwargs)ly = tf.stop_gradient(guess.p_target)lx = tf.one_hot(l_in, self.nclass)xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])x, y = xy[0], xy[1:]labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)del xy, labels_xybatches = layers.interleave([x] + y, batch)skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)logits = [classifier(batches[0], training=True)]post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]for batchi in batches[1:]:logits.append(classifier(batchi, training=True))logits = layers.interleave(logits, batch)logits_x = logits[0]logits_y = tf.concat(logits[1:], 0)loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)loss_xe = tf.reduce_mean(loss_xe)loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))loss_l2u = tf.reduce_mean(loss_l2u)tf.summary.scalar('losses/xe', loss_xe)tf.summary.scalar('losses/l2u', loss_l2u)post_ops.append(ema_op)post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)with tf.control_dependencies([train_op]):train_op = tf.group(*post_ops)# Tuning op: only retrain batch norm.skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)classifier(batches[0], training=True)train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)if v not in skip_ops])return EasyDict(x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,classify_raw=tf.nn.softmax(classifier(x_in, training=False)),classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

该段代码是一个方法model,包含了训练模型的整个过程。该方法接受一些参数,如nu、w_match、warmup_kimg、batch、lr、wd、ema、beta、mixmode、use_ema_guess等,其中x_in表示输入图像,y_in表示与x_in对应的标签图像,l_in表示标签的类别。该方法的目标是在给定标签样本的情况下,使用半监督学习算法来训练分类模型。

该方法的大致流程如下:

  1. 定义输入placeholder:x_in、y_in、l_in。

  2. 对于给定的参数,进行预处理:计算wd * lr、w_match、augment、classifier等。

  3. 对输入图像x_in进行一次前向传播,以便实例化网络。同时,定义一个ExponentialMovingAverage对象ema,并应用于模型变量。

  4. 将标签图像y_in展开成一维张量,并根据guess_label方法和classifier对其进行预测。其中,guess_label方法会使用模型对标签图像进行猜测,并返回猜测后的标签,即p_target。使用tf.stop_gradient方法对p_target进行梯度截断,以防止误差反向传播。

  5. 对标签图像进行one-hot编码,得到labels_x,将x_in和y用MixMode方法进行数据增强,并将labels_x和p_target合并成labels_y。

  6. 将增强后的数据集拆分成batch,并使用分类器对每个batch进行前向传播,得到对应的logits。将logits_x和logits_y分别提取出来。

  7. 计算交叉熵损失loss_xe和l2正则化损失loss_l2u,并计算它们的平均值。

  8. 进行优化操作。首先,使用Adam优化器对loss_xe和w_match * loss_l2u进行优化。然后,将ema_op和model_vars()中所有名称带有kernel的变量进行指数滑动平均操作,再将它们乘以(1-wd)进行权重衰减。最后,将所有操作合并成train_op。

  9. 对于调参,只重新训练batch norm,即将所有除skip_ops外的其他更新操作合并为train_bn。

  10. 返回一个EasyDict对象,包含了x_in、y_in、l_in、train_op、tune_op、classify_raw、classify_op等。其中,classify_raw表示在没有应用ema时,分类模型对x_in进行前向传播的结果,classify_op表示在应用ema之后,分类模型对x_in进行前向传播的结果。

展开来讲:

    def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')l_in = tf.placeholder(tf.int32, [None], 'labels')wd *= lrw_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)augment = MixMode(mixmode)classifier = functools.partial(self.classifier, **kwargs)

这是一个模型定义函数,输入参数包括nu(无标签数据集的数量)、w_match(匹配权重)、warmup_kimg(预热步长)、batch(批大小)、lr(学习率)、wd(权重衰减)、ema(指数滑动平均系数)、beta(数据增强的beta参数)、mixmode(数据增强模式)和其他可选参数。该函数返回一个分类器。

该函数首先根据数据集的高、宽和通道数创建一个输入占位符x_in和一个标签占位符y_in。

然后,将权重衰减乘以学习率,并将匹配权重乘以一个warmup_kimg参数,以在前几个迭代中逐渐增加该权重。

接着,使用给定的数据增强模式创建一个数据增强器augment。

最后,函数返回一个分类器,该分类器使用self.classifier函数作为主要分类器,其中的参数使用了kwargs,该函数是一个偏函数,其中已经部分确定了一些参数。

        classifier(x_in, training=True)  # Instantiate network.ema = tf.train.ExponentialMovingAverage(decay=ema)ema_op = ema.apply(utils.model_vars())ema_getter = functools.partial(utils.getter_ema, ema)

这段代码用于实例化一个神经网络分类器(classifier),并使用指数移动平均(Exponential Moving Average)对其参数进行平滑处理。

首先,通过调用classifier(x_in, training=True)来实例化网络,其中x_in是输入数据,training=True表示在训练模式下运行网络。

然后,使用指数移动平均(Exponential Moving Average,简称EMA)对网络的参数进行平滑处理。具体来说,通过调用tf.train.ExponentialMovingAverage(decay=ema)来创建一个指数移动平均器,其中decay参数指定了平均的衰减率

接着,通过调用ema.apply(utils.model_vars())来将指数移动平均器应用于网络的所有参数。这将为每个参数创建一个EMA副本,并更新其值。

最后,使用functools.partial将utils.getter_ema和EMA副本绑定在一起,创建一个ema_getter函数,用于在测试模式下获取网络参数的EMA副本。这将确保在测试模式下,网络参数将始终是平滑的EMA副本,而不是训练模式下的原始参数。

y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)guess = self.guess_label(tf.split(y, nu), classifier,getter=ema_getter if use_ema_guess else None, **kwargs)ly = tf.stop_gradient(guess.p_target)lx = tf.one_hot(l_in, self.nclass)xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])x, y = xy[0], xy[1:]labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)del xy, labels_xy

这段代码用于对输入数据进行一些操作,包括将输入标签(y_in)进行转置和重塑,使用分类器对标签进行预测,然后进行数据增强和标签拼接。

首先,使用tf.transpose将输入标签y_in进行转置,以便在后面进行reshape操作。然后,使用tf.reshape将转置后的标签y_in重塑为[-1] + hwc的形状,其中hwc表示标签y_in的高度、宽度和通道数。这将y_in从一个5D张量转换为一个2D张量num_views 表示每个样本所包含的视角数量。

(详细解释)tf.transpose(y_in, [1,0,2,3,4]),将y_in的维度从(batch_size, num_views, height, width, channels)变为(num_views, batch_size, height, width, channels)。使用tf.reshape函数将转置后的标签y_in重塑为一个新的形状,即[-1] + hwc。将y_in从一个5D张量转换为一个2D张量,其中第一维度表示了所有样本和视角的总数。具体来说,将第二到第五维度平坦化,即(batch_size * num_views, height, width, channels)转换为(batch_size * num_views * height * width * channels)。这种重塑操作可以将标签变为一个长向量,方便后续操作。

接下来,使用分类器对标签进行预测。具体来说,使用self.guess_label函数对重塑后的标签y进行预测,其中guess.p_target是预测的概率,可以用于计算分类器的损失函数。如果use_ema_guess为True,则使用ema_getter获取分类器参数的EMA副本进行预测。

(详细解释)y形状为(batch_size * num_views * height * width * channels),tf.split(y, nu)y 按照 nu 的值在第一维度上进行分割,得到一个包含 nu 个张量的列表。classifier 是用于分类的网络模型,它将每个标签数据映射为一个类别,并同时输出每个类别的置信度。getter 是一个函数,用于获取模型中的参数,这里使用 ema_getter 函数来获取使用指数移动平均法(Exponential Moving Average,EMA)计算的模型参数,以提高模型的鲁棒性。**kwargs 表示其他可选的参数,这些参数会传递给 guess_label 方法。

生成标签的独热编码和停止梯度的标签分布。具体来说,l_in 是一个形状为 (batch_size, ) 的张量,表示输入的真实标签。self.nclass 是一个标量,表示标签的类别数量。因此tf.one_hot(l_in, self.nclass) 会将真实标签 l_in 编码为一个形状为 (batch_size, self.nclass) 的独热编码张量,其中每一行表示一个标签的独热编码。

guess 是通过 guess_label 方法生成的一组伪标签。在该方法中,伪标签的生成是通过预测标签的分布来实现的,即 guess.p_target 表示标签数据的估计分布。为了避免在训练时反向传播误差到伪标签,导致网络训练不稳定,这里使用 tf.stop_gradient 函数将 guess.p_target 停止梯度,生成一个形状与其相同的新张量 ly

最终,lxly 分别表示真实标签和伪标签的独热编码,它们会被用于训练网络。

然后,进行数据增强和标签拼接。具体来说,将输入数据x_in和预测的标签guess.p_target(使用tf.split对预测的标签进行分割)传递给augment函数,对它们进行数据增强(augmentation)。augment函数返回增强后的数据和标签。最后,将增强后的数据x和增强后的标签y拆分为单独的张量,并将输入标签l_in进行one-hot编码,得到labels_x和labels_y。最后,删除xy和labels_xy以释放内存。

(详细解释)augment 是一个数据增强的函数,它接受三个参数:datalabelsparams,分别表示原始数据、标签和数据增强的参数。在这里,[x_in] + tf.split(y, nu) 表示将原始数据 x_in 和伪标签数据 y 按照视角数 nu 进行拆分,拼接成一个列表传递给 augment 函数。类似地,[lx] + [ly] * nu 表示将真实标签的独热编码 lx 和伪标签的独热编码 ly 按照视角数 nu 进行拆分,并使用列表推导式生成一个长度为 nu 的列表,最后将这两个列表拼接起来。params 参数中包含了两个值,都是标量 beta。它们用于控制数据增强时两种操作的强度,具体操作是随机剪裁和随机翻转。augment 函数的返回值是一个元组,包含增强后的数据和标签。在这里,xylabels_xy 分别表示增强后的数据和标签。其中,xy[0] 表示增强后的原始数据,xy[1:] 表示增强后的伪标签数据;labels_xy[0] 表示增强后的真实标签独热编码,labels_xy[1:] 表示增强后的伪标签独热编码。最后,通过将 xy 拆分为 xy,将 labels_xy 拆分为 labels_xlabels_y,分别表示增强后的原始数据、增强后的伪标签数据、增强后的真实标签独热编码和增强后的伪标签独热编码,用于训练网络。最后通过 del xy, labels_xy 删除不再需要的变量,释放内存。

        batches = layers.interleave([x] + y, batch)skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)logits = [classifier(batches[0], training=True)]post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]for batchi in batches[1:]:logits.append(classifier(batchi, training=True))logits = layers.interleave(logits, batch)logits_x = logits[0]logits_y = tf.concat(logits[1:], 0)

这段代码主要是为了计算训练数据和伪标签数据的logits(分类器输出的未经softmax处理的概率),并计算损失函数。

首先通过调用layers.interleave函数将训练数据和伪标签数据交错分组,形成一个新的batch列表,其中第一个元素是训练数据,其余元素是伪标签数据。

然后通过循环遍历每个batch,调用分类器函数classifier计算每个batch的logits。在计算logits的过程中,通过设置training=True来启用训练模式,以便在BN层中记录训练过程中的均值和方差,并在测试过程中使用它们进行归一化。

计算完logits后,通过调用layers.interleave函数将它们重新交错分组,然后将第一个元素赋给logits_x变量,其余元素赋给logits_y变量。

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)loss_xe = tf.reduce_mean(loss_xe)loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))loss_l2u = tf.reduce_mean(loss_l2u)tf.summary.scalar('losses/xe', loss_xe)tf.summary.scalar('losses/l2u', loss_l2u)

这段代码计算了两个损失函数。第一个是 softmax 交叉熵损失函数,用来计算有标签数据的分类误差,它被赋值给了变量 loss_xe。第二个是 L2 损失函数,用于衡量无标签数据的预测结果与其平滑后的伪标签之间的差异,它被赋值给了变量 loss_l2u。这两个损失函数分别使用了 TensorFlow 中的 tf.nn.softmax_cross_entropy_with_logits_v2()tf.square() 函数进行计算,并用 tf.reduce_mean() 函数求取了它们的平均值。在这里,tf.summary.scalar() 函数被用来在 TensorBoard 中记录损失函数的值。

        post_ops.append(ema_op)post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)with tf.control_dependencies([train_op]):train_op = tf.group(*post_ops)# Tuning op: only retrain batch norm.skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)classifier(batches[0], training=True)train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)if v not in skip_ops])

这段代码是在训练模型的过程中,定义了一些后处理操作(post_ops),然后使用Adam优化器最小化交叉熵损失(loss_xe)和L2正则化损失(loss_l2u)的和。其中,L2正则化损失用于匹配有标签样本和无标签样本的特征分布,以实现半监督学习的目的。tf.summary.scalar用于记录损失的变化情况。with tf.control_dependencies([train_op])语句确保在进行后续操作之前,train_op操作先被执行。另外,还定义了一个操作train_bn,用于只重新训练BN层。

        return EasyDict(x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,classify_raw=tf.nn.softmax(classifier(x_in, training=False)),classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

这段代码返回了一个包含各种操作和张量的EasyDict对象。它包括输入张量x_in,y_in和l_in,两个分类器的softmax输出,训练操作train_op和调整操作train_bn,以及其他一些操作。这个对象的目的是使训练和测试代码更加简洁和易于理解。

def main(argv):del argv  # Unused.dataset = DATASETS[FLAGS.dataset]()log_width = utils.ilog2(dataset.width)model = AblationMixMatch(os.path.join(FLAGS.train_dir, dataset.name),dataset,lr=FLAGS.lr,wd=FLAGS.wd,arch=FLAGS.arch,batch=FLAGS.batch,nclass=dataset.nclass,ema=FLAGS.ema,beta=FLAGS.beta,use_ema_guess=FLAGS.use_ema_guess,T=FLAGS.T,mixmode=FLAGS.mixmode,nu=FLAGS.nu,w_match=FLAGS.w_match,warmup_kimg=FLAGS.warmup_kimg,scales=FLAGS.scales or (log_width - 2),filters=FLAGS.filters,repeat=FLAGS.repeat)model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)

这段代码是一个调用AblationMixMatch类实例化一个模型,并训练的函数。首先会根据FLAGS中的数据集名称选择对应的数据集,然后通过AblationMixMatch类构造模型。FLAGS中包含了训练需要用到的超参数,例如学习率、权重衰减、卷积神经网络结构等。最后,调用模型的train方法,训练模型并输出训练结果。其中,FLAGS.train_kimg和FLAGS.report_kimg是指训练步数和结果输出步数,都需要左移10位,因为模型使用的是Mini-batch SGD,每一步的batch size是2的整数次幂

if __name__ == '__main__':utils.setup_tf()flags.DEFINE_float('wd', 0.02, 'Weight decay.')flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')flags.DEFINE_bool('use_ema_guess', False, 'Whether to use EMA parameters when guessing labels.')flags.DEFINE_float('T', 0.5, 'Softmax sharpening temperature.')flags.DEFINE_enum('mixmode', 'xxy.yxy', MixMode.MODES, 'Mixup mode')flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')flags.DEFINE_integer('warmup_kimg', 128, 'Warmup in kimg for the matching loss.')flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')FLAGS.set_default('dataset', 'cifar10.3@250-5000')FLAGS.set_default('batch', 64)FLAGS.set_default('lr', 0.002)FLAGS.set_default('train_kimg', 1 << 16)app.run(main)

这段代码是一个 Python 脚本的主函数,会在运行时执行。该脚本提供了许多可选的命令行参数,用于指定不同的超参数设置。在此之后,脚本调用了 utils.setup_tf() 函数,该函数是一个工具函数,用于设置 TensorFlow 运行时的 GPU 环境等配置。最后,脚本调用了 app.run(main) 函数来运行 main 函数。main 函数主要是构建了一个 AblationMixMatch 模型,并调用 train 函数来训练模型。

半监督学习之MixMatch(代码解读 ablation)相关推荐

  1. 半监督学习:MixMatch

    MixMatch: A Holistic Approach to Semi-Supervised Learning 官方代码---tensorflow版本 pytorch版 论文 2.1以上的内容都是 ...

  2. 详解基于图卷积的半监督学习

    Kipf和Welling最近发表的一篇论文提出,使用谱传播规则(spectral propagation)快速近似spectral Graph Convolution. 和之前讨论的求和规则和平均规则 ...

  3. 机器学习教程 之 半监督学习 Co-training 协同训练 (论文、算法、数据集、代码)

    这篇博客介绍的是一篇用于半监督分类问题的方法: 协同训练 Co-training, A. Blum and T. Mitchell, "Combining labeled and unlab ...

  4. 机器学习教程 之 半监督学习 Tri-training方法 (论文、数据集、代码)

    最近因为项目需要研究了一下半监督学习,稍经了解以后发现当存在大量未标签数据时,这确实是一种非常好用的方法,可以很好的提升分类精度.这里介绍一下周志华教授的Tri-triaining方法,在实现上非常的 ...

  5. MixMatch:半监督学习

    MixMatch:半监督学习 1 摘要 2 介绍 3 已有相关工作 3.1 Consistency Regularization 一致性正则化 3.2 Entropy Minimization/ En ...

  6. 长文总结半监督学习(Semi-Supervised Learning)

    ©PaperWeekly 原创 · 作者|燕皖 单位|渊亭科技 研究方向|计算机视觉.CNN 在现实生活中,无标签的数据易于获取,而有标签的数据收集起来通常很困难,标注也耗时和耗力.在这种情况下,半监 ...

  7. 新技术“红”不过十年,半监督学习为什么是个例外?

    作者 | 严林 来源 | 授权转载自知乎(ID:严林) 这一波深度学习的发展,以2006年Hinton发表Deep Belief Networks的论文为起点,到今年已经超过了10年.从过往学术界和产 ...

  8. 手把手教你实现GAN半监督学习

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言  本文主要介绍如何在tensorflow上仅使用200个带标 ...

  9. [论文学习]DIVIDEMIX:带噪声标签的半监督学习LEARNING WITH NOISY LABELS AS SEMI-SUPERVISED LEARNING

      本文研究含噪声标签数据的训练,是目前该领域的SOTA.主要方法是:首先使用高斯混合模型,根据训练集每样本的损失函数值对样本进行分类,分为干净样本和噪声样本,把噪声样本作为无标签样本:然后使用半监督 ...

最新文章

  1. 项目经理的超越(一)你超越了吗?
  2. yum groupinstall安装一组软件使用示例
  3. POJ - 2289 Jamie's Contact Groups(二分图多重匹配)
  4. 使用Spring和Hibernate进行集成测试有多酷
  5. matlab建立的发动机的模型,奇瑞使用基于模型的设计实现发动机管理系统软件的自主开发...
  6. Oracle Hint(提示)与常用方法
  7. 从入门到退坑,详解数分行业的3个岗位,起薪高达40W的是哪个?
  8. 顺序存储循环队列的基本操作
  9. LINUX安装7Zip
  10. python题库刷题训练软件
  11. 【每周一爬】爬取盗版小说网的小说
  12. 高德地图 js自动定位到当前城市
  13. 简析H264编码中的GOP
  14. php绑定银行卡实现,php网站如何绑定银行卡
  15. oracle pdb监听配置,oracle 12c 监听
  16. 从零基础到斩获BAT算法岗offer,围观复旦大佬的秋招之路
  17. USB通讯入门(二)CyUSB.inf文件修改后,设备管理器可以识别出USB设备,但Cypress USB Console没有任何显示
  18. 视频教程-项目1——无线自助点餐平台-Java
  19. redis加锁、解锁
  20. android低电量提示,Android P系统低电量提醒功能 根据使用情况判断充电时间

热门文章

  1. java 千分符_java中关于千分位
  2. 数据库运维的一些操作
  3. 智能防汛系统分级式水位监测-雨量水位报警站
  4. 稳控科技水库水坝监测系统解决方案
  5. Android--通过关键字查找短消息数据库并将匹配的信息显示
  6. 数据时代,谁来舞动存储“三叉戟”?
  7. zabbix监控部署
  8. 平安WiFi牵手“黑科技“,引领WiFi行业新变革
  9. c++心形编码_使用C++描绘心形
  10. 机动车 科目一 之 标识标志 (警告标志 [黄色])