Update Dependencies
Jiacheng-WU committed Apr 5, 2021
1 parent 93f8217 commit 599dcc8
Showing 5 changed files with 22 additions and 28 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -2,10 +2,10 @@

## Environment Setup

python 3 + torch + numpy
python 3 + torch + numpy + einops

```bash
pip install torch numpy
pip install torch numpy einops
```

## Usage
4 changes: 2 additions & 2 deletions compose.py
@@ -44,10 +44,10 @@ def decompose(input_query, attr_num=10):
def compose(input_res, attr_num=10):
# there are 2**attr_num input results, one for each of the queries generated above
# input_res = np.random.random(2**attr_num)
res = 0.0
res = np.array([0.0])
for i in range(0, 2**attr_num):
count_one = np.sum(convert_int_to_bool_list(i))
sign = (-1)**count_one
res = res + sign * input_res[i]
# Due to precision error, res can sometimes be a small negative value; we fix it by clamping with max
return max(res, 0)
return np.maximum(res, np.array([0.0]))
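The switch from `max(res, 0)` to `np.maximum(res, np.array([0.0]))` matters because `res` is now an ndarray rather than a Python float: `np.maximum` clamps element-wise and preserves the array shape. A minimal sketch of the difference (values are illustrative, not from the repo):

```python
import numpy as np

res = np.array([-1e-12])                   # tiny negative value from floating-point error
print(np.maximum(res, np.array([0.0])))    # [0.] -- element-wise clamp, shape preserved

# The built-in max compares the two operands as wholes: here it returns the
# int 0 (changing the type), and it raises ValueError on multi-element arrays
# because their truth value is ambiguous.
print(max(np.array([-1e-12]), 0))          # 0
```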
7 changes: 2 additions & 5 deletions net.py
@@ -1,10 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
from shuffle import shuffle
from einops import rearrange


class AQPNet(nn.Module):
@@ -32,7 +29,7 @@ def forward(self, x):
x = self.conv2(x)
x = F.relu(x)
# the flatten should not touch dim 0, which is the batch dim
x = torch.flatten(x, start_dim=1)
x = rearrange(x, 'b c h w -> b (c h w)')
x = self.fc1(x)
x = self.do1(x)
x = F.relu(x)
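For reference, `rearrange(x, 'b c h w -> b (c h w)')` is a drop-in equivalent of `torch.flatten(x, start_dim=1)`: dim 0 (the batch) is kept and the channel and spatial dims are merged in the same C-order. A quick sanity check, with an assumed input shape:

```python
import torch
from einops import rearrange

x = torch.randn(4, 8, 5, 5)               # (batch, channels, h, w) -- shape is an assumption
a = torch.flatten(x, start_dim=1)         # keep dim 0, merge the rest
b = rearrange(x, 'b c h w -> b (c h w)')  # same merge, but the layout is spelled out
assert a.shape == b.shape == (4, 8 * 5 * 5)
assert torch.equal(a, b)
```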
14 changes: 5 additions & 9 deletions query.py
@@ -1,11 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from einops import repeat
from net import AQPNet
from shuffle import shuffle, shuffle_batch
from shuffle import shuffle_batch
from compose import compose, decompose
from globals import *

@@ -25,12 +22,11 @@ def do_query(query):
# print(output_queries)
shuffle_output_queries = shuffle_batch(output_queries, ATTR_NUM, SHUFFLE_TIME)
tensor_queries = torch.from_numpy(np.array(shuffle_output_queries)).to(device=device, dtype=torch.float)
queries_size = list(tensor_queries.size())
queries_size.insert(1, 1)
tensor_queries = torch.reshape(tensor_queries, queries_size)
# Same reason as in train.py
tensor_queries = repeat(tensor_queries, 'b w h -> b c w h', c=1)
output_tensors = model(tensor_queries)
output_array = output_tensors.data.cpu().numpy()
output_array = np.reshape(output_array, output_array.size)
# output_array = np.reshape(output_array, output_array.size)
res = compose(output_array)
return res

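The `repeat(tensor_queries, 'b w h -> b c w h', c=1)` call inserts a singleton channel axis for Conv2D, replacing the old size-list-and-reshape dance; with `c=1` it behaves like `unsqueeze(1)`. A small check under an assumed input shape:

```python
import torch
from einops import repeat

t = torch.randn(3, 10, 10)                # (batch, w, h) -- shape is an assumption
a = repeat(t, 'b w h -> b c w h', c=1)    # add a singleton channel axis
b = t.unsqueeze(1)                        # plain-torch equivalent for c=1
assert a.shape == b.shape == (3, 1, 10, 10)
assert torch.equal(a, b)
```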
21 changes: 11 additions & 10 deletions train.py
@@ -1,12 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
import pandas as pd
from net import AQPNet
from shuffle import shuffle, shuffle_batch
from shuffle import shuffle_batch
from einops import repeat
from globals import *
# We assume the attributes are labeled from 0; otherwise we need to normalize
# We have 10 attributes, per Professor Wang Ying
@@ -18,13 +16,16 @@ def train(model, device, data, target, optimizer, batch_size):
for i in range(len(target) // batch_size):
batch_data = data[i*batch_size:(i+1)*batch_size]
batch_target = target[i*batch_size:(i+1)*batch_size]
batch_data, batch_target = torch.from_numpy(batch_data).to(device=device, dtype=torch.float), torch.from_numpy(batch_target).to(device=device, dtype=torch.float)
batch_data = torch.from_numpy(batch_data).to(device=device, dtype=torch.float)
batch_target = torch.from_numpy(batch_target).to(device=device, dtype=torch.float)
# In fact, we consider the input as a collection of 2D graphs, which is already a 3D tensor
# However, Conv2D requires input in the format [batch, channels, length, width]
# However, Conv2D requires input in the format [batch, channels, width, height]
# We only have one channel, so we should add an extra dim here.
batch_data_size = list(batch_data.size())
batch_data_size.insert(1, 1)
batch_data = torch.reshape(batch_data, batch_data_size)
# We use einops to replace the following reshape code
# batch_data_size = list(batch_data.size())
# batch_data_size.insert(1, 1)
# batch_data = torch.reshape(batch_data, batch_data_size)
batch_data = repeat(batch_data, 'b w h -> b c w h', c=1)

optimizer.zero_grad()
output = model(batch_data)
@@ -37,7 +38,7 @@ def train(model, device, data, target, optimizer, batch_size):

def process_train_set(train_sets, attr_num, shuffle_time):
train_sets = np.array(train_sets)
targets = train_sets[:, 0:1].T[0]
targets = train_sets[:, 0:1]
datas = train_sets[:, 1:]
datas = np.array(shuffle_batch(datas, attr_num, shuffle_time))
return datas, targets
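The change from `train_sets[:, 0:1].T[0]` to `train_sets[:, 0:1]` keeps the targets as an (N, 1) column instead of a flat (N,) vector, so they match the (batch, 1) shape the network outputs and the loss no longer broadcasts unexpectedly. A shape sketch with made-up numbers:

```python
import numpy as np

train_sets = np.array([[0.7, 1.0, 0.0],    # assumed layout: target in column 0,
                       [0.2, 0.0, 1.0]])   # attributes in the remaining columns
targets_old = train_sets[:, 0:1].T[0]      # shape (2,)   -- flat vector
targets_new = train_sets[:, 0:1]           # shape (2, 1) -- matches model output
datas = train_sets[:, 1:]                  # shape (2, 2)
print(targets_old.shape, targets_new.shape, datas.shape)
```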
