Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / ie_bridges / python / sample / greengrass_samples / greengrass_object_detection_sample_ssd.py
1 """
2 BSD 3-clause "New" or "Revised" license
3
4 Copyright (C) 2018 Intel Corporation.
5
6 Redistribution and use in source and binary forms, with or without
7 modification, are permitted provided that the following conditions are met:
8
9 * Redistributions of source code must retain the above copyright notice, this
10   list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above copyright notice,
13   this list of conditions and the following disclaimer in the documentation
14   and/or other materials provided with the distribution.
15
16 * Neither the name of the copyright holder nor the names of its
17   contributors may be used to endorse or promote products derived from
18   this software without specific prior written permission.
19
20 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 """
31
32 import sys
33 import os
34 import cv2
35 import numpy as np
36 import greengrasssdk
37 import boto3
38 import timeit
39 import datetime
40 import json
41 from collections import OrderedDict
42
43 from openvino.inference_engine import IENetwork, IEPlugin
44
# Specify the delta in seconds between each report
reporting_interval = 1.0

# Parameters for IoT Cloud
enable_iot_cloud_output = True

# Parameters for Kinesis
enable_kinesis_output = False
kinesis_stream_name = ""
kinesis_partition_key = ""
# AWS region of the Kinesis stream; when left empty the client below falls
# back to "us-west-2"
kinesis_region = ""

# Parameters for S3
enable_s3_jpeg_output = False
# NOTE(review): current S3 naming rules appear to disallow underscores in
# bucket names - confirm this bucket exists before enabling S3 output.
s3_bucket_name = "ssd_test"

# Parameters for jpeg output on local disk
enable_local_jpeg_output = False

# Create a Greengrass Core SDK client for publishing messages to AWS Cloud
client = greengrasssdk.client("iot-data")

# Create an S3 client for uploading files to S3
if enable_s3_jpeg_output:
    s3_client = boto3.client("s3")

# Create a Kinesis client for putting records to streams.
# Honour the configured region instead of silently ignoring it.
if enable_kinesis_output:
    kinesis_client = boto3.client("kinesis", region_name=kinesis_region or "us-west-2")

# Read environment variables set by Lambda function configuration.
# Every value is None when its variable is unset, except the topic name,
# which has a default.
PARAM_MODEL_XML = os.environ.get("PARAM_MODEL_XML")
PARAM_INPUT_SOURCE = os.environ.get("PARAM_INPUT_SOURCE")
PARAM_DEVICE = os.environ.get("PARAM_DEVICE")
PARAM_OUTPUT_DIRECTORY = os.environ.get("PARAM_OUTPUT_DIRECTORY")
PARAM_CPU_EXTENSION_PATH = os.environ.get("PARAM_CPU_EXTENSION_PATH")
PARAM_LABELMAP_FILE = os.environ.get("PARAM_LABELMAP_FILE")
PARAM_TOPIC_NAME = os.environ.get("PARAM_TOPIC_NAME", "intel/faas/ssd")
83
84
def report(res_json, frame):
    """Deliver one inference report to every enabled output channel.

    Args:
        res_json: JSON-serializable mapping with the detection results.
        frame: Image (as accepted by ``cv2.imwrite``) with the detections
            drawn on it; only used by the JPEG outputs.

    The channels are selected by the module-level ``enable_*`` flags:
    IoT topic, Kinesis stream, S3 bucket and local disk.
    """
    # One timestamp names both the S3 object and the local file of a report.
    date_prefix = str(datetime.datetime.now()).replace(" ", "_")

    if enable_iot_cloud_output or enable_kinesis_output:
        # Serialize once and share the payload between both text channels.
        payload = json.dumps(res_json)
        if enable_iot_cloud_output:
            client.publish(topic=PARAM_TOPIC_NAME, payload=payload)
        if enable_kinesis_output:
            kinesis_client.put_record(StreamName=kinesis_stream_name, Data=payload,
                                      PartitionKey=kinesis_partition_key)

    if enable_s3_jpeg_output:
        # The frame is encoded through a scratch file that is overwritten on
        # every report.
        temp_image = os.path.join(PARAM_OUTPUT_DIRECTORY, "inference_result.jpeg")
        cv2.imwrite(temp_image, frame)
        # JPEG data is binary: text mode would try to decode it and either
        # raise UnicodeDecodeError or corrupt the bytes.
        with open(temp_image, "rb") as image_file:
            s3_client.put_object(Body=image_file.read(), Bucket=s3_bucket_name, Key=date_prefix + ".jpeg")

    if enable_local_jpeg_output:
        cv2.imwrite(os.path.join(PARAM_OUTPUT_DIRECTORY, date_prefix + ".jpeg"), frame)
102
103
def _detection_to_record(obj, labeldata):
    """Turn one raw detection row into a JSON-serializable dict.

    ``obj`` is read positionally: index 1 is the class id, index 2 the
    confidence and indices 3-6 the xmin/ymin/xmax/ymax box corners.
    ``labeldata`` is the optional label map (class id as string -> name).
    """
    class_id = int(obj[1])
    # A class id missing from the label map yields an empty name instead of
    # aborting the whole inference loop with KeyError.
    class_label = labeldata.get(str(class_id), "") if labeldata else ""
    # presumably the values are NumPy scalars (TODO confirm); json.dumps()
    # rejects e.g. float32, so convert to built-in floats before rounding.
    confidence, xmin, ymin, xmax, ymax = (round(float(value), 2) for value in obj[2:7])
    return {"label": class_id, "class": class_label, "confidence": confidence,
            "xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}


def greengrass_object_detection_sample_ssd_run():
    """Run SSD object detection on the configured input and report results.

    Loads the IR model named by ``PARAM_MODEL_XML`` (weights are expected in
    the sibling ``.bin`` file) on ``PARAM_DEVICE``, reads frames from
    ``PARAM_INPUT_SOURCE`` and infers synchronously on each one. At most once
    per ``reporting_interval`` seconds, the detections of the current frame
    are passed to ``report()`` together with the annotated frame. Status
    messages are published to ``PARAM_TOPIC_NAME``.

    Raises:
        AssertionError: If the network does not have exactly one input and
            exactly one output.
    """
    min_confidence = 0.5  # detections at or below this score are ignored

    client.publish(topic=PARAM_TOPIC_NAME, payload="OpenVINO: Initializing...")
    model_bin = os.path.splitext(PARAM_MODEL_XML)[0] + ".bin"

    # Plugin initialization for specified device and load extensions library if specified
    plugin = IEPlugin(device=PARAM_DEVICE, plugin_dirs="")
    if "CPU" in PARAM_DEVICE and PARAM_CPU_EXTENSION_PATH:
        plugin.add_cpu_extension(PARAM_CPU_EXTENSION_PATH)
    # Read IR
    net = IENetwork(model=PARAM_MODEL_XML, weights=model_bin)
    assert len(net.inputs) == 1, "Sample supports only single input topologies"
    assert len(net.outputs) == 1, "Sample supports only single output topologies"
    input_blob = next(iter(net.inputs))
    out_blob = next(iter(net.outputs))
    # Input geometry of the network, used to resize every frame
    n, c, h, w = net.inputs[input_blob]

    cap = cv2.VideoCapture(PARAM_INPUT_SOURCE)
    try:
        exec_net = plugin.load(network=net)
        del net
        client.publish(topic=PARAM_TOPIC_NAME, payload="Starting inference on %s" % PARAM_INPUT_SOURCE)

        labeldata = None
        if PARAM_LABELMAP_FILE is not None:
            with open(PARAM_LABELMAP_FILE) as labelmap_file:
                labeldata = json.load(labelmap_file)

        # Counters are reset after every report, so the published FPS covers
        # only the frames inferred since the previous report.
        start_time = timeit.default_timer()
        inf_seconds = 0.0
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frameid = cap.get(cv2.CAP_PROP_POS_FRAMES)
            initial_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            initial_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            in_frame = cv2.resize(frame, (w, h))
            in_frame = in_frame.transpose((2, 0, 1))  # Change data layout from HWC to CHW
            in_frame = in_frame.reshape((n, c, h, w))

            # Start synchronous inference; only this call counts towards the FPS figure
            inf_start_time = timeit.default_timer()
            res = exec_net.infer(inputs={input_blob: in_frame})
            inf_seconds += timeit.default_timer() - inf_start_time

            # Parse detection results of the current request
            res_json = OrderedDict()
            frame_timestamp = datetime.datetime.now()
            object_id = 0
            for obj in res[out_blob][0][0]:
                if obj[2] > min_confidence:
                    # Box corners are scaled by the source frame size before drawing
                    xmin = int(obj[3] * initial_w)
                    ymin = int(obj[4] * initial_h)
                    xmax = int(obj[5] * initial_w)
                    ymax = int(obj[6] * initial_h)
                    cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (255, 165, 20), 4)
                    res_json["Object" + str(object_id)] = _detection_to_record(obj, labeldata)
                    object_id += 1
            frame_count += 1

            # Measure elapsed seconds since the last report
            seconds_elapsed = timeit.default_timer() - start_time
            if seconds_elapsed >= reporting_interval:
                res_json["timestamp"] = frame_timestamp.isoformat()
                res_json["frame_id"] = int(frameid)
                # Guard against a zero denominator on very coarse timers
                res_json["inference_fps"] = frame_count / inf_seconds if inf_seconds > 0 else 0.0
                start_time = timeit.default_timer()
                report(res_json, frame)
                frame_count = 0
                inf_seconds = 0.0
    finally:
        # Free the camera/file handle even when inference or reporting fails
        cap.release()

    client.publish(topic=PARAM_TOPIC_NAME, payload="End of the input, exiting...")
    del exec_net
    del plugin
177
178
# Start the inference loop as soon as this module is loaded; the call only
# returns once the input source stops delivering frames.
# NOTE(review): presumably intended for a long-lived Greengrass Lambda - confirm.
greengrass_object_detection_sample_ssd_run()
180
181
def function_handler(event, context):
    """Lambda handler: announce on the IoT topic that it was invoked.

    ``event`` and ``context`` are accepted but not inspected. Always
    returns None.
    """
    notice = 'HANDLER_CALLED!'
    client.publish(topic=PARAM_TOPIC_NAME, payload=notice)