From: Duc Ngo Date: Fri, 14 Dec 2018 04:42:59 +0000 (-0800) Subject: caffe2 - easy - test utils to compare tensors in two workspaces (#15181) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2243 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d0b4ae835df09f28d91fb695b4ebc5b57650534b;p=platform%2Fupstream%2Fpytorch.git caffe2 - easy - test utils to compare tensors in two workspaces (#15181) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15181 Add test utils to compare tensors in two workspaces Reviewed By: ZolotukhinM Differential Revision: D13387212 fbshipit-source-id: e19d932a1ecc696bd0a08ea14d9a7485cce67bb2 --- diff --git a/caffe2/core/test_utils.cc b/caffe2/core/test_utils.cc index eafe5e2..6b58240 100644 --- a/caffe2/core/test_utils.cc +++ b/caffe2/core/test_utils.cc @@ -3,9 +3,49 @@ #include "test_utils.h" +namespace { +template +void assertTensorEqualsWithType( + const caffe2::TensorCPU& tensor1, + const caffe2::TensorCPU& tensor2) { + CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes()); + for (auto idx = 0; idx < tensor1.numel(); ++idx) { + CAFFE_ENFORCE_EQ(tensor1.data()[idx], tensor2.data()[idx]); + } +} +} // namespace + namespace caffe2 { namespace testing { +void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) { + CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes()); + if (tensor1.IsType()) { + CAFFE_ENFORCE(tensor2.IsType()); + assertTensorEqualsWithType(tensor1, tensor2); + } else if (tensor1.IsType()) { + CAFFE_ENFORCE(tensor2.IsType()); + assertTensorEqualsWithType(tensor1, tensor2); + } else if (tensor1.IsType()) { + CAFFE_ENFORCE(tensor2.IsType()); + assertTensorEqualsWithType(tensor1, tensor2); + } + // Add more types if needed. +} + +void assertTensorListEquals( + const std::vector& tensorNames, + const Workspace& workspace1, + const Workspace& workspace2) { + for (const string& tensorName : tensorNames) { + CAFFE_ENFORCE(workspace1.HasBlob(tensorName)); + CAFFE_ENFORCE(workspace2.HasBlob(tensorName)); + auto& tensor1 = getTensor(workspace1, tensorName); + auto& tensor2 = getTensor(workspace2, tensorName); + assertTensorEquals(tensor1, tensor2); + } +} + const caffe2::Tensor& getTensor( const caffe2::Workspace& workspace, const std::string& name) { diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index f95f981..3b3af64 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -10,14 +10,14 @@ namespace caffe2 { namespace testing { -// Asserts that the numeric values of two tensors are the same. -template -void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) { - CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes()); - for (auto idx = 0; idx < tensor1.numel(); ++idx) { - CAFFE_ENFORCE_EQ(tensor1.data()[idx], tensor2.data()[idx]); - } -} +// Asserts that the values of two tensors are the same. +void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2); + +// Asserts a list of tensors presented in two workspaces are equal. +void assertTensorListEquals( + const std::vector& tensorNames, + const Workspace& workspace1, + const Workspace& workspace2); // Read a tensor from the workspace. const caffe2::Tensor& getTensor(