Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / gna / matchers / pwl_matcher.hpp
1 /*
2  * INTEL CONFIDENTIAL
3  * Copyright (C) 2018-2019 Intel Corporation.
4  *
5  * The source code contained or described herein and all documents
6  * related to the source code ("Material") are owned by Intel Corporation
7  * or its suppliers or licensors. Title to the Material remains with
8  * Intel Corporation or its suppliers and licensors. The Material may
9  * contain trade secrets and proprietary and confidential information
10  * of Intel Corporation and its suppliers and licensors, and is protected
11  * by worldwide copyright and trade secret laws and treaty provisions.
12  * No part of the Material may be used, copied, reproduced, modified,
13  * published, uploaded, posted, transmitted, distributed, or disclosed
14  * in any way without Intel's prior express written permission.
15  *
16  * No license under any patent, copyright, trade secret or other
17  * intellectual property right is granted to or conferred upon you by
18  * disclosure or delivery of the Materials, either expressly, by implication,
19  * inducement, estoppel or otherwise. Any license under such intellectual
20  * property rights must be express and approved by Intel in writing.
21  *
22  * Include any supplier copyright notices as supplier requires Intel to use.
23  *
24  * Include supplier trademarks or logos as supplier requires Intel to use,
25  * preceded by an asterisk. An asterisked footnote can be added as follows:
26  * *Third Party trademarks are the property of their respective owners.
27  *
28  * Unless otherwise agreed by Intel in writing, you may not remove or alter
29  * this notice or any other notice embedded in Materials by Intel or Intel's
30  * suppliers or licensors in any way.
31  */
32
33 #pragma once
34 #include "nnet_base_matcher.hpp"
35
36 class PWLMatcher : public ::testing::MatcherInterface<const intel_nnet_type_t*> {
37     bool matchInserted;
38     const int matchQuantity;
39     mutable int timesInserted = 0;
40  public:
41     PWLMatcher(bool inserted, int matchQuantity) : matchInserted(inserted), matchQuantity(matchQuantity) {}
42
43     bool MatchAndExplain(const intel_nnet_type_t *foo, ::testing::MatchResultListener *listener) const override {
44         if (foo == nullptr)
45             return false;
46         timesInserted = 0;
47         for(int i = 0; i < foo->nLayers; i++) {
48             if (foo->pLayers[i].nLayerKind != INTEL_AFFINE &&
49                 foo->pLayers[i].nLayerKind != INTEL_AFFINE_DIAGONAL &&
50                 foo->pLayers[i].nLayerKind != INTEL_CONVOLUTIONAL) continue;
51             auto affine = reinterpret_cast<intel_affine_layer_t*>(foo->pLayers[i].pLayerStruct);
52             if (affine == nullptr) continue;
53
54             bool hasPwl = affine->pwl.nSegments != 0 && affine->pwl.pSegments != nullptr;
55
56             if (hasPwl) {
57                 if (matchQuantity == -1)
58                     return matchInserted;
59                 else
60                     timesInserted ++;
61             }
62         }
63         if (matchInserted) {
64             if (matchQuantity != -1) {
65                 return timesInserted == matchQuantity;
66             }
67             return timesInserted != 0;
68         }
69
70         return timesInserted == 0;
71     };
72     void DescribeTo(::std::ostream *os) const override {
73         if (!matchInserted ) {
74             *os << "should not have PWL layer as part of nnet structure, but was found " << timesInserted <<" times" ;
75         } else {
76             if (matchQuantity == -1) {
77                 *os << "should have PWL layer as part of nnet structure, but it was not found " ;
78             } else {
79                 *os << "should have PWL layer as part of nnet structure, for " << matchQuantity <<" times, but was found only " << timesInserted ;
80             }
81         }
82     }
83 };
84
85 inline ::testing::Matcher<const intel_nnet_type_t*> HasPwlLayer(bool inserted = true, int matchQuantity = -1) {
86     std::unique_ptr<NNetComponentMatcher> c (new NNetComponentMatcher());
87     c->add(new PWLMatcher(inserted, matchQuantity));
88     return ::testing::MakeMatcher(c.release());
89 }