[BYOC] Support Tuple Output in C/DNNL Codegen (#5701)
authorCody Yu <comaniac0422@gmail.com>
Sat, 30 May 2020 02:11:24 +0000 (19:11 -0700)
committerGitHub <noreply@github.com>
Sat, 30 May 2020 02:11:24 +0000 (11:11 +0900)
* Support tuple output runtime

* fix unit test

src/relay/backend/contrib/codegen_c/codegen.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/backend/contrib/dnnl/codegen.cc
tests/python/relay/test_pass_partition_graph.py

index b8803d4..2968966 100644 (file)
@@ -56,6 +56,25 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
     return {output};
   }
 
+  std::vector<Output> VisitExpr_(const TupleNode* node) final {
+    std::vector<Output> outs;
+    for (auto field : node->fields) {
+      auto res = VisitExpr(field);
+      CHECK_EQ(res.size(), 1U) << "Do not support tuple nest";
+      outs.push_back(res[0]);
+    }
+    return outs;
+  }
+
+  std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
+    auto res = VisitExpr(op->tuple);
+    CHECK_GT(res.size(), static_cast<size_t>(op->index));
+
+    // Only keep the item we want for the child node.
+    // FIXME(@comaniac): The other items should still be requried for the primary outputs.
+    return {res[op->index]};
+  }
+
   std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
     // Note this is for demonstration purpose. ConstantNode doesn't necessarily
     // belong to calls. We need to revisit this when tuples come into play.
index 2ee68ce..3a3c486 100644 (file)
@@ -125,7 +125,7 @@ class CodegenCBase {
    * \endcode
    */
   void GenerateBackendCFunc(const std::string& func_name, const Array<Var>& args,
-                            const Output& out) {
+                            const std::vector<Output>& outs) {
     // Print signature
     code_stream_ << "\n";
     code_stream_ << "extern \"C\" int " << func_name << "_wrapper_(";
@@ -133,9 +133,11 @@ class CodegenCBase {
       code_stream_ << "DLTensor* arg" << i << ",\n";
       code_stream_ << "\t";
     }
-    if (args.size() > 0) {
-      code_stream_ << "DLTensor* arg" << args.size() << ") {\n";
+    for (size_t i = 0; i < outs.size() - 1; i++) {
+      code_stream_ << "DLTensor* out" << i << ",\n";
+      code_stream_ << "\t";
     }
+    code_stream_ << "DLTensor* out" << outs.size() - 1 << ") {\n";
 
     EnterScope();
 
@@ -147,10 +149,12 @@ class CodegenCBase {
       code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n";
       PrintIndents();
     }
-    if (args.size() > 0) {
-      code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)";
+    for (size_t i = 0; i < outs.size() - 1; i++) {
+      code_stream_ << "static_cast<" << outs[i].dtype << "*>(out" << i << "->data),\n";
+      PrintIndents();
     }
-    code_stream_ << ");\n";
+    code_stream_ << "static_cast<" << outs.back().dtype << "*>(out" << outs.size() - 1
+                 << "->data));\n";
     PrintIndents();
     code_stream_ << "return 0;\n";
     ExitScope();
@@ -186,18 +190,19 @@ class CodegenCBase {
    */
   std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                       const std::vector<std::string>& buf_decl,
-                      const std::vector<std::string>& body, const std::vector<Output>& out) {
+                      const std::vector<std::string>& body, const std::vector<Output>& outs) {
     // Create the signature. For example, it could be:
-    // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
+    // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {}
     code_stream_ << "extern \"C\" void " << ext_func_id << "_(";
 
-    CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";
-
     for (const auto& arg : args) {
       const auto& dtype_str = GetDtypeString(arg);
       code_stream_ << dtype_str << "* " << arg->name_hint() << ", ";
     }
-    code_stream_ << out[0].dtype << "* out) {\n";
+    for (size_t i = 0; i < outs.size() - 1; ++i) {
+      code_stream_ << outs[i].dtype << "* out" << i << ", ";
+    }
+    code_stream_ << outs.back().dtype << "* out" << outs.size() - 1 << ") {\n";
     this->EnterScope();
 
     // Function body
@@ -212,22 +217,26 @@ class CodegenCBase {
     }
 
     // Copy output
-    if (out[0].need_copy) {
+    for (size_t i = 0; i < outs.size(); ++i) {
+      if (!outs[i].need_copy) {
+        continue;
+      }
       this->PrintIndents();
-      code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n";
+      code_stream_ << "std::memcpy(out" << i << ", " << outs[i].name << ", 4 * " << outs[i].size
+                   << ");\n";
+    }
 
-      // Free buffers
-      for (size_t i = 0; i < buf_decl.size(); i++) {
-        this->PrintIndents();
-        code_stream_ << "std::free(buf_" << i << ");\n";
-      }
+    // Free buffers
+    for (size_t i = 0; i < buf_decl.size(); i++) {
+      this->PrintIndents();
+      code_stream_ << "std::free(buf_" << i << ");\n";
     }
 
     this->ExitScope();
     code_stream_ << "}\n";
 
     // Create the wrapper to call the ext_func
-    this->GenerateBackendCFunc(ext_func_id, args, out[0]);
+    this->GenerateBackendCFunc(ext_func_id, args, outs);
     return code_stream_.str();
   }
 
index 3db5dc4..3f9ad7c 100644 (file)
@@ -144,6 +144,16 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
     return {output};
   }
 
+  std::vector<Output> VisitExpr_(const TupleNode* node) final {
+    std::vector<Output> outs;
+    for (auto field : node->fields) {
+      auto res = VisitExpr(field);
+      CHECK_EQ(res.size(), 1U) << "Do not support tuple nest";
+      outs.push_back(res[0]);
+    }
+    return outs;
+  }
+
   std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
     auto res = VisitExpr(op->tuple);
     CHECK_GT(res.size(), static_cast<size_t>(op->index));
@@ -347,8 +357,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
   // Create a corresponding DNNL function for the given relay Function.
   void GenDNNLFunc(const Function& func) {
     CHECK(func.defined()) << "Input error: expect a Relay function.";
-    const auto* call = func->body.as<CallNode>();
-    CHECK(call) << "DNNL expects a single convolution or dense op";
 
     // Record the external symbol for runtime lookup.
     auto sid = GetExtSymbol(func);
index fd76285..354b616 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for graph partitioning."""
+# pylint: disable=not-callable
 import os
 import sys
 
@@ -201,8 +202,11 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         exe = runtime.vm.Executable.load_exec(code, lib)
         vm = runtime.vm.VirtualMachine(exe)
         vm.init(ctx)
-        out = vm.run(**map_inputs)
-        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+        outs = vm.run(**map_inputs)
+        outs = outs if isinstance(outs, runtime.container.ADT) else [outs]
+        results = result if isinstance(result, list) else [result]
+        for out, ref in zip(outs, results):
+            tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=tol, atol=tol)
 
     def check_graph_runtime_result():
         compile_engine.get().clear()
@@ -215,10 +219,14 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
             rt_mod.set_input(name, data)
         rt_mod.set_input(**param)
         rt_mod.run()
-        out = tvm.nd.empty(out_shape, ctx=ctx)
-        out = rt_mod.get_output(0, out)
 
-        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+        out_shapes = out_shape if isinstance(out_shape, list) else [out_shape]
+        results = result if isinstance(result, list) else [result]
+
+        for idx, shape in enumerate(out_shapes):
+            out = tvm.nd.empty(shape, ctx=ctx)
+            out = rt_mod.get_output(idx, out)
+            tvm.testing.assert_allclose(out.asnumpy(), results[idx], rtol=tol, atol=tol)
 
     check_vm_result()
     check_graph_runtime_result()
@@ -1082,11 +1090,11 @@ def test_duplicate_merge_and_tuplegetitem():
     target = "test_duplicate_merge_and_tuplegetitem"
 
     @reg.register("nn.batch_norm", "target." + target)
-    def abs(attrs, args): # pylint: disable=unused-variable
+    def batch_norm(attrs, args): # pylint: disable=unused-variable
         return True
 
     @reg.register("nn.relu", "target." + target)
-    def abs(attrs, args): # pylint: disable=unused-variable
+    def relu(attrs, args): # pylint: disable=unused-variable
         return True
 
     def create_graph():
@@ -1195,11 +1203,11 @@ def test_flatten_tuple_output():
     target = "test_flatten_tuple_output"
 
     @reg.register("split", "target." + target)
-    def foo(attrs, args): # pylint: disable=unused-variable
+    def split(attrs, args): # pylint: disable=unused-variable
         return True
 
     @reg.register("abs", "target." + target)
-    def foo(attrs, args): # pylint: disable=unused-variable
+    def abs(attrs, args): # pylint: disable=unused-variable
         return True
 
     def create_graph():
@@ -1259,6 +1267,27 @@ def test_flatten_tuple_output():
     partitioned = seq(create_graph())
     assert tvm.ir.structural_equal(partitioned, expected(), map_free_vars=True)
 
+def test_tuple_output_exec():
+    """Test C codegen and runtime for a subgraph with a tuple output"""
+    a = relay.var('a', shape=(10, 10), dtype='float32')
+    b = relay.var('b', shape=(10, 10), dtype='float32')
+    ba = relay.annotation.compiler_begin(a, 'ccompiler')
+    bb = relay.annotation.compiler_begin(b, 'ccompiler')
+    add = relay.add(ba, bb)
+    sub = relay.subtract(ba, bb)
+    out = relay.Tuple((add, sub))
+    eout = relay.annotation.compiler_end(out, 'ccompiler')
+    func=relay.Function([a, b], eout)
+    mod = tvm.IRModule()
+    mod["main"] = func
+    mod = transform.PartitionGraph()(mod)
+
+    a_data = np.random.rand(10, 10).astype('float32')
+    b_data = np.random.rand(10, 10).astype('float32')
+
+    check_result(mod, {'a': a_data, 'b': b_data},
+                 [(10, 10), (10, 10)],
+                 [(a_data + b_data), (a_data - b_data)])
 
 if __name__ == "__main__":
     test_multi_node_compiler()
@@ -1278,3 +1307,4 @@ if __name__ == "__main__":
     test_duplicate_merge_and_tuplegetitem()
     test_constant_tuples()
     test_flatten_tuple_output()
+    test_tuple_output_exec()