add double_data, double_diff to BlobProto for weights/snapshots saved
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 30 Jul 2015 00:27:04 +0000 (17:27 -0700)
committerEric Tzeng <etzeng@eecs.berkeley.edu>
Fri, 7 Aug 2015 20:48:42 +0000 (13:48 -0700)
when using Dtype == double

src/caffe/blob.cpp
src/caffe/proto/caffe.proto

index 94fdcc3..8450aa1 100644 (file)
@@ -456,10 +456,25 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
   }
   // copy data
   Dtype* data_vec = mutable_cpu_data();
-  for (int i = 0; i < count_; ++i) {
-    data_vec[i] = proto.data(i);
+  if (proto.double_data_size() > 0) {
+    CHECK_EQ(count_, proto.double_data_size());
+    for (int i = 0; i < count_; ++i) {
+      data_vec[i] = proto.double_data(i);
+    }
+  } else {
+    CHECK_EQ(count_, proto.data_size());
+    for (int i = 0; i < count_; ++i) {
+      data_vec[i] = proto.data(i);
+    }
   }
-  if (proto.diff_size() > 0) {
+  if (proto.double_diff_size() > 0) {
+    CHECK_EQ(count_, proto.double_diff_size());
+    Dtype* diff_vec = mutable_cpu_diff();
+    for (int i = 0; i < count_; ++i) {
+      diff_vec[i] = proto.double_diff(i);
+    }
+  } else if (proto.diff_size() > 0) {
+    CHECK_EQ(count_, proto.diff_size());
     Dtype* diff_vec = mutable_cpu_diff();
     for (int i = 0; i < count_; ++i) {
       diff_vec[i] = proto.diff(i);
@@ -467,20 +482,40 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
   }
 }
 
-template <typename Dtype>
-void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
+template <>
+void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
+  proto->clear_shape();
+  for (int i = 0; i < shape_.size(); ++i) {
+    proto->mutable_shape()->add_dim(shape_[i]);
+  }
+  proto->clear_double_data();
+  proto->clear_double_diff();
+  const double* data_vec = cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    proto->add_double_data(data_vec[i]);
+  }
+  if (write_diff) {
+    const double* diff_vec = cpu_diff();
+    for (int i = 0; i < count_; ++i) {
+      proto->add_double_diff(diff_vec[i]);
+    }
+  }
+}
+
+template <>
+void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
   proto->clear_shape();
   for (int i = 0; i < shape_.size(); ++i) {
     proto->mutable_shape()->add_dim(shape_[i]);
   }
   proto->clear_data();
   proto->clear_diff();
-  const Dtype* data_vec = cpu_data();
+  const float* data_vec = cpu_data();
   for (int i = 0; i < count_; ++i) {
     proto->add_data(data_vec[i]);
   }
   if (write_diff) {
-    const Dtype* diff_vec = cpu_diff();
+    const float* diff_vec = cpu_diff();
     for (int i = 0; i < count_; ++i) {
       proto->add_diff(diff_vec[i]);
     }
index adcf4e2..03daa80 100644 (file)
@@ -11,6 +11,8 @@ message BlobProto {
   optional BlobShape shape = 7;
   repeated float data = 5 [packed = true];
   repeated float diff = 6 [packed = true];
+  repeated double double_data = 8 [packed = true];
+  repeated double double_diff = 9 [packed = true];
 
   // 4D dimensions -- deprecated.  Use "shape" instead.
   optional int32 num = 1 [default = 0];