#include <cstring>
+#include <vector>
#include "gtest/gtest.h"
EXPECT_EQ(this->blob_->count(), 120);
}
+TYPED_TEST(BlobSimpleTest, TestLegacyBlobProtoShapeEquals) {
+ BlobProto blob_proto;
+
+ // Reshape to (3 x 2).
+ vector<int> shape(2);
+ shape[0] = 3;
+ shape[1] = 2;
+ this->blob_->Reshape(shape);
+
+ // (3 x 2) blob == (1 x 1 x 3 x 2) legacy blob
+ blob_proto.set_num(1);
+ blob_proto.set_channels(1);
+ blob_proto.set_height(3);
+ blob_proto.set_width(2);
+ EXPECT_TRUE(this->blob_->ShapeEquals(blob_proto));
+
+ // (3 x 2) blob != (0 x 1 x 3 x 2) legacy blob
+ blob_proto.set_num(0);
+ blob_proto.set_channels(1);
+ blob_proto.set_height(3);
+ blob_proto.set_width(2);
+ EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+
+ // (3 x 2) blob != (3 x 1 x 3 x 2) legacy blob
+ blob_proto.set_num(3);
+ blob_proto.set_channels(1);
+ blob_proto.set_height(3);
+ blob_proto.set_width(2);
+ EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+
+ // Reshape to (1 x 3 x 2).
+ shape.insert(shape.begin(), 1);
+ this->blob_->Reshape(shape);
+
+ // (1 x 3 x 2) blob == (1 x 1 x 3 x 2) legacy blob
+ blob_proto.set_num(1);
+ blob_proto.set_channels(1);
+ blob_proto.set_height(3);
+ blob_proto.set_width(2);
+ EXPECT_TRUE(this->blob_->ShapeEquals(blob_proto));
+
+ // Reshape to (2 x 3 x 2).
+ shape[0] = 2;
+ this->blob_->Reshape(shape);
+
+ // (2 x 3 x 2) blob != (1 x 1 x 3 x 2) legacy blob
+ blob_proto.set_num(1);
+ blob_proto.set_channels(1);
+ blob_proto.set_height(3);
+ blob_proto.set_width(2);
+ EXPECT_FALSE(this->blob_->ShapeEquals(blob_proto));
+}
+
template <typename TypeParam>
class BlobMathTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;