caffe2 - easy - test utils to compare tensors in two workspaces (#15181)
authorDuc Ngo <duc@fb.com>
Fri, 14 Dec 2018 04:42:59 +0000 (20:42 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 04:45:46 +0000 (20:45 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15181

Add test utils to compare tensors in two workspaces

Reviewed By: ZolotukhinM

Differential Revision: D13387212

fbshipit-source-id: e19d932a1ecc696bd0a08ea14d9a7485cce67bb2

caffe2/core/test_utils.cc
caffe2/core/test_utils.h

index eafe5e2..6b58240 100644 (file)
@@ -3,9 +3,49 @@
 
 #include "test_utils.h"
 
+namespace {
+template <typename T>
+void assertTensorEqualsWithType(
+    const caffe2::TensorCPU& tensor1,
+    const caffe2::TensorCPU& tensor2) {
+  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
+  for (auto idx = 0; idx < tensor1.numel(); ++idx) {
+    CAFFE_ENFORCE_EQ(tensor1.data<T>()[idx], tensor2.data<T>()[idx]);
+  }
+}
+} // namespace
+
 namespace caffe2 {
 namespace testing {
 
+void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) {
+  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
+  if (tensor1.IsType<float>()) {
+    CAFFE_ENFORCE(tensor2.IsType<float>());
+    assertTensorEqualsWithType<float>(tensor1, tensor2);
+  } else if (tensor1.IsType<int>()) {
+    CAFFE_ENFORCE(tensor2.IsType<int>());
+    assertTensorEqualsWithType<int>(tensor1, tensor2);
+  } else if (tensor1.IsType<int64_t>()) {
+    CAFFE_ENFORCE(tensor2.IsType<int64_t>());
+    assertTensorEqualsWithType<int64_t>(tensor1, tensor2);
+  }
+  // Add more types if needed.
+}
+
+void assertTensorListEquals(
+    const std::vector<std::string>& tensorNames,
+    const Workspace& workspace1,
+    const Workspace& workspace2) {
+  for (const string& tensorName : tensorNames) {
+    CAFFE_ENFORCE(workspace1.HasBlob(tensorName));
+    CAFFE_ENFORCE(workspace2.HasBlob(tensorName));
+    auto& tensor1 = getTensor(workspace1, tensorName);
+    auto& tensor2 = getTensor(workspace2, tensorName);
+    assertTensorEquals(tensor1, tensor2);
+  }
+}
+
 const caffe2::Tensor& getTensor(
     const caffe2::Workspace& workspace,
     const std::string& name) {
index f95f981..3b3af64 100644 (file)
 namespace caffe2 {
 namespace testing {
 
-// Asserts that the numeric values of two tensors are the same.
-template <typename T>
-void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) {
-  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
-  for (auto idx = 0; idx < tensor1.numel(); ++idx) {
-    CAFFE_ENFORCE_EQ(tensor1.data<T>()[idx], tensor2.data<T>()[idx]);
-  }
-}
+// Asserts that the values of two tensors are the same.
+void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2);
+
+// Asserts a list of tensors presented in two workspaces are equal.
+void assertTensorListEquals(
+    const std::vector<std::string>& tensorNames,
+    const Workspace& workspace1,
+    const Workspace& workspace2);
 
 // Read a tensor from the workspace.
 const caffe2::Tensor& getTensor(