GitHub:sczhou/ProPainter
一、介绍:
分割任何物体(SA)项目提出了Segment Anything Model (SAM),它可以通过输入提示生成高质量的对象掩码,用于图像中所有对象的分割任务。该模型在包含1100万张图像和110亿个掩码的数据集上训练,展现出强大的零样本性能。研究人员还提出了数据引擎和SAM模型,用于训练和使用该项目。该研究提供了一种全新的思路和方法,具有广泛的应用前景,并且可以有效地解决图像分割问题。最终,他们将该数据集和模型公开发布,以促进计算机视觉基础模型的研究。该模型有望取代传统的OpenCV,成为未来图像分割抠图领域的主流模型。
该项目用于将视频中的人物或水印移除,总的来说效果还是不错的,但美中不足的是,需要自己对视频需要处理的主体进行抠像,如果将视频转换成序列帧,然后在Photoshop中逐帧来处理显示是不现实的,工程量极大。
二、配置环境
conda create -n propainter python=3.8
pip install opencv-python -i https://mirror.baidu.com/pypi/simple
pip install pycocotools -i https://mirror.baidu.com/pypi/simple
pip install matplotlib -i https://mirror.baidu.com/pypi/simple
pip install onnxruntime -i https://mirror.baidu.com/pypi/simple
pip install onnx -i https://mirror.baidu.com/pypi/simple
三、模型下载
官网下载
- vit_h
- vit_l
- vit_b
四、使用教程
全局分割
方法说明
使用 SAM 自动生成对象掩码,由于SAM
可以有效地处理提示,因此可以通过在图像上采样大量提示来生成整个图像的掩码。这种方法被用来生成SA-1B
数据集。
类SamAutomaticMaskGenerator
实现了这种能力。它通过在图像上的网格中对单点输入提示进行采样工作,从每个采样点中 SAM 可以预测多个掩码。然后,对掩码进行质量过滤和非最大值抑制去重。其他选项允许进一步提高掩码的质量和数量,例如在图像的多个裁剪区域上运行预测,或对掩码进行后处理以消除小的断开区域和孔洞。
效果展示
本站所提供的脚本内置三种输出方式,得到的效果分别如下:
修改配置
- 第12行
- 您所下载的模型路径,如果您下载本站所提供的模型则无需修改。
- 第21行
- 您想使用的模型名称,默认h模型效果最好,建议不要修改。
- 第22行
- 您需要进行全局自动分割的原图文件路径
- 第23行
- 如果您需要使用OutMask方法进行分割,则需要填写输出路径,也可以保持默认自动生成。
- 第173行
- 您想以哪种方式进行自动分割?直接修改为您想分割的方法名称即可。
代码示例
# 此脚本由openai.wiki提供,转载请注明出处。
# 导入相关模块
import os,time,cv2,torch,torch.nn
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# 模型路径
model_Path = './checkpoint'
# 定义模型路径
models = {'b':['%s/sam_vit_b_01ec64.pth' % model_Path,'vit_b'],
'h':['%s/sam_vit_h_4b8939.pth' % model_Path,'vit_h'],
'l':['%s/sam_vit_l_0b3195.pth' % model_Path,'vit_l']
}
# 必要参数
model = models['h'] # 直接输入想使用的模型字母即可,可用修改参数为'b'|'h'|'l'
image_path = './notebooks/images/dog.jpg' # 输入的图片路径
output_folder = './Mask_Folder' # 输出 Mask 的文件夹
# 窗口定义
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))
# 通用定义
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
sam_checkpoint = model[0]
model_type = model[1]
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# 自动掩码生成
def mask_generator():
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print('mask总数:%s' % len(masks))
# 掩码生成返回一个掩码列表,其中每个掩码都是一个包含有关掩码的各种数据的字典。 这些键是:
# 分割:面具
# area :掩码的面积(以像素为单位)
# bbox : XYWH 格式的掩码边界框
# predicted_iou :模型自己对掩码质量的预测
# point_coords :生成此掩码的采样输入点
# stability_score :掩模质量的附加度量
# crop_box :用于生成 XYWH 格式蒙版的图像裁剪
print(masks[0].keys())
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
# 自动掩码生成选项
# 自动掩码生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩码的阈值。
# 此外,生成可以在图像的裁剪上自动运行以提高较小对象的性能,并且后处理可以去除杂散像素和孔洞。
# 以下是对更多掩码进行采样的示例配置:
def mask_generator_2():
mask_generator_2 = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,
pred_iou_thresh=0.86,
stability_score_thresh=0.92,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # 需要 opencv 运行后处理
)
masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()
def OutMask():
# 检测输出文件夹是否存在,不存在则自动创建。
os.makedirs(output_folder, exist_ok=True)
# 载入模型
print("%s 模型正在载入" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
sam = sam_model_registry["vit_h"](checkpoint=model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam = sam.to(device)
# 输出模型加载完成的当前时间
print("%s 模型载入完成" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
# 这里是加载图片
image = cv2.imread(image_path)
# 输出图片加载完成的current时间
print("%s 图片加载完成" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
# 这里是加载图片,这里的image_path是图片的路径
print("%s 正在分割图像" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
# 这里是预测,不用提示词,进行全图分割
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
# 使用提示词,进行局部分割
# predictor = SamPredictor(sam)
# predictor.set_image(image)
# masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)
print('%s 图像分割完成' % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
# 循环保存mask
# 遍历 masks 列表并保存每个掩码
for i, mask in enumerate(masks):
mask_array = mask['segmentation']
mask_uint8 = (mask_array * 255).astype(np.uint8)
# 为每个掩码生成一个唯一的文件名
output_file = os.path.join(output_folder, f"Mask_{i+1}.png")
# 保存掩码
cv2.imwrite(output_file, mask_uint8)
# 输出完整的mask
# 获取输入图像的尺寸
height, width, _ = image.shape
# 创建一个全零数组,用于合并掩码
merged_mask = np.zeros((height, width), dtype=np.uint8)
# 遍历 masks 列表并合并每个掩码
for i, mask in enumerate(masks):
mask_array = mask['segmentation']
mask_uint8 = (mask_array * 255).astype(np.uint8)
# 为每个掩码生成一个唯一的文件名
output_file = os.path.join(output_folder, f"Mask_{i+1}.png")
# 保存掩码
cv2.imwrite(output_file, mask_uint8)
# 将当前掩码添加到合并掩码上
merged_mask = np.maximum(merged_mask, mask_uint8)
# 保存合并后的掩码
merged_output_file = os.path.join(output_folder, "Mask_All.png")
cv2.imwrite(merged_output_file, merged_mask)
# 本站提供三种输方式,分别为mask_generator|mask_generator_2|OutMask三种,请根据自身情况修改。
run_function = mask_generator() # 运行模式
局部分割
方法说明
使用SAM从提示中获取对象掩码
Segment Anything Model (SAM) 是一种能够根据指定对象的提示,预测对象掩码的强大模型。该模型首先将图像转换为一个图像嵌入,使得可以从提示中高效地生成高质量的掩码。
SamPredictor类提供了一个易于使用的接口,用于与该模型进行交互。它允许用户首先使用set_image方法设置一张图像,并计算出必要的图像嵌入。然后,通过predict方法提供提示,从这些提示中高效地预测出掩码。该模型可以接受点提示、框提示、以及上一次预测的掩码作为输入。
效果展示
本站所提供的脚本内置五种输出方式,得到的效果分别如下:
修改配置
- 第13行
- 您所下载的模型路径,如果您下载本站所提供的模型则无需修改。
- 第22行
- 您想使用的模型名称,默认h模型效果最好,建议不要修改。
- 第23行
- 您需要进行全局自动分割的原图文件路径
- 第251行
- 您想以哪种方式进行自动分割?直接修改为您想分割的方法名称即可。
# 此脚本由openai.wiki提供,转载请注明出处。
# 导入相关模块
import cv2,torch,torch.nn,sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
# 模型路径
model_Path = './checkpoint'
# 定义模型路径
models = {'b':['%s/sam_vit_b_01ec64.pth' % model_Path,'vit_b'],
'h':['%s/sam_vit_h_4b8939.pth' % model_Path,'vit_h'],
'l':['%s/sam_vit_l_0b3195.pth' % model_Path,'vit_l']
}
# 必要参数
model = models['h'] # 直接输入想使用的模型字母即可,可用修改参数为'b'|'h'|'l'
image_path = './notebooks/images/truck.jpg' # 输入的图片路径
# 窗口定义
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
# 通用定义
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
sam_checkpoint = model[0]
model_type = model[1]
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
# 示例图像 | Example image
def EI():
masks.shape # (number_of_masks) x H x W
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
# 选择对象的 SAM | Selecting objects with SAM
# 首先,加载SAM模型和预测器。将下面的路径更改为指向SAM检查点。
# 建议在CUDA上运行并使用默认模型以获得最佳结果。
def SOWS():
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
masks.shape
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# 使用框指定特定对象 | Specifying a specific object with a box
# 该模型还可以接受框作为输入,框以xyxy格式提供。
def SASOWAB():
input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
# 结合点和框 | Combining points and boxes
# 可以通过将两种类型的提示都包含在预测器中来结合点和框。这可以用于选择卡车的轮胎而不是整个车轮。
def CPAB():
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# 批量提示输入 | Batched prompt inputs
# SamPredictor可以使用predict_torch方法为同一图像输入多个提示。该方法假定输入点已经是torch张量,并且已经转换为输入帧。例如,假设我们有来自目标检测器的多个框输出。
def BPI():
input_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
# 端到端的批量推断 | End-to-end batched inference
# 如果所有提示都提前可用,则可以直接以端到端的方式运行SAM。这也允许在图像上进行批处理。
def ETEBI():
image1 = image # truck.jpg from above
image1_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=sam.device)
image2 = cv2.imread(image_path)
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
[450, 170, 520, 350],
[350, 190, 450, 350],
[500, 170, 580, 350],
[580, 170, 640, 350],
], device=sam.device)
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
def prepare_image(image, transform, device):
image = transform.apply_image(image)
image = torch.as_tensor(image, device=device.device)
return image.permute(2, 0, 1).contiguous()
batched_input = [
{
'image': prepare_image(image1, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
'original_size': image1.shape[:2]
},
{
'image': prepare_image(image2, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
'original_size': image2.shape[:2]
}
]
batched_output = sam(batched_input, multimask_output=False)
# 输出是针对每个输入图像的结果列表,其中列表元素是具有以下键的字典:
# masks:预测的二进制掩码的批处理torch张量,大小与原始图像相同。
# iou_predictions:模型对每个掩码的质量预测。
# low_res_logits:每个掩码的低分辨率logits,可以在稍后的迭代中作为掩码输入传回模型。
batched_output[0].keys()
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')
ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')
plt.tight_layout()
plt.show()
# 本站提供五种输方式,分别为SOWS|SASOWAB|CPAB|BPI|ETEBI五种,请根据自身情况修改。
run_function = SOWS() # 运行模式
四、总结
经本站测试,暂时并未发现较为明显的问题,但是分割效果真的非常不错。