-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgen_nn_graphviz_dot.py
84 lines (71 loc) · 2.44 KB
/
gen_nn_graphviz_dot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Generate a Graphviz DOT description of a fully connected feed-forward
# neural network and print it to stdout.
# Kept Python 2/3 compatible: single-argument print() and %-formatting only.

# Layer sizes: input layer first, output layer last, hidden layers between.
architecture = (4, 8, 5, 3)
# architecture = (3, 5, 2)
# When True, add an "argmax" box fed by every output node.
classifier = True

# One node-name letter per hidden layer ("a1", "a2", ... for the first one).
# "i" and "o" are skipped because the input ("i1", ...) and output
# ("o1", ...) layers already use them. No label attribute is set on nodes,
# so these names are what Graphviz displays; the first 13 letters therefore
# keep the original ordering and extra letters only extend the range.
_HIDDEN_SYMBOLS = "abcdefghjklmnpqrstuvwxyz"


def _cluster_lines(name, label, nodes):
    """Return DOT lines for an invisible cluster holding *nodes* in one rank."""
    lines = [
        "subgraph cluster_%s {" % name,
        " style=invis;",
        ' label="%s";' % label,
    ]
    lines.extend(" %s;" % node for node in nodes)
    lines.append(" {rank=same %s};" % " ".join(nodes))
    lines.append("}")
    return lines


def _dense_edges(sources, targets):
    """Return an edge from every source to every target, grouped by target."""
    return ["%s->%s;" % (src, dst) for dst in targets for src in sources]


def gen_nn_dot(layer_sizes, classifier=True):
    """Return Graphviz DOT source for a fully connected network.

    Args:
        layer_sizes: Node count per layer, input layer first and output
            layer last; every layer in between is drawn as a hidden layer.
        classifier: If true, add an "argmax" box fed by every output node,
            followed by an arrow labelled "classification".

    Returns:
        The DOT text as a list of lines (without trailing newlines). All
        node/cluster declarations come first, then every edge, then "}".

    Raises:
        ValueError: if fewer than two layers are given, or if there are
            more hidden layers than available node-name letters.
    """
    if len(layer_sizes) < 2:
        raise ValueError(
            "need at least an input and an output layer, got %r"
            % (layer_sizes,))
    hidden_sizes = layer_sizes[1:-1]
    if len(hidden_sizes) > len(_HIDDEN_SYMBOLS):
        raise ValueError(
            "at most %d hidden layers are supported, got %d"
            % (len(_HIDDEN_SYMBOLS), len(hidden_sizes)))

    lines = [
        "digraph G{",
        "rankdir=LR;",
        "ranksep=1.0;",
        "splines=line;",
        "node [shape=circle,style=filled,color=black,fontcolor=white];",
    ]
    # Edges are collected separately and emitted after all nodes.
    edges = []

    # Invisible points placed in the source rank, one per input node, so
    # that every input node gets an incoming arrow.
    virtual_nodes = ["vi%d" % (k + 1) for k in range(layer_sizes[0])]
    lines.extend("%s [shape=point, style=invis];" % v for v in virtual_nodes)
    lines.append(" {rank=source %s};" % " ".join(virtual_nodes))

    input_nodes = ["i%d" % (k + 1) for k in range(layer_sizes[0])]
    lines.extend(_cluster_lines("input", "input", input_nodes))
    edges.extend("%s->%s;" % pair for pair in zip(virtual_nodes, input_nodes))

    prev_nodes = input_nodes
    for layer_i, size in enumerate(hidden_sizes):
        letter = _HIDDEN_SYMBOLS[layer_i]
        nodes = ["%s%d" % (letter, k + 1) for k in range(size)]
        tag = "hidden%d" % layer_i
        lines.extend(_cluster_lines(tag, tag, nodes))
        edges.extend(_dense_edges(prev_nodes, nodes))
        prev_nodes = nodes

    output_nodes = ["o%d" % (k + 1) for k in range(layer_sizes[-1])]
    lines.extend(_cluster_lines("output", "output", output_nodes))
    edges.extend(_dense_edges(prev_nodes, output_nodes))

    if classifier:
        # Raw strings so DOT receives the two-character "\n" escape (a
        # centered line break) in both labels, not a literal newline.
        lines.append(r'argmax [label="arg-\nmax",shape=rect,'
                     r'style=solid,fontcolor=black];')
        edges.extend("%s->argmax;" % node for node in output_nodes)
        # Invisible endpoint so the "classification" arrow has a target.
        lines.append("hidden [style=invis, shape=point];")
        lines.append(r'argmax->hidden [label="clas-\nsifi-\ncation"];')

    lines.extend(edges)
    lines.append("}")
    return lines


# Printed at module level, as the original flat script did, so running the
# file writes the DOT text to stdout.
print("\n".join(gen_nn_dot(architecture, classifier)))