diff --git a/solvers/u-net.py b/solvers/u-net.py index b3e9534..e665bd1 100644 --- a/solvers/u-net.py +++ b/solvers/u-net.py @@ -7,6 +7,26 @@ import deepinv as dinv +class UnetSR(torch.nn.Module): + def __init__(self, scale_factor): + super(UnetSR, self).__init__() + self.scale_factor = scale_factor + self.unet = dinv.models.UNet( + in_channels=3, out_channels=3*scale_factor**2, + scales=3, batch_norm=False + ) + self.pixel_shuffle = torch.nn.PixelShuffle(scale_factor) + self.conv = torch.nn.Conv2d( + 3, 3, kernel_size=3, padding=1, stride=1 + ) + + def forward(self, x_input, physics, **kwargs): + x = self.unet(x_input, physics, **kwargs) + x = self.pixel_shuffle(x) + x = self.conv(x) + return x + + class Solver(BaseSolver): name = 'UNet' @@ -24,6 +44,16 @@ def set_objective(self, train_dataset, physics): self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) + + # Determine scale factor for super-resolution tasks + hr, lr = next(iter(self.train_dataloader)) + _, _, h_lr, _ = lr.shape + _, _, h_hr, _ = hr.shape + if h_lr != h_hr: + self.scale_factor = h_hr // h_lr + else: + self.scale_factor = None + self.device = ( dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) @@ -32,9 +62,12 @@ def set_objective(self, train_dataset, physics): def run(self, n_iter): epochs = 4 - model = dinv.models.UNet( - in_channels=3, out_channels=3, scales=3, batch_norm=False - ).to(self.device) + if self.scale_factor is not None: + model = UnetSR(self.scale_factor).to(self.device) + else: + model = dinv.models.UNet( + in_channels=3, out_channels=3, scales=3, batch_norm=False + ).to(self.device) verbose = True # print training information wandb_vis = False # plot curves and images in Weight&Bias