
Download: access to the dataset has to be requested through the EchoNet Dynamic site. The source code was also released by Stanford.






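Args:
    num_epochs (int, optional): Number of epochs during training. Defaults to 50.
    modelname (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50``, ``deeplabv3_resnet101``, ``fcn_resnet50``, or ``fcn_resnet101``. These are mainly models that ship with PyTorch (torchvision).
    pretrained (bool, optional): Whether to use pretrained weights for the model. Defaults to False.
    output (str or None, optional): Location of the output files.
    device (str or None, optional): Device to run on (CPU/GPU).
    n_train_patients (str or None, optional): Number of training patients to use, for ablations on training-set size. If None, all patients are used.
    num_workers (int, optional): Number of workers used for data loading.
    batch_size (int, optional): Size of one batch.
    seed (int, optional): Random seed.
    lr_step_period (int or None, optional): Period of the learning-rate decay.
    save_segmentation (bool, optional): Whether to save the segmentation results.
    block_size (int, optional): Block size used when saving segmentations, so that long videos are processed in chunks.
    run_test (bool, optional): Whether to run on the test split.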


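  • Set the random seed, the output location, and the device (CPU/GPU).
  • Model: load the model from PyTorch and replace its output layer so that it has a single output channel (left ventricle vs. background), from which a per-pixel probability is predicted.
  • Choose the optimizer, and compute the mean and standard deviation of the images so that they can be normalized.
  • Load the data: read the videos. The label file contains the video name and frame, the polygon outline of the mask, the dataset split, and so on, so the masks have to be generated from it.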
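As a rough illustration of the normalization step, the per-channel mean and standard deviation can be accumulated over a sample of frames. This is a minimal pure-Python sketch, not the routine used by the EchoNet code; the helper names `channel_mean_std` and `normalize` and the nested-list frame layout are assumptions made for the example.

```python
import math


def channel_mean_std(frames):
    """Return per-channel (mean, std) over all pixels of all frames.

    `frames` is a list of frames; each frame is a list of channels; each
    channel is a flat list of pixel values (hypothetical layout).
    """
    n_channels = len(frames[0])
    count = [0] * n_channels
    total = [0.0] * n_channels
    total_sq = [0.0] * n_channels

    for frame in frames:
        for c, channel in enumerate(frame):
            count[c] += len(channel)
            total[c] += sum(channel)
            total_sq[c] += sum(v * v for v in channel)

    mean = [total[c] / count[c] for c in range(n_channels)]
    # Population variance: E[x^2] - E[x]^2, clamped at 0 against round-off.
    std = [math.sqrt(max(total_sq[c] / count[c] - mean[c] ** 2, 0.0))
           for c in range(n_channels)]
    return mean, std


def normalize(channel, mean, std):
    """Shift and scale one channel using the given mean and std."""
    return [(v - mean) / std for v in channel]
```

In practice the statistics would be estimated once on (a subset of) the training videos and then reused for every split.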

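To make the mask-generation step concrete, here is a minimal sketch that rasterizes a polygon into a binary mask with the even-odd (ray casting) rule, testing the centre of each pixel. It assumes the outline is already available as a list of (x, y) vertices; the exact layout of the label file and the routine the EchoNet code uses for this are not reproduced here, and `point_in_polygon` and `polygon_to_mask` are hypothetical helpers.

```python
def point_in_polygon(x, y, vertices):
    """Even-odd rule: count edge crossings of a ray cast to the right."""
    inside = False
    n = len(vertices)
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]
        # Does this edge straddle the horizontal line through y?
        if (y1 > y) != (y2 > y):
            # x coordinate where the edge crosses that line
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def polygon_to_mask(vertices, height, width):
    """Return a height x width list of 0/1 values (1 = inside the polygon)."""
    return [[1 if point_in_polygon(col + 0.5, row + 0.5, vertices) else 0
             for col in range(width)]
            for row in range(height)]


# Hypothetical example: a square outline on a 6 x 6 frame.
square = [(1, 1), (4, 1), (4, 4), (1, 4)]
mask = polygon_to_mask(square, 6, 6)
```

A per-pixel loop like this is slow for real frames; it is only meant to show what "generating the mask" means.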
  • Then start training and collect the relevant metrics:
import math

import numpy as np
import torch
import tqdm


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer
        device (torch.device): Device to run on
    """

    total = 0.
    n = 0

    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0

    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count number of pixels in/out of human segmentation
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()

                # Count, per pixel location, how often it is in/out of the human segmentation
                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses and compute baselines
                total += loss.item()
                n += large_trace.size(0)
                p = pos / (pos + neg)
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)

                # Show info on progress bar: running loss (batch loss) / global baseline,
                # per-pixel baseline, Dice on diastolic frames, Dice on systolic frames
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
                pbar.update()

    large_inter_list = np.array(large_inter_list)
    large_union_list = np.array(large_union_list)
    small_inter_list = np.array(small_inter_list)
    small_union_list = np.array(small_union_list)

    return (total / n / 112 / 112,
            large_inter_list,
            large_union_list,
            small_inter_list,
            small_union_list,
            )
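The intersection and union arrays returned by run_epoch can be turned into Dice scores with the same formula the progress bar uses, 2 * inter / (union + inter), which is equivalent to 2 * |A ∩ B| / (|A| + |B|). The following is a minimal sketch; the helper names are hypothetical and the counts are made up for illustration, not measured results.

```python
def dice_per_frame(inter, union):
    """Dice score of each frame from its pixel intersection/union counts.

    Assumes union + inter > 0 for every frame (no frame where both the
    human and the predicted mask are empty).
    """
    return [2 * i / (u + i) for i, u in zip(inter, union)]


def dice_overall(inter, union):
    """Dice score pooled over all frames (pixels are summed before dividing)."""
    total_inter = sum(inter)
    total_union = sum(union)
    return 2 * total_inter / (total_union + total_inter)


# Illustrative counts only, not measured results.
large_inter = [90, 80]
large_union = [110, 120]
per_frame = dice_per_frame(large_inter, large_union)
overall = dice_overall(large_inter, large_union)
```

The pooled score weights large masks more heavily, whereas averaging the per-frame scores treats every frame equally, so the two numbers generally differ.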

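The third number shown on the progress bar in run_epoch, -p log p - (1 - p) log(1 - p), is the per-pixel cross-entropy of a constant predictor that outputs the foreground fraction p for every pixel, so it serves as a baseline that the model's loss is expected to fall below. A small sketch of that calculation follows; `baseline_entropy` is a hypothetical name, and the guard for p = 0 or p = 1 is an addition that the original expression does not have.

```python
import math


def baseline_entropy(pos, neg):
    """Binary entropy (in nats) of the foreground fraction p = pos / (pos + neg)."""
    p = pos / (pos + neg)
    if p in (0.0, 1.0):
        # Limit of p * log(p) as p -> 0; avoids math.log(0).
        return 0.0
    return -p * math.log(p) - (1 - p) * math.log(1 - p)
```

For a balanced mask (pos equal to neg) this gives log 2, the largest possible value; the more lopsided the foreground fraction, the lower the baseline.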

