Blob: add SyncedMemory shape accessor for GPU shape access
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 5 Mar 2015 05:31:34 +0000 (21:31 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 19 Sep 2015 00:53:19 +0000 (17:53 -0700)
include/caffe/blob.hpp
src/caffe/blob.cpp

index dda7b1f..fea5117 100644 (file)
@@ -219,6 +219,7 @@ class Blob {
 
   const Dtype* cpu_data() const;
   void set_cpu_data(Dtype* data);
+  const int* gpu_shape() const;
   const Dtype* gpu_data() const;
   const Dtype* cpu_diff() const;
   const Dtype* gpu_diff() const;
@@ -268,6 +269,7 @@ class Blob {
  protected:
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
+  shared_ptr<SyncedMemory> shape_data_;
   vector<int> shape_;
   int count_;
   int capacity_;
index 8450aa1..c86fd5d 100644 (file)
@@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
   CHECK_LE(shape.size(), kMaxBlobAxes);
   count_ = 1;
   shape_.resize(shape.size());
+  if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+    shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+  }
+  int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
   for (int i = 0; i < shape.size(); ++i) {
     CHECK_GE(shape[i], 0);
     CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
     count_ *= shape[i];
     shape_[i] = shape[i];
+    shape_data[i] = shape[i];
   }
   if (count_ > capacity_) {
     capacity_ = count_;
@@ -68,6 +73,12 @@ Blob<Dtype>::Blob(const vector<int>& shape)
 }
 
 template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+  CHECK(shape_data_);
+  return (const int*)shape_data_->gpu_data();
+}
+
+template <typename Dtype>
 const Dtype* Blob<Dtype>::cpu_data() const {
   CHECK(data_);
   return (const Dtype*)data_->cpu_data();