Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / shape_infer / adult_test_utils.hpp
1 #include <utility>
2
3 #include <utility>
4
5 // Copyright (C) 2018-2019 Intel Corporation
6 // SPDX-License-Identifier: Apache-2.0
7 //
8
9 #pragma once
10
11 #include <gtest/gtest.h>
12 #include <inference_engine/shape_infer/const_infer/ie_const_infer_holder.hpp>
13 #include "built_in_shape_infer_general_test.hpp"
14
15 namespace IE = InferenceEngine;
16
17 struct InOutData {
18     testing::InOutShapes inOutShapes;
19     std::vector<std::vector<float>> inData;
20     std::vector<std::vector<float>> outData;
21 };
22
23 using FloatMap = std::map<std::string, std::vector<float>>;
24 using InitBlobsFunc = std::function<IE::BlobMap(const FloatMap& inOutData)>;
25
26 struct ASIConfig {
27     InOutData inOutData;
28     std::string type;
29     FloatMap floatBlobData;
30     std::map<std::string, std::string> strParams;
31     InitBlobsFunc initBlobs;
32     std::vector<IE::Precision> inPrecisions;
33     std::vector<IE::Precision> outPrecisions;
34 };
35
36 class BaseMatcher {
37 public:
38     explicit BaseMatcher(ASIConfig config) : config(std::move(config)) {}
39
40 protected:
41     void compareWithRef(const std::vector<IE::Blob::Ptr>& outBlobs,
42                         const std::vector<std::vector<float>>& refData,
43                         float tolerance = 0.0001);
44
45     std::vector<IE::Blob::Ptr>
46     createBlobs(const std::vector<IE::SizeVector>& shapes, const std::vector<IE::Precision>& precisions);
47
48     void fillBlobs(const std::vector<IE::Blob::Ptr>& blobs, const std::vector<std::vector<float>>& data);
49
50     ASIConfig config;
51 };
52
53 class ConstInferMatcher : public BaseMatcher {
54 public:
55     explicit ConstInferMatcher(const ASIConfig& config) : BaseMatcher(config) {}
56
57     void toData(const std::vector<std::vector<float>>& refData);
58
59 private:
60     std::shared_ptr<IE::ShapeInfer::ConstInferHolder> holder;
61 };
62
63 class ShapeInferMatcher : public BaseMatcher {
64 public:
65     explicit ShapeInferMatcher(const ASIConfig& config) : BaseMatcher(config) {}
66
67     void toShapes(const std::vector<IE::SizeVector>& refShape);
68
69 private:
70     std::unique_ptr<IE::ShapeInfer::BuiltInShapeInferHolder> siHolder;
71     IE::StatusCode sts;
72     IE::ResponseDesc desc;
73 };
74
75 template<typename M>
76 class MatcherConfigurator {
77 public:
78     explicit MatcherConfigurator(ASIConfig config) : config(std::move(config)) {}
79
80     MatcherConfigurator& withParams(const std::map<std::string, std::string>& params) {
81         config.strParams = params;
82         return *this;
83     }
84
85     MatcherConfigurator& withInputPrecisions(const std::vector<IE::Precision>& inputPrecisions) {
86         config.inPrecisions = inputPrecisions;
87         return *this;
88     }
89
90     MatcherConfigurator& withOutputPrecisions(const std::vector<IE::Precision>& outputPrecisions) {
91         config.outPrecisions = outputPrecisions;
92         return *this;
93     }
94
95     MatcherConfigurator& withBlobs(const FloatMap& blobDataMap) {
96         config.floatBlobData = blobDataMap;
97         return *this;
98     }
99
100     M equals() {
101         return M(config);
102     }
103
104 private:
105     ASIConfig config;
106 };
107
108 class ASITestBuilder {
109     ASIConfig config;
110 public:
111     ASITestBuilder() {
112         config.initBlobs = defaultBlobInit();
113     }
114
115     ASITestBuilder& withData(const InOutData& data) {
116         config.inOutData = data;
117         config.inPrecisions = {data.inOutShapes.inDims.size(), IE::Precision::FP32};
118         config.outPrecisions = {data.inOutShapes.outDims.size(), IE::Precision::FP32};
119         return *this;
120     }
121
122     ASITestBuilder& withType(const std::string& type) {
123         config.type = type;
124         return *this;
125     }
126
127     MatcherConfigurator<ConstInferMatcher> constInferResultFor();
128
129     MatcherConfigurator<ShapeInferMatcher> shapeInferResultFor();
130
131 private:
132     InitBlobsFunc defaultBlobInit();
133 };
134
135 PRETTY_PARAM(BlobsParam, FloatMap)
136
137 PRETTY_PARAM(InOutDataParam, InOutData)