TestBlob: test that legacy BlobProtos are correctly handled by ShapeEquals
authorJeff Donahue <jeff.donahue@gmail.com>
Sun, 30 Nov 2014 00:11:07 +0000 (16:11 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 3 Mar 2015 23:55:13 +0000 (15:55 -0800)
src/caffe/test/test_blob.cpp

index a654896..7da6423 100644 (file)
@@ -1,4 +1,5 @@
 #include <cstring>
+#include <vector>
 
 #include "gtest/gtest.h"
 
@@ -51,6 +52,59 @@ TYPED_TEST(BlobSimpleTest, TestReshape) {
   EXPECT_EQ(this->blob_->count(), 120);
 }
 
+TYPED_TEST(BlobSimpleTest, TestLegacyBlobProtoShapeEquals) {
+  BlobProto blob_proto;
+
+  // Reshape to (3 x 2).
+  vector<int> shape(2);
+  shape[0] = 3;
+  shape[1] = 2;
+  this->blob_->Reshape(shape);
+
+  // (3 x 2) blob == (1 x 1 x 3 x 2) legacy blob
+  blob_proto.set_num(1);
+  blob_proto.set_channels(1);
+  blob_proto.set_height(3);
+  blob_proto.set_width(2);
+  EXPECT_TRUE(this->blob_->ShapeEquals(blob_proto));
+
+  // (3 x 2) blob != (0 x 1 x 3 x 2) legacy blob
+  blob_proto.set_num(0);
+  blob_proto.set_channels(1);
+  blob_proto.set_height(3);
+  blob_proto.set_width(2);
+  EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+
+  // (3 x 2) blob != (3 x 1 x 3 x 2) legacy blob
+  blob_proto.set_num(3);
+  blob_proto.set_channels(1);
+  blob_proto.set_height(3);
+  blob_proto.set_width(2);
+  EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+
+  // Reshape to (1 x 3 x 2).
+  shape.insert(shape.begin(), 1);
+  this->blob_->Reshape(shape);
+
+  // (1 x 3 x 2) blob == (1 x 1 x 3 x 2) legacy blob
+  blob_proto.set_num(1);
+  blob_proto.set_channels(1);
+  blob_proto.set_height(3);
+  blob_proto.set_width(2);
+  EXPECT_TRUE(this->blob_->ShapeEquals(blob_proto));
+
+  // Reshape to (2 x 3 x 2).
+  shape[0] = 2;
+  this->blob_->Reshape(shape);
+
+  // (2 x 3 x 2) blob != (1 x 1 x 3 x 2) legacy blob
+  blob_proto.set_num(1);
+  blob_proto.set_channels(1);
+  blob_proto.set_height(3);
+  blob_proto.set_width(2);
+  EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+}
+
 template <typename TypeParam>
 class BlobMathTest : public MultiDeviceTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;