caffe2 - easy - test utils to fill tensors (#15019)
authorDuc Ngo <duc@fb.com>
Fri, 14 Dec 2018 04:42:59 +0000 (20:42 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 04:45:44 +0000 (20:45 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15019

Put some utils to fill tensors to test_utils

Reviewed By: salexspb

Differential Revision: D13386691

fbshipit-source-id: 51d891aad1ca12dc5133c0352df65b8db4f96edb

caffe2/core/test_utils.h

index cd4763b..f95f981 100644 (file)
@@ -36,6 +36,57 @@ caffe2::OperatorDef* createOperator(
     const std::vector<string>& outputs,
     caffe2::NetDef* net);
 
+// Fill data from a vector to a tensor.
+template <typename T>
+void fillTensor(
+    const vector<int64_t>& shape,
+    const vector<T>& data,
+    TensorCPU* tensor) {
+  tensor->Resize(shape);
+  CAFFE_ENFORCE_EQ(data.size(), tensor->numel());
+  auto ptr = tensor->mutable_data<T>();
+  for (int i = 0; i < tensor->numel(); ++i) {
+    ptr[i] = data[i];
+  }
+}
+
+// Create a tensor and fill data.
+template <typename T>
+caffe2::Tensor* createTensorAndFill(
+    const string& name,
+    const vector<int64_t>& shape,
+    const vector<T>& data,
+    Workspace* workspace) {
+  auto* tensor = createTensor(name, workspace);
+  fillTensor<T>(shape, data, tensor);
+  return tensor;
+}
+
+// Fill a constant to a tensor.
+template <typename T>
+void constantFillTensor(
+    const vector<int64_t>& shape,
+    const T& data,
+    TensorCPU* tensor) {
+  tensor->Resize(shape);
+  auto ptr = tensor->mutable_data<T>();
+  for (int i = 0; i < tensor->numel(); ++i) {
+    ptr[i] = data;
+  }
+}
+
+// Create a tensor and fill a constant.
+template <typename T>
+caffe2::Tensor* createTensorAndConstantFill(
+    const string& name,
+    const vector<int64_t>& shape,
+    const T& data,
+    Workspace* workspace) {
+  auto* tensor = createTensor(name, workspace);
+  constantFillTensor<T>(shape, data, tensor);
+  return tensor;
+}
+
 // Coincise util class to mutate a net in a chaining fashion.
 class NetMutator {
  public:
@@ -50,6 +101,36 @@ class NetMutator {
   caffe2::NetDef* net_;
 };
 
+// Coincise util class to mutate a workspace in a chaining fashion.
+class WorkspaceMutator {
+ public:
+  explicit WorkspaceMutator(caffe2::Workspace* workspace)
+      : workspace_(workspace) {}
+
+  // New tensor filled by a data vector.
+  template <typename T>
+  WorkspaceMutator& newTensor(
+      const string& name,
+      const vector<int64_t>& shape,
+      const vector<T>& data) {
+    createTensorAndFill<T>(name, shape, data, workspace_);
+    return *this;
+  }
+
+  // New tensor filled by a constant.
+  template <typename T>
+  WorkspaceMutator& newTensorConst(
+      const string& name,
+      const vector<int64_t>& shape,
+      const T& data) {
+    createTensorAndConstantFill<T>(name, shape, data, workspace_);
+    return *this;
+  }
+
+ private:
+  caffe2::Workspace* workspace_;
+};
+
 } // namespace testing
 } // namespace caffe2