GitHub:sczhou/ProPainter

一、介绍:

Segment-Anything|万物分割-openAI维基百科

分割任何物体(SA)项目提出了Segment Anything Model (SAM),它可以通过输入提示生成高质量的对象掩码,用于图像中所有对象的分割任务。该模型在包含1100万张图像和110亿个掩码的数据集上训练,展现出强大的零样本性能。研究人员还提出了数据引擎和SAM模型,用于训练和使用该项目。该研究提供了一种全新的思路和方法,具有广泛的应用前景,并且可以有效地解决图像分割问题。最终,他们将该数据集和模型公开发布,以促进计算机视觉基础模型的研究。该模型有望取代传统的OpenCV,成为未来图像分割抠图领域的主流模型。

Segment-Anything|万物分割-openAI维基百科

该项目用于将视频中的人物或水印移除,总的来说效果还是不错的,但美中不足的是,需要自己对视频需要处理的主体进行抠像,如果将视频转换成序列帧,然后在Photoshop中逐帧来处理显示是不现实的,工程量极大。

二、配置环境

conda create -n propainter python=3.8
pip install opencv-python -i https://mirror.baidu.com/pypi/simple
pip install pycocotools -i https://mirror.baidu.com/pypi/simple
pip install matplotlib -i https://mirror.baidu.com/pypi/simple
pip install onnxruntime -i https://mirror.baidu.com/pypi/simple
pip install onnx -i https://mirror.baidu.com/pypi/simple

三、模型下载

官网下载

四、使用教程

全局分割

方法说明

使用 SAM 自动生成对象掩码,由于SAM可以有效地处理提示,因此可以通过在图像上采样大量提示来生成整个图像的掩码。这种方法被用来生成SA-1B数据集。

SamAutomaticMaskGenerator实现了这种能力。它通过在图像上的网格中对单点输入提示进行采样工作,从每个采样点中 SAM 可以预测多个掩码。然后,对掩码进行质量过滤和非最大值抑制去重。其他选项允许进一步提高掩码的质量和数量,例如在图像的多个裁剪区域上运行预测,或对掩码进行后处理以消除小的断开区域和孔洞。

效果展示

本站所提供的脚本内置三种输出方式,得到的效果分别如下:

Segment-Anything|万物分割-openAI维基百科
原图

修改配置

  • 第12行
    • 您所下载的模型路径,如果您下载本站所提供的模型则无需修改。
  • 第21行
    • 您想使用的模型名称,默认h模型效果最好,建议不要修改。
  • 第22行
    • 您需要进行全局自动分割的原图文件路径
  • 第23行
    • 如果您需要使用OutMask方法进行分割,则需要填写输出路径,也可以保持默认自动生成。
  • 第173行
    • 您想以哪种方式进行自动分割?直接修改为您想分割的方法名称即可。

代码示例

# 此脚本由openai.wiki提供转载请注明出处

# 导入相关模块
import os,time,cv2,torch,torch.nn
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# 模型路径
model_Path = './checkpoint'

# 定义模型路径
models = {'b':['%s/sam_vit_b_01ec64.pth' % model_Path,'vit_b'],
          'h':['%s/sam_vit_h_4b8939.pth' % model_Path,'vit_h'],
          'l':['%s/sam_vit_l_0b3195.pth' % model_Path,'vit_l'] 
}

# 必要参数
model = models['h']                           # 直接输入想使用的模型字母即可可用修改参数为'b'|'h'|'l'
image_path = './notebooks/images/dog.jpg'     # 输入的图片路径
output_folder = './Mask_Folder'               # 输出 Mask 的文件夹

# 窗口定义
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))

# 通用定义
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam_checkpoint = model[0]
model_type = model[1]

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# 自动掩码生成
def mask_generator():
    mask_generator = SamAutomaticMaskGenerator(sam)

    masks = mask_generator.generate(image)

    print('mask总数:%s' % len(masks))
    # 掩码生成返回一个掩码列表其中每个掩码都是一个包含有关掩码的各种数据的字典这些键是

    # 分割面具
    # area掩码的面积以像素为单位
    # bbox : XYWH 格式的掩码边界框
    # predicted_iou模型自己对掩码质量的预测
    # point_coords生成此掩码的采样输入点
    # stability_score掩模质量的附加度量
    # crop_box用于生成 XYWH 格式蒙版的图像裁剪
    
    print(masks[0].keys())

    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()

# 自动掩码生成选项
# 自动掩码生成中有几个可调参数用于控制采样点的密度以及去除低质量或重复掩码的阈值
# 此外生成可以在图像的裁剪上自动运行以提高较小对象的性能并且后处理可以去除杂散像素和孔洞
# 以下是对更多掩码进行采样的示例配置
def mask_generator_2():
    mask_generator_2 = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,           # 需要 opencv 运行后处理
    )

    masks2 = mask_generator_2.generate(image)

    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks2)
    plt.axis('off')
    plt.show() 

def OutMask():
    # 检测输出文件夹是否存在不存在则自动创建
    os.makedirs(output_folder, exist_ok=True)

    # 载入模型
    print("%s 模型正在载入" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    sam = sam_model_registry["vit_h"](checkpoint=model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sam = sam.to(device)

    # 输出模型加载完成的当前时间
    print("%s 模型载入完成" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    # 这里是加载图片
    image = cv2.imread(image_path)
    # 输出图片加载完成的current时间
    print("%s 图片加载完成" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    # 这里是加载图片这里的image_path是图片的路径
    print("%s 正在分割图像" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    # 这里是预测,不用提示词,进行全图分割
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    # 使用提示词,进行局部分割
    # predictor = SamPredictor(sam)
    # predictor.set_image(image)
    # masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=None, mask_input=None, multimask_output=True, return_logits=False)

    print('%s 图像分割完成' % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    # 循环保存mask
    # 遍历 masks 列表并保存每个掩码
    for i, mask in enumerate(masks):
        mask_array = mask['segmentation']
        mask_uint8 = (mask_array * 255).astype(np.uint8)

        # 为每个掩码生成一个唯一的文件名
        output_file = os.path.join(output_folder, f"Mask_{i+1}.png")

        # 保存掩码
        cv2.imwrite(output_file, mask_uint8)


    # 输出完整的mask
    # 获取输入图像的尺寸
    height, width, _ = image.shape

    # 创建一个全零数组用于合并掩码
    merged_mask = np.zeros((height, width), dtype=np.uint8)

    # 遍历 masks 列表并合并每个掩码
    for i, mask in enumerate(masks):
        mask_array = mask['segmentation']
        mask_uint8 = (mask_array * 255).astype(np.uint8)

        # 为每个掩码生成一个唯一的文件名
        output_file = os.path.join(output_folder, f"Mask_{i+1}.png")

        # 保存掩码
        cv2.imwrite(output_file, mask_uint8)

        # 将当前掩码添加到合并掩码上
        merged_mask = np.maximum(merged_mask, mask_uint8)

    # 保存合并后的掩码
    merged_output_file = os.path.join(output_folder, "Mask_All.png")
    cv2.imwrite(merged_output_file, merged_mask)


# 本站提供三种输方式分别为mask_generator|mask_generator_2|OutMask三种请根据自身情况修改
run_function = mask_generator()               # 运行模式

局部分割

方法说明

使用SAM从提示中获取对象掩码

Segment Anything Model (SAM) 是一种能够根据指定对象的提示,预测对象掩码的强大模型。该模型首先将图像转换为一个图像嵌入,使得可以从提示中高效地生成高质量的掩码。

SamPredictor类提供了一个易于使用的接口,用于与该模型进行交互。它允许用户首先使用set_image方法设置一张图像,并计算出必要的图像嵌入。然后,通过predict方法提供提示,从这些提示中高效地预测出掩码。该模型可以接受点提示、框提示、以及上一次预测的掩码作为输入。

效果展示

本站所提供的脚本内置五种输出方式,得到的效果分别如下:

Segment-Anything|万物分割-openAI维基百科
SOWS
Segment-Anything|万物分割-openAI维基百科
ETEBI

修改配置

  • 第13行
    • 您所下载的模型路径,如果您下载本站所提供的模型则无需修改。
  • 第22行
    • 您想使用的模型名称,默认h模型效果最好,建议不要修改。
  • 第23行
    • 您需要进行全局自动分割的原图文件路径
  • 第251行
    • 您想以哪种方式进行自动分割?直接修改为您想分割的方法名称即可。
# 此脚本由openai.wiki提供转载请注明出处

# 导入相关模块
import cv2,torch,torch.nn,sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# 模型路径
model_Path = './checkpoint'

# 定义模型路径
models = {'b':['%s/sam_vit_b_01ec64.pth' % model_Path,'vit_b'],
          'h':['%s/sam_vit_h_4b8939.pth' % model_Path,'vit_h'],
          'l':['%s/sam_vit_l_0b3195.pth' % model_Path,'vit_l'] 
}

# 必要参数
model = models['h']                             # 直接输入想使用的模型字母即可可用修改参数为'b'|'h'|'l'
image_path = './notebooks/images/truck.jpg'     # 输入的图片路径

# 窗口定义
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))  

# 通用定义
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam_checkpoint = model[0]
model_type = model[1]

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
predictor.set_image(image)


input_point = np.array([[500, 375]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# 示例图像 | Example image
def EI():
    masks.shape  # (number_of_masks) x H x W

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10,10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show() 
    
# 选择对象的 SAM | Selecting objects with SAM
# 首先加载SAM模型和预测器将下面的路径更改为指向SAM检查点
# 建议在CUDA上运行并使用默认模型以获得最佳结果
def SOWS():
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 1])

    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    masks.shape
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show() 
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 0])

    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

# 使用框指定特定对象 | Specifying a specific object with a box
# 该模型还可以接受框作为输入框以xyxy格式提供
def SASOWAB():
    input_box = np.array([425, 600, 700, 875])
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    plt.axis('off')
    plt.show()

# 结合点和框 | Combining points and boxes
# 可以通过将两种类型的提示都包含在预测器中来结合点和框这可以用于选择卡车的轮胎而不是整个车轮
def CPAB():
    input_box = np.array([425, 600, 700, 875])
    input_point = np.array([[575, 750]])
    input_label = np.array([0])
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

# 批量提示输入 | Batched prompt inputs
# SamPredictor可以使用predict_torch方法为同一图像输入多个提示该方法假定输入点已经是torch张量并且已经转换为输入帧例如假设我们有来自目标检测器的多个框输出
def BPI():
    input_boxes = torch.tensor([
        [75, 275, 1725, 850],
        [425, 600, 700, 875],
        [1375, 550, 1650, 800],
        [1240, 675, 1400, 750],
    ], device=predictor.device)

    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks.shape  # (batch_size) x (num_predicted_masks_per_input) x H x W
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box in input_boxes:
        show_box(box.cpu().numpy(), plt.gca())
    plt.axis('off')
    plt.show()

# 端到端的批量推断 | End-to-end batched inference
# 如果所有提示都提前可用则可以直接以端到端的方式运行SAM这也允许在图像上进行批处理
def ETEBI():
    image1 = image  # truck.jpg from above
    image1_boxes = torch.tensor([
        [75, 275, 1725, 850],
        [425, 600, 700, 875],
        [1375, 550, 1650, 800],
        [1240, 675, 1400, 750],
    ], device=sam.device)

    image2 = cv2.imread(image_path)
    image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
    image2_boxes = torch.tensor([
        [450, 170, 520, 350],
        [350, 190, 450, 350],
        [500, 170, 580, 350],
        [580, 170, 640, 350],
    ], device=sam.device)
    from segment_anything.utils.transforms import ResizeLongestSide
    resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    def prepare_image(image, transform, device):
        image = transform.apply_image(image)
        image = torch.as_tensor(image, device=device.device) 
        return image.permute(2, 0, 1).contiguous()
    batched_input = [
        {
            'image': prepare_image(image1, resize_transform, sam),
            'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
            'original_size': image1.shape[:2]
        },
        {
            'image': prepare_image(image2, resize_transform, sam),
            'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
            'original_size': image2.shape[:2]
        }
    ]
    batched_output = sam(batched_input, multimask_output=False)
    # 输出是针对每个输入图像的结果列表其中列表元素是具有以下键的字典

    # masks预测的二进制掩码的批处理torch张量大小与原始图像相同
    # iou_predictions模型对每个掩码的质量预测
    # low_res_logits每个掩码的低分辨率logits可以在稍后的迭代中作为掩码输入传回模型
    batched_output[0].keys()
    fig, ax = plt.subplots(1, 2, figsize=(20, 20))

    ax[0].imshow(image1)
    for mask in batched_output[0]['masks']:
        show_mask(mask.cpu().numpy(), ax[0], random_color=True)
    for box in image1_boxes:
        show_box(box.cpu().numpy(), ax[0])
    ax[0].axis('off')

    ax[1].imshow(image2)
    for mask in batched_output[1]['masks']:
        show_mask(mask.cpu().numpy(), ax[1], random_color=True)
    for box in image2_boxes:
        show_box(box.cpu().numpy(), ax[1])
    ax[1].axis('off')

    plt.tight_layout()
    plt.show()

# 本站提供五种输方式分别为SOWS|SASOWAB|CPAB|BPI|ETEBI五种请根据自身情况修改
run_function = SOWS()               # 运行模式

四、总结

经本站测试,暂时并未发现较为明显的问题,但是分割效果真的非常不错。