Making HDF5 blob data non-mutable for copy (minor)
authorSergey Karayev <sergeykarayev@gmail.com>
Mon, 17 Mar 2014 21:05:32 +0000 (14:05 -0700)
committerSergey Karayev <sergeykarayev@gmail.com>
Mon, 17 Mar 2014 21:05:32 +0000 (14:05 -0700)
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/layers/hdf5_data_layer.cu

index 5b568a8..98873cb 100644 (file)
@@ -6,6 +6,8 @@ Contributors:
 TODO:
 - load file in a separate thread ("prefetch")
 - can be smarter about the memcpy call instead of doing it row-by-row
+  :: use util functions caffe_copy, and Blob->offset()
+  :: don't forget to update hdf5_daa_layer.cu accordingly
 */
 #include <stdint.h>
 #include <string>
@@ -110,11 +112,11 @@ void HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     }
 
     memcpy(&(*top)[0]->mutable_cpu_data()[i * data_count],
-            &data_blob_.mutable_cpu_data()[current_row_ * data_count],
-            sizeof(Dtype) * data_count);
+           &data_blob_.cpu_data()[current_row_ * data_count],
+           sizeof(Dtype) * data_count);
 
     memcpy(&(*top)[1]->mutable_cpu_data()[i * label_data_count],
-            &label_blob_.mutable_cpu_data()[current_row_ * label_data_count],
+            &label_blob_.cpu_data()[current_row_ * label_data_count],
             sizeof(Dtype) * label_data_count);
   }
 }
index f1a6434..bed7f35 100644 (file)
@@ -43,13 +43,13 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
     CUDA_CHECK(cudaMemcpy(
             &(*top)[0]->mutable_gpu_data()[i * data_count],
-            &data_blob_.mutable_cpu_data()[current_row_ * data_count],
+            &data_blob_.cpu_data()[current_row_ * data_count],
             sizeof(Dtype) * data_count,
             cudaMemcpyHostToDevice));
 
     CUDA_CHECK(cudaMemcpy(
             &(*top)[1]->mutable_gpu_data()[i * label_data_count],
-            &label_blob_.mutable_cpu_data()[current_row_ * label_data_count],
+            &label_blob_.cpu_data()[current_row_ * label_data_count],
             sizeof(Dtype) * label_data_count,
             cudaMemcpyHostToDevice));
   }