caffe2 - Expose tensor filler util to Python (#18886)
authorDuc Ngo <duc@fb.com>
Mon, 8 Apr 2019 18:48:42 +0000 (11:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 8 Apr 2019 18:54:10 +0000 (11:54 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18886

Expose tensor filler util to Python and add a unit test (both C++/Python)

Reviewed By: salexspb

Differential Revision: D14784470

fbshipit-source-id: bb8e013d1755c27c166e87d5a8491a97c65d3d8d

caffe2/CMakeLists.txt
caffe2/core/test_utils.h
caffe2/predictor/emulator/CMakeLists.txt [new file with mode: 0644]
caffe2/predictor/emulator/data_filler.cc
caffe2/predictor/emulator/data_filler.h
caffe2/predictor/emulator/data_filler_test.cc [new file with mode: 0644]
caffe2/python/filler_test.py [new file with mode: 0644]
caffe2/python/pybind_state.cc
caffe2/python/workspace.py

index 31b842b..4e3f036 100644 (file)
@@ -69,6 +69,7 @@ if(NOT BUILD_ATEN_ONLY)
   add_subdirectory(core)
   add_subdirectory(utils)
   add_subdirectory(predictor)
+  add_subdirectory(predictor/emulator)
   add_subdirectory(core/nomnigraph)
   add_subdirectory(serialize)
   if (USE_NVRTC)
index fcf069d..7e286e1 100644 (file)
@@ -56,7 +56,7 @@ void assertTensorListEquals(
     const Workspace& workspace2);
 
 // Read a tensor from the workspace.
-const caffe2::Tensor& getTensor(
+CAFFE2_API const caffe2::Tensor& getTensor(
     const caffe2::Workspace& workspace,
     const std::string& name);
 
diff --git a/caffe2/predictor/emulator/CMakeLists.txt b/caffe2/predictor/emulator/CMakeLists.txt
new file mode 100644 (file)
index 0000000..6906990
--- /dev/null
@@ -0,0 +1,13 @@
+set(Caffe2_EMULATOR_CPU_SRC
+    "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.h"
+    "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.cc"
+)
+set(Caffe2_EMULATOR_CPU_TEST_SRC
+  "${CMAKE_CURRENT_SOURCE_DIR}/data_filler_test.cc")
+
+# Common files that are always going to be included.
+list(APPEND Caffe2_CPU_SRCS ${Caffe2_EMULATOR_CPU_SRC})
+list(APPEND Caffe2_CPU_TEST_SRCS ${Caffe2_EMULATOR_CPU_TEST_SRC})
+
+set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
+set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
index e4e64a3..1979025 100644 (file)
@@ -245,5 +245,14 @@ void TestDataRandomFiller::fillInputToWorkspace(Workspace* workspace) const {
   }
 }
 
+void fillRandomNetworkInputs(
+    const NetDef& net,
+    const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+    const std::vector<std::vector<std::string>>& inputTypes,
+    Workspace* workspace) {
+  TestDataRandomFiller(net, inputDims, inputTypes)
+      .fillInputToWorkspace(workspace);
+}
+
 } // namespace emulator
 } // namespace caffe2
index a540e4a..78692ac 100644 (file)
@@ -138,5 +138,12 @@ class TestDataRandomFiller : public DataRandomFiller {
   void fillInputToWorkspace(Workspace* workspace) const;
 };
 
+// Convenient helpers to fill data to workspace.
+CAFFE2_API void fillRandomNetworkInputs(
+    const NetDef& net,
+    const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+    const std::vector<std::vector<std::string>>& inputTypes,
+    Workspace* workspace);
+
 } // namespace emulator
 } // namespace caffe2
diff --git a/caffe2/predictor/emulator/data_filler_test.cc b/caffe2/predictor/emulator/data_filler_test.cc
new file mode 100644 (file)
index 0000000..b29bcec
--- /dev/null
@@ -0,0 +1,25 @@
+#include "caffe2/core/common.h"
+#include "caffe2/core/test_utils.h"
+#include "caffe2/predictor/emulator/data_filler.h"
+
+#include <gtest/gtest.h>
+
+TEST(DataFiller, FillNetInputTest) {
+  using namespace caffe2::testing;
+  using namespace caffe2::emulator;
+  caffe2::NetDef net;
+  NetMutator(&net)
+      .newOp("Concat", {"X0", "X1", "X2"}, {"concat_out", "split_info"})
+      .addArgument("axis", 1);
+
+  std::vector<int64_t> input_dim = {30, 20};
+  std::vector<std::vector<std::vector<int64_t>>> input_dims = {
+      {/* X0 */ input_dim, /* X1 */ input_dim, /* X2 */ input_dim}};
+  std::vector<std::vector<std::string>> input_types = {
+      {"float", "float", "float"}};
+  caffe2::Workspace workspace;
+  EXPECT_FALSE(workspace.HasBlob("X0"));
+  fillRandomNetworkInputs(net, input_dims, input_types, &workspace);
+  EXPECT_TRUE(workspace.HasBlob("X0"));
+  EXPECT_EQ(getTensor(workspace, "X0").sizes(), input_dim);
+}
diff --git a/caffe2/python/filler_test.py b/caffe2/python/filler_test.py
new file mode 100644 (file)
index 0000000..52ea756
--- /dev/null
@@ -0,0 +1,20 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from caffe2.python import core, test_util, workspace
+
+
+class TestFiller(test_util.TestCase):
+    def test_filler(self):
+        net = core.Net("test_filler")
+        net.Concat(["X0", "X1", "X2"], ["concat_out", "split_info"])
+        self.assertFalse(workspace.HasBlob("X0"))
+        input_dim = (30, 20)
+        workspace.FillRandomNetworkInputs(net, [[input_dim, input_dim, input_dim]], [["float", "float", "float"]])
+        self.assertTrue(workspace.HasBlob("X0"))
+        self.assertEqual(workspace.FetchBlob("X0").shape, input_dim)
+
+        with self.assertRaises(RuntimeError):
+            # Filler should throw if number of input dims/types is mismatched.
+            workspace.FillRandomNetworkInputs(net, [[input_dim]], [["float"]])
index 18f9097..f437b4e 100644 (file)
@@ -25,6 +25,7 @@
 #include "caffe2/opt/onnxifi_transformer.h"
 #include "caffe2/opt/optimize_ideep.h"
 #include "caffe2/opt/passes.h"
+#include "caffe2/predictor/emulator/data_filler.h"
 #include "caffe2/predictor/predictor.h"
 #include "caffe2/python/pybind_state_registry.h"
 #include "caffe2/utils/cpuid.h"
@@ -1146,6 +1147,19 @@ void addGlobalMethods(py::module& m) {
     return gWorkspace->HasBlob(name);
   });
   m.def(
+      "fill_random_network_inputs",
+      [](const py::bytes& net_def,
+         const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+         const std::vector<std::vector<std::string>>& inputTypes) {
+        CAFFE_ENFORCE(gWorkspace);
+        py::gil_scoped_release g;
+        NetDef net;
+        CAFFE_ENFORCE(
+            ParseProtoFromLargeString(net_def.cast<std::string>(), &net));
+        caffe2::emulator::fillRandomNetworkInputs(
+            net, inputDims, inputTypes, gWorkspace);
+      });
+  m.def(
       "create_net",
       [](py::bytes net_def, bool overwrite) {
         CAFFE_ENFORCE(gWorkspace);
index 18fcd9b..c288650 100644 (file)
@@ -83,6 +83,11 @@ GetNumNUMANodes = C.get_num_numa_nodes
 GetBlobNUMANode = C.get_blob_numa_node
 GetBlobSizeBytes = C.get_blob_size_bytes
 
+
+def FillRandomNetworkInputs(net, input_dims, input_types):
+    C.fill_random_network_inputs(net.Proto().SerializeToString(), input_dims, input_types)
+
+
 def _GetFreeFlaskPort():
     """Get a free flask port."""
     # We will prefer to use 5000. If not, we will then pick a random port.