const std::vector<string>& outputs,
caffe2::NetDef* net);
+// Fill data from a vector to a tensor.
+template <typename T>
+void fillTensor(
+ const vector<int64_t>& shape,
+ const vector<T>& data,
+ TensorCPU* tensor) {
+ tensor->Resize(shape);
+ CAFFE_ENFORCE_EQ(data.size(), tensor->numel());
+ auto ptr = tensor->mutable_data<T>();
+ for (int i = 0; i < tensor->numel(); ++i) {
+ ptr[i] = data[i];
+ }
+}
+
+// Create a tensor and fill data.
+template <typename T>
+caffe2::Tensor* createTensorAndFill(
+ const string& name,
+ const vector<int64_t>& shape,
+ const vector<T>& data,
+ Workspace* workspace) {
+ auto* tensor = createTensor(name, workspace);
+ fillTensor<T>(shape, data, tensor);
+ return tensor;
+}
+
+// Fill a constant to a tensor.
+template <typename T>
+void constantFillTensor(
+ const vector<int64_t>& shape,
+ const T& data,
+ TensorCPU* tensor) {
+ tensor->Resize(shape);
+ auto ptr = tensor->mutable_data<T>();
+ for (int i = 0; i < tensor->numel(); ++i) {
+ ptr[i] = data;
+ }
+}
+
+// Create a tensor and fill a constant.
+template <typename T>
+caffe2::Tensor* createTensorAndConstantFill(
+ const string& name,
+ const vector<int64_t>& shape,
+ const T& data,
+ Workspace* workspace) {
+ auto* tensor = createTensor(name, workspace);
+ constantFillTensor<T>(shape, data, tensor);
+ return tensor;
+}
+
// Coincise util class to mutate a net in a chaining fashion.
class NetMutator {
public:
caffe2::NetDef* net_;
};
+// Coincise util class to mutate a workspace in a chaining fashion.
+class WorkspaceMutator {
+ public:
+ explicit WorkspaceMutator(caffe2::Workspace* workspace)
+ : workspace_(workspace) {}
+
+ // New tensor filled by a data vector.
+ template <typename T>
+ WorkspaceMutator& newTensor(
+ const string& name,
+ const vector<int64_t>& shape,
+ const vector<T>& data) {
+ createTensorAndFill<T>(name, shape, data, workspace_);
+ return *this;
+ }
+
+ // New tensor filled by a constant.
+ template <typename T>
+ WorkspaceMutator& newTensorConst(
+ const string& name,
+ const vector<int64_t>& shape,
+ const T& data) {
+ createTensorAndConstantFill<T>(name, shape, data, workspace_);
+ return *this;
+ }
+
+ private:
+ caffe2::Workspace* workspace_;
+};
+
} // namespace testing
} // namespace caffe2