5 // Copyright (C) 2018-2019 Intel Corporation
6 // SPDX-License-Identifier: Apache-2.0
11 #include <gtest/gtest.h>
12 #include <inference_engine/shape_infer/const_infer/ie_const_infer_holder.hpp>
13 #include "built_in_shape_infer_general_test.hpp"
// Short alias for the InferenceEngine namespace; the declarations below refer to it as IE::.
15 namespace IE = InferenceEngine;
// Data members describing one test case: its shapes plus raw float payloads.
// NOTE(review): the enclosing declaration is not visible in this view; presumably these are
// the members of InOutData, since ASITestBuilder::withData reads data.inOutShapes — confirm.
// Shapes of the test case; withData reads its inDims and outDims.
18 testing::InOutShapes inOutShapes;
// Float values for the inputs — presumably one inner vector per input blob; TODO confirm.
19 std::vector<std::vector<float>> inData;
// Float values for the outputs — presumably the expected (reference) results; TODO confirm.
20 std::vector<std::vector<float>> outData;
// Maps a string key to a flat vector of float values.
23 using FloatMap = std::map<std::string, std::vector<float>>;
// Callable that builds an IE::BlobMap from a FloatMap.
24 using InitBlobsFunc = std::function<IE::BlobMap(const FloatMap& inOutData)>;
// Configuration fields consumed by the matchers and builders below.
// NOTE(review): the enclosing declaration is not visible in this view; presumably these are
// members of ASIConfig, since the `config` objects below assign to fields with these names — confirm.
// Named float data; assigned by MatcherConfigurator::withBlobs.
29 FloatMap floatBlobData;
// String-valued parameters; assigned by MatcherConfigurator::withParams.
30 std::map<std::string, std::string> strParams;
// Factory that turns a FloatMap into an IE::BlobMap; ASITestBuilder assigns defaultBlobInit() to it.
31 InitBlobsFunc initBlobs;
// Input precisions; assigned by withInputPrecisions and by ASITestBuilder::withData.
32 std::vector<IE::Precision> inPrecisions;
// Output precisions; assigned by withOutputPrecisions and by ASITestBuilder::withData.
33 std::vector<IE::Precision> outPrecisions;
// Members of BaseMatcher, the common base of ConstInferMatcher and ShapeInferMatcher
// (the class header and the `config` member declaration are not visible in this view).
// Constructs the matcher from a configuration; the by-value argument is moved into the `config` member.
38 explicit BaseMatcher(ASIConfig config) : config(std::move(config)) {}
// Declaration only: compares output blobs against reference float data.
// `tolerance` defaults to 0.0001.
// NOTE(review): presumably the maximum allowed per-element difference — definition not visible, confirm.
41 void compareWithRef(const std::vector<IE::Blob::Ptr>& outBlobs,
42 const std::vector<std::vector<float>>& refData,
43 float tolerance = 0.0001);
// Declaration only: returns a vector of blobs built from the given shapes and precisions.
// NOTE(review): presumably shapes[i] is paired with precisions[i] — confirm against the definition.
45 std::vector<IE::Blob::Ptr>
46 createBlobs(const std::vector<IE::SizeVector>& shapes, const std::vector<IE::Precision>& precisions);
// Declaration only: fills the given blobs from the given float data.
48 void fillBlobs(const std::vector<IE::Blob::Ptr>& blobs, const std::vector<std::vector<float>>& data);
// Matcher derived from BaseMatcher that holds a ConstInferHolder and exposes toData().
// NOTE(review): access specifiers and the closing brace are not visible in this view.
53 class ConstInferMatcher : public BaseMatcher {
// Forwards the configuration to BaseMatcher (whose constructor takes it by value, so it is copied).
55 explicit ConstInferMatcher(const ASIConfig& config) : BaseMatcher(config) {}
// Declaration only: takes reference float data.
// NOTE(review): presumably runs constant inference and checks the result against refData — confirm.
57 void toData(const std::vector<std::vector<float>>& refData);
// Shared handle to the const-infer holder.
60 std::shared_ptr<IE::ShapeInfer::ConstInferHolder> holder;
// Matcher derived from BaseMatcher that holds a BuiltInShapeInferHolder and exposes toShapes().
// NOTE(review): access specifiers and the closing brace are not visible in this view.
63 class ShapeInferMatcher : public BaseMatcher {
// Forwards the configuration to BaseMatcher (whose constructor takes it by value, so it is copied).
65 explicit ShapeInferMatcher(const ASIConfig& config) : BaseMatcher(config) {}
// Declaration only: takes reference shapes.
// NOTE(review): presumably runs shape inference and checks the result against refShape — confirm.
67 void toShapes(const std::vector<IE::SizeVector>& refShape);
// Uniquely owned built-in shape-infer holder.
70 std::unique_ptr<IE::ShapeInfer::BuiltInShapeInferHolder> siHolder;
// Response descriptor object; NOTE(review): presumably receives error text from IE calls — confirm.
72 IE::ResponseDesc desc;
// Configurator that adjusts an ASIConfig through a set of with*() setters.
// NOTE(review): it is used below as MatcherConfigurator<ConstInferMatcher>, so a template header
// precedes this line outside the visible view. The setters return MatcherConfigurator&, presumably
// *this for call chaining; their return statements and closing braces are not visible here.
76 class MatcherConfigurator {
// Takes the configuration by value and moves it into the `config` member.
78 explicit MatcherConfigurator(ASIConfig config) : config(std::move(config)) {}
// Replaces the configuration's string parameters with `params`.
80 MatcherConfigurator& withParams(const std::map<std::string, std::string>& params) {
81 config.strParams = params;
// Replaces the configuration's input precisions.
85 MatcherConfigurator& withInputPrecisions(const std::vector<IE::Precision>& inputPrecisions) {
86 config.inPrecisions = inputPrecisions;
// Replaces the configuration's output precisions.
90 MatcherConfigurator& withOutputPrecisions(const std::vector<IE::Precision>& outputPrecisions) {
91 config.outPrecisions = outputPrecisions;
// Replaces the configuration's named float blob data.
95 MatcherConfigurator& withBlobs(const FloatMap& blobDataMap) {
96 config.floatBlobData = blobDataMap;
// Builder that assembles a test configuration and hands out matcher configurators.
// NOTE(review): the constructor header, the `config` member, return statements and closing
// braces are not visible in this view.
108 class ASITestBuilder {
// Installs the default blob-initialisation callable (presumably inside the constructor — confirm).
112 config.initBlobs = defaultBlobInit();
// Stores the test data and sets the input/output precision vectors from the number of
// input/output shapes and IE::Precision::FP32.
115 ASITestBuilder& withData(const InOutData& data) {
116 config.inOutData = data;
// NOTE(review): `{count, value}` yields `count` copies of FP32 only if the vector's
// std::initializer_list constructor is not viable, i.e. size_t is not implicitly convertible
// to IE::Precision. Otherwise it would create a two-element vector. Spelling this as
// std::vector<IE::Precision>(count, IE::Precision::FP32) would make the intent explicit — confirm.
117 config.inPrecisions = {data.inOutShapes.inDims.size(), IE::Precision::FP32};
118 config.outPrecisions = {data.inOutShapes.outDims.size(), IE::Precision::FP32};
// Takes a type name string; the body is not visible here (presumably stores it in the config — confirm).
122 ASITestBuilder& withType(const std::string& type) {
// Declaration only: returns a configurator specialised for ConstInferMatcher.
127 MatcherConfigurator<ConstInferMatcher> constInferResultFor();
// Declaration only: returns a configurator specialised for ShapeInferMatcher.
129 MatcherConfigurator<ShapeInferMatcher> shapeInferResultFor();
// Declaration only: returns the InitBlobsFunc assigned to config.initBlobs above.
132 InitBlobsFunc defaultBlobInit();
// PRETTY_PARAM is a macro defined outside this view (presumably in the included test header).
// NOTE(review): presumably each invocation declares a named wrapper type (BlobsParam, InOutDataParam)
// around the given value type for use as a printable gtest parameter — confirm against the macro.
135 PRETTY_PARAM(BlobsParam, FloatMap)
137 PRETTY_PARAM(InOutDataParam, InOutData)