-
Notifications
You must be signed in to change notification settings - Fork 8
add page 6.4 and image #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
quxiao
wants to merge
2
commits into
supremind:master
Choose a base branch
from
quxiao:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# 使用训练的模型运行推理服务 | ||
|
||
## 下载模型 | ||
|
||
在之前创建的训练进行过程中,经过一些 epoch 会生成模型,训练代码中的 AVA-SDK 会将这些模型上传至 AVA 平台。用户在 "模型列表 -> 训练产生模型" 的页面中,可以查看到对应的训练任务,以及训练任务产生的模型列表,如下图: | ||
|
||
 | ||
|
||
点击右侧的 "下载",即可将训练产生的模型下载至本地。 | ||
|
||
```bash | ||
➜ ll models | ||
total 6256 | ||
-rw-r--r--@ 1 quxiao staff 2.7M 7 1 10:35 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz | ||
``` | ||
|
||
解压之后,就可以看到模型对应的网络结构和权重文件了。 | ||
|
||
```bash | ||
➜ tar -zxvf 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz | ||
x snapshot-symbol.json | ||
x snapshot-0030.params | ||
➜ ll | ||
total 12464 | ||
-rw-r--r--@ 1 quxiao staff 2.7M 7 1 10:35 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz | ||
-rw-r--r--@ 1 quxiao staff 2.9M 6 27 20:24 snapshot-0030.params | ||
-rw-r--r--@ 1 quxiao staff 105K 6 27 20:24 snapshot-symbol.json | ||
``` | ||
|
||
## 编写推理代码 | ||
|
||
接着,我们可以根据 mxnet 提供的教程,对上面的推理代码进行一些修改。demo 代码如下: | ||
|
||
```python | ||
import os | ||
import base64 | ||
import mxnet as mx | ||
import numpy as np | ||
from collections import namedtuple | ||
from bottle import route, run | ||
import urllib | ||
from urlparse import urlparse | ||
|
||
Batch = namedtuple('Batch', ['data']) | ||
|
||
|
||
class ImagePredictor (object): | ||
def __init__(self): | ||
self.Batch = namedtuple('Batch', ['data']) | ||
self.model_dir = "model" | ||
self.data_dir = "data" | ||
self.mod = None | ||
self.labels = [] | ||
self.symbol_filename = 'deploy.symbol.json' | ||
self.weight_filename = 'weight.params' | ||
self.label_filename = 'labels.csv' | ||
|
||
def load_model(self, custom_model_dir=None, symbol_fn=None, weight_fn=None, label_fn=None): | ||
if custom_model_dir != None: | ||
self.model_dir = custom_model_dir | ||
if symbol_fn != None: | ||
self.symbol_filename = symbol_fn | ||
if weight_fn != None: | ||
self.weight_filename = weight_fn | ||
if label_fn != None: | ||
self.label_filename = label_fn | ||
|
||
print "set cpu/gpu model" | ||
# set the context on CPU, switch to GPU if there is one available | ||
ctx = mx.cpu() | ||
|
||
print "loading model" | ||
sym, arg_params, aux_params, labels = self._load_model() | ||
self.mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) | ||
self.mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], | ||
label_shapes=self.mod._label_shapes) | ||
self.mod.set_params(arg_params, aux_params, allow_missing=True) | ||
self.labels = labels | ||
|
||
def _load_model(self): | ||
sym = mx.symbol.load(os.path.join(self.model_dir, self.symbol_filename)) | ||
save_dict = mx.ndarray.load(os.path.join(self.model_dir, self.weight_filename)) | ||
arg_params = {} | ||
aux_params = {} | ||
for k, v in save_dict.items(): | ||
tp, name = k.split(':', 1) | ||
if tp == 'arg': | ||
arg_params[name] = v | ||
if tp == 'aux': | ||
aux_params[name] = v | ||
|
||
labels = [] | ||
with open(os.path.join(self.model_dir, self.label_filename), 'r') as f: | ||
labels = [l.rstrip() for l in f] | ||
|
||
return sym, arg_params, aux_params, labels | ||
|
||
def get_image(self, uri): | ||
# download and show the image | ||
# fname = mx.test_utils.download(url, dirname=self.data_dir) | ||
# img = mx.image.imread(fname) | ||
content = urllib.urlopen(uri).read() | ||
img = mx.image.imdecode(content) | ||
if img is None: | ||
return None | ||
# convert into format (batch, RGB, width, height) | ||
img = mx.image.imresize(img, 224, 224) # resize | ||
img = img.transpose((2, 0, 1)) # Channel first | ||
img = img.expand_dims(axis=0) # batchify | ||
return img | ||
|
||
def predict(self, uri, topN=5): | ||
img = self.get_image(uri) | ||
# compute the predict probabilities | ||
self.mod.forward(Batch([img])) | ||
prob = self.mod.get_outputs()[0].asnumpy() | ||
# print the top-5 | ||
prob = np.squeeze(prob) | ||
a = np.argsort(prob)[::-1] | ||
ret = [] | ||
for i in a[0:topN]: | ||
print('probability=%f, class=%s' %(prob[i], self.labels[i])) | ||
ret.append({'prob': float(prob[i]), 'label': self.labels[i]}) | ||
return ret | ||
|
||
``` | ||
|
||
另外,我们还需要准备一份分类新的文件 `label.csv`,包含了 cifar10 数据集中数据分类 ID 和名称的映射关系,文件可参考:https://github.com/quxiao/cifar10-example/blob/master/image-predict/models/ava-snapshot-model/labels.csv | ||
|
||
## 查看推理结果 | ||
|
||
之后,我们运行几张图片的推理: | ||
|
||
```python | ||
def main(): | ||
worker = predictor.ImagePredictor() | ||
worker.load_model( | ||
custom_model_dir='models/ava-snapshot-model', | ||
symbol_fn='snapshot-symbol.json', | ||
weight_fn='snapshot-0030.params', | ||
label_fn='labels.csv') | ||
worker.predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true') | ||
worker.predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true') | ||
``` | ||
|
||
就可以得到推理的结果: | ||
|
||
``` | ||
➜ python local.py | ||
set cpu/gpu model | ||
loading model | ||
[12:15:51] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.1.0. Attempting to upgrade... | ||
[12:15:51] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded! | ||
uri: https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true | ||
probability=0.533429, class=n01491361 tiger shark, Galeocerdo cuvieri | ||
probability=0.309057, class=n01494475 hammerhead, hammerhead shark | ||
probability=0.089500, class=n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias | ||
probability=0.042574, class=n01498041 stingray | ||
probability=0.024384, class=n01496331 electric ray, crampfish, numbfish, torpedo | ||
|
||
uri: https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true | ||
probability=0.656869, class=n01491361 tiger shark, Galeocerdo cuvieri | ||
probability=0.165087, class=n01496331 electric ray, crampfish, numbfish, torpedo | ||
probability=0.063514, class=n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias | ||
probability=0.056640, class=n01494475 hammerhead, hammerhead shark | ||
probability=0.037440, class=n01498041 stingray | ||
``` | ||
|
||
PS:从结果可以看出,推理的效果与预期存在差距,这是由多方面原因造成的,例如训练数据集编码的参数、训练代码的超参等等。 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个示例能不能迁移到 qiniu-ava
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
还需要把这一节加到 summary.md