Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions solvers/u-net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@
import deepinv as dinv


class UnetSR(torch.nn.Module):
def __init__(self, scale_factor):
super(UnetSR, self).__init__()
self.scale_factor = scale_factor
self.unet = dinv.models.UNet(
in_channels=3, out_channels=3*scale_factor**2,
scales=3, batch_norm=False
)
self.pixel_shuffle = torch.nn.PixelShuffle(scale_factor)
self.conv = torch.nn.Conv2d(
3, 3, kernel_size=3, padding=1, stride=1
)

def forward(self, x_input, physics, **kwargs):
x = self.unet(x_input, physics, **kwargs)
x = self.pixel_shuffle(x)
x = self.conv(x)
return x


class Solver(BaseSolver):

name = 'UNet'
Expand All @@ -24,6 +44,16 @@ def set_objective(self, train_dataset, physics):
self.train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=False
)

# Determine scale factor for super-resolution tasks
hr, lr = next(iter(self.train_dataloader))
_, _, h_lr, _ = lr.shape
_, _, h_hr, _ = hr.shape
if h_lr != h_hr:
self.scale_factor = h_hr // h_lr
else:
self.scale_factor = None

self.device = (
dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
)
Expand All @@ -32,9 +62,12 @@ def set_objective(self, train_dataset, physics):
def run(self, n_iter):
epochs = 4

model = dinv.models.UNet(
in_channels=3, out_channels=3, scales=3, batch_norm=False
).to(self.device)
if self.scale_factor is not None:
model = UnetSR(self.scale_factor).to(self.device)
else:
model = dinv.models.UNet(
in_channels=3, out_channels=3, scales=3, batch_norm=False
).to(self.device)

verbose = True # print training information
wandb_vis = False # plot curves and images in Weight&Bias
Expand Down
Loading