class DiagLayerMatcher : public ::testing::MatcherInterface<const intel_nnet_type_t*> {
bool matchInserted;
- int matchQuantity;
+ int matchQuantity;
+ mutable int actualQuantity;
public:
DiagLayerMatcher(bool matchInserted, int matchQuantity) : matchInserted(matchInserted), matchQuantity(matchQuantity) {}
bool MatchAndExplain(const intel_nnet_type_t *foo, ::testing::MatchResultListener *listener) const override {
if (foo == nullptr)
return false;
+ actualQuantity = 0;
for(int i = 0; i < foo->nLayers; i++) {
if (foo->pLayers[i].nLayerKind != INTEL_AFFINE_DIAGONAL) continue;
// diagonal layer has to have 1 for weights and 0 for biases
// if all weights are zero, or zero value doesn't look like padding
if (!bWeightsOK && beforePadding == -1) continue;
-
- return matchInserted;
+ actualQuantity ++;
+ }
+ // means any quantity > 0
+ if (matchQuantity == -1) {
+ if (actualQuantity > 0)
+ return matchInserted;
+ else
+ return !matchInserted;
}
- return !matchInserted;
+ if (actualQuantity == matchQuantity)
+ return matchInserted;
+ else
+ return !matchInserted;
+
};
void DescribeTo(::std::ostream *os) const override {
- *os << "should "<< (matchInserted ? "" : "not ") << "have Identity Diagonal Primitive primitive as part of nnet structure";
+ *os << "should "<< (matchInserted ? "" : "not ") << "have "
+ << (matchQuantity == -1 ? "any" : std::to_string(matchQuantity))
+ << " Identity Diagonal Primitive primitive as part of nnet structure, but was " << actualQuantity;
}
};