-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
}
return nullptr;
}
+
+
+ #define ASSERT_N_CONNECTIONS(a, b, n) \
+ ASSERT_EQ(countForwardConnections(#a, #b), n);\
+ ASSERT_EQ(countBackwardConnections(#a, #b), n);
+
#define ASSERT_CONNECTION(a, b) \
- ASSERT_TRUE(assertConnection(#a, #b));
+ ASSERT_N_CONNECTIONS(a,b,1);
+
+ #define ASSERT_2_CONNECTIONS(a, b) \
+ ASSERT_N_CONNECTIONS(a,b,2);
+
+ #define ASSERT_3_CONNECTIONS(a, b) \
+ ASSERT_N_CONNECTIONS(a,b,3);
+
+ /**
+ * @brief check connection without direction
+ */
+ #define ASSERT_NO_CONNECTION(a, b) \
+ ASSERT_EQ(countConnections(#a, #b), 0);\
+ ASSERT_EQ(countConnections(#b, #a), 0);\
void ASSERT_DIMS(int x, const SizeVector & dims) {
}
}
- bool assertConnection(std::string a, std::string b) {
+ int countForwardConnections(std::string a, std::string b) {
+ long int nForward = 0;
+ CNNLayerPtr layerExist;
+ try {
+ layerExist = wrap.getLayerByName(a.c_str());
+ if (!layerExist) {
+ return 0;
+ }
+ } catch(...) {
+ return 0;
+ }
- bool bForward = false;
- for (auto && outData : wrap.getLayerByName(a.c_str())->outData) {
+ for (auto && outData : layerExist->outData) {
auto &inputMap = outData->inputTo;
- auto i =
- std::find_if(inputMap.begin(), inputMap.end(), [&](std::map<std::string, CNNLayerPtr>::value_type &vt) {
+ nForward +=
+ std::count_if(inputMap.begin(), inputMap.end(), [&](std::map<std::string, CNNLayerPtr>::value_type &vt) {
return vt.second->name == b;
});
- if (i != inputMap.end()) {
- bForward = true;
- break;
- }
}
- if (!bForward) {
- return false;
+
+ return nForward;
+ }
+
+ int countBackwardConnections(std::string a, std::string b) {
+ CNNLayerPtr layerExist;
+ try {
+ layerExist = wrap.getLayerByName(b.c_str());
+ if (!layerExist) {
+ return 0;
+ }
+ } catch(...) {
+ return 0;
}
- auto prevData = wrap.getLayerByName(b.c_str())->insData;
+ auto prevData = layerExist->insData;
- auto j = std::find_if(prevData.begin(), prevData.end(), [&](DataWeakPtr wp) {
+ auto nBackward = std::count_if(prevData.begin(), prevData.end(), [&](DataWeakPtr wp) {
return wp.lock()->getCreatorLayer().lock()->name == a;
});
- return j != prevData.end();
+
+ return nBackward;
+ }
+
+ int countConnections(std::string a, std::string b) {
+ return countForwardConnections(a, b) + countBackwardConnections(a, b);
}
int numCreated = 0;
}
}
+ void TearDown() override {
+ // Reset shared_pointer circular dependencies to mitigate memory leaks.
+ for (auto& items : datas) {
+ for (auto& data : items) {
+ for (auto& input : data->getInputTo()) {
+ input.second.reset();
+ }
+ }
+ }
+ }
+
int ID(const CNNLayerPtr &ptr) {
for (int i = 0; i < layers.size(); i++) {
if (layers[i].get() == ptr.get())