GitHub repository: https://github.com/PeterL1n/RobustVideoMatting
1. Project Introduction

RobustVideoMatting (RVM) is a real-time human video matting model. Unlike trimap- or background-based approaches, it requires no auxiliary inputs: it uses a recurrent architecture to carry temporal memory across frames, which makes it well suited to live video streams such as a webcam feed.
2. Environment Setup
- Platform: Windows 10
- IDE: PyCharm
- CUDA 11.3
- cuDNN 8.2.0.53
The `torch` and `torchvision` commands below install local wheel files; download the cu113 builds for Python 3.8 (`cp38`) from the official PyTorch wheel index first.

```bash
conda create -n RobustVideoMatting python=3.8
conda activate RobustVideoMatting
pip install av==8.0.3
pip install torch-1.11.0+cu113-cp38-cp38-win_amd64.whl
pip install torchvision-0.12.0+cu113-cp38-cp38-win_amd64.whl
pip install tqdm==4.61.1
pip install pims==0.5
```
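After installation, a quick sanity check (a minimal sketch) confirms that the GPU build of PyTorch is active before moving on:

```python
import torch
import torchvision

print(torch.__version__)          # expected: 1.11.0+cu113
print(torchvision.__version__)    # expected: 0.12.0+cu113
print(torch.cuda.is_available())  # should print True if CUDA 11.3 / cuDNN are set up correctly
```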
3. Running the Code

Real-time webcam matting: create `main.py` with the following code.
```python
import time
from typing import Optional, Tuple

import cv2
import torch
from torchvision import transforms

from model import MattingNetwork


def auto_downsample_ratio(h, w):
    """Automatically find a downsample ratio so that the largest side of the resolution is 512px."""
    return min(512 / max(h, w), 1)


def get_frame(num):
    cap = cv2.VideoCapture(num)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("Camera FPS:", fps)  # nominal FPS as reported by the driver
    while True:
        ret, frame = cap.read()
        if not ret:  # camera disconnected or no frame available
            break
        yield frame
    cap.release()


def convert_video(model,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
    """
    Args:
        input_resize: If provided, the input is first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, it is set automatically.
        device: Only needs to be provided manually if the model is a TorchScript frozen model.
        dtype: Only needs to be provided manually if the model is a TorchScript frozen model.
    """
    assert downsample_ratio is None or (
        0 < downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'

    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device

    # Solid green replacement background, shaped [B, T, C, H, W] for broadcasting
    bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)

    with torch.no_grad():
        rec = [None] * 4  # recurrent states, carried across frames
        for src in get_frame(0):
            src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)  # OpenCV delivers BGR; the model expects RGB
            src = transform(src).unsqueeze(0)           # [T=1, C, H, W]
            if downsample_ratio is None:
                downsample_ratio = auto_downsample_ratio(*src.shape[2:])
            src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
            t1 = time.time()
            fgr, pha, *rec = model(src, *rec, downsample_ratio)
            frame_cost = (time.time() - t1) / src.shape[1]
            print("frame_cost:", frame_cost)
            print("Inference FPS: {:.2f}".format(1 / frame_cost))
            com = fgr * pha + bgr * (1 - pha)  # composite foreground over the green background
            frames = com[0]
            if frames.size(1) == 1:
                frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
            frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()  # e.g. [1, 480, 640, 3]
            yield cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imshow


def show_frame(frames):
    for frame in frames:
        cv2.imshow("capture", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # # ------- Test whether the camera works ------------ #
    # for frame in get_frame(0):
    #     cv2.imshow("capture", frame)
    #     if cv2.waitKey(1) & 0xFF == ord('q'):
    #         break
    # # -------------------------------------------------- #

    # Load the model
    # model1 = MattingNetwork('mobilenetv3').eval().cuda()
    # model1.load_state_dict(torch.load('Models/rvm_mobilenetv3.pth'))
    model1 = MattingNetwork('resnet50').eval().cuda()  # or 'mobilenetv3'
    model1.load_state_dict(torch.load('Models/rvm_resnet50.pth'))

    # Run inference on the webcam stream
    frames = convert_video(model1)
    # Display the composited frames
    show_frame(frames)
```
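To run the demo, download the pretrained weights (`rvm_mobilenetv3.pth` or `rvm_resnet50.pth`) linked from the repository's README, place them in a `Models/` directory next to `main.py`, and run `python main.py`. A window titled "capture" shows the composited stream; press `q` to quit.

To composite onto a background image instead of the solid green color, the `bgr` tensor in `convert_video` can be replaced by an image tensor of the same shape. A minimal sketch, assuming a background file `bg.jpg` (a placeholder name) resized to the camera resolution:

```python
import cv2
import torch

def load_background(path, width, height, device, dtype):
    """Load an image as a [1, 1, 3, H, W] tensor in [0, 1], matching the shape of `bgr` above."""
    bg = cv2.imread(path)
    bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)  # match the RGB order used for the model input
    bg = cv2.resize(bg, (width, height))      # match the camera resolution
    bg = torch.from_numpy(bg).to(device, dtype).div(255)
    return bg.permute(2, 0, 1).unsqueeze(0).unsqueeze(0)

# Usage inside convert_video, in place of the solid-green tensor (hypothetical call):
# bgr = load_background('bg.jpg', 640, 480, device, dtype)
```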