Skip to content

Commit 94f6da8

Browse files
committed
moving the model load option to global context
1 parent f8c80a9 commit 94f6da8

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

flask_pytorch_web_app/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from torchvision import models
66
from PIL import Image
77

8+
# Make sure to pass `pretrained` as `True` to use the pretrained weights:
9+
model = models.densenet121(pretrained=True)
10+
# Since we are using our model only for inference, switch to `eval` mode:
11+
model.eval()
12+
813

914
def transform_image(image_bytes):
1015
my_transforms = transforms.Compose([transforms.Resize(255),
@@ -26,11 +31,6 @@ def supported_image_type(img):
2631
def predict(image_file, class_file):
2732
try:
2833
class_file = getcwd() + '/flask_pytorch_web_app/' + class_file
29-
# Make sure to pass `pretrained` as `True` to use the pretrained weights:
30-
model = models.densenet121(pretrained=True)
31-
# Since we are using our model only for inference, switch to `eval` mode:
32-
model.eval()
33-
3434
imagenet_class_index = json.load(open(class_file))
3535
with open(image_file, 'rb') as f:
3636
image_bytes = f.read()
@@ -50,3 +50,4 @@ def test():
5050

5151
p = predict(img, cf)
5252
print(f'Given image is: {p[1]}')
53+

0 commit comments

Comments
 (0)