# predict.py — Cog prediction entry point (forked from XPixelGroup/HAT).
import os
import shutil
import subprocess
import sys
import tempfile

import numpy as np
from PIL import Image
from cog import BasePredictor, Input, Path
class Predictor(BasePredictor):
    """Cog predictor that runs HAT 4x super-resolution on a single image.

    Inference is delegated to the HAT test script (``hat/test.py``), run in
    a subprocess with ``options/test/HAT_SRx4_ImageNet-LR.yml``. The input
    image is staged in ``input_dir`` and the result is read back from
    ``results/HAT_SRx4_ImageNet-LR/visualization/custom``. Both paths are
    relative to the current working directory.

    NOTE(review): this presumes the YAML config points its test dataset at
    ``input_dir`` and names the output folder ``custom``; the config is not
    visible here — confirm.
    """

    def predict(
        self,
        image: Path = Input(
            description="Input Image.",
        ),
    ) -> Path:
        """Super-resolve ``image`` and return the path to the PNG result.

        Args:
            image: Path to the uploaded input image.

        Returns:
            Path to ``output.png`` inside a freshly created temp directory.

        Raises:
            subprocess.CalledProcessError: if ``hat/test.py`` exits non-zero.
            FileNotFoundError: if the expected results directory is missing.
            RuntimeError: if the results directory does not contain exactly
                one file.
        """
        input_dir = "input_dir"
        results_dir = "results"
        output_path = Path(tempfile.mkdtemp()) / "output.png"

        try:
            # Start from a clean slate. Leftovers in either directory from a
            # previous prediction would be fed to test.py or trip the
            # single-result check below. (The original loop only ever
            # removed input_dir, never "results".)
            for d in (input_dir, results_dir):
                shutil.rmtree(d, ignore_errors=True)
            os.makedirs(input_dir, exist_ok=False)

            input_path = os.path.join(input_dir, os.path.basename(image))
            shutil.copy(str(image), input_path)

            # Run with the current interpreter so the subprocess sees the
            # same environment, and fail immediately if inference fails
            # instead of erroring later on a missing results directory.
            subprocess.run(
                [
                    sys.executable,
                    "hat/test.py",
                    "-opt",
                    "options/test/HAT_SRx4_ImageNet-LR.yml",
                ],
                check=True,
            )

            res_dir = os.path.join(
                results_dir, "HAT_SRx4_ImageNet-LR", "visualization", "custom"
            )
            res_files = os.listdir(res_dir)
            # Explicit check rather than `assert`, which is stripped under -O.
            if len(res_files) != 1:
                raise RuntimeError(
                    "Should contain only one result for Single prediction."
                )
            # Context manager closes the underlying file handle.
            with Image.open(os.path.join(res_dir, res_files[0])) as res:
                res.save(str(output_path))
        finally:
            # ignore_errors so cleanup never masks the original exception,
            # e.g. when test.py failed before creating "results".
            shutil.rmtree(input_dir, ignore_errors=True)
            shutil.rmtree(results_dir, ignore_errors=True)

        return output_path