const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
+ const int* gpu_shape() const;
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
+ shared_ptr<SyncedMemory> shape_data_;
vector<int> shape_;
int count_;
int capacity_;
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
+ if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+ shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+ }
+ int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
count_ *= shape[i];
shape_[i] = shape[i];
+ shape_data[i] = shape[i];
}
if (count_ > capacity_) {
capacity_ = count_;
}
template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+ CHECK(shape_data_);
+ return (const int*)shape_data_->gpu_data();
+}
+
+template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
return (const Dtype*)data_->cpu_data();