Merge pull request openvinotoolkit#643 from evgeny-izutov/feature/ei/aslnet

Add ASL Recognition models description and Python Demo
Roman Donchenko authored Dec 16, 2019
2 parents d9ee546 + 975d410 commit 955ec18
Showing 25 changed files with 1,345 additions and 2 deletions.
84 changes: 84 additions & 0 deletions demos/python_demos/asl_recognition_demo/README.md
@@ -0,0 +1,84 @@
# ASL Recognition Python* Demo

This demo shows how to run ASL (American Sign Language) recognition models using the OpenVINO™ toolkit.

## How It Works

The demo application expects an ASL recognition model in the Intermediate Representation (IR) format.

As input, the demo application takes:
* a path to a video file or a device node of a web camera, specified with the `--input` command-line argument
* a path to a JSON file with ASL class names, specified with the `--class_map` argument

The demo workflow is the following (a condensed code sketch is shown after the list):

1. The demo application reads video frames one by one, runs a person detector to extract the ROI, and tracks the ROI of the first detected person. A separate process prepares batches of frames at a constant frame rate.
2. The batch of frames and the extracted ROI are passed to a neural network that predicts the ASL gesture.
3. The application visualizes the results in a graphical window, which shows the following:
   - Input frame with the detected ROI.
   - Last recognized ASL gesture.
   - Performance characteristics.
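
A condensed sketch of this loop, based on the `main()` function in `asl_recognition_demo.py` from this commit (setup and visualization omitted; `args`, `person_detector`, `action_recognizer`, and `class_map` are created exactly as in the full source below):

```python
import numpy as np

# Condensed from main() in asl_recognition_demo.py.
video_stream = VideoStream(args.input, ACTION_NET_INPUT_FPS, action_recognizer.input_length)
video_stream.start()
person_tracker = PersonTracker(person_detector, TRACKER_SCORE_THRESHOLD, TRACKER_IOU_THRESHOLD)

while True:
    frame = video_stream.get_live_frame()  # latest frame, used for display
    batch = video_stream.get_batch()       # constant-FPS clip for the recognizer
    if frame is None or batch is None:
        break
    person_roi = person_tracker.get_roi(frame)         # ROI of the tracked person
    if person_roi is not None:
        scores = action_recognizer(batch, person_roi)  # per-class confidences
        if scores is not None and np.max(scores) > ACTION_SCORE_THRESHOLD:
            gesture = class_map[int(np.argmax(scores))]  # recognized ASL gesture
```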

> **NOTE**: By default, Open Model Zoo demos expect input with BGR channels order. If you trained your model to work with RGB order, you need to manually rearrange the default channels order in the demo application or reconvert your model using the Model Optimizer tool with `--reverse_input_channels` argument specified. For more information about the argument, refer to **When to Reverse Input Channels** section of [Converting a Model Using General Conversion Parameters](https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_Converting_Model_General.html).
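
For instance, a minimal sketch of converting a decoded BGR frame to RGB in application code (the demo's own preprocessing uses the same `cv2.cvtColor` call):

```python
import cv2

def to_rgb(bgr_frame):
    """Reverse the channel order of an OpenCV-decoded (BGR) frame."""
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
```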
## Running

Run the application with the `-h` option to see the following usage message:

```
usage: asl_recognition_demo.py [-h] -m_a ACTION_MODEL -m_d DETECTION_MODEL -i
                               INPUT -c CLASS_MAP [-d DEVICE]
                               [-l CPU_EXTENSION] [--no_show]

Options:
  -h, --help            Show this help message and exit.
  -m_a ACTION_MODEL, --action_model ACTION_MODEL
                        Required. Path to an .xml file with a trained ASL
                        recognition model.
  -m_d DETECTION_MODEL, --detection_model DETECTION_MODEL
                        Required. Path to an .xml file with a trained person
                        detector model.
  -i INPUT, --input INPUT
                        Required. Path to a video file or a device node of a
                        web-camera.
  -c CLASS_MAP, --class_map CLASS_MAP
                        Required. Path to a file with ASL classes.
  -d DEVICE, --device DEVICE
                        Optional. Specify the target device to infer on: CPU,
                        GPU, FPGA, HDDL or MYRIAD. The demo will look for a
                        suitable plugin for device specified (by default, it
                        is CPU).
  -l CPU_EXTENSION, --cpu_extension CPU_EXTENSION
                        Optional. Required for CPU custom layers. Absolute
                        path to a shared library with the kernels
                        implementations.
  --no_show             Optional. Do not visualize inference results.
```

Running the application with an empty list of options yields the short version of the usage message and an error message.

To run the demo, you can use public or pre-trained models. To download the pre-trained models, use the OpenVINO [Model Downloader](../../../tools/downloader/README.md) or go to [https://download.01.org/opencv/](https://download.01.org/opencv/).
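
For example, a possible downloader invocation for the two models used in this demo (assuming the `downloader.py` script from the Model Downloader linked above; the script path and `--name` option may differ between releases):

```bash
python3 tools/downloader/downloader.py --name asl-recognition-0003,person-detection-asl-0001
```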

> **NOTE**: Before running the demo with a trained model, make sure the model is converted to the Inference Engine format (`*.xml` + `*.bin`) using the [Model Optimizer tool](https://docs.openvinotoolkit.org/latest/_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html).

To run the demo, provide paths to the ASL recognition and person detection models in the IR format, to a file with class names, and to an input video or camera:
```bash
python asl_recognition_demo.py \
    -m_a /home/user/asl-recognition-0003.xml \
    -m_d /home/user/person-detection-asl-0001.xml \
    -i 0 \
    -c ./classes.json
```

An example file with class names can be found [here](./classes.json).
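
The demo builds its index-to-name mapping with `dict(enumerate(...))` (see `load_class_map` in the demo source below), so the file is expected to be a flat JSON array of class names ordered by class ID. A minimal sketch with hypothetical class names:

```python
import json

# Hypothetical classes.json contents: ["hello", "thank you", "please"]
with open('classes.json') as input_stream:
    class_map = dict(enumerate(json.load(input_stream)))
# class_map == {0: 'hello', 1: 'thank you', 2: 'please'}
```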

## Demo Output

The application uses OpenCV to display the ASL gesture recognition results and current inference performance.

![](./asl_recognition_demo.jpg)

## See Also
* [Using Open Model Zoo demos](../../README.md)
* [Model Optimizer](https://docs.openvinotoolkit.org/latest/_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html)
* [Model Downloader](../../../tools/downloader/README.md)
163 changes: 163 additions & 0 deletions demos/python_demos/asl_recognition_demo/asl_recognition_demo.py
@@ -0,0 +1,163 @@
#!/usr/bin/env python
"""
Copyright (c) 2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging as log
import sys
import time
import json
from os.path import exists
from argparse import ArgumentParser, SUPPRESS

import cv2
import numpy as np

from asl_recognition_demo.common import load_ie_core
from asl_recognition_demo.video_stream import VideoStream
from asl_recognition_demo.person_detector import PersonDetector
from asl_recognition_demo.person_tracker import PersonTracker
from asl_recognition_demo.action_recognizer import ActionRecognizer

DETECTOR_OUTPUT_NAME = "12688/Split.0"  # name of the person detector's output blob
TRACKER_SCORE_THRESHOLD = 0.5           # min detection confidence for tracking
TRACKER_IOU_THRESHOLD = 0.5             # IoU threshold for matching tracked ROIs
ACTION_NET_INPUT_FPS = 15               # frame rate expected by the recognizer
ACTION_NUM_CLASSES = 100                # number of ASL classes in the model output
ACTION_IMAGE_SCALE = 256                # reference scale for the central ROI crop
ACTION_SCORE_THRESHOLD = 0.8            # min score to report a recognized gesture


def build_argparser():
""" Returns argument parser. """

parser = ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', default=SUPPRESS,
help='Show this help message and exit.')
args.add_argument('-m_a', '--action_model',
                      help='Required. Path to an .xml file with a trained ASL recognition model.',
required=True, type=str)
args.add_argument('-m_d', '--detection_model',
help='Required. Path to an .xml file with a trained person detector model.',
required=True, type=str)
args.add_argument('-i', '--input',
help='Required. Path to a video file or a device node of a web-camera.',
required=True, type=str)
args.add_argument('-c', '--class_map',
help='Required. Path to a file with ASL classes.',
required=True, type=str)
args.add_argument('-d', '--device',
help='Optional. Specify the target device to infer on: CPU, GPU, FPGA, HDDL '
'or MYRIAD. The demo will look for a suitable plugin for device '
'specified (by default, it is CPU).',
default='CPU', type=str)
args.add_argument("-l", "--cpu_extension",
help="Optional. Required for CPU custom layers. Absolute path to "
"a shared library with the kernels implementations.", type=str,
default=None)
args.add_argument('--no_show', action='store_true',
help='Optional. Do not visualize inference results.')

return parser


def load_class_map(file_path):
""" Returns class names map. """

if file_path is not None and exists(file_path):
with open(file_path, 'r') as input_stream:
data = json.load(input_stream)
class_map = dict(enumerate(data))
else:
class_map = None

return class_map


def main():
""" Main function. """

log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()

ie_core = load_ie_core(args.device, args.cpu_extension)

person_detector = PersonDetector(args.detection_model, args.device, ie_core,
num_requests=2, output_name=DETECTOR_OUTPUT_NAME)
action_recognizer = ActionRecognizer(args.action_model, args.device, ie_core,
num_requests=2, img_scale=ACTION_IMAGE_SCALE,
num_classes=ACTION_NUM_CLASSES)

video_stream = VideoStream(args.input, ACTION_NET_INPUT_FPS, action_recognizer.input_length)
video_stream.start()

person_tracker = PersonTracker(person_detector, TRACKER_SCORE_THRESHOLD, TRACKER_IOU_THRESHOLD)

class_map = load_class_map(args.class_map)
assert class_map is not None

last_caption = None
person_roi = None

start_time = time.perf_counter()
    while True:
        # Fetch the latest frame for display and a constant-FPS clip for inference.
        frame = video_stream.get_live_frame()
        batch = video_stream.get_batch()
        if frame is None or batch is None:
            break

        # Track the person ROI and recognize the ASL gesture in the current clip.
        person_roi = person_tracker.get_roi(frame)
        if person_roi is not None:
            recognizer_result = action_recognizer(batch, person_roi)
if recognizer_result is not None:
action_class_id = np.argmax(recognizer_result)
action_class_label = \
class_map[action_class_id] if class_map is not None else action_class_id

action_class_score = np.max(recognizer_result)
if action_class_score > ACTION_SCORE_THRESHOLD:
last_caption = 'Last: {} '.format(action_class_label)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
start_time = end_time
current_fps = 1.0 / elapsed_time
cv2.putText(frame, 'FPS: {:.2f}'.format(current_fps), (10, 40),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

if last_caption is not None:
cv2.putText(frame, last_caption, (10, frame.shape[0] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

if person_roi is not None:
cv2.rectangle(frame, (person_roi[0], person_roi[1]),
(person_roi[2], person_roi[3]), (128, 128, 128), 1)

if args.no_show:
continue

cv2.imshow('Demo', frame)

key = cv2.waitKey(1)
if key == 27:
break

cv2.destroyAllWindows()
video_stream.release()


if __name__ == '__main__':
sys.exit(main() or 0)
Empty file.
@@ -0,0 +1,109 @@
"""
Copyright (c) 2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import cv2
import numpy as np

from asl_recognition_demo.common import IEModel


class ActionRecognizer(IEModel):
""" Class that is used to work with action recognition model. """

def __init__(self, model_path, device, ie_core, num_requests, img_scale, num_classes):
"""Constructor"""

super().__init__(model_path, device, ie_core, num_requests)

        _, _, t, h, w = self.input_size  # input blob layout is (N, C, T, H, W)
self.input_height = h
self.input_width = w
self.input_length = t

self.img_scale = img_scale
self.num_test_classes = num_classes

@staticmethod
def _convert_to_central_roi(src_roi, input_height, input_width, img_scale):
"""Extracts from the input ROI the central square part with specified side size"""

src_roi_height, src_roi_width = src_roi[3] - src_roi[1], src_roi[2] - src_roi[0]
src_roi_center_x = 0.5 * (src_roi[0] + src_roi[2])
src_roi_center_y = 0.5 * (src_roi[1] + src_roi[3])

height_scale = float(input_height) / float(img_scale)
width_scale = float(input_width) / float(img_scale)
assert height_scale < 1.0
assert width_scale < 1.0

min_roi_size = min(src_roi_height, src_roi_width)
trg_roi_height = int(height_scale * min_roi_size)
trg_roi_width = int(width_scale * min_roi_size)

trg_roi = [int(src_roi_center_x - 0.5 * trg_roi_width),
int(src_roi_center_y - 0.5 * trg_roi_height),
int(src_roi_center_x + 0.5 * trg_roi_width),
int(src_roi_center_y + 0.5 * trg_roi_height)]

return trg_roi

def _process_image(self, input_image, roi):
"""Converts input image according to model requirements"""

cropped_image = input_image[roi[1]:roi[3], roi[0]:roi[2]]
resized_image = cv2.resize(cropped_image, (self.input_width, self.input_height))
out_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
return out_image.transpose(2, 0, 1).astype(np.float32)

def _prepare_net_input(self, images, roi):
"""Converts input sequence of images into blob of data"""

        # Per-frame tensors are (C, H, W); stacking over time gives (T, C, H, W).
        data = np.stack([self._process_image(img, roi) for img in images], axis=0)
        # Add the batch dimension and reorder to the (N, C, T, H, W) input layout.
        data = data.reshape((1,) + data.shape)
        data = np.transpose(data, (0, 2, 1, 3, 4))
return data

def async_infer(self, frame_buffer, person_roi, req_id):
"""Requests model inference for the specified batch of images"""

central_roi = self._convert_to_central_roi(person_roi,
self.input_height, self.input_width,
self.img_scale)

clip_data = self._prepare_net_input(frame_buffer, central_roi)

super().async_infer(clip_data, req_id)

def wait_request(self, req_id):
"""Waits for the model output"""

result = super().wait_request(req_id)
if result is None:
return None
else:
return result[:self.num_test_classes]

def __call__(self, frame_buffer, person_roi):
"""Runs model on the specified input"""

central_roi = self._convert_to_central_roi(person_roi,
self.input_height, self.input_width,
self.img_scale)
clip_data = self._prepare_net_input(frame_buffer, central_roi)

result = self.infer(clip_data)

return result[:self.num_test_classes]