From 66b26806fc6e9a0a2db24d31cfff4f88c1d51ae6 Mon Sep 17 00:00:00 2001 From: Duc Ngo Date: Thu, 13 Dec 2018 20:43:00 -0800 Subject: [PATCH] caffe2 - easy - utils to set argument of operator (#15022) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15022 Add setArgument testing utils to make it easy to set argument for an operator Reviewed By: yinghai Differential Revision: D13405225 fbshipit-source-id: b5c1859c6819d53c1a44718e2868e3137067df36 --- caffe2/core/test_utils.cc | 2 +- caffe2/core/test_utils.h | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/caffe2/core/test_utils.cc b/caffe2/core/test_utils.cc index 070ec34..3da7397 100644 --- a/caffe2/core/test_utils.cc +++ b/caffe2/core/test_utils.cc @@ -86,7 +86,7 @@ NetMutator& NetMutator::newOp( const std::string& type, const std::vector& inputs, const std::vector& outputs) { - createOperator(type, inputs, outputs, net_); + lastCreatedOp_ = createOperator(type, inputs, outputs, net_); return *this; } diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index 6754684..53094d1 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -3,6 +3,7 @@ #include "caffe2/core/tensor.h" #include "caffe2/core/workspace.h" +#include "caffe2/utils/proto_utils.h" #include #include @@ -121,7 +122,7 @@ caffe2::Tensor* createTensorAndConstantFill( return tensor; } -// Coincise util class to mutate a net in a chaining fashion. +// Concise util class to mutate a net in a chaining fashion. class NetMutator { public: explicit NetMutator(caffe2::NetDef* net) : net_(net) {} @@ -131,11 +132,20 @@ class NetMutator { const std::vector& inputs, const std::vector& outputs); + // Add argument to the last created op. + template + NetMutator& addArgument(const string& name, const T& value) { + CAFFE_ENFORCE(lastCreatedOp_ != nullptr); + AddArgument(name, value, lastCreatedOp_); + return *this; + } + private: caffe2::NetDef* net_; + caffe2::OperatorDef* lastCreatedOp_; }; -// Coincise util class to mutate a workspace in a chaining fashion. +// Concise util class to mutate a workspace in a chaining fashion. class WorkspaceMutator { public: explicit WorkspaceMutator(caffe2::Workspace* workspace) -- 2.7.4