Skip to content

Commit 5f9fd8d

Browse files
authored
Reproduce the table 2 of paper about FiRe
https://arxiv.org/pdf/2411.18970
2 parents c3dc41f + 7d4b2f7 commit 5f9fd8d

File tree

8 files changed

+329
-15
lines changed

8 files changed

+329
-15
lines changed

benchmark_utils/image_dataset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
import os
2+
import random
3+
24
from torch.utils.data import Dataset
35
from typing import Callable
46
from PIL import Image
57

68

79
class ImageDataset(Dataset):
8-
def __init__(self, folder: str, transform: Callable = None) -> None:
10+
def __init__(self,
11+
folder: str,
12+
transform: Callable = None,
13+
num_images=None):
914
self.folder = folder
1015
self.transform = transform
1116
self.files = [f for f in os.listdir(folder) if f.endswith((
1217
'.png', '.jpg', '.jpeg'))]
1318

19+
if num_images is not None:
20+
self.files.sort()
21+
self.files = random.sample(self.files, num_images)
22+
1423
def __len__(self):
1524
return len(self.files)
1625

datasets/bsd500_bsd20.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from benchopt import BaseDataset, safe_import_context, config
2+
3+
with safe_import_context() as import_ctx:
4+
import deepinv as dinv
5+
import torch
6+
from torchvision import transforms
7+
from benchmark_utils.image_dataset import ImageDataset
8+
from deepinv.physics import Downsampling, Denoising, GaussianNoise
9+
from deepinv.physics.generator import MotionBlurGenerator
10+
11+
12+
class Dataset(BaseDataset):
13+
14+
name = "BSD500_BSD20"
15+
16+
parameters = {
17+
'task': ['denoising',
18+
'gaussian-debluring',
19+
'motion-debluring',
20+
'SRx4'],
21+
'img_size': [256],
22+
}
23+
24+
requirements = ["datasets"]
25+
26+
def get_data(self):
27+
# TODO: Remove
28+
device = (
29+
dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu"
30+
31+
n_channels = 3
32+
33+
if self.task == "denoising":
34+
noise_level_img = 0.03
35+
physics = Denoising(GaussianNoise(sigma=noise_level_img))
36+
elif self.task == "gaussian-debluring":
37+
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
38+
noise_level_img = 0.03
39+
n_channels = 3
40+
41+
physics = dinv.physics.BlurFFT(
42+
img_size=(n_channels, self.img_size, self.img_size),
43+
filter=filter_torch,
44+
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
45+
device=device
46+
)
47+
elif self.task == "motion-debluring":
48+
psf_size = 31
49+
n_channels = 3
50+
motion_generator = MotionBlurGenerator(
51+
(psf_size, psf_size),
52+
device=device
53+
)
54+
55+
filters = motion_generator.step(batch_size=1)
56+
57+
physics = dinv.physics.BlurFFT(
58+
img_size=(n_channels, self.img_size, self.img_size),
59+
filter=filters["filter"],
60+
device=device
61+
)
62+
elif self.task == "SRx4":
63+
physics = Downsampling(img_size=(n_channels,
64+
self.img_size,
65+
self.img_size),
66+
filter="bicubic",
67+
factor=4,
68+
device=device)
69+
else:
70+
raise Exception("Unknown task")
71+
72+
transform = transforms.Compose([
73+
transforms.Resize((self.img_size, self.img_size)),
74+
transforms.ToTensor()
75+
])
76+
77+
train_dataset = ImageDataset(
78+
config.get_data_path("BSD500") / "train",
79+
transform=transform
80+
)
81+
82+
test_dataset = ImageDataset(
83+
config.get_data_path("BSD500") / "val",
84+
transform=transform,
85+
num_images=20
86+
)
87+
88+
dinv_dataset_path = dinv.datasets.generate_dataset(
89+
train_dataset=train_dataset,
90+
test_dataset=test_dataset,
91+
physics=physics,
92+
save_dir=config.get_data_path(
93+
key="generated_datasets"
94+
) / "bsd500_bsd20",
95+
dataset_filename=self.task,
96+
device=device
97+
)
98+
99+
train_dataset = dinv.datasets.HDF5Dataset(
100+
path=dinv_dataset_path, train=True
101+
)
102+
test_dataset = dinv.datasets.HDF5Dataset(
103+
path=dinv_dataset_path, train=False
104+
)
105+
106+
x, y = train_dataset[0]
107+
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])
108+
109+
x, y = test_dataset[0]
110+
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])
111+
112+
return dict(
113+
train_dataset=train_dataset,
114+
test_dataset=test_dataset,
115+
physics=physics,
116+
dataset_name="BSD68",
117+
task_name=self.task
118+
)

datasets/bsd500_cbsd68.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,19 @@
99
from benchmark_utils.hugging_face_torch_dataset import (
1010
HuggingFaceTorchDataset
1111
)
12-
from deepinv.physics import Denoising, GaussianNoise
12+
from deepinv.physics import Denoising, GaussianNoise, Downsampling
13+
from deepinv.physics.generator import MotionBlurGenerator
1314

1415

1516
class Dataset(BaseDataset):
1617

1718
name = "BSD500_CBSD68"
1819

1920
parameters = {
20-
'task': ['denoising', 'debluring'],
21+
'task': ['denoising',
22+
'gaussian-debluring',
23+
'motion-debluring',
24+
'SRx4'],
2125
'img_size': [256],
2226
}
2327

@@ -31,17 +35,40 @@ def get_data(self):
3135
if self.task == "denoising":
3236
noise_level_img = 0.03
3337
physics = Denoising(GaussianNoise(sigma=noise_level_img))
34-
elif self.task == "debluring":
38+
elif self.task == "gaussian-debluring":
3539
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
3640
noise_level_img = 0.03
37-
n_channels = 3 # 3 for color images, 1 for gray-scale images
41+
n_channels = 3
3842

3943
physics = dinv.physics.BlurFFT(
4044
img_size=(n_channels, self.img_size, self.img_size),
4145
filter=filter_torch,
4246
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
4347
device=device
4448
)
49+
elif self.task == "motion-debluring":
50+
psf_size = 31
51+
n_channels = 3
52+
motion_generator = MotionBlurGenerator(
53+
(psf_size, psf_size),
54+
device=device
55+
)
56+
57+
filters = motion_generator.step(batch_size=1)
58+
59+
physics = dinv.physics.BlurFFT(
60+
img_size=(n_channels, self.img_size, self.img_size),
61+
filter=filters["filter"],
62+
device=device
63+
)
64+
elif self.task == "SRx4":
65+
n_channels = 3
66+
physics = Downsampling(img_size=(n_channels,
67+
self.img_size,
68+
self.img_size),
69+
filter="bicubic",
70+
factor=4,
71+
device=device)
4572
else:
4673
raise Exception("Unknown task")
4774

datasets/bsd500_imnet100.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from benchopt import BaseDataset, safe_import_context, config
2+
3+
with safe_import_context() as import_ctx:
4+
import deepinv as dinv
5+
import torch
6+
from torchvision import transforms
7+
from benchmark_utils.image_dataset import ImageDataset
8+
from benchmark_utils.hugging_face_torch_dataset import (
9+
HuggingFaceTorchDataset
10+
)
11+
from deepinv.physics import Downsampling, Denoising, GaussianNoise
12+
from deepinv.physics.generator import MotionBlurGenerator
13+
from datasets import load_dataset
14+
15+
16+
class Dataset(BaseDataset):
17+
18+
name = "BSD500_imnet100"
19+
20+
parameters = {
21+
'task': ['denoising',
22+
'gaussian-debluring',
23+
'motion-debluring',
24+
'SRx4'],
25+
'img_size': [256],
26+
}
27+
28+
requirements = ["datasets"]
29+
30+
def get_data(self):
31+
# TODO: Remove
32+
device = (
33+
dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu"
34+
35+
if self.task == "denoising":
36+
noise_level_img = 0.03
37+
physics = Denoising(GaussianNoise(sigma=noise_level_img))
38+
elif self.task == "gaussian-debluring":
39+
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
40+
noise_level_img = 0.03
41+
n_channels = 3
42+
43+
physics = dinv.physics.BlurFFT(
44+
img_size=(n_channels, self.img_size, self.img_size),
45+
filter=filter_torch,
46+
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
47+
device=device
48+
)
49+
elif self.task == "motion-debluring":
50+
psf_size = 31
51+
n_channels = 3
52+
motion_generator = MotionBlurGenerator(
53+
(psf_size, psf_size),
54+
device=device
55+
)
56+
57+
filters = motion_generator.step(batch_size=1)
58+
59+
physics = dinv.physics.BlurFFT(
60+
img_size=(n_channels, self.img_size, self.img_size),
61+
filter=filters["filter"],
62+
device=device
63+
)
64+
elif self.task == "SRx4":
65+
n_channels = 3
66+
physics = Downsampling(img_size=(n_channels,
67+
self.img_size,
68+
self.img_size),
69+
filter="bicubic",
70+
factor=4,
71+
device=device)
72+
else:
73+
raise Exception("Unknown task")
74+
75+
transform = transforms.Compose([
76+
transforms.Resize((self.img_size, self.img_size)),
77+
transforms.ToTensor()
78+
])
79+
80+
train_dataset = ImageDataset(
81+
config.get_data_path("BSD500") / "train",
82+
transform=transform
83+
)
84+
85+
dataset_miniImnet100 = load_dataset("mterris/miniImnet100")
86+
test_dataset = HuggingFaceTorchDataset(
87+
dataset_miniImnet100["validation"],
88+
key="image",
89+
transform=transform
90+
)
91+
92+
dinv_dataset_path = dinv.datasets.generate_dataset(
93+
train_dataset=train_dataset,
94+
test_dataset=test_dataset,
95+
physics=physics,
96+
save_dir=config.get_data_path(
97+
key="generated_datasets"
98+
) / "bsd500_imnet100",
99+
dataset_filename=self.task,
100+
device=device
101+
)
102+
103+
train_dataset = dinv.datasets.HDF5Dataset(
104+
path=dinv_dataset_path, train=True
105+
)
106+
test_dataset = dinv.datasets.HDF5Dataset(
107+
path=dinv_dataset_path, train=False
108+
)
109+
110+
x, y = train_dataset[0]
111+
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])
112+
113+
x, y = test_dataset[0]
114+
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])
115+
116+
return dict(
117+
train_dataset=train_dataset,
118+
test_dataset=test_dataset,
119+
physics=physics,
120+
dataset_name="BSD68",
121+
task_name=self.task
122+
)

datasets/cbsd68_set3c.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
from benchmark_utils.hugging_face_torch_dataset import (
99
HuggingFaceTorchDataset
1010
)
11-
from deepinv.physics import Denoising, GaussianNoise
11+
from deepinv.physics import Denoising, GaussianNoise, Downsampling
12+
from deepinv.physics.generator import MotionBlurGenerator
1213

1314

1415
class Dataset(BaseDataset):
1516

1617
name = "CBSD68_Set3c"
1718

1819
parameters = {
19-
'task': ['denoising', 'debluring'],
20+
'task': ['denoising',
21+
'gaussian-debluring',
22+
'motion-debluring',
23+
'SRx4'],
2024
'img_size': [256],
2125
}
2226

@@ -31,17 +35,40 @@ def get_data(self):
3135
if self.task == "denoising":
3236
noise_level_img = 0.03
3337
physics = Denoising(GaussianNoise(sigma=noise_level_img))
34-
elif self.task == "debluring":
38+
elif self.task == "gaussian-debluring":
3539
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
3640
noise_level_img = 0.03
37-
n_channels = 3 # 3 for color images, 1 for gray-scale images
41+
n_channels = 3
3842

3943
physics = dinv.physics.BlurFFT(
4044
img_size=(n_channels, self.img_size, self.img_size),
4145
filter=filter_torch,
4246
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
4347
device=device
4448
)
49+
elif self.task == "motion-debluring":
50+
psf_size = 31
51+
n_channels = 3
52+
motion_generator = MotionBlurGenerator(
53+
(psf_size, psf_size),
54+
device=device
55+
)
56+
57+
filters = motion_generator.step(batch_size=1)
58+
59+
physics = dinv.physics.BlurFFT(
60+
img_size=(n_channels, self.img_size, self.img_size),
61+
filter=filters["filter"],
62+
device=device
63+
)
64+
elif self.task == "SRx4":
65+
n_channels = 3
66+
physics = Downsampling(img_size=(n_channels,
67+
self.img_size,
68+
self.img_size),
69+
filter="bicubic",
70+
factor=4,
71+
device=device)
4572
else:
4673
raise Exception("Unknown task")
4774

0 commit comments

Comments
 (0)