Commit 28ab694

add tensorboard docs
1 parent e4147e2 commit 28ab694

5 files changed

+184
-11
lines changed

README.md

Lines changed: 9 additions & 4 deletions
@@ -8,17 +8,22 @@ Intel® Low Precision Optimization Tool is an open-source python library which i
88
>
99
> GPU support is under development.
1010
11-
Currently supported Intel optimized DL frameworks are:
11+
Supported Intel optimized DL frameworks are:
1212
* [TensorFlow\*](https://www.tensorflow.org)
1313
* [PyTorch\*](https://pytorch.org/)
1414
* [Apache\* MXNet](https://mxnet.apache.org)
1515

16-
Currently supported tuning strategies are:
16+
Supported tuning strategies are:
1717
* [Basic](docs/introduction.md#basic-strategy)
1818
* [Bayesian](docs/introduction.md#bayesian-strategy)
1919
* [MSE](docs/introduction.md#mse-strategy)
2020
* [Exhaustive](docs/introduction.md#exhaustive-strategy)
2121
* [Random](docs/introduction.md#random-strategy)
22+
* [TPE](docs/tuning_strategy.md#TPE-strategy)
23+
24+
Mixed precision support:
25+
* [int8](docs/mixed_precision.md#int8)
26+
* [BFP16](docs/mixed_precision.md#BFP16)
2227

2328

2429
# Introduction
@@ -27,8 +32,8 @@ Currently supported tuning strategies are:
2732

2833
# Tutorials
2934
* [Hello World](examples/helloworld/README.md) demonstrates the simple steps to use Intel® Low Precision Optimization Tool for quantization, which can help you get started quickly with the tool.
30-
* [Tutorials](docs/README.md) provides
31-
comprehensive instructions of how to utilize diffrennt features of Intel® Low Precision Optimization Tool.
35+
* [Tutorials](docs/README.md) provides comprehensive instructions on how to use the different features of Intel® Low Precision Optimization Tool.
36+
* [Features](docs/index.md) introduces features such as tuning strategies, QAT, pruning, and so on.
3237
* [Examples](examples) is a tuning zoo that demonstrates the usage of Intel® Low Precision Optimization Tool in TensorFlow, PyTorch and MXNet for industry models of different categories.
3338

3439
# Install from source
(Three binary files: 169 KB, 237 KB, 94.5 KB; previews not shown.)

docs/tensorboard.md

Lines changed: 175 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,193 @@
11
# Introduction
2-
(Introduce the concept and objective of the feature)
32

4-
PyTorch Tensorboard
3+
TensorBoard is a suite of web applications for inspecting and understanding your topology runs and graphs (see [TensorFlow TensorBoard](https://github.com/tensorflow/tensorboard) and [PyTorch TensorBoard](https://github.com/pytorch/pytorch/tree/master/torch/utils/tensorboard)). Intel® Low Precision Optimization Tool performs accuracy-driven quantization: the tuning process quantizes tensors and applies graph transformations and optimizations to reach the best performance that still meets the accuracy requirement. If you want to observe how these optimizations behave, or to find out why an accuracy target cannot be met, TensorBoard can provide valuable information. You can inspect the graph and tensors after each tuning run, and if a model cannot meet the accuracy requirement, you can analyze it by comparing the FP32 and int8 tensor histograms.
4+
5+
We collect the TensorBoard event summary during evaluation: first on the baseline FP32 model, and then at the end of each tuning run on the quantized model. The TensorBoard log directories are named `baseline_acc_<accuracy>` and `tune_<runs>_acc_<accuracy>` to indicate the stage and the accuracy at which the data was generated. You can then select the data you are interested in and view it with TensorBoard.
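
As a rough illustration of the naming scheme described above, the following sketch builds the directory name for a given run. The helper `event_log_dir`, the `runs/eval` root, and the three-decimal accuracy formatting are assumptions made for this example; they are not part of the tool's API.

```python
def event_log_dir(run_index, accuracy, root="runs/eval"):
    """Return the event log directory for one evaluation.

    Assumption: run_index 0 denotes the FP32 baseline and any larger
    index denotes the corresponding tuning run.
    """
    acc = "{:.3f}".format(accuracy)  # three decimals is an assumption
    if run_index == 0:
        name = "baseline_acc_" + acc
    else:
        name = "tune_{}_acc_{}".format(run_index, acc)
    return root + "/" + name


print(event_log_dir(0, 0.776))  # FP32 baseline directory
print(event_log_dir(1, 0.095))  # directory of the first tuning run
```

Encoding the stage and the accuracy in the directory name makes it possible to pick the interesting runs without opening any event file.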
6+
7+
8+
PyTorch TensorBoard
59
================================
610
# Design
7-
(Introduce design philosophy and API)
11+
12+
The implementation of PyTorch TensorBoard has 3 basic steps:
13+
1. Before evaluation, `_pre_eval_hook()` inserts observers into the model.
14+
2. During evaluation, the observers collect tensor information in a dict data structure.
15+
3. After evaluation, `_post_eval_hook()` dumps the graph and tensor information with the TensorBoard summary writer.
16+
17+
18+
The detailed algorithm is described by the following pseudocode:
19+
```
def evaluate(self, model, dataloader, postprocess=None,
             metric=None, measurer=None, iteration=-1, tensorboard=False):
    # The TensorBoard summary is collected in the evaluation function of the adaptor

    if tensorboard:
        model = self._pre_eval_hook(model)
    # evaluation code
    ....
    acc = metric.result()
    if tensorboard:
        self._post_eval_hook(model, accuracy=acc, input=input)

def _pre_eval_hook(self, model):
    # Insert an observer submodule into each module in the whitelist to collect tensor information

    class _RecordingObserver(ABC, torch.nn.Module):
        # Define the Observer class

        def forward(self, x):
            # Record the tensor information in a dict structure
            self.output_tensors_dict[self.current_iter] = x.to("cpu")

        @torch.jit.export
        def get_tensor_value(self):
            return self.output_tensors_dict

    def _observer_forward_hook(module, input, output):
        # Forward hook that calls the observer on the output
        return module.activation_post_process(output)

    def _add_observer_(module, op_list=None, prefix=""):

        # Add an observer for each child module
        for name, child in module.named_children():
            _add_observer_(child, op_list, op_name)

        if module is a leaf:
            module.add_module(
                'activation_post_process',
                module.qconfig.activation())
            module.register_forward_hook(_observer_forward_hook)

def _post_eval_hook(self, model, **args):
    # Dump tensor and graph information with the TensorBoard summary writer
    if self.dump_times == 0:
        writer = SummaryWriter('runs/eval/baseline' +
                               '_acc' + str(accuracy), model)
    else:
        writer = SummaryWriter('runs/eval/tune_' +
                               str(self.dump_times) +
                               '_acc' + str(accuracy), model)

    if args is not None and 'input' in args and self.dump_times == 0:
        writer.add_graph(model, args['input'])

    from torch.quantization import get_observer_dict
    get_observer_dict(model, observer_dict)
    for key in observer_dict:
        ......
        # Remove the observer suffix; str.strip() would remove a character set, not the suffix
        op_name = key.replace(".activation_post_process", "")
        summary[op_name + ".output"] = observer_dict[key].get_tensor_value()

        for iter in summary[op_name + ".output"]:
            # Record the output tensor; for a fused op, only record the parent op output
            ......
            if summary[op_name + ".output"][iter].is_quantized:
                writer.add_histogram(
                    op + "/Output/int8",
                    torch.dequantize(summary[op_name +
                                             ".output"][iter]))
            else:
                writer.add_histogram(
                    op + "/Output/fp32",
                    summary[op_name + ".output"][iter])

    state_dict = model.state_dict()
    for key in state_dict:
        # Record the weight tensor; the TensorBoard tags of fused children are merged
        if state_dict[key].is_quantized:
            writer.add_histogram(op + "/int8",
                                 torch.dequantize(state_dict[key]))
        else:
            writer.add_histogram(op + "/fp32", state_dict[key])
```
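
In the pseudocode above, the op name is derived from an observer key that ends in `.activation_post_process`. One detail worth keeping in mind for this step: Python's `str.strip` treats its argument as a set of characters to remove from both ends, not as a suffix, so it can also eat characters that belong to the op name. The sketch below contrasts the two behaviors; `op_name_from_key` is a hypothetical helper written for this illustration.

```python
SUFFIX = ".activation_post_process"


def op_name_from_key(key, suffix=SUFFIX):
    """Drop the trailing observer suffix and leave the rest of the key untouched."""
    if key.endswith(suffix):
        return key[: -len(suffix)]
    return key


key = "conv1" + SUFFIX

# strip() removes every leading/trailing character found in SUFFIX,
# including the 'c', 'o', 'n', 'v' of the op name itself.
print(key.strip(SUFFIX))      # prints "1"

# Suffix removal keeps the op name intact.
print(op_name_from_key(key))  # prints "conv1"
```

Op names built only from characters outside the suffix happen to survive `strip`, which is why the difference is easy to miss.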
106+
8107

9108
# Usage
10109
(Introduce the usage method of the feature)
110+
1. Add "tensorboard: true" to the yaml file.
111+
2. Run quantization tuning; a "./runs" folder will be generated in the working folder.
112+
3. Start TensorBoard:
113+
```
114+
tensorboard --bind_all --logdir_spec baseline:./runs/eval/tune_0_acc0.80,tune_1:./runs/eval/tune_1_acc0.79
115+
```
11116

12117
# Examples
13-
(Link to the example code)
118+
119+
```
120+
examples/pytorch/image_recognition/imagenet/cpu/ptq/run_tuning_dump_tensor.sh
121+
```
14122

15123
TensorFlow Tensorboard
16124
================================
17125
# Design
18-
(Introduce design philosophy and API)
126+
The implementation of TensorFlow TensorBoard has 4 basic steps:
127+
1. Before evaluation, create the TensorBoard summary writer and write the graph, collect fp32 and node names for inspection, and dump the histograms of the weight and bias tensors directly from graph_def.
128+
2. Run get_tensor_by_name_with_import() to get the data output tensors.
129+
3. Run session.run() to predict and get the inference results for the output tensor list collected in step 2.
130+
4. Enumerate the output tensors and write their histograms.
131+
132+
See the evaluate() function in ilit/adaptor/tensorflow.py for details.
19133

20134
# Usage
21-
(Introduce the usage method of the feature)
135+
136+
1. Add "tensorboard: true" to the yaml file.
137+
2. Run quantization tuning; a "./runs" folder will be generated in the working folder. For example:
138+
```
139+
ls ./runs/eval
140+
baseline_acc_0.776 tune_1_acc_0.095
141+
```
142+
The baseline_acc_0.776 folder contains the FP32 event log, and 0.776 is the FP32 accuracy. The tune_1_acc_0.095 folder contains the evaluation event log of the first tuning run.
143+
3. Start TensorBoard:
144+
```
145+
tensorboard --bind_all --logdir_spec baseline:./runs_v3/eval/baseline_acc_0.776/,tune_1:./runs_v3/eval/tune_1_acc_0.095/
146+
```
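
With many tuning runs, typing the `--logdir_spec` value by hand becomes tedious. The sketch below assembles it from the sub-directories of an evaluation folder. `build_logdir_spec` is a hypothetical helper, and it assumes the directory names follow the `baseline_acc_...` and `tune_<n>_acc_...` pattern shown in step 2.

```python
import os
import tempfile


def build_logdir_spec(eval_dir):
    """Return a 'label:path,label:path' string for `tensorboard --logdir_spec`."""
    entries = []
    for name in sorted(os.listdir(eval_dir)):
        path = os.path.join(eval_dir, name)
        if not os.path.isdir(path):
            continue  # ignore stray files
        # 'baseline_acc_0.776' -> 'baseline', 'tune_1_acc_0.095' -> 'tune_1'
        label = name.split("_acc")[0]
        entries.append("{}:{}".format(label, path))
    return ",".join(entries)


# Demo on a throwaway directory that mimics ./runs/eval
with tempfile.TemporaryDirectory() as eval_dir:
    for run in ("baseline_acc_0.776", "tune_1_acc_0.095"):
        os.mkdir(os.path.join(eval_dir, run))
    print("tensorboard --bind_all --logdir_spec " + build_logdir_spec(eval_dir))
```

The printed command can be pasted into a shell; the labels before each colon become the run names shown in the TensorBoard "Run" box.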
22147

23148
# Examples
24-
(Link to the example code)
149+
150+
151+
1. Add "tensorboard: true" to examples/tensorflow/image_recognition/inceptionv3.yaml. To demonstrate the usage of TensorBoard, please remove the following lines, which were added to skip the quantization of 'v0/cg/conv0/conv2d/Conv2D' to work around a known limitation.
152+
```
op_wise: {
  'v0/cg/conv0/conv2d/Conv2D': {
    'activation': {'dtype': ['fp32']},
  }
}
```
159+
2. Run tuning:
160+
```
bash run_tuning.sh --topology=inception_v3 --dataset_location=<imagenet> \
    --input_model=./inceptionv3_fp32_pretrained_model.pb --output_model=./ilit_inceptionv3.pb --config=./inceptionv3_dump_tensor.yaml
```
164+
3. Start TensorBoard
165+
```
166+
tensorboard --bind_all --logdir_spec baseline:./runs_v3/eval/baseline_acc_0.776/,tune_1:./runs_v3/eval/tune_1_acc_0.095/
167+
```
168+
169+
4. To find out why tune_1 has such poor accuracy, inspect the results in TensorBoard.
170+
1). On the Graphs tab, select "baseline/." in the "Run" box and find the first 'Conv2d' op after the 'input' op; the op name is "v0/cg/conv0/Relu".
171+
172+
173+
<div align="left">
174+
<img src="imgs/tensorboard_baseline_v0_cg_conv0.png" width="700px" />
175+
</div>
176+
177+
2). On the Graphs tab, select "tune_1/." in the "Run" box and find the first 'Conv2d' op after the 'input' op; the tensor name is 'v0/cg/conv0/conv2d/Conv2D_eightbit_requantize'.
178+
179+
180+
<div align="left">
181+
<img src="imgs/tensorboard_tune_1_v0_cg_conv0.png" width="700px" />
182+
</div>
183+
184+
3). Switch to the Histograms tab and enter the op name 'v0/cg/conv0' in the search box. TensorBoard groups the tensors with the same op name together, so you can compare the baseline tensor 'v0/cg/conv0/Relu' with the tune_1 tensor 'v0/cg/conv0/conv2d/Conv2D_eightbit_requantize_int8.output'. Note that tensor names can change after quantization, so group the tensors by op name when comparing. The chart shows that the histograms of the first conv2d output tensor differ. This is caused by a known TensorFlow issue. After the op 'v0/cg/conv0/conv2d/Conv2D' is excluded from quantization by adding "op_wise" to the yaml file, the issue disappears.
185+
186+
187+
<div align="left">
188+
<img src="imgs/tensorboard_v0_cg_conv0_histogram.png" width="700px" />
189+
</div>
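
To make the idea of comparing FP32 and int8 histograms more concrete, here is a small framework-free sketch. It is a toy model only: `fake_quantize`, `histogram` and `histogram_gap` are hypothetical helpers, the symmetric int8 scheme is an assumption, and the artificially narrow clipping range merely stands in for a faulty quantization. It does not reproduce the TensorFlow issue mentioned above.

```python
def fake_quantize(values, clip, num_bits=8):
    """Symmetric quantize-then-dequantize of floats, saturating at +/- clip."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    scale = clip / qmax
    out = []
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))  # saturated integer code
        out.append(q * scale)                        # back to float for comparison
    return out


def histogram(values, lo, hi, bins=10):
    """Count values per equal-width bin over [lo, hi]; outliers go to the edge bins."""
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        idx = int((v - lo) / width)
        counts[min(bins - 1, max(0, idx))] += 1
    return counts


def histogram_gap(a, b):
    """L1 distance between two histograms, normalized by the sample count."""
    return sum(abs(x - y) for x, y in zip(a, b)) / sum(a)


fp32 = [i / 50.0 - 1.0 for i in range(101)]  # toy "FP32 tensor" spanning [-1, 1]
reference = histogram(fp32, -1.0, 1.0)

good = histogram(fake_quantize(fp32, clip=1.0), -1.0, 1.0)
bad = histogram(fake_quantize(fp32, clip=0.1), -1.0, 1.0)

print("gap with full-range clipping:", histogram_gap(reference, good))
print("gap with too-narrow clipping:", histogram_gap(reference, bad))
```

A dequantized tensor whose histogram stays close to the FP32 one is usually harmless, while a large gap points at the op worth inspecting first.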
190+
191+
192+
25193
