GitHub repository: https://github.com/PeterL1n/RobustVideoMatting

1. Project Introduction
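
RobustVideoMatting (RVM) is a real-time human video matting model. Unlike per-frame image matting methods, it uses a recurrent architecture to exploit temporal information in the video stream, so it needs no trimap or pre-captured background and produces temporally stable mattes. This post covers setting up the environment on Windows and running real-time matting on a webcam feed.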

2. Environment Setup

  • Platform: Windows 10
  • IDE: PyCharm
  • CUDA 11.3
  • cuDNN 8.2.0.53
conda create -n RobustVideoMatting python=3.8
conda activate RobustVideoMatting
pip install av==8.0.3
pip install torch-1.11.0+cu113-cp38-cp38-win_amd64.whl
pip install torchvision-0.12.0+cu113-cp38-cp38-win_amd64.whl
pip install tqdm==4.61.1
pip install pims==0.5
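
The two .whl files are the CUDA 11.3 builds of PyTorch 1.11.0 and torchvision 0.12.0 for Python 3.8 on 64-bit Windows; download them from the official PyTorch wheel index (https://download.pytorch.org/whl/torch_stable.html) and install from the local files as shown above. A quick sanity check, a minimal sketch that is not part of the project, confirms that PyTorch sees the GPU:

import torch

# Expect "1.11.0+cu113" and True on a correctly configured machine.
print(torch.__version__)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))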

3. Running the Code

Real-time webcam matting

Create a main.py with the following code:

import cv2
import time
from torchvision import transforms
from typing import Optional, Tuple
import torch
from model import MattingNetwork
 
 
 
def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution be 512px.
    """
    return min(512 / max(h, w), 1)
 
 
def get_frame(num):
    """Yield frames from camera `num` indefinitely."""
    cap = cv2.VideoCapture(num)
    if not cap.isOpened():
        raise RuntimeError("Cannot open camera %d" % num)
    print("Camera FPS:", cap.get(cv2.CAP_PROP_FPS))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        yield frame
    cap.release()
 
 
def convert_video(model,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
    """
    Args:
        input_resize: If provided, the input are first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, model automatically set one.
        device: Only need to manually provide if model is a TorchScript freezed model.
        dtype: Only need to manually provide if model is a TorchScript freezed model.
    """
    assert downsample_ratio is None or (
                downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
 
 
    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()
 
    # Inference
    print("------------------------------------------------------------>")
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device
 
    # Solid background color (a light green), shaped for broadcasting over [B, T, C, H, W].
    bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
 
    with torch.no_grad():
        rec = [None] * 4  # the model's four recurrent states
        for src in get_frame(0):
            src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)  # OpenCV delivers BGR; the model expects RGB
            src = transform(src)
            src = src.unsqueeze(0)  # [T, C, H, W]

            if downsample_ratio is None:
                downsample_ratio = auto_downsample_ratio(*src.shape[2:])

            src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
            t1 = time.time()
            fgr, pha, *rec = model(src, *rec, downsample_ratio)
            print("frame_cost:", (time.time() - t1) / src.shape[1])
            print("Inference FPS: {:.2f}".format(1 / ((time.time() - t1) / src.shape[1])))

            # Composite the predicted foreground over the background using the alpha matte.
            com = fgr * pha + bgr * (1 - pha)
            frames = com[0]
            if frames.size(1) == 1:
                frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
            frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()  # [T, H, W, 3]

            yield cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imshow
 
 
def show_frame(frames):
    for frame in frames:
        cv2.imshow("capture", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()
 
if __name__ == '__main__':

    # #------ Test whether the camera works ------#
    # for frame in get_frame(0):
    #     cv2.imshow("capture", frame)
    #     if cv2.waitKey(1) & 0xFF == ord('q'):
    #         break
    # #-------------------------------------------#

    # Load the model. Pretrained weights (rvm_mobilenetv3.pth / rvm_resnet50.pth)
    # are available from the repository's release page.
    # model1 = MattingNetwork('mobilenetv3').eval().cuda()
    # model1.load_state_dict(torch.load('Models/rvm_mobilenetv3.pth'))

    model1 = MattingNetwork('resnet50').eval().cuda()  # or 'mobilenetv3'
    model1.load_state_dict(torch.load('Models/rvm_resnet50.pth'))

    # Run inference; convert_video yields composited frames one by one.
    frames = convert_video(model1)

    # Display the results.
    show_frame(frames)
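
The bgr tensor in convert_video composites onto a solid green; to composite onto an image instead, only that tensor needs to change. Below is a minimal sketch of a helper that loads and shapes a background image; load_background and 'background.jpg' are illustrative names, and 640x480 assumes the default webcam resolution:

def load_background(path, width, height, device, dtype):
    """Load an image and shape it like the solid-color bgr tensor: [1, 1, 3, H, W]."""
    img = cv2.imread(path)                      # BGR, uint8, [H, W, 3]
    img = cv2.resize(img, (width, height))      # match the camera resolution
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # the pipeline above works in RGB
    ten = torch.from_numpy(img).to(device, dtype).div(255)
    return ten.permute(2, 0, 1).unsqueeze(0).unsqueeze(0)  # [1, 1, 3, H, W]

# Inside convert_video, replace the bgr line with:
# bgr = load_background('background.jpg', 640, 480, device, dtype)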

4. Results
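
To record the matted output for playback, the frames yielded by convert_video can be written to a video file with OpenCV's VideoWriter. A minimal sketch, assuming a 640x480 camera at 30 FPS (both are placeholders to match your device):

writer = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (640, 480))
for frame in convert_video(model1):
    writer.write(frame)  # frames are BGR uint8, which VideoWriter expects
    cv2.imshow("capture", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
writer.release()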

5. Summary
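
With the CUDA 11.3 / PyTorch 1.11 environment above, RVM performs real-time matting directly on a webcam stream: each frame is read with OpenCV, passed through the network together with its recurrent states, composited over a background, and displayed. Choosing between the mobilenetv3 and resnet50 backbones trades inference speed against matting quality.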