diff --git a/.gitignore b/.gitignore index f6b4615..723fe68 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ benchopt.ini .DS_Store coverage.xml + +tmp +data/ \ No newline at end of file diff --git a/benchmark_utils/denoiser_2c.py b/benchmark_utils/denoiser_2c.py new file mode 100644 index 0000000..82ec22f --- /dev/null +++ b/benchmark_utils/denoiser_2c.py @@ -0,0 +1,26 @@ +import torch +from deepinv.models import DRUNet +from deepinv.models import Denoiser + + +class Denoiser_2c(Denoiser): + def __init__(self, device): + super(Denoiser_2c, self).__init__() + self.model_c1 = DRUNet( + in_channels=1, out_channels=1, + pretrained="download", device=device + ) + self.model_c2 = DRUNet( + in_channels=1, out_channels=1, + pretrained="download", device=device + ) + + def forward(self, y, sigma): + y1, y2 = torch.split(y, 1, dim=1) + + x_hat_1 = self.model_c1(y1, sigma=sigma) + x_hat_2 = self.model_c2(y2, sigma=sigma) + + x_hat = torch.cat([x_hat_1, x_hat_2], dim=1) + + return x_hat diff --git a/benchmark_utils/helper.py b/benchmark_utils/helper.py new file mode 100644 index 0000000..5b5b388 --- /dev/null +++ b/benchmark_utils/helper.py @@ -0,0 +1,78 @@ +from benchopt import safe_import_context + +with safe_import_context() as import_ctx: + import deepinv as dinv + from deepinv.physics import ( + Denoising, + GaussianNoise, + Downsampling, + Demosaicing, + BlurFFT, + Inpainting + ) + from deepinv.physics.generator import MotionBlurGenerator + + +DEVICE = None + + +def get_device(): + global DEVICE + if DEVICE is not None: + return DEVICE + if dinv.torch.cuda.is_available(): + DEVICE = dinv.utils.get_freer_gpu() + else: + DEVICE = "cpu" + return DEVICE + + +def get_task_physic(task, img_size, device): + if task == "denoising": + noise_level_img = 0.1 + physics = Denoising(GaussianNoise(sigma=noise_level_img)) + elif task == "gaussian-debluring": + filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) + noise_level_img = 0.03 + + physics = BlurFFT( + img_size=img_size, + filter=filter_torch, + noise_model=GaussianNoise(sigma=noise_level_img), + device=device + ) + elif task == "motion-debluring": + psf_size = 31 + motion_generator = MotionBlurGenerator( + (psf_size, psf_size), + device=device + ) + + filters = motion_generator.step(batch_size=1) + + physics = BlurFFT( + img_size=img_size, + filter=filters["filter"], + device=device + ) + elif task == "SRx4": + physics = Downsampling( + img_size=img_size, + filter="bicubic", + factor=4, + device=device + ) + elif task == "inpainting": + physics = Inpainting( + img_size, + mask=0.7, + device=device + ) + elif task == "demosaicing": + physics = Demosaicing( + img_size=img_size, + device=device + ) + else: + raise Exception("Unknown task") + return physics diff --git a/benchmark_utils/hugging_face_torch_dataset.py b/benchmark_utils/hugging_face_torch_dataset.py index 95edc10..9f251b9 100644 --- a/benchmark_utils/hugging_face_torch_dataset.py +++ b/benchmark_utils/hugging_face_torch_dataset.py @@ -2,19 +2,26 @@ class HuggingFaceTorchDataset(torch.utils.data.Dataset): - def __init__(self, hf_dataset, key, transform=None): + def __init__(self, hf_dataset, key, physics, device, transform=None): self.hf_dataset = hf_dataset self.transform = transform self.key = key + self.device = device + self.physics = physics def __len__(self): return len(self.hf_dataset) def __getitem__(self, idx): sample = self.hf_dataset[idx] - image = sample[self.key] # Image PIL + x = sample[self.key] # Image PIL if self.transform: - image = self.transform(image) + x = self.transform(x) - return image + x = x.to(self.device) + + y = self.physics(x.unsqueeze(0)) + y = y.squeeze(0) + + return x, y diff --git a/datasets/bsd500_cbsd68.py b/datasets/bsd500_cbsd68.py index bdc5e8a..4e3301f 100644 --- a/datasets/bsd500_cbsd68.py +++ b/datasets/bsd500_cbsd68.py @@ -3,14 +3,12 @@ with safe_import_context() as import_ctx: import deepinv as dinv - import torch from torchvision import transforms from datasets import load_dataset from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling - from deepinv.physics.generator import MotionBlurGenerator + from benchmark_utils.helper import get_task_physic, get_device class Dataset(BaseDataset): @@ -18,59 +16,27 @@ class Dataset(BaseDataset): name = "BSD500_CBSD68" parameters = { - 'task': ['denoising', - 'gaussian-debluring', - 'motion-debluring', - 'SRx4'], + 'task': [ + 'denoising', + 'gaussian-debluring', + 'motion-debluring', + 'SRx4', + 'inpainting', + 'demosaicing' + ], 'img_size': [256], + 'batch_size': [2] } requirements = ["datasets"] def get_data(self): - # TODO: Remove - device = ( - dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + device = get_device() - if self.task == "denoising": - noise_level_img = 0.03 - physics = Denoising(GaussianNoise(sigma=noise_level_img)) - elif self.task == "gaussian-debluring": - filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) - noise_level_img = 0.03 - n_channels = 3 + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filter_torch, - noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), - device=device - ) - elif self.task == "motion-debluring": - psf_size = 31 - n_channels = 3 - motion_generator = MotionBlurGenerator( - (psf_size, psf_size), - device=device - ) - - filters = motion_generator.step(batch_size=1) - - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filters["filter"], - device=device - ) - elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), - filter="bicubic", - factor=4, - device=device) - else: - raise Exception("Unknown task") + physics = get_task_physic(self.task, img_size, device) transform = transforms.Compose([ transforms.Resize((self.img_size, self.img_size)), @@ -78,41 +44,32 @@ def get_data(self): ]) path = get_data_path("BSD500") - train_dataset = dinv.datasets.BSDS500( + bsd500_dataset = dinv.datasets.BSDS500( path, download=True, transform=transform ) + train_dataset = HuggingFaceTorchDataset( + bsd500_dataset, + key=..., + physics=physics, + device=device, + transform=transforms.Resize((self.img_size, self.img_size)) + ) dataset_cbsd68 = load_dataset("deepinv/CBSD68") test_dataset = HuggingFaceTorchDataset( - dataset_cbsd68["train"], key="png", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_cbsd68["train"], + key="png", physics=physics, - save_dir=get_data_path("bsd500_cbsd68"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True + device=device, + transform=transform ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False - ) - - x, y = train_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size, + batch_size=self.batch_size ) diff --git a/datasets/bsd500_imnet100.py b/datasets/bsd500_imnet100.py index d0f6959..abcb77b 100644 --- a/datasets/bsd500_imnet100.py +++ b/datasets/bsd500_imnet100.py @@ -3,13 +3,11 @@ with safe_import_context() as import_ctx: import deepinv as dinv - import torch from torchvision import transforms from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Downsampling, Denoising, GaussianNoise - from deepinv.physics.generator import MotionBlurGenerator + from benchmark_utils.helper import get_task_physic, get_device from datasets import load_dataset @@ -18,59 +16,27 @@ class Dataset(BaseDataset): name = "BSD500_imnet100" parameters = { - 'task': ['denoising', - 'gaussian-debluring', - 'motion-debluring', - 'SRx4'], + 'task': [ + 'denoising', + 'gaussian-debluring', + 'motion-debluring', + 'SRx4', + 'inpainting', + 'demosaicing' + ], 'img_size': [256], + 'batch_size': [2] } requirements = ["datasets"] def get_data(self): - # TODO: Remove - device = ( - dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + device = get_device() - if self.task == "denoising": - noise_level_img = 0.03 - physics = Denoising(GaussianNoise(sigma=noise_level_img)) - elif self.task == "gaussian-debluring": - filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) - noise_level_img = 0.03 - n_channels = 3 + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filter_torch, - noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), - device=device - ) - elif self.task == "motion-debluring": - psf_size = 31 - n_channels = 3 - motion_generator = MotionBlurGenerator( - (psf_size, psf_size), - device=device - ) - - filters = motion_generator.step(batch_size=1) - - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filters["filter"], - device=device - ) - elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), - filter="bicubic", - factor=4, - device=device) - else: - raise Exception("Unknown task") + physics = get_task_physic(self.task, img_size, device) transform = transforms.Compose([ transforms.Resize((self.img_size, self.img_size)), @@ -78,43 +44,32 @@ def get_data(self): ]) path = get_data_path("BSD500") - train_dataset = dinv.datasets.BSDS500( + bsd500_dataset = dinv.datasets.BSDS500( path, download=True, transform=transform ) + train_dataset = HuggingFaceTorchDataset( + bsd500_dataset, + key=..., + physics=physics, + device=device, + transform=transforms.Resize((self.img_size, self.img_size)) + ) dataset_miniImnet100 = load_dataset("mterris/miniImnet100") test_dataset = HuggingFaceTorchDataset( dataset_miniImnet100["validation"], key="image", - transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, physics=physics, - save_dir=get_data_path("bsd500_imnet100"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False + device=device, + transform=transform ) - x, y = train_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size, + batch_size=self.batch_size ) diff --git a/datasets/cbsd68_set3c.py b/datasets/cbsd68_set3c.py index fadd716..970ba96 100644 --- a/datasets/cbsd68_set3c.py +++ b/datasets/cbsd68_set3c.py @@ -1,16 +1,12 @@ from benchopt import BaseDataset, safe_import_context -from benchopt.config import get_data_path with safe_import_context() as import_ctx: - import deepinv as dinv - import torch from torchvision import transforms from datasets import load_dataset from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling - from deepinv.physics.generator import MotionBlurGenerator + from benchmark_utils.helper import get_task_physic, get_device class Dataset(BaseDataset): @@ -18,60 +14,27 @@ class Dataset(BaseDataset): name = "CBSD68_Set3c" parameters = { - 'task': ['denoising', - 'gaussian-debluring', - 'motion-debluring', - 'SRx4'], + 'task': [ + 'denoising', + 'gaussian-debluring', + 'motion-debluring', + 'SRx4', + 'inpainting', + 'demosaicing' + ], 'img_size': [256], + 'batch_size': [2] } requirements = ["datasets"] def get_data(self): - # TODO: Remove - device = ( - dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" - ) - - if self.task == "denoising": - noise_level_img = 0.03 - physics = Denoising(GaussianNoise(sigma=noise_level_img)) - elif self.task == "gaussian-debluring": - filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) - noise_level_img = 0.03 - n_channels = 3 + device = get_device() - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filter_torch, - noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), - device=device - ) - elif self.task == "motion-debluring": - psf_size = 31 - n_channels = 3 - motion_generator = MotionBlurGenerator( - (psf_size, psf_size), - device=device - ) + n_channels = 3 + image_size = (n_channels, self.img_size, self.img_size) - filters = motion_generator.step(batch_size=1) - - physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), - filter=filters["filter"], - device=device - ) - elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), - filter="bicubic", - factor=4, - device=device) - else: - raise Exception("Unknown task") + physics = get_task_physic(self.task, image_size, device) transform = transforms.Compose([ transforms.Resize((self.img_size, self.img_size)), @@ -80,42 +43,28 @@ def get_data(self): dataset_CBSD68 = load_dataset("deepinv/CBSD68") train_dataset = HuggingFaceTorchDataset( - dataset_CBSD68["train"], key="png", transform=transform + dataset_CBSD68["train"], + key="png", + physics=physics, + device=device, + transform=transform ) dataset_Set3c = load_dataset("deepinv/set3c") test_dataset = HuggingFaceTorchDataset( - dataset_Set3c["train"], key="image", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_Set3c["train"], + key="image", physics=physics, - save_dir=get_data_path("cbsd68_set3c"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=True + device=device, + transform=transform ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=False - ) - - x, y = train_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="Set3c", - task_name=self.task + task_name=self.task, + image_size=image_size, + batch_size=self.batch_size ) diff --git a/datasets/simulated.py b/datasets/simulated.py index 27befd3..87c78de 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -25,5 +25,7 @@ def get_data(self): test_dataset=test_dataset, physics=Denoising(GaussianNoise(sigma=0.03)), dataset_name="simulated", - task_name="test" + task_name="test", + image_size=(3, 32, 32), + batch_size=4 ) diff --git a/objective.py b/objective.py index 5ed8c3b..ffdcf3b 100644 --- a/objective.py +++ b/objective.py @@ -7,6 +7,9 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + import torch.nn.functional as F + from tqdm import tqdm + import time # The benchmark objective must be named `Objective` and @@ -43,7 +46,9 @@ def set_data(self, test_dataset, physics, dataset_name, - task_name): + task_name, + image_size, + batch_size): # The keyword arguments of this function are the keys of the dictionary # returned by `Dataset.get_data`. This defines the benchmark's # API to pass data. This is customizable for each benchmark. @@ -52,6 +57,8 @@ def set_data(self, self.physics = physics self.dataset_name = dataset_name self.task_name = task_name + self.image_size = image_size + self.batch_size = batch_size def evaluate_result(self, model, model_name, device): # The keyword arguments of this function are the keys of the @@ -59,49 +66,74 @@ def evaluate_result(self, model, model_name, device): # benchmark's API to pass solvers' result. This is customizable for # each benchmark. - batch_size = 2 test_dataloader = DataLoader( - self.test_dataset, batch_size=batch_size, shuffle=False + self.test_dataset, batch_size=self.batch_size, shuffle=False ) - if isinstance(model, dinv.models.DeepImagePrior): - psnr = [] - ssim = [] - lpips = [] + # DeepImagePrior use images one by one, thus we can't use dinv.test + # if isinstance(model, dinv.models.DeepImagePrior): + psnr = [] + ssim = [] + lpips = [] + times = [] - for x, y in test_dataloader: - x, y = x.to(device), y.to(device) - x_hat = torch.cat([ + for x, y in tqdm(test_dataloader, desc=f"Evaluating {model_name}"): + x, y = x.to(device), y.to(device) + + if isinstance(model, dinv.models.DeepImagePrior): + start = time.time() + x_hat = [ model(y_i[None], self.physics) for y_i in y - ]) - psnr.append(dinv.metric.PSNR()(x_hat, x)) - ssim.append(dinv.metric.SSIM()(x_hat, x)) - lpips.append(dinv.metric.LPIPS(device=device)(x_hat, x)) - - psnr = torch.mean(torch.cat(psnr)).item() - ssim = torch.mean(torch.cat(ssim)).item() - lpips = torch.mean(torch.cat(lpips)).item() - - results = dict(PSNR=psnr, SSIM=ssim, LPIPS=lpips) - else: - results = dinv.test( - model, - test_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), - dinv.metric.SSIM(), - dinv.metric.LPIPS(device=device)], - device=device - ) - - # This method can return many metrics in a dictionary. One of these - # metrics needs to be `value` for convergence detection purposes. - return dict( + ] + exec_time = time.time() - start + x_hat = torch.cat(x_hat) + else: + if ( + isinstance(self.physics, dinv.physics.blur.Downsampling) + and model_name == 'U-Net' + ): + _, _, x_h, x_w = x.shape + _, _, y_h, y_w = y.shape + + diff_h = x_h - y_h + diff_w = x_w - y_w + + pad_top = diff_h // 2 + pad_bottom = diff_h - pad_top + pad_left = diff_w // 2 + pad_right = diff_w - pad_left + + y = F.pad( + y, + pad=(pad_left, pad_right, pad_top, pad_bottom), + value=0 + ) + + start = time.time() + x_hat = model(y, self.physics) + exec_time = time.time() - start + + times.append(exec_time) + + psnr.append(dinv.metric.PSNR()(x_hat, x)) + ssim.append(dinv.metric.SSIM()(x_hat, x)) + lpips.append(dinv.metric.LPIPS(device=device)(x_hat, x)) + + psnr = torch.mean(torch.cat(psnr)).item() + times = torch.mean(torch.tensor(times)).item() + + results = dict(PSNR=psnr) + + results['Time'] = times + + values = dict( value=results["PSNR"], - ssim=results["SSIM"], - lpips=results["LPIPS"] ) + values['time'] = results["Time"] + + return values + def get_one_result(self): # Return one solution. The return value should be an object compatible # with `self.evaluate_result`. This is mainly for testing purposes. @@ -114,4 +146,10 @@ def get_objective(self): # for `Solver.set_objective`. This defines the # benchmark's API for passing the objective to the solver. # It is customizable for each benchmark. - return dict(train_dataset=self.train_dataset, physics=self.physics) + + return dict( + train_dataset=self.train_dataset, + physics=self.physics, + image_size=self.image_size, + batch_size=self.batch_size + ) diff --git a/solvers/diffpir.py b/solvers/diffpir.py index a3bcaf2..55b573d 100644 --- a/solvers/diffpir.py +++ b/solvers/diffpir.py @@ -1,9 +1,10 @@ from benchopt import BaseSolver, safe_import_context with safe_import_context() as import_ctx: - import torch from torch.utils.data import DataLoader import deepinv as dinv + from benchmark_utils.denoiser_2c import Denoiser_2c + from benchmark_utils.helper import get_device class Solver(BaseSolver): @@ -15,24 +16,30 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, batch_size): self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) - self.device = ( - dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" - ) + self.device = get_device() self.physics = physics + self.image_size = image_size + def run(self, n_iter): - denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) + if self.image_size[0] == 2: + denoiser = Denoiser_2c(device=self.device) + else: + denoiser = dinv.models.DRUNet( + pretrained="download", + device=self.device + ) self.model = dinv.sampling.DiffPIR( model=denoiser, data_fidelity=dinv.optim.data_fidelity.L2(), device=self.device ) + self.model.eval() def get_result(self): diff --git a/solvers/dip.py b/solvers/dip.py index 3a41c91..63ebd61 100644 --- a/solvers/dip.py +++ b/solvers/dip.py @@ -5,6 +5,7 @@ from torch.utils.data import DataLoader import deepinv as dinv import optuna + from benchmark_utils.helper import get_device class Solver(BaseSolver): @@ -16,34 +17,31 @@ class Solver(BaseSolver): requirements = ["optuna"] - def set_objective(self, train_dataset, physics): + def set_objective(self, train_dataset, physics, image_size, batch_size): self.train_dataset = train_dataset - batch_size = 32 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) - self.device = ( - dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" - ) + self.device = get_device() self.physics = physics.to(self.device) + self.image_size = image_size def run(self, n_iter): def objective(trial): lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True) iterations = trial.suggest_int('iterations', 50, 500, log=True) - # TODO: Remove - # iterations = 5 - model = self.get_model(lr, iterations) psnr = [] for x, y in self.train_dataloader: x, y = x.to(self.device), y.to(self.device) + x_hat = torch.cat([ model(y_i[None], self.physics) for y_i in y ]) + psnr.append(dinv.metric.PSNR()(x_hat, x)) psnr = torch.mean(torch.cat(psnr)).item() @@ -51,13 +49,11 @@ def objective(trial): return psnr study = optuna.create_study(direction='maximize') - study.optimize(objective, n_trials=1) + study.optimize(objective, n_trials=3) best_trial = study.best_trial best_params = best_trial.params - # TODO : replace 5 by best_params['iterations']) - # self.model = self.get_model(best_params['lr'], 5) self.model = self.get_model( best_params['lr'], best_params['iterations'] diff --git a/solvers/dpir.py b/solvers/dpir.py index 5669eed..3010eb0 100644 --- a/solvers/dpir.py +++ b/solvers/dpir.py @@ -5,6 +5,8 @@ from torch.utils.data import DataLoader import deepinv as dinv import numpy as np + from tqdm import tqdm + from benchmark_utils.helper import get_device class Solver(BaseSolver): @@ -16,35 +18,43 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, batch_size): self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) - self.device = ( - dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" - ) + self.device = get_device() self.physics = physics + self.image_size = image_size def run(self, n_iter): best_sigma = 0 best_psnr = 0 + + # If the number of channels is different from 1 or 3 + # then we can't use pretrained DRUNet for sigma in np.linspace(0.01, 0.1, 10): model = dinv.optim.DPIR(sigma=sigma, device=self.device) - results = dinv.test( - model, + psnr = [] + + bar = tqdm( self.train_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()], - device=self.device + desc="DPIR : Looking for the best sigma" ) + for x, y in bar: + x, y = x.to(self.device), y.to(self.device) + + x_hat = model(y, self.physics) + + psnr.append(dinv.metric.PSNR()(x_hat, x)) + + psnr = torch.mean(torch.cat(psnr)).item() - if results["PSNR"] > best_psnr: + if psnr > best_psnr: best_sigma = sigma - best_psnr = results["PSNR"] + best_psnr = psnr - self.model = dinv.optim.DPIR(sigma=best_sigma, device=self.device) + self.model = dinv.optim.DPIR(sigma=best_sigma, device=self.device) self.model.eval() def get_result(self): diff --git a/solvers/u-net.py b/solvers/u-net.py index b3e9534..84ba825 100644 --- a/solvers/u-net.py +++ b/solvers/u-net.py @@ -5,6 +5,7 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + from benchmark_utils.helper import get_device class Solver(BaseSolver): @@ -19,45 +20,40 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, batch_size): self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) - self.device = ( - dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" - ) + self.device = get_device() self.physics = physics.to(self.device) + self.image_size = image_size def run(self, n_iter): epochs = 4 model = dinv.models.UNet( - in_channels=3, out_channels=3, scales=3, batch_norm=False + in_channels=3, out_channels=3, scales=4, + batch_norm=False ).to(self.device) - verbose = True # print training information - wandb_vis = False # plot curves and images in Weight&Bias - - # choose training losses - losses = dinv.loss.SupLoss(metric=dinv.metric.MSE()) - - # choose optimizer and scheduler optimizer = torch.optim.Adam( model.parameters(), lr=self.lr, weight_decay=1e-8 ) scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, step_size=int(epochs * 0.8) + optimizer, step_size=int(epochs * 0.7) ) + + criterion = dinv.loss.SupLoss(metric=dinv.metric.MSE()) + trainer = dinv.Trainer( model, device=self.device, - verbose=verbose, - wandb_vis=wandb_vis, + verbose=True, + wandb_vis=False, physics=self.physics, epochs=epochs, scheduler=scheduler, - losses=losses, + losses=criterion, optimizer=optimizer, show_progress_bar=True, train_dataloader=self.train_dataloader, @@ -67,4 +63,8 @@ def run(self, n_iter): self.model.eval() def get_result(self): - return dict(model=self.model, model_name="U-Net", device=self.device) + return dict( + model=self.model, + model_name=f"U-Net_{self.lr}", + device=self.device + )