Segment Anything使用手册(交互式数据标柱|自动数据标柱)
主要内容包含segment-anything项目的安装、基于SamPredictor对单点输入生成mask、基于SamPredictor对多点输入生成mask、基于SamAutomaticMaskGenerator自动生成mask。
Segment Anything项目是一个可以对任何图像进行分割的项目,其论文介绍可以查看https://blog.csdn.net/a486259/article/details/131137939,其测试网站为 https://segment-anything.com
这里对Segment Anything项目的使用进行初步总结,绝大部分内容源自https://github.com/facebookresearch/segment-anything 。
注:segment-anything训练VIT模型时的输入size为1024x1024,其输出的feature size为256x64x64,进行了16倍的下采样
1、安装segment-anything
下载segment-anything项目,进入目录后执行pip install -e .
安装项目。
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
该项目依赖opencv-python pycocotools matplotlib onnxruntime onnx torch等包,安装命令如下
pip install opencv-python pycocotools matplotlib onnxruntime onnx torch
segment-anything模型是基于torch框架实现的
2. 根据提示输入生成mask
Segment Anything Model (SAM) 预测对象mask,给出所需识别出对象的提示输入(对象的粗略位置信息)。该模型首先将图像转换为图像嵌入,然后解码器根据用户输入的提示(粗略位置信息)可以生成高质量的掩模。
SamPredictor类为模型调用提供了一个简单的接口,用于提示模型的输入。它先让用户使用“set_image”方法设置图像,该方法会将图像输入转换到特征空间嵌入。然后,可以通过“predict”方法输入提示信息,以根据这些提示有效地预测掩码。predict函数支持将点和框提示以及上一次预测迭代中的mask作为输入。
2.1 前置函数库
前置库导入和函数实现
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):if random_color:color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)else:color = np.array([30/255, 144/255, 255/255, 0.6])h, w = mask.shape[-2:]mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_points(coords, labels, ax, marker_size=375):pos_points = coords[labels==1]neg_points = coords[labels==0]ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) def show_box(box, ax):x0, y0 = box[0], box[1]w, h = box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
2.2 显示样图
读取图片并展示
image = cv2.imread('images/truck.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()
2.3 加载SAM模型
sam_vit_b_01ec64模型的下载地址为: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
这里需要注意要使用cuda
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictorsam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)predictor = SamPredictor(sam)
#predictor.set_image(image)
其它版本的模型下载地址为
default
orvit_h
: ViT-H SAM model.vit_l
: ViT-L SAM model.vit_b
: ViT-B SAM model.
通过调用“SamPredictor.set_image”处理图像以生成图像嵌入(特征向量)。“SamPrejector”会记住此特征向量,并将其用于后续掩码预测。
predictor.set_image(image)
2.4 单点输入生成mask
要选择卡车,可以卡车上选择一个点。点以(x,y)格式输入到模型中,并带有标签1(前景点)或0(背景点)。可以输入多个点;这里我们只使用一个。所选的点将在图像上显示为星形。
此时代码及执行效果如下:
input_point = np.array([[250, 187]])
input_label = np.array([1])plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
使用“SamPredictor.prdict”进行预测。该模型返回掩码(masks)、掩码的分数(scores)以及可传递到下一次预测迭代的低分辨率掩码(logits)。
在“multimask_output=True”(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。此设置用于存在不明确输入提示的时候(光凭一个点无法有效识别出用户意图是组件局部、组件还是整体),并帮助模型消除与提示一致的不同对象的歧义。当为“multimask_output=False”时,它将返回一个掩码。对于单点等不明确的提示,建议使用“multimask_output=True”,即使只需要一个掩码;可以通过选择在“分数”中返回的分数最高的一个来选择最佳的单个掩码。这通常会得到更好的mask。
masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True,
)print(masks.shape) # (number_of_masks) x H x W | output (3, 600, 900)for i, (mask, score) in enumerate(zip(masks, scores)):plt.figure(figsize=(10,10))plt.imshow(image)show_mask(mask, plt.gca())show_points(input_point, input_label, plt.gca())plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)plt.axis('off')plt.show()
3、多输入生成mask
3.1 多点输入生成mask
单个输入点不明确,需要让模型返回了与其一致的多个对象。要获得单个对象,可以提供多个点。如果可用,还可以将先前迭代的掩码(logits值)提供给模型以帮助预测。当使用多个提示指定单个对象时,可以通过设置“multimask_output=False”来请求获取单个掩码。
input_point = np.array([[250, 184], [562, 322]])
input_label = np.array([1, 1])mask_input = logits[np.argmax(scores), :, :] # Choose the model's best maskmasks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,mask_input=mask_input[None, :, :],multimask_output=False,
)
print(masks.shape) #output: (1, 600, 900)plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
输入负点明确区域
input_label与input_point相对应,为0时表示是负点
input_point = np.array([[250, 187], [561, 322]])
input_label = np.array([1, 0])#为0时表示是负点,即第二个点[561, 322]是负点mask_input = logits[np.argmax(scores), :, :] # Choose the model's best maskmasks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,mask_input=mask_input[None, :, :],multimask_output=False,
)plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
3.2 boxes输入生成mask
支持将xyxy格式的box作为输入,将框内的主体目标识别出来(类似于实例分割)
input_box = np.array([212, 300, 350, 437])
masks, _, _ = predictor.predict(point_coords=None,point_labels=None,box=input_box[None, :],multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
3.3 同时输入点与boxes生成mask
point和boxes可以同时输入,只需将这两种类型的提示都包括在预测器中即可。在这里,这可以用来只选择卡车的轮胎(将车轴部分设置为负点),而不是整个车轮。
input_box = np.array([215, 310, 350, 430]) #只能默认框住正类
input_point = np.array([[287, 375]])
input_label = np.array([0]) #将车轴部分设置为负点masks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,box=input_box,multimask_output=False,
)plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
3.4 同时输入多个boxes生成mask
SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。
input_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],
], device=predictor.device) #假设这是目标检测的预测结果
input_boxes=input_boxes/2transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(point_coords=None,point_labels=None,boxes=transformed_boxes,multimask_output=False,
)print(masks.shape) # (batch_size) x (num_predicted_masks_per_input) x H x W | output: torch.Size([4, 1, 600, 900])plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
3.5 端到端的批量推理
如果所有提示输入都已经明确的,则可以以端到端的方式直接运行SAM。这允许SAM对图像进行批处理,以下代码构建了2个image和boxes。
image1 = cv2.imread('images/truck.jpg')
image1_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],
], device=sam.device)image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([[450, 170, 520, 350],[350, 190, 450, 350],[500, 170, 580, 350],[580, 170, 640, 350],
], device=sam.device)
图像和提示都作为PyTorch张量输入,这些张量(图像和提示输入)已经被编码为特征向量。所有的输入数据都被封装为list,每个元素都是一个dict,它的key如下:
image
: CHW格式的PyTorch tensor .original_size
: 图像原始大小, (H, W) format.point_coords
: 一批输入点格式.point_labels
: 每个输入点所对应的类型(正例或负例).boxes
: 一批输入的boxe(只能是正例).mask_inputs
: 一批输入的mask.
如果没有相应的信息,可以不进行输入,但image必须输入
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)def prepare_image(image, transform, device):image = transform.apply_image(image)image = torch.as_tensor(image, device=device.device) return image.permute(2, 0, 1).contiguous()batched_input = [{'image': prepare_image(image1, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),'original_size': image1.shape[:2]},{'image': prepare_image(image2, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),'original_size': image2.shape[:2]}
]
batched_output = sam(batched_input, multimask_output=False)
print(batched_output[0].keys()) # output:dict_keys(['masks', 'iou_predictions', 'low_res_logits'])
输出是每个输入图像的结果列表,其中元素是字典对象,其key为:
masks
: 一批mask,tensor张量iou_predictions
: 与mask相对应的iou预测值.low_res_logits
: 每个掩码的低分辨率logits,可以在以后的迭代中作为掩码输入再次调用模型。
fig, ax = plt.subplots(1, 2, figsize=(20, 20))ax[0].imshow(image1)
for mask in batched_output[0]['masks']:show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')ax[1].imshow(image2)
for mask in batched_output[1]['masks']:show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')plt.tight_layout()
plt.show()
4、自动生成mask
4.1 基础前置库
这里加载了一些基础库,并读取images/dog.jpg作为样例数据
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_anns(anns):if len(anns) == 0:returnsorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)ax = plt.gca()ax.set_autoscale_on(False)img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))img[:,:,3] = 0for ann in sorted_anns:m = ann['segmentation']color_mask = np.concatenate([np.random.random(3), [0.35]])img[m] = color_maskax.imshow(img)image = cv2.imread('images/dog.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()
4.2 自动生成mask
要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictorsam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)masks = mask_generator.generate(image)print(len(masks))
print(masks[0].keys())
print(masks[0])plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
代码输出的文字信息如下:
42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{'segmentation': array([[False, False, False, ..., False, False, False],[False, False, False, ..., False, False, False],[False, False, False, ..., False, False, False],...,[ True, True, True, ..., False, False, False],[ True, True, True, ..., False, False, False],[False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}
所生成的图像如下
masks = mask_generator.generate(image)
Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:
segmentation
: np的二维数组,为二值的mask图片area
: mask的像素面积bbox
: mask的外接矩形框,为XYWH格式predicted_iou
: 该mask的质量(模型预测出的与真实框的iou)point_coords
: 用于生成该mask的point输入stability_score
: mask质量的附加指标crop_box
: 用于以XYWH格式生成此遮罩的图像裁剪
4.3 自动mask的参数
在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:
mask_generator_2 = SamAutomaticMaskGenerator(model=sam,points_per_side=32,#控制采样点的间隔,值越小,采样点越密集pred_iou_thresh=0.86,#mask的iou阈值stability_score_thresh=0.92,#mask的稳定性阈值crop_n_layers=1,crop_n_points_downscale_factor=2,min_mask_region_area=50, #最小mask面积,会使用opencv滤除掉小面积的区域
)
masks2 = mask_generator_2.generate(image)print(len(masks2)) # 69plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()
Segment Anything使用手册(交互式数据标柱|自动数据标柱)相关推荐
- R语言vtreat包自动处理dataframe的缺失值、使用分组的中位数来标准化数据列中每个数据的值(和中位数表连接并基于中位数进行数据标化)、计算数据列的中位数或者均值并进行数据标准化
R语言vtreat包自动处理dataframe的缺失值.使用分组的中位数来标准化数据列中每个数据的值(和中位数表连接并基于中位数进行数据标化).计算数据列的中位数或者均值并基于中位数或者均值进行数据标 ...
- web数据交互_通过体育运动使用定制的交互式Web应用程序数据科学探索任何数据...
web数据交互 Most good data projects start with the analyst doing something to get a feel for the data th ...
- 数据产品经理修炼手册pdf_【尼读书】数据产品经理修炼手册(附思维导图)
前言:进入一个行业,除了要多在工作中实践和思考之外,还需要多读书.这样能够站在一个更高的角度去看问题,往往会对问题有更全面的掌握和新的认知.在[尼读书]这个栏目中,尼同学通过自己读书后的理解和整理与大 ...
- 数据探查平台-元数据对标专利 -- 普帝
专利申请步骤 申请(专利权)人 讯飞智元信息科技有限公司 发明人 石金普;王慧敏;冯小凯;姚素雅 地址 安徽省合肥市高新区望江西路666号讯飞大厦8层-10层 邮编 230088 案例分享 1.一种确 ...
- 再谈数据标准落标,论数据模型设计工具
工欲善其事必先利其器.工具是用来提高生产效率,其次才是管理属性. 一个工具顺不顺手极大影响生产效率和管理效果.工具用不起来,管理制度也落不下去.管理自说自话,下面各干各的,最终两张皮.Datablau ...
- 数值型数据和标称型数据
在学习机器学习的工程中,发现有一种名为标称型的数据,具体如下: 标称型:一般在有限的数据中取,而且只存在'是'和'否'两种不同的结果(一般用于分类) 数值型:可以在无限的数据中取,而且数值比较具体化, ...
- excel柱状图负值柱下数据标签移到坐标轴上方
excel做出的柱状图,如何将负值柱的数据标签如何移到横坐标上方,而同时正值柱数据标签保持位置不变呢? 首选需要删除影响显示效果的刻度线标签 设置数据标志数据标签为显示状态 然后针对在坐标轴以下的数据 ...
- 直通车实战手册:如何利用直通车开好标品类目
金九银十是行业的旺季,同时也是各类目产品蓄力爆发的时候.双11不到两个月,现在疯狂掘金的模式已经开启,你还在等待吗? 现在做淘宝,产品技术都要占优势,如果你只有优质产品,没有相对应的资源技术相匹配,是 ...
- php 数据中心,数据层 · Thinkphp 独立数据中心使用手册 · 看云
[TOC] ## 概述 数据层是用来直接操作数据表的,数据层的方法都是原子的操作,应避免在数据层中处理具体业务流程,具体业务流程应在逻辑层进行处理. ## 数据层类定义 数据层类通常需要继承核心的\t ...
最新文章
- php 学习笔记 数组1
- qt android 开发之wifi开发篇
- java多态调用优先级_关于java的多态方法调用顺序的问题
- Python-线程的生命周期
- Oracle 1204 RAC failover 测试 (五)
- beautifulsoup find函数返回值_再端一碗BeautifulSoup
- 怎么求平均数_EXCEL怎么求企业连续几年业绩的平均增长率
- Java中怎么样检查一个字符串是不是数字呢
- Linux下如何杀死终端
- centeros下安装python
- 一起谈.NET技术,Visual Studio对程序集签名时一个很不好用的地方
- 对软件研发项目管理的深入探讨
- Jasmine基础API
- 加密对冲基金究竟靠谱吗?全球第一份行业报告揭开秘密
- 某个蝰蛇音效的卡刷包代码分析
- 云呐智能运维工具,智能运维工具怎么用
- java if经典程序_java经典程序题15道(另附自己做的答案)
- arcgis小班编号问题 工具箱来喽
- js 将json数据自动绑定到 html table 表格中
- 二叉树前中后序遍历的非递归实现以及层次遍历、zig-zag型遍历详解