-
Notifications
You must be signed in to change notification settings - Fork 1
/
retrieve_prompt_pcl.py
122 lines (105 loc) · 4.45 KB
/
retrieve_prompt_pcl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import json
import os
import torch
from torch.utils.data import DataLoader
from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory
from common.dataset import SHREC23_Test_TextData
from common.predict import predict
from pointcloud.dataset import SHREC23_PointCloudData_TextQuery, SHREC23_Test_PointCloudData_Objects
from pointcloud.curvenet import CurveNet
from common.models import BertExtractor, MLP
from common.test import test_loop
from common.train import train_loop
from pointcloud.pointmlp import PointMLP, PointMLPElite
from utils.plot_logs import plot_logs
'''
Example usage (note: the flag for the text-query checkpoint is --txt-weight):

python retrieve_prompt_pcl.py \
--info-json exps/pcl_exp_0/args.json \
--pcl-model pointmlp \
--obj-data-path data/TextANIMAR2023/3D_Model_References/References \
--obj-csv-path data/TextANIMAR2023/3D_Model_References/References/References.csv \
--txt-csv-path data/TextANIMAR2023/Test/TextQuery_Test.csv \
--obj-weight exps/pcl_exp_0/weights/best_obj_embedder.pth \
--txt-weight exps/pcl_exp_0/weights/best_query_embedder.pth
'''
# Command-line interface for text->point-cloud retrieval prediction.
parser = argparse.ArgumentParser()
parser.add_argument('--pcl-model', type=str,
default='curvenet', choices=['curvenet', 'pointmlp', 'pointmlpelite'], help='Model for point cloud feature extraction')
parser.add_argument('--info-json', type=str, required=True, help='Path to model information json')
parser.add_argument('--output-path', type=str, default='./prompt', help='Path to output folder')
parser.add_argument('--obj-data-path', type=str, required=True, help='Path to 3D objects folder')
parser.add_argument('--obj-csv-path', type=str, required=True, help='Path to CSV file of objects')
parser.add_argument('--txt-csv-path', type=str, help='Path to CSV file of prompts')
parser.add_argument('--obj-weight', type=str, required=True, help='Path to 3D object weight')
parser.add_argument('--txt-weight', type=str, required=True, help='Path to prompt weight')
args = parser.parse_args()
# Load the training-time configuration saved alongside the checkpoints;
# batch_size / latent_dim must match what the embedders were trained with.
with open(args.info_json) as json_file:
    arg_dict = json.load(json_file)
batch_size = arg_dict['batch_size']
latent_dim = arg_dict['latent_dim']

# Initialize
## Model parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BUGFIX: the original script assigned `query_weight = args.query_weight`,
# which raises AttributeError (the parser defines --txt-weight, not
# --query-weight). Neither local was ever read again — the checkpoint paths
# are consumed directly from `args` below — so both assignments are removed.

## Storage
# Pick a fresh numbered output folder: pcl_predict_<max existing id + 1>.
output_path = args.output_path
if not os.path.exists(output_path):
    os.makedirs(output_path)
folders = os.listdir(output_path)
new_id = 0
if len(folders) > 0:
    for folder in folders:
        if not folder.startswith('pcl_predict_'):
            continue
        new_id = max(new_id, int(folder.split('pcl_predict_')[-1]))
    new_id += 1
output_path = os.path.join(output_path, f'pcl_predict_{new_id}')
os.makedirs(output_path)
# ---- Restore the two embedders from their checkpoints ----
# Checkpoints are saved as a sequence whose first element is the state dict.
obj_state = torch.load(args.obj_weight)[0]
query_state = torch.load(args.txt_weight)[0]

# Point-cloud backbone is selected by the training-time config.
_PCL_BACKBONES = {
    'curvenet': CurveNet,
    'pointmlp': PointMLP,
    'pointmlpelite': PointMLPElite,
}
backbone_cls = _PCL_BACKBONES.get(arg_dict['pcl_model'])
if backbone_cls is None:
    raise NotImplementedError
obj_extractor = backbone_cls(device=device)

# Text feature extractor — frozen BERT (OOM otherwise, per baseline).
query_extractor = BertExtractor(arg_dict['text_model'])


def _restore_embedder(extractor, state):
    """Wrap `extractor` in an MLP projection head, load `state`, move to device."""
    embedder = MLP(extractor, latent_dim=latent_dim)
    embedder.load_state_dict(state)
    return embedder.to(device)


obj_embedder = _restore_embedder(obj_extractor, obj_state)
query_embedder = _restore_embedder(query_extractor, query_state)
# ---- Datasets and loaders (shuffle=False: retrieval results must keep
# the CSV row order) ----
obj_ds = SHREC23_Test_PointCloudData_Objects(
    obj_data_path=args.obj_data_path,
    csv_data_path=args.obj_csv_path,
)
txt_ds = SHREC23_Test_TextData(csv_data_path=args.txt_csv_path)

num_workers = arg_dict['num_workers']
obj_dl = DataLoader(obj_ds, batch_size=batch_size, shuffle=False,
                    num_workers=num_workers, collate_fn=obj_ds.collate_fn)
txt_dl = DataLoader(txt_ds, batch_size=batch_size, shuffle=False,
                    num_workers=num_workers, collate_fn=txt_ds.collate_fn)

# ---- Run retrieval and write ranked results into output_path ----
predict(
    obj_embedder=obj_embedder,
    query_embedder=query_embedder,
    obj_input='pointclouds',
    query_input='tokens',
    obj_dl=obj_dl,
    query_dl=txt_dl,
    dimension=latent_dim,
    output_path=output_path,
    device=device,
)