Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / gna / matchers / diag_matcher.hpp
index b39813d..2c17865 100644 (file)
@@ -9,12 +9,14 @@
 
 class DiagLayerMatcher : public ::testing::MatcherInterface<const intel_nnet_type_t*> {
     bool matchInserted;
-    int matchQuantity;
+    int  matchQuantity;
+    mutable int  actualQuantity;
 public:
     DiagLayerMatcher(bool matchInserted, int matchQuantity) : matchInserted(matchInserted), matchQuantity(matchQuantity) {}
     bool MatchAndExplain(const intel_nnet_type_t *foo, ::testing::MatchResultListener *listener) const override {
         if (foo == nullptr)
             return false;
+        actualQuantity = 0;
         for(int i = 0; i < foo->nLayers; i++) {
             if (foo->pLayers[i].nLayerKind != INTEL_AFFINE_DIAGONAL) continue;
             // diagonal layer has to have 1 for weights and 0 for biases
@@ -45,13 +47,25 @@ public:
 
             // if all weights are zero, or zero value doesn't look like padding
             if (!bWeightsOK && beforePadding == -1) continue;
-
-            return matchInserted;
+            actualQuantity ++;
+        }
+        // means any quantity > 0
+        if (matchQuantity == -1) {
+            if (actualQuantity > 0)
+                return matchInserted;
+            else
+                return !matchInserted;
         }
-        return !matchInserted;
+        if (actualQuantity == matchQuantity)
+            return matchInserted;
+        else
+            return !matchInserted;
+
     };
     void DescribeTo(::std::ostream *os) const override {
-        *os << "should "<< (matchInserted ? "" : "not ") << "have Identity Diagonal Primitive primitive as part of nnet structure";
+        *os << "should "<< (matchInserted ? "" : "not ") << "have "
+            << (matchQuantity == -1 ? "any" : std::to_string(matchQuantity))
+            << " Identity Diagonal Primitive primitive as part of nnet structure, but was " << actualQuantity;
     }
 };