Skip to content

Commit

Permalink
Updated RIFE-CUDA to support new 3.2-3.5 models with fallback for 3.0-3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
n00mkrad committed Jun 15, 2021
1 parent 532d556 commit 7abf45f
Show file tree
Hide file tree
Showing 14 changed files with 751 additions and 283 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ bld/
[Ll]ogs/
Flowframes*.7z
FF*.7z
Build/WebInstaller

# NMKD Python Redist Pkg
[Pp]y*/
Expand Down
11 changes: 5 additions & 6 deletions Pkgs/rife-cuda/model/IFNet_HD.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,9 @@ def __init__(self):
self.block2 = IFBlock(8, scale=2, c=96)
self.block3 = IFBlock(8, scale=1, c=48)

def forward(self, x, UHD=False):
if UHD:
x = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False)
else:
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
def forward(self, x, scale=1.0):
x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear",
align_corners=False)
flow0 = self.block0(x)
F1 = flow0
warped_img0 = warp(x[:, :3], F1)
Expand All @@ -111,6 +108,8 @@ def forward(self, x, UHD=False):
warped_img1 = warp(x[:, 3:], -F3)
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1))
F4 = (flow0 + flow1 + flow2 + flow3)
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear",
align_corners=False) / scale
return F4, [F1, F2, F3, F4]

if __name__ == '__main__':
Expand Down
14 changes: 8 additions & 6 deletions Pkgs/rife-cuda/model/IFNet_HDv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,28 @@ def __init__(self):
self.block2 = IFBlock(10, scale=2, c=96)
self.block3 = IFBlock(10, scale=1, c=48)

def forward(self, x, UHD=False):
if UHD:
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
def forward(self, x, scale=1.0):
if scale != 1.0:
x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F3_large[:, :2])
warped_img1 = warp(x[:, 3:], F3_large[:, 2:4])
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1))
F4 = (flow0 + flow1 + flow2 + flow3)
if scale != 1.0:
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
return F4, [F1, F2, F3, F4]

if __name__ == '__main__':
Expand Down
138 changes: 87 additions & 51 deletions Pkgs/rife-cuda/model/IFNet_HDv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,80 +2,116 @@
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
from model.refine import *

def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a transposed-convolution layer followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the PReLU parameter count).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the ``ConvTranspose2d`` and the ``PReLU``.
    """
    # Forward the caller's arguments. The literals 4/2/1 used to be hard-coded
    # here, so any non-default kernel_size/stride/padding was silently ignored.
    return nn.Sequential(
        torch.nn.ConvTranspose2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        ),
        nn.PReLU(out_planes),
    )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(out_planes),
nn.PReLU(out_planes)
)

class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 2, 1),
conv(c, 2*c, 3, 2, 1),
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock0 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
conv(c, c),
conv(c, c)
)
self.convblock1 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
conv(c, c),
conv(c, c)
)
self.convblock2 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
conv(c, c),
conv(c, c)
)
self.convblock3 = nn.Sequential(
conv(c, c),
conv(c, c)
)
self.conv1 = nn.Sequential(
nn.ConvTranspose2d(c, 4, 4, 2, 1),
)
self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)
self.conv2 = nn.ConvTranspose2d(c, 1, 4, 2, 1)

def forward(self, x, flow=None, scale=1):
    """Estimate a flow field from ``x`` at ``1 / scale`` resolution.

    Args:
        x: Input tensor; resized by ``1 / scale`` before the convolutions.
        flow: Optional flow tensor. When given, it is resized by
            ``1 / scale``, its values are multiplied by ``1 / scale``, and it
            is concatenated to ``x`` along dim 1.
        scale: Divisor applied to the working resolution.

    Returns:
        The output of ``self.conv1``. When ``scale != 1`` it is resized by
        ``scale`` and its values are multiplied by ``scale``.
    """
    inv_scale = 1. / scale
    x = F.interpolate(x, scale_factor=inv_scale, mode="bilinear", align_corners=False)
    # Identity test: comparing a tensor to None with != relies on operator
    # fallback behaviour instead of stating the intent directly.
    if flow is not None:
        flow = F.interpolate(flow, scale_factor=inv_scale, mode="bilinear", align_corners=False) * inv_scale
        x = torch.cat((x, flow), 1)
    x = self.conv0(x)
    # Three residual stages: each adds its input back to its output.
    x = self.convblock0(x) + x
    x = self.convblock1(x) + x
    x = self.convblock2(x) + x
    flow = self.conv1(x)
    if scale != 1:
        # Undo the initial resize; flow values scale with resolution.
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False) * scale
    return flow

# Run one estimation block and return a (flow, mask) pair.
# x and flow are both resized by 1/scale before the convolutions; the values
# of flow are additionally multiplied by 1/scale.
def forward(self, x, flow, scale=1):
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
# x and the rescaled flow are concatenated along dim 1 and fed to conv0.
feat = self.conv0(torch.cat((x, flow), 1))
# Four residual stages: each adds its input back to its output.
feat = self.convblock0(feat) + feat
feat = self.convblock1(feat) + feat
feat = self.convblock2(feat) + feat
feat = self.convblock3(feat) + feat
# Two separate heads read the same features: conv1 -> flow, conv2 -> mask.
flow = self.conv1(feat)
mask = self.conv2(feat)
# Both outputs are resized by scale*2; only the flow values are multiplied by scale*2.
# NOTE(review): the extra factor of 2 presumably compensates for a net 2x
# downsampling through conv0 and the output heads - confirm against the layer definitions.
flow = F.interpolate(flow, scale_factor=scale*2, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale*2
mask = F.interpolate(mask, scale_factor=scale*2, mode="bilinear", align_corners=False, recompute_scale_factor=False)
return flow, mask

class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, c=80)
self.block1 = IFBlock(10, c=80)
self.block2 = IFBlock(10, c=80)
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
self.block2 = IFBlock(7+4, c=90)
self.block_tea = IFBlock(10+4, c=90)
# self.contextnet = Contextnet()
# self.unet = Unet()

def forward(self, x, scale_list=(4, 2, 1)):
    """Estimate flow with three blocks, accumulating their outputs.

    Args:
        x: Input tensor. Channels ``[:3]`` and ``[3:]`` are warped as the two
            images (assumes two 3-channel frames stacked on dim 1 - TODO confirm).
        scale_list: Per-block ``scale`` values passed to block0, block1 and
            block2 in that order. Any indexable sequence of three values works;
            the default is a tuple so it cannot be mutated between calls.

    Returns:
        ``(F3, [F1, F2, F3])`` where ``F1``, ``F2`` and ``F3`` are the running
        sums of the block outputs after one, two and three blocks.
    """
    flow0 = self.block0(x, scale=scale_list[0])
    F1 = flow0
    # The accumulated flow is resized 2x and its values doubled before warping.
    F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
    warped_img0 = warp(x[:, :3], F1_large[:, :2])
    warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
    flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
    F2 = (flow0 + flow1)
    F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
    warped_img0 = warp(x[:, :3], F2_large[:, :2])
    warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
    flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
    F3 = (flow0 + flow1 + flow2)
    return F3, [F1, F2, F3]
# Run the three blocks in sequence, accumulating flow and mask, and return
# (flow_list, mask_list[2], merged).
# NOTE(review): scale_list has a mutable list default; it is only read (indexed) in this method.
def forward(self, x, scale_list=[4, 2, 1], scale=1.0, training=False):
# The input is resized by `scale` unconditionally, including when scale == 1.0.
x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
# NOTE(review): img0/img1 are only assigned under this condition in the visible
# code; indentation is lost in this view, so confirm what a training=True call is expected to do.
if training == False:
# The channel dimension is split in half: first half -> img0, second half -> img1.
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
# Flow (4 channels) and mask (1 channel) start as zeros shaped like slices of x.
# NOTE(review): .to(device) moves them to the module-level `device`; this
# assumes x already lives on that device - confirm with the caller.
flow = torch.zeros_like(x[:, :4]).to(device)
mask = torch.zeros_like(x[:, :1]).to(device)
# NOTE(review): loss_cons is assigned but not read anywhere in this method.
loss_cons = 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
# Each block is evaluated twice: once as-is, and once with the two images
# swapped, the mask negated and the two flow halves exchanged.
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
# The second result is mapped back (flow halves re-swapped, mask negated)
# and averaged with the first before being added to the running values.
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
mask = mask + (m0 + (-m1)) / 2
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append((warped_img0, warped_img1))
# When a non-unit scale was used, the final flow is resized by 1/scale with its
# values divided by scale, and the last mask is resized by 1/scale.
# NOTE(review): flow_list was appended to before this resize, so its entries keep
# the pre-resize flow. Whether the re-warp lines below belong inside this `if`
# cannot be determined from the flattened text - confirm against the real file.
if scale != 1.0:
flow = F.interpolate(flow, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
mask_list[2] = F.interpolate(mask_list[2], scale_factor=1 / scale, mode="bilinear", align_corners=False)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged[2] = (warped_img0, warped_img1)
'''
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, 1:4] * 2 - 1
'''
# Each mask goes through a sigmoid and is used as the blend weight between the two warped images.
for i in range(3):
mask_list[i] = torch.sigmoid(mask_list[i])
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
# merged[i] = torch.clamp(merged[i] + res, 0, 1)
return flow_list, mask_list[2], merged
12 changes: 5 additions & 7 deletions Pkgs/rife-cuda/model/RIFE_HD.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, local_rank=-1):
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
Expand Down Expand Up @@ -188,11 +188,9 @@ def save_model(self, path, rank):
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow)
c1 = self.contextnet(img1, -flow)
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
Expand All @@ -209,10 +207,10 @@ def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
else:
return pred

def inference(self, img0, img1, UHD=False):
def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD)
return self.predict(imgs, flow, training=False, UHD=UHD)
flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False)

def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
Expand Down
12 changes: 5 additions & 7 deletions Pkgs/rife-cuda/model/RIFE_HDv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, local_rank=-1):
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
Expand Down Expand Up @@ -173,11 +173,9 @@ def save_model(self, path, rank):
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
Expand All @@ -194,10 +192,10 @@ def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
else:
return pred

def inference(self, img0, img1, UHD=False):
def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD)
return self.predict(imgs, flow, training=False, UHD=UHD)
flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False)

def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
Expand Down
Loading

0 comments on commit 7abf45f

Please sign in to comment.