Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / benchmark_app / infer_request_wrap.hpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <chrono>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

#include "inference_engine.hpp"
13
/// Clock used for every timestamp taken in this header.
/// NOTE(review): std::chrono::high_resolution_clock may be an alias of system_clock, which is not
/// monotonic; std::chrono::steady_clock is the usual choice for measuring intervals — consider switching.
using Time = std::chrono::high_resolution_clock;
/// Nanosecond duration type.
using ns = std::chrono::nanoseconds;
16
17 /// @brief Wrapper class for InferenceEngine::InferRequest. Handles asynchronous callbacks and calculates execution time.
18 class InferReqWrap {
19 public:
20     using Ptr = std::shared_ptr<InferReqWrap>;
21
22     explicit InferReqWrap(InferenceEngine::ExecutableNetwork& net) : _request(net.CreateInferRequest()) {
23         _request.SetCompletionCallback(
24                 [&]() {
25                     _endTime = Time::now();
26                 });
27     }
28
29     void startAsync() {
30         _startTime = Time::now();
31         _request.StartAsync();
32     }
33
34     void infer() {
35         _startTime = Time::now();
36         _request.Infer();
37         _endTime = Time::now();
38     }
39
40     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> getPerformanceCounts() {
41         return _request.GetPerformanceCounts();
42     }
43
44     void wait() {
45         InferenceEngine::StatusCode code = _request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
46         if (code != InferenceEngine::StatusCode::OK) {
47             throw std::logic_error("Wait");
48         }
49     }
50
51     InferenceEngine::Blob::Ptr getBlob(const std::string &name) {
52         return _request.GetBlob(name);
53     }
54
55     double getExecTime() const {
56         auto execTime = std::chrono::duration_cast<ns>(_endTime - _startTime);
57         return static_cast<double>(execTime.count()) * 0.000001;
58     }
59
60 private:
61     InferenceEngine::InferRequest _request;
62     Time::time_point _startTime;
63     Time::time_point _endTime;
64 };