-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata.py
55 lines (48 loc) · 1.66 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import os.path
from random import shuffle
from base import base
class DATA:
    """Build, save and load labelled board-position datasets for an N x N board.

    A dataset called ``name`` is stored as two ``.npz`` files,
    ``name + "_x.npz"`` (inputs) and ``name + "_y.npz"`` (labels), each
    holding a single array under numpy's default key ``"arr_0"``.
    """

    def __init__(self, N):
        """Store the board side length ``N`` (boards have ``N * N`` cells)."""
        self.N = N

    def gen(self, dataset, size):
        """Generate ``size`` random samples and save them as ``dataset``.

        The first ``N * N`` cells of a ``base`` board are filled alternately
        with 1 and 2 (cell ``i`` gets ``i % 2 + 1``).  For every sample the
        board and a cell visiting order are shuffled with ``random.shuffle``
        (so the global ``random`` seed determines the output).  Then:

        - the first cell in that order holding ``last_player`` is unmade;
        - the resulting board is recorded;
        - the stone is made again;
        - ``judge.checkwin()`` is converted into the label.

        Args:
            dataset: Path prefix; ``dataset + "_x.npz"`` and
                ``dataset + "_y.npz"`` are written.
            size: Number of samples to generate.

        Saves:
            x: int8 array of shape ``(size, N, N, 2)``.  Channel 0 is 1 where
               the recorded board holds a 1; channel 1 is 1 where it holds
               a 2.
            y: int8 array of shape ``(size,)``, ``(1 - checkwin()) / 2``
               truncated toward zero.  NOTE(review): presumably checkwin()
               returns +1/-1, giving labels 0/1 -- confirm against
               ``base.checkwin``.

        Raises:
            RuntimeError: if none of the first ``N * N`` cells holds
                ``last_player`` after shuffling.
        """
        n = self.N
        n_cells = n * n
        # Value the alternating fill writes into the final cell (1 if n_cells
        # is odd, else 2) -- presumably the player who moved last.
        last_player = (n_cells - 1) % 2 + 1

        judge = base(n)
        # NOTE(review): presumably the stone count; the board starts full.
        judge.count = n_cells
        for cell in range(n_cells):
            judge.board[cell] = cell % 2 + 1
        order = list(range(n_cells))

        # Preallocate the int8 outputs rather than growing Python lists of
        # size * N * N * 2 ints and converting at the end.
        datax = np.zeros((size, n, n, 2), dtype=np.int8)
        datay = np.zeros(size, dtype=np.int8)
        for sample in range(size):
            shuffle(judge.board)
            shuffle(order)
            # First cell, in the shuffled visiting order, holding last_player.
            p = next((c for c in order if judge.board[c] == last_player), None)
            if p is None:
                raise RuntimeError(
                    "no cell holds player %d's stone after shuffling" % last_player
                )
            row, col = divmod(p, n)
            judge.unmake(row, col)
            cells = np.array([judge.board[k] for k in range(n_cells)]).reshape(n, n)
            datax[sample, :, :, 0] = cells == 1
            datax[sample, :, :, 1] = cells == 2
            judge.make(row, col, last_player)
            # The float result is truncated toward zero by the int8 store.
            datay[sample] = (1 - judge.checkwin()) / 2
        np.savez(dataset + "_x.npz", datax)
        np.savez(dataset + "_y.npz", datay)

    def load(self, dataset):
        """Load the ``(x, y)`` arrays saved under the path prefix ``dataset``.

        Returns:
            Tuple ``(x, y)`` read from key ``"arr_0"`` of
            ``dataset + "_x.npz"`` and ``dataset + "_y.npz"``.
        """
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the with-block closes it once the array has been read into memory.
        with np.load(dataset + "_x.npz") as archive:
            x = archive["arr_0"]
        with np.load(dataset + "_y.npz") as archive:
            y = archive["arr_0"]
        return x, y

    def load_trn(self):
        """Load the ``"trn"`` dataset (relative to the current directory)."""
        return self.load("trn")

    def load_vld(self):
        """Load the ``"vld"`` dataset (relative to the current directory)."""
        return self.load("vld")

    def load_tst(self):
        """Load the ``"tst"`` dataset (relative to the current directory)."""
        return self.load("tst")
# Example usage for a 9 x 9 board: writes the "trn", "vld" and "tst" datasets
# that DATA.load_trn / load_vld / load_tst read back.  It is wrapped in a bare
# string literal, so nothing here executes when the module is run or imported.
'''
data = DATA(9)
data.gen("trn", 800000)
data.gen("vld", 100000)
data.gen("tst", 100000)
'''