主要内容包含segment-anything项目的安装、基于SamPredictor对单点输入生成mask、基于SamPredictor对多点输入生成mask、基于SamAutomaticMaskGenerator自动生成mask。

Segment Anything项目是一个可以对任何图像进行分割的项目,其论文介绍可以查看https://blog.csdn.net/a486259/article/details/131137939,其测试网站为 https://segment-anything.com

这里对Segment Anything项目的使用进行初步总结,绝大部分内容源自https://github.com/facebookresearch/segment-anything 。

注:segment-anything训练VIT模型时的输入size为1024x1024,其输出的feature size为256x64x64,进行了16倍的下采样

1、安装segment-anything

下载segment-anything项目,进入目录后执行pip install -e .安装项目。

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .

该项目依赖opencv-python pycocotools matplotlib onnxruntime onnx torch等包,安装命令如下

pip install opencv-python pycocotools matplotlib onnxruntime onnx torch

segment-anything模型是基于torch框架实现的

2. 根据提示输入生成mask

Segment Anything Model (SAM) 预测对象mask,给出所需识别出对象的提示输入(对象的粗略位置信息)。该模型首先将图像转换为图像嵌入,然后解码器根据用户输入的提示(粗略位置信息)可以生成高质量的掩模。
SamPredictor类为模型调用提供了一个简单的接口,用于提示模型的输入。它先让用户使用“set_image”方法设置图像,该方法会将图像输入转换到特征空间嵌入。然后,可以通过“predict”方法输入提示信息,以根据这些提示有效地预测掩码。predict函数支持将点和框提示以及上一次预测迭代中的mask作为输入。

2.1 前置函数库

前置库导入和函数实现

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):if random_color:color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)else:color = np.array([30/255, 144/255, 255/255, 0.6])h, w = mask.shape[-2:]mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_points(coords, labels, ax, marker_size=375):pos_points = coords[labels==1]neg_points = coords[labels==0]ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   def show_box(box, ax):x0, y0 = box[0], box[1]w, h = box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

2.2 显示样图

读取图片并展示

image = cv2.imread('images/truck.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

2.3 加载SAM模型

sam_vit_b_01ec64模型的下载地址为: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
这里需要注意要使用cuda

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictorsam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)predictor = SamPredictor(sam)
#predictor.set_image(image)

其它版本的模型下载地址为

  • default or vit_h: ViT-H SAM model.
  • vit_l: ViT-L SAM model.
  • vit_b: ViT-B SAM model.

通过调用“SamPredictor.set_image”处理图像以生成图像嵌入(特征向量)。“SamPrejector”会记住此特征向量,并将其用于后续掩码预测。

predictor.set_image(image)

2.4 单点输入生成mask

要选择卡车,可以卡车上选择一个点。点以(x,y)格式输入到模型中,并带有标签1(前景点)或0(背景点)。可以输入多个点;这里我们只使用一个。所选的点将在图像上显示为星形。
此时代码及执行效果如下:

input_point = np.array([[250, 187]])
input_label = np.array([1])plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

使用“SamPredictor.prdict”进行预测。该模型返回掩码(masks)、掩码的分数(scores)以及可传递到下一次预测迭代的低分辨率掩码(logits)。

在“multimask_output=True”(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。此设置用于存在不明确输入提示的时候(光凭一个点无法有效识别出用户意图是组件局部、组件还是整体),并帮助模型消除与提示一致的不同对象的歧义。当为“multimask_output=False”时,它将返回一个掩码。对于单点等不明确的提示,建议使用“multimask_output=True”,即使只需要一个掩码;可以通过选择在“分数”中返回的分数最高的一个来选择最佳的单个掩码。这通常会得到更好的mask。

masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True,
)print(masks.shape)  # (number_of_masks) x H x W  | output (3, 600, 900)for i, (mask, score) in enumerate(zip(masks, scores)):plt.figure(figsize=(10,10))plt.imshow(image)show_mask(mask, plt.gca())show_points(input_point, input_label, plt.gca())plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)plt.axis('off')plt.show()



3、多输入生成mask

3.1 多点输入生成mask

单个输入点不明确,需要让模型返回了与其一致的多个对象。要获得单个对象,可以提供多个点。如果可用,还可以将先前迭代的掩码(logits值)提供给模型以帮助预测。当使用多个提示指定单个对象时,可以通过设置“multimask_output=False”来请求获取单个掩码。

input_point = np.array([[250, 184], [562, 322]])
input_label = np.array([1, 1])mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best maskmasks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,mask_input=mask_input[None, :, :],multimask_output=False,
)
print(masks.shape) #output: (1, 600, 900)plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

输入负点明确区域

input_label与input_point相对应,为0时表示是负点

input_point = np.array([[250, 187], [561, 322]])
input_label = np.array([1, 0])#为0时表示是负点,即第二个点[561, 322]是负点mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best maskmasks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,mask_input=mask_input[None, :, :],multimask_output=False,
)plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

3.2 boxes输入生成mask

支持将xyxy格式的box作为输入,将框内的主体目标识别出来(类似于实例分割)

input_box = np.array([212, 300, 350, 437])
masks, _, _ = predictor.predict(point_coords=None,point_labels=None,box=input_box[None, :],multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

3.3 同时输入点与boxes生成mask

point和boxes可以同时输入,只需将这两种类型的提示都包括在预测器中即可。在这里,这可以用来只选择卡车的轮胎(将车轴部分设置为负点),而不是整个车轮。

input_box = np.array([215, 310, 350, 430]) #只能默认框住正类
input_point = np.array([[287, 375]])
input_label = np.array([0]) #将车轴部分设置为负点masks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,box=input_box,multimask_output=False,
)plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

3.4 同时输入多个boxes生成mask

SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。

input_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],
], device=predictor.device) #假设这是目标检测的预测结果
input_boxes=input_boxes/2transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(point_coords=None,point_labels=None,boxes=transformed_boxes,multimask_output=False,
)print(masks.shape)  # (batch_size) x (num_predicted_masks_per_input) x H x W | output: torch.Size([4, 1, 600, 900])plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

3.5 端到端的批量推理

如果所有提示输入都已经明确的,则可以以端到端的方式直接运行SAM。这允许SAM对图像进行批处理,以下代码构建了2个image和boxes。

image1 = cv2.imread('images/truck.jpg')
image1_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],
], device=sam.device)image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([[450, 170, 520, 350],[350, 190, 450, 350],[500, 170, 580, 350],[580, 170, 640, 350],
], device=sam.device)

图像和提示都作为PyTorch张量输入,这些张量(图像和提示输入)已经被编码为特征向量。所有的输入数据都被封装为list,每个元素都是一个dict,它的key如下:

  • image: CHW格式的PyTorch tensor .
  • original_size: 图像原始大小, (H, W) format.
  • point_coords: 一批输入点格式.
  • point_labels: 每个输入点所对应的类型(正例或负例).
  • boxes: 一批输入的boxe(只能是正例).
  • mask_inputs: 一批输入的mask.

如果没有相应的信息,可以不进行输入,但image必须输入

from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)def prepare_image(image, transform, device):image = transform.apply_image(image)image = torch.as_tensor(image, device=device.device) return image.permute(2, 0, 1).contiguous()batched_input = [{'image': prepare_image(image1, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),'original_size': image1.shape[:2]},{'image': prepare_image(image2, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),'original_size': image2.shape[:2]}
]
batched_output = sam(batched_input, multimask_output=False)
print(batched_output[0].keys()) # output:dict_keys(['masks', 'iou_predictions', 'low_res_logits'])

输出是每个输入图像的结果列表,其中元素是字典对象,其key为:

  • masks: 一批mask,tensor张量
  • iou_predictions: 与mask相对应的iou预测值.
  • low_res_logits: 每个掩码的低分辨率logits,可以在以后的迭代中作为掩码输入再次调用模型。
fig, ax = plt.subplots(1, 2, figsize=(20, 20))ax[0].imshow(image1)
for mask in batched_output[0]['masks']:show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')ax[1].imshow(image2)
for mask in batched_output[1]['masks']:show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')plt.tight_layout()
plt.show()

4、自动生成mask

4.1 基础前置库

这里加载了一些基础库,并读取images/dog.jpg作为样例数据

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_anns(anns):if len(anns) == 0:returnsorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)ax = plt.gca()ax.set_autoscale_on(False)img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))img[:,:,3] = 0for ann in sorted_anns:m = ann['segmentation']color_mask = np.concatenate([np.random.random(3), [0.35]])img[m] = color_maskax.imshow(img)image = cv2.imread('images/dog.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

4.2 自动生成mask

要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictorsam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)masks = mask_generator.generate(image)print(len(masks))
print(masks[0].keys())
print(masks[0])plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

代码输出的文字信息如下:

42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{'segmentation': array([[False, False, False, ..., False, False, False],[False, False, False, ..., False, False, False],[False, False, False, ..., False, False, False],...,[ True,  True,  True, ..., False, False, False],[ True,  True,  True, ..., False, False, False],[False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}

所生成的图像如下

masks = mask_generator.generate(image)

Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:

  • segmentation : np的二维数组,为二值的mask图片
  • area : mask的像素面积
  • bbox : mask的外接矩形框,为XYWH格式
  • predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
  • point_coords : 用于生成该mask的point输入
  • stability_score : mask质量的附加指标
  • crop_box : 用于以XYWH格式生成此遮罩的图像裁剪

4.3 自动mask的参数

在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:

mask_generator_2 = SamAutomaticMaskGenerator(model=sam,points_per_side=32,#控制采样点的间隔,值越小,采样点越密集pred_iou_thresh=0.86,#mask的iou阈值stability_score_thresh=0.92,#mask的稳定性阈值crop_n_layers=1,crop_n_points_downscale_factor=2,min_mask_region_area=50,  #最小mask面积,会使用opencv滤除掉小面积的区域
)
masks2 = mask_generator_2.generate(image)print(len(masks2)) # 69plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()

Segment Anything使用手册(交互式数据标柱|自动数据标柱)相关推荐

  1. R语言vtreat包自动处理dataframe的缺失值、使用分组的中位数来标准化数据列中每个数据的值(和中位数表连接并基于中位数进行数据标化)、计算数据列的中位数或者均值并进行数据标准化

    R语言vtreat包自动处理dataframe的缺失值.使用分组的中位数来标准化数据列中每个数据的值(和中位数表连接并基于中位数进行数据标化).计算数据列的中位数或者均值并基于中位数或者均值进行数据标 ...

  2. web数据交互_通过体育运动使用定制的交互式Web应用程序数据科学探索任何数据...

    web数据交互 Most good data projects start with the analyst doing something to get a feel for the data th ...

  3. 数据产品经理修炼手册pdf_【尼读书】数据产品经理修炼手册(附思维导图)

    前言:进入一个行业,除了要多在工作中实践和思考之外,还需要多读书.这样能够站在一个更高的角度去看问题,往往会对问题有更全面的掌握和新的认知.在[尼读书]这个栏目中,尼同学通过自己读书后的理解和整理与大 ...

  4. 数据探查平台-元数据对标专利 -- 普帝

    专利申请步骤 申请(专利权)人 讯飞智元信息科技有限公司 发明人 石金普;王慧敏;冯小凯;姚素雅 地址 安徽省合肥市高新区望江西路666号讯飞大厦8层-10层 邮编 230088 案例分享 1.一种确 ...

  5. 再谈数据标准落标,论数据模型设计工具

    工欲善其事必先利其器.工具是用来提高生产效率,其次才是管理属性. 一个工具顺不顺手极大影响生产效率和管理效果.工具用不起来,管理制度也落不下去.管理自说自话,下面各干各的,最终两张皮.Datablau ...

  6. 数值型数据和标称型数据

    在学习机器学习的工程中,发现有一种名为标称型的数据,具体如下: 标称型:一般在有限的数据中取,而且只存在'是'和'否'两种不同的结果(一般用于分类) 数值型:可以在无限的数据中取,而且数值比较具体化, ...

  7. excel柱状图负值柱下数据标签移到坐标轴上方

    excel做出的柱状图,如何将负值柱的数据标签如何移到横坐标上方,而同时正值柱数据标签保持位置不变呢? 首选需要删除影响显示效果的刻度线标签 设置数据标志数据标签为显示状态 然后针对在坐标轴以下的数据 ...

  8. 直通车实战手册:如何利用直通车开好标品类目

    金九银十是行业的旺季,同时也是各类目产品蓄力爆发的时候.双11不到两个月,现在疯狂掘金的模式已经开启,你还在等待吗? 现在做淘宝,产品技术都要占优势,如果你只有优质产品,没有相对应的资源技术相匹配,是 ...

  9. php 数据中心,数据层 · Thinkphp 独立数据中心使用手册 · 看云

    [TOC] ## 概述 数据层是用来直接操作数据表的,数据层的方法都是原子的操作,应避免在数据层中处理具体业务流程,具体业务流程应在逻辑层进行处理. ## 数据层类定义 数据层类通常需要继承核心的\t ...

最新文章

  1. php 学习笔记 数组1
  2. qt android 开发之wifi开发篇
  3. java多态调用优先级_关于java的多态方法调用顺序的问题
  4. Python-线程的生命周期
  5. Oracle 1204 RAC failover 测试 (五)
  6. beautifulsoup find函数返回值_再端一碗BeautifulSoup
  7. 怎么求平均数_EXCEL怎么求企业连续几年业绩的平均增长率
  8. Java中怎么样检查一个字符串是不是数字呢
  9. Linux下如何杀死终端
  10. centeros下安装python
  11. 一起谈.NET技术,Visual Studio对程序集签名时一个很不好用的地方
  12. 对软件研发项目管理的深入探讨
  13. Jasmine基础API
  14. 加密对冲基金究竟靠谱吗?全球第一份行业报告揭开秘密
  15. 某个蝰蛇音效的卡刷包代码分析
  16. 云呐智能运维工具,智能运维工具怎么用
  17. java if经典程序_java经典程序题15道(另附自己做的答案)
  18. arcgis小班编号问题 工具箱来喽
  19. js 将json数据自动绑定到 html table 表格中
  20. 二叉树前中后序遍历的非递归实现以及层次遍历、zig-zag型遍历详解

热门文章

  1. 文本处理三剑客之sed命令详解
  2. 【转】隐函数绘制并提取数据
  3. 可靠的UDP (RUDP)
  4. 这些岗位要注意啦,不想被迫离职,要早做打算!
  5. Lambda架构简介
  6. HTML+CSS+JavaScript网页制作案例教程-黑马程序员-第四章课后习题(播放器图标)
  7. 公安部消防局全面推进“智慧消防”建设
  8. 说说我认为的网络推广专员
  9. Matlab的多维数组操作
  10. 【经验】使用Java控制kiftd网盘服务器并实现定时导入文件功能