caffe2 - easy - utils to set argument of operator (#15022)
authorDuc Ngo <duc@fb.com>
Fri, 14 Dec 2018 04:43:00 +0000 (20:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 04:45:50 +0000 (20:45 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15022

Add setArgument testing utils to make it easy to set argument for an operator

Reviewed By: yinghai

Differential Revision: D13405225

fbshipit-source-id: b5c1859c6819d53c1a44718e2868e3137067df36

caffe2/core/test_utils.cc
caffe2/core/test_utils.h

index 070ec34..3da7397 100644 (file)
@@ -86,7 +86,7 @@ NetMutator& NetMutator::newOp(
     const std::string& type,
     const std::vector<string>& inputs,
     const std::vector<string>& outputs) {
-  createOperator(type, inputs, outputs, net_);
+  lastCreatedOp_ = createOperator(type, inputs, outputs, net_);
   return *this;
 }
 
index 6754684..53094d1 100644 (file)
@@ -3,6 +3,7 @@
 
 #include "caffe2/core/tensor.h"
 #include "caffe2/core/workspace.h"
+#include "caffe2/utils/proto_utils.h"
 
 #include <cmath>
 #include <vector>
@@ -121,7 +122,7 @@ caffe2::Tensor* createTensorAndConstantFill(
   return tensor;
 }
 
-// Coincise util class to mutate a net in a chaining fashion.
+// Concise util class to mutate a net in a chaining fashion.
 class NetMutator {
  public:
   explicit NetMutator(caffe2::NetDef* net) : net_(net) {}
@@ -131,11 +132,20 @@ class NetMutator {
       const std::vector<string>& inputs,
       const std::vector<string>& outputs);
 
+  // Add argument to the last created op.
+  template <typename T>
+  NetMutator& addArgument(const string& name, const T& value) {
+    CAFFE_ENFORCE(lastCreatedOp_ != nullptr);
+    AddArgument(name, value, lastCreatedOp_);
+    return *this;
+  }
+
  private:
   caffe2::NetDef* net_;
+  caffe2::OperatorDef* lastCreatedOp_;
 };
 
-// Coincise util class to mutate a workspace in a chaining fashion.
+// Concise util class to mutate a workspace in a chaining fashion.
 class WorkspaceMutator {
  public:
   explicit WorkspaceMutator(caffe2::Workspace* workspace)