generated from dogeplusplus/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
70 lines (55 loc) · 1.97 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import cv2
import jax
import json
import pickle
import haiku as hk
import tensorflow as tf
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from pathlib import Path
from einops import rearrange
from argparse import ArgumentParser
from tempfile import TemporaryDirectory
from mlflow.tracking import MlflowClient
from models import VisionTransformer
def parse_arguments(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what a plain ``parse_arguments()`` call has always done.

    Returns:
        The argparse namespace with the attributes ``experiment``, ``run``
        and ``image``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--run`` or
            ``--image`` is missing, or when an unknown option is supplied.
    """
    # The sentence is given as ``description``: the first positional parameter
    # of ArgumentParser is ``prog``, so passing it positionally replaced the
    # program name in the usage line with this whole sentence.
    parser = ArgumentParser(description="Load Vision Transformer and perform inference.")
    parser.add_argument("--experiment", type=str, default="cifar_100", help="Experiment name to extract")
    # Both values are consumed unconditionally by the script, so a missing one
    # is rejected here with a usage message instead of flowing on as None.
    parser.add_argument("--run", type=str, required=True, help="Run to load")
    parser.add_argument("--image", type=str, required=True, help="Path to image")
    return parser.parse_args(argv)
def _load_run_config(client, run_id):
    """Download ``config.json`` from the MLflow run and return it parsed."""
    # The artifact is only needed long enough to parse it, so it is fetched
    # into a temporary directory that is removed when the block exits.
    with TemporaryDirectory() as temp_dir:
        client.download_artifacts(run_id, "config.json", temp_dir)
        with open(Path(temp_dir, "config.json"), "r", encoding="utf-8") as f:
            return json.load(f)


def _load_image(path, image_size):
    """Read the image at ``path`` and turn it into a standardized batch of one.

    Raises:
        OSError: If OpenCV could not read an image from ``path``.
    """
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the problem only surfaces as an opaque error in cv2.resize.
    if image is None:
        raise OSError(f"Could not read an image from {path!r}")
    # NOTE(review): cv2.imread yields channels in BGR order; confirm this
    # matches the channel order the model was trained with.
    # NOTE(review): image_size is passed straight from the JSON config;
    # presumably a two-element sequence -- TODO confirm.
    image = cv2.resize(image, image_size)
    # Add a leading batch axis of length one.
    image = jnp.array(rearrange(image, "h w c -> 1 h w c"))
    return tf.image.per_image_standardization(image).numpy()


def main():
    """Run the Vision Transformer on one image and display the result.

    Reads the model configuration from the MLflow run given on the command
    line, loads the parameters from ``weights.pkl`` in the working directory,
    applies the model to the image and shows the model's second output as a
    semi-transparent overlay on the preprocessed image.

    Raises:
        OSError: If the image given by ``--image`` cannot be read.
    """
    args = parse_arguments()
    config = _load_run_config(MlflowClient(), args.run)

    def create_transformer(x):
        return VisionTransformer(
            config["k"],
            config["heads"],
            config["depth"],
            config["num_classes"],
            config["patch_size"],
            config["image_size"],
            config["dropout"],
        )(x)

    transformer = hk.transform(create_transformer)

    # NOTE(review): the parameters come from a local file, not from the MLflow
    # run. pickle.load executes arbitrary code from the file, so only use a
    # weights.pkl from a trusted source.
    with open("weights.pkl", "rb") as f:
        params = pickle.load(f)

    key = random.PRNGKey(0)
    image = _load_image(args.image, config["image_size"])

    # The first output of the model is not used here; only the second one is
    # visualized (presumably an attention rollout map -- TODO confirm).
    _, rollout = transformer.apply(params, key, image)
    rollout_resized = jax.image.resize(rollout[0], config["image_size"], method="linear")

    plt.imshow(image[0])
    plt.imshow(rollout_resized, cmap="jet", alpha=0.3)
    plt.show()
# Run the inference only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()