[Relay] Add pass for getting calibration data from a relay module (#5997)
authorYi-Hsiang (Sean) Lai <seanlatias@users.noreply.github.com>
Mon, 13 Jul 2020 22:39:10 +0000 (18:39 -0400)
committerGitHub <noreply@github.com>
Mon, 13 Jul 2020 22:39:10 +0000 (15:39 -0700)
* add simple pass to extract outputs

* complete pass that collects all function inputs/outputs

* add analysis pass for collecting outputs

* reorganize the files

* add the first test

* update test with tuples

* clean up Python code

* merge with upstream

* clean up transform.py

* add comments for cpp files

* fix lint issues

* update submodules

* modify files according to the review

* fix style and typo

* fix lint error

* add checks for repeated function calls

* fix lint error

* merge review comments

* small simplification

* revise the code according to the review comments

* add username in TODO

* use IRModule directly

* use better APIs according to the review

* apply comments from the reviewer

* retrigger ci

include/tvm/relay/analysis.h
python/tvm/relay/analysis/analysis.py
src/relay/analysis/get_calibration_data.cc [new file with mode: 0644]
tests/python/relay/test_analysis_get_calibration_data.py [new file with mode: 0644]

index b4b1b9d..8eda7dd 100644 (file)
@@ -236,6 +236,24 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
  */
 TVM_DLL std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body);
 
+/*!
+ * \brief Get the updated module for collecting calibration data.
+ *
+ * \param mod The module to be updated.
+ *
+ * \return The updated module.
+ */
+TVM_DLL IRModule GetCalibrateModule(IRModule mod);
+
+/*!
+ * \brief Get the output map between subgrpahs and its inputs/output.
+ *
+ * \param mod The module for running calibration.
+ *
+ * \return The mapping between a subgraph name and its postition in the output tuple.
+ */
+TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);
+
 }  // namespace relay
 }  // namespace tvm
 
index c237859..632af46 100644 (file)
@@ -21,6 +21,8 @@ This file contains the set of passes for Relay, which exposes an interface for
 configuring the passes and scripting them in Python.
 """
 from tvm.ir import IRModule
+from tvm.relay import transform, build_module
+from tvm.runtime.ndarray import cpu
 
 from . import _ffi_api
 from .feature import Feature
@@ -351,3 +353,49 @@ def search_fc_transpose(expr):
     """
     ret = _ffi_api.search_fc_transpose(expr)
     return ret
+
+
+def get_calibration_data(mod, data):
+    """Get the calibration data of a given relay graph
+
+    This pass uses the graph runtime to get the calibration data of a module, which
+    includes the input and output values of each function. The returned data uses
+    the GlobalVar of each function as a key. Users can further access the inputs and
+    outputs by using `inputs` or  `outputs` as the key.
+
+    Following are some limitations:
+    1. The input module (graph) cannot have control flows.
+    2. The input arguments of each function cannot be tuples (outputs can be tuples).
+    3. We only handle top-level functions (i.e., nested function is not handled).
+    4. We only handle functions with `Compiler` attribute being set.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The input module for collecting the calibration data
+
+    data : Dict[str, NDArray]
+        The input data for running the module
+
+    Returns
+    -------
+    data : Dict[tvm.relay.GlobalVar, Dict[str, NDArray]]
+    """
+    output_map = _ffi_api.get_calibrate_output_map(mod)
+
+    mod = _ffi_api.get_calibrate_module(mod)
+    mod = transform.Inline()(mod)
+
+    ref_ex = build_module.create_executor("graph", mod=mod, ctx=cpu(0))
+    ref_res = ref_ex.evaluate()(**data)
+
+    calib_data = {}
+    for gvar, indices in output_map.items():
+        offset = int(indices[0])
+        in_len = int(indices[1])
+        out_len = int(indices[2])
+        value = {"inputs": ref_res[offset:offset + in_len],
+                 "outputs": ref_res[offset + in_len:offset + in_len + out_len]}
+        calib_data[gvar] = value
+
+    return calib_data
diff --git a/src/relay/analysis/get_calibration_data.cc b/src/relay/analysis/get_calibration_data.cc
new file mode 100644 (file)
index 0000000..34d0d00
--- /dev/null
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/get_calibration_data.cc
+ *
+ * \brief To get the calibration data, we need to perform two
+ * steps. First, we need to prepare the module that generates
+ * the tensor values (GetCalibrateModule). Second, we need to
+ * generate the mapping between the values and the functions
+ * (GetCalibrateOutputMap).
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief This function returns a module that will be used by
+ * the relay graph runtime for collecting the calibration data.
+ * To do that, we first make all inputs and outputs of each
+ * function into the final output (i.e., the final output is a
+ * tuple of tensors). Then, we change the compiler attribute of
+ * each function. Finally, we mark all function to be inlined.
+ */
+
+class Collector : public ExprRewriter {
+ public:
+  explicit Collector(const IRModule& module) : module_(module) {}
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    // check if the function implementation is available
+    // intrinsic functions are excluded for now
+    if (call->op->IsInstance<GlobalVarNode>()) {
+      auto var = Downcast<GlobalVar>(call->op);
+      CHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
+      // we only handle functions with Compiler attribute set
+      auto func = Downcast<Function>(module_->Lookup(var));
+      if (func->GetAttr<String>(attr::kCompiler)) {
+        // collect all the inputs and outputs
+        for (const auto& it : call->args) new_outputs_.push_back(it);
+        new_outputs_.push_back(post);
+      }
+    }
+    return post;
+  }
+
+  Array<Expr> GetNewOutputs() { return new_outputs_; }
+
+ private:
+  const IRModule& module_;
+  Array<Expr> new_outputs_;
+};
+
+Expr FlattenOutputTuple(const Array<Expr>& exprs) {
+  Array<Expr> fields;
+  for (const auto& it : exprs) {
+    CHECK(it->checked_type_.defined());
+    if (auto* tn = it->checked_type_.as<TupleTypeNode>()) {
+      // TODO(seanlatias): for now input argument cannot be a tuple
+      CHECK(it->IsInstance<CallNode>());
+      for (size_t i = 0; i < tn->fields.size(); i++) {
+        fields.push_back(TupleGetItem(it, i));
+      }
+    } else {
+      fields.push_back(it);
+    }
+  }
+  return Tuple(fields);
+}
+
+IRModule GetCalibrateModule(IRModule module) {
+  auto glob_funcs = module->functions;
+  // module is mutable, hence, we make a copy of it.
+  module.CopyOnWrite();
+  for (const auto& pair : glob_funcs) {
+    if (auto* fn = pair.second.as<FunctionNode>()) {
+      auto func = GetRef<Function>(fn);
+      // we only collect the outputs for main function
+      if (pair.first->name_hint == "main") {
+        Collector collector(module);
+        PostOrderRewrite(func->body, &collector);
+        auto new_outputs = collector.GetNewOutputs();
+        Expr tuple = FlattenOutputTuple(new_outputs);
+        func = Function(func->params, tuple, tuple->checked_type_, func->type_params, func->attrs);
+        module->Update(pair.first, func);
+      }
+    }
+  }
+  // reset the attribute of functions for running graph runtime
+  for (const auto& pair : glob_funcs) {
+    if (auto* fn = pair.second.as<FunctionNode>()) {
+      auto func = GetRef<Function>(fn);
+      if (func->GetAttr<String>(attr::kCompiler)) {
+        // we need to inline the functions in order to run grpah runtime
+        func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));
+        // reset the compiler attribute to null for llvm execution
+        func = WithAttr(std::move(func), attr::kCompiler, NullValue<ObjectRef>());
+        module->Update(pair.first, func);
+      }
+    }
+  }
+  return module;
+}
+
+/*!
+ * \brief This function generates the output mapping between
+ * the calibration data and each function. The key is a
+ * GlobalVar that corresponds to each function and the value
+ * is an array of integers. The size of the array is always
+ * three. The first value is the offset the points to the start.
+ * The second value is the number of inputs. The third value
+ * is the number of outputs.
+ */
+
+class OutputMapper : public ExprRewriter {
+ public:
+  OutputMapper(Map<GlobalVar, Array<Integer>>* output_map, const IRModule& module, size_t* offset)
+      : output_map_(output_map), module_(module), offset_(offset) {}
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    if (call->op->IsInstance<GlobalVarNode>()) {
+      auto var = Downcast<GlobalVar>(call->op);
+      CHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
+      CHECK_EQ(output_map_->count(var), 0)
+          << "Repeated function call " << var << " is not supported.";
+      auto func = Downcast<Function>(module_->Lookup(var));
+      // we only handle functions with Compiler attribute set
+      if (func->GetAttr<String>(attr::kCompiler)) {
+        Array<Integer> info;
+        // the first value is the offset
+        info.push_back(Integer(*offset_));
+        // the second value is the number of inputs
+        info.push_back(Integer(call->args.size()));
+        // the third value is the number of outputs
+        // we need to check if the output is a tuple
+        size_t out_size = 1;
+        if (auto* tn = func->body.as<TupleNode>()) {
+          info.push_back(Integer(tn->fields.size()));
+          out_size = tn->fields.size();
+        } else {
+          info.push_back(Integer(1));
+        }
+        output_map_->Set(var, info);
+        // calculate the offset for the next function
+        *offset_ = *offset_ + call->args.size() + out_size;
+      }
+    }
+    return post;
+  }
+
+ private:
+  Map<GlobalVar, Array<Integer>>* output_map_;
+  const IRModule& module_;
+  size_t* offset_;
+};
+
+Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& module) {
+  Map<GlobalVar, Array<Integer>> output_map;
+  size_t offset = 0;
+  auto glob_funcs = module->functions;
+  for (const auto& pair : glob_funcs) {
+    if (auto* fn = pair.second.as<FunctionNode>()) {
+      if (pair.first->name_hint == "main") {
+        OutputMapper output_mapper(&output_map, module, &offset);
+        auto func = GetRef<Function>(fn);
+        PostOrderRewrite(func->body, &output_mapper);
+      }
+    }
+  }
+
+  return output_map;
+}
+
+TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) {
+  return GetCalibrateModule(mod);
+});
+
+TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map")
+    .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_analysis_get_calibration_data.py b/tests/python/relay/test_analysis_get_calibration_data.py
new file mode 100644 (file)
index 0000000..9a29f2e
--- /dev/null
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.analysis import get_calibration_data
+
+
+def check_data_size(mod, data):
+    assert len(data) == len(mod.functions) - 1
+    for key, value in mod.functions.items():
+        if key.name_hint != "main":
+            assert len(data[key]["inputs"]) == len(value.params)
+            if isinstance(value.body, relay.Tuple):
+                assert len(data[key]["outputs"]) == len(value.body.fields)
+            else:
+                assert len(data[key]["outputs"]) == 1
+
+def test_simple_graph():
+    # A module with two subgraphs
+    mod = tvm.IRModule()
+
+    x0 = relay.var('x0', shape=(8, 8))
+    y0 = relay.var('y0', shape=(8, 8))
+    z0 = x0 + y0
+    z1 = x0 - y0
+    z2 = relay.Tuple((z0, z1))
+    f0 = relay.Function([x0, y0], z2)
+    f0 = f0.with_attr("Compiler", "test_graph")
+    g0 = relay.GlobalVar("g0")
+    mod[g0] = f0
+
+    x1 = relay.var('x1', shape=(8, 8))
+    y1 = relay.var('y1', shape=(8, 8))
+    z1 = x1 - y1
+    f1 = relay.Function([x1, y1], z1)
+    f1 = f1.with_attr("Compiler", "test_graph")
+    g1 = relay.GlobalVar("g1")
+    mod[g1] = f1
+
+
+    x = relay.var('x', shape=(8, 8))
+    y = relay.var('y', shape=(8, 8))
+    z = relay.var('z', shape=(8, 8))
+    c0 = relay.Call(g0, [x, y])
+    c1 = relay.Call(g1, [relay.TupleGetItem(c0, 0), z])
+    fm = relay.Function([x, y, z], c1)
+    mod["main"] = fm
+
+    x_data = np.random.rand(8, 8).astype('float32')
+    y_data = np.random.rand(8, 8).astype('float32')
+    z_data = np.random.rand(8, 8).astype('float32')
+    data = get_calibration_data(mod, {"x": x_data, "y": y_data, "z": z_data})
+
+    # Check the number and orders
+    check_data_size(mod, data)
+    tvm.testing.assert_allclose(data[g0]["inputs"][0].asnumpy(), x_data)
+    tvm.testing.assert_allclose(data[g0]["inputs"][1].asnumpy(), y_data)
+    tvm.testing.assert_allclose(data[g0]["outputs"][0].asnumpy(), x_data + y_data)
+    tvm.testing.assert_allclose(data[g0]["outputs"][1].asnumpy(), x_data - y_data)
+    tvm.testing.assert_allclose(data[g1]["inputs"][0].asnumpy(), x_data + y_data)
+    tvm.testing.assert_allclose(data[g1]["inputs"][1].asnumpy(), z_data)
+    tvm.testing.assert_allclose(data[g1]["outputs"][0].asnumpy(), x_data + y_data - z_data)
+
+def test_mobilenet_dnnl():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(
+        batch_size=1, dtype='float32')
+
+    mod = transform.AnnotateTarget(["dnnl"])(mod)
+    mod = transform.MergeCompilerRegions()(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+    data = get_calibration_data(mod, {"data": i_data, **params})
+
+    # Check the number and orders
+    check_data_size(mod, data)
+
+if __name__ == "__main__":
+    test_simple_graph()
+    test_mobilenet_dnnl()