First we define a compiler from a single Relay expression to the
graph language. We require the expression to be a function.
-The function's parameters correpond to the placeholder/inputs
+The function's parameters correspond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
-contrib.graph_runtime or any other TVM runtime comptatible system.
+contrib.graph_runtime or any other TVM runtime compatible system.
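+
+A minimal sketch of the flow (shapes and names here are illustrative;
+relay.build drives this codegen internally)::
+
+    import tvm
+    from tvm import relay
+    from tvm.contrib import graph_runtime
+
+    x = relay.var("x", shape=(2, 2), dtype="float32")
+    func = relay.Function([x], relay.add(x, x))
+    graph_json, lib, params = relay.build(func, target="llvm")
+    module = graph_runtime.create(graph_json, lib, tvm.cpu(0))
+    module.set_input(**params)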
"""
-
from __future__ import absolute_import
-import json
-from collections import defaultdict, OrderedDict
-import attr
-from . import _backend
-from . import compile_engine
-from ..op import Op
-from ..expr import Function, GlobalVar
-from ..expr_functor import ExprFunctor
-from ..ty import TupleType, TensorType
-from ... import target as _target
-
-
-@attr.s
-class NodeRef(object):
- """A reference to a node, used for constructing the graph."""
- ident = attr.ib()
- index = attr.ib(default=0)
- version = attr.ib(default=0)
-
- def to_json(self):
- return [self.ident, self.index, self.version]
-
-
-@attr.s
-class Node(object):
- """The base class for nodes in the TVM runtime system graph input."""
- name = attr.ib()
- attrs = attr.ib()
-
- def to_json(self):
- raise Exception("Abstract method, please implement me.")
-
-
-@attr.s
-class InputNode(Node):
- """An input node in the TVM runtime system graph input."""
- name = attr.ib()
- attrs = attr.ib()
-
- def to_json(self):
- return {
- "op": "null",
- "name": self.name,
- "inputs": []
- }
+from tvm.ndarray import empty
+from tvm._ffi.function import _init_api
-@attr.s
-class OpNode(Node):
- """An operator node in the TVM runtime system"s graph input."""
- op_name = attr.ib()
- inputs = attr.ib()
- op_attrs = attr.ib()
- num_outputs = attr.ib(default=1)
+from tvm.relay import build_module
+from tvm import target as _target
- def to_json(self):
- attrs = dict.copy(self.op_attrs)
- # Extend ops with extra info.
- attrs["func_name"] = self.op_name
- attrs["flatten_data"] = "0"
- attrs["num_inputs"] = str(len(self.inputs))
- attrs["num_outputs"] = str(self.num_outputs)
+_init_api("tvm.relay.build_module")
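+# _init_api registers the C++ functions exported under the
+# "relay.build_module" prefix (e.g. _GraphRuntimeCodegen) onto the
+# tvm.relay.build_module module, which is why they are reached as
+# build_module._GraphRuntimeCodegen() below.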
- return {
- "op": "tvm_op",
- "name": self.name,
- "attrs": attrs,
- "inputs": self.inputs
- }
-
-
-def shape_to_json(shape):
- """Convert symbolic shape to json compatible forma."""
- return [sh.value for sh in shape]
-
-
-class GraphRuntimeCodegen(ExprFunctor):
+class GraphRuntimeCodegen(object):
"""The compiler from Relay to the TVM runtime system."""
- nodes = attr.ib()
- var_map = attr.ib()
    def __init__(self, mod, target):
- ExprFunctor.__init__(self)
- self.mod = mod
- self.target = target
- self.nodes = []
- self.var_map = {}
- self.params = {}
- self.storage_device_map = None
- self.compile_engine = compile_engine.get()
- self.lowered_funcs = defaultdict(set)
- self._name_map = {}
-
- def add_node(self, node, expr):
- """
- Add a node to the graph.
-
- Parameters
- ----------
- node: Node
- The node to add to the graph.
-
- expr: tvm.relay.Expr
- The corresponding expression.
-
- Returns
- -------
- node_ref: Union[NodeRef, List[NodeRef]]
- A reference to the node.
- """
- checked_type = expr.checked_type
- # setup storage ids
- assert expr in self.storage_device_map
- storage_device_info = self.storage_device_map[expr]
- assert len(storage_device_info) == 2
- node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
- device_types = [x.value for x in storage_device_info[1]]
- num_unknown_devices = device_types.count(0)
- if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
- raise RuntimeError("The graph contains not annotated nodes for "
- "heterogeneous execution. All nodes must be "
- "annotated.")
-
- # Add the `device_index` attribute when the graph is annotated.
- if num_unknown_devices == 0:
- node.attrs["device_index"] = device_types
-
- node_id = len(self.nodes)
- self.nodes.append(node)
- # Tuple return value, flatten as tuple
- if isinstance(checked_type, TupleType):
- ret = []
- shape = []
- dtype = []
- for i, typ in enumerate(checked_type.fields):
- if not isinstance(typ, TensorType):
- raise RuntimeError("type %s not supported" % typ)
- ret.append(NodeRef(node_id, i))
- shape.append(shape_to_json(typ.shape))
- dtype.append(typ.dtype)
- node.attrs["shape"] = shape
- node.attrs["dtype"] = dtype
- assert isinstance(node, OpNode)
- node.num_outputs = len(checked_type.fields)
- return tuple(ret)
- # Normal tensor return type
- if not isinstance(checked_type, TensorType):
- raise RuntimeError("type %s not supported" % checked_type)
- node.attrs["shape"] = [shape_to_json(checked_type.shape)]
- node.attrs["dtype"] = [checked_type.dtype]
- node.num_outputs = 1
- return NodeRef(node_id, 0)
-
- def visit_tuple(self, vtuple):
- fields = []
- for field in vtuple.fields:
- ref = self.visit(field)
- assert isinstance(ref, NodeRef)
- fields.append(ref)
- return tuple(fields)
-
- def visit_tuple_getitem(self, op):
- vtuple = self.visit(op.tuple_value)
- assert isinstance(vtuple, tuple)
- return vtuple[op.index]
-
- def visit_constant(self, op):
- index = len(self.params)
- name = "p%d" % index
- self.params[name] = op.data
- node = InputNode(name, {})
- return self.add_node(node, op)
-
- def visit_function(self, _):
- raise RuntimeError("function not supported")
-
- def visit_if(self, _):
- raise RuntimeError("if not supported")
-
- def visit_global_var(self, _):
- raise RuntimeError()
-
- def visit_let(self, let):
- """
- Visit the let binding, by first traversing its value,
- then setting the metadata on the returned NodeRef.
-
- Finally visit the body, and return the NodeRef corresponding
- to it.
-
- Parameters
- ----------
- let: tvm.relay.Expr
- The let binding to transform.
-
- Returns
- -------
- ref: NodeRef
- The node reference to the body.
- """
- assert let.var not in self.var_map
- self.var_map[let.var] = self.visit(let.value)
- return self.visit(let.body)
-
- def visit_var(self, rvar):
- return self.var_map[rvar]
-
- def visit_call(self, call):
- """Transform a ::tvm.relay.Call into an operator in the TVM graph."""
- if isinstance(call.op, Op):
- raise Exception(
- "Operators should be transformed away; try applying" +
- "the fuse_ops transformation to the expression.")
- elif isinstance(call.op, GlobalVar):
- func = self.mod[call.op]
- elif isinstance(call.op, Function):
- func = call.op
- else:
- raise Exception(
- "TVM runtime does not support calls to {0}".format(type(call.op)))
- if int(func.attrs.Primitive) != 1:
- raise Exception(
- "TVM only support calls to primitive functions " +
- "(i.e functions composed of fusable operator invocations)")
-
- assert call in self.storage_device_map
- device_types = self.storage_device_map[call][1]
- call_dev_type = device_types[0].value
- if isinstance(self.target, (str, _target.Target)):
- # homogeneous execution.
- cached_func = self.compile_engine.lower(func, self.target)
- self.target = {0: str(self.target)}
- elif isinstance(self.target, dict):
- # heterogeneous execution.
- if call_dev_type not in self.target:
- raise Exception("No target is provided for device " +
- "{0}".format(call_dev_type))
- cached_func = self.compile_engine.lower(func,
- self.target[call_dev_type])
- else:
- raise ValueError("self.target must be the type of str," +
- "tvm.target.Target, or dict of int to str")
- for loweredf in cached_func.funcs:
- self.lowered_funcs[self.target[call_dev_type]].add(loweredf)
-
- inputs = []
- # flatten tuple in the call.
- for arg in call.args:
- res = self.visit(arg)
- if isinstance(arg.checked_type, TupleType):
- assert isinstance(res, tuple)
- inputs += res
- else:
- inputs.append(res)
-
- inputs = [x.to_json() for x in inputs]
- op_name = cached_func.func_name
- op_node = OpNode(self._get_unique_name(op_name), {},
- op_name, inputs, {})
- return self.add_node(op_node, call)
-
- def visit_op(self, _):
- raise Exception("can not compile op in non-eta expanded form")
-
- def visit_ref_create(self, _):
- raise RuntimeError("reference not supported")
-
- def visit_ref_read(self, _):
- raise RuntimeError("reference not supported")
-
- def visit_ref_write(self, _):
- raise RuntimeError("reference not supported")
-
- def visit_constructor(self, _):
- raise Exception("ADT constructor case not yet implemented")
-
- def visit_match(self, _):
- raise Exception("match case not yet implemented")
-
- def _get_json(self):
- """
- Convert the sequence of nodes stored by the compiler into the
- TVM graph runtime format.
-
- Returns
- -------
- graph_json : str
- The generated JSON as a string.
- """
- nodes = []
- # First we compute "nodes" field.
- for node in self.nodes:
- nodes.append(node.to_json())
-
- arg_nodes = []
- # Compute "arg_nodes" and "heads" fields.
- for i, node in enumerate(self.nodes):
- if isinstance(node, InputNode):
- arg_nodes.append(i)
-
- heads = self.heads
- heads = heads if isinstance(heads, tuple) else [heads]
- heads = [x.to_json() for x in heads]
-
- # Compute "node_row_ptr" and entry attributes.
- num_entry = 0
- shapes = []
- storage_ids = []
- device_types = []
- dltypes = []
- node_row_ptr = [0]
- for node in self.nodes:
- assert node.num_outputs == len(node.attrs["shape"])
- shapes += node.attrs["shape"]
- dltypes += node.attrs["dtype"]
- storage_ids += node.attrs["storage_id"]
- if "device_index" in node.attrs:
- device_types += node.attrs["device_index"]
- num_entry += node.num_outputs
- node_row_ptr.append(num_entry)
-
- # Compute "attrs" field.
- attrs = {}
- attrs["shape"] = ["list_shape", shapes]
- attrs["storage_id"] = ["list_int", storage_ids]
- if device_types:
- attrs["device_index"] = ["list_int", device_types]
- attrs["dltype"] = ["list_str", dltypes]
-
- # Metadata definitions
- def nested_defaultdict():
- return defaultdict(nested_defaultdict)
- metadata = nested_defaultdict()
- for node_id in arg_nodes:
- node_name = nodes[node_id]['name']
- if node_name not in self.params:
- metadata['signatures']['default']['inputs'][node_name]['id'] = node_id
- metadata['signatures']['default']['inputs'][node_name]['dtype'] = dltypes[node_id]
- metadata['signatures']['default']['inputs'][node_name]['shape'] = shapes[node_id]
- for node_id in heads:
- node_name = nodes[node_id[0]]['name']
- metadata['signatures']['default']['outputs'][node_name]['id'] = node_id[0]
- metadata['signatures']['default']['outputs'][node_name]['dtype'] = dltypes[node_id[0]]
- metadata['signatures']['default']['outputs'][node_name]['shape'] = shapes[node_id[0]]
-
- # Keep 'metadata' always at end
- json_dict = OrderedDict([
- ("nodes", nodes),
- ("arg_nodes", arg_nodes),
- ("heads", heads),
- ("attrs", attrs),
- ("node_row_ptr", node_row_ptr),
- ("metadata", metadata),
- ])
-
- return json.dumps(json_dict, indent=2)
-
- def debug_dump_memory_plan(self, func):
- """Debug function to dump memory plan."""
- def _annotate(expr):
- if expr in self.storage_device_map:
- storage_device_info = self.storage_device_map[expr]
- assert len(storage_device_info) == 2
- return str(storage_device_info[0])
- return ""
- return func.astext(show_meta_data=False, annotate=_annotate)
-
- def debug_dump_device_annotation(self, func):
- """Debug function to dump device annotation result."""
- def _annotate(expr):
- if expr in self.storage_device_map:
- storage_device_info = self.storage_device_map[expr]
- assert len(storage_device_info) == 2
- return str(storage_device_info[1])
- return ""
- return func.astext(show_meta_data=False, annotate=_annotate)
-
+        self._mod = build_module._GraphRuntimeCodegen()
+        self._init = self._mod["init"]
+        self._codegen = self._mod["codegen"]
+        self._get_graph_json = self._mod["get_graph_json"]
+        self._list_params_name = self._mod["list_params_name"]
+        self._get_param_by_name = self._mod["get_param_by_name"]
+        self._get_lowered_funcs = self._mod["get_lowered_funcs"]
+        self._setup(mod, target)
+
+    def _setup(self, mod, target):
+        # Flatten the target spec into an alternating [device, target, ...]
+        # list for the FFI, e.g. {"cpu": "llvm", "gpu": "cuda"} becomes
+        # ["cpu", "llvm", "gpu", "cuda"]; a plain target is keyed by
+        # device "0" for homogeneous execution.
+        tgts = []
+        if isinstance(target, dict):
+            for device, tgt in target.items():
+                if not isinstance(tgt, (str, _target.Target)):
+                    raise Exception("Unknown target type")
+                tgts.append(device)
+                tgts.append(str(tgt))
+        elif isinstance(target, (str, _target.Target)):
+            tgts.append("0")
+            tgts.append(str(target))
+        else:
+            raise Exception("Unknown target type")
+        self._init(mod, tgts)
    def codegen(self, func):
        """Compile a single function into a graph.

        Parameters
        ----------
        func: tvm.relay.Expr
            The function to compile.

        Returns
        -------
        graph_json : str
            The graph json that can be consumed by runtime.
        lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
            The lowered functions.
        params : Dict[str, tvm.nd.NDArray]
            Additional constant parameters.
        """
- self.storage_device_map = _backend.GraphPlanMemory(func)
- # First we convert all the parameters into input nodes.
- for param in func.params:
- node = InputNode(param.name_hint, {})
- self.var_map[param] = self.add_node(node, param)
-
- # Then we compile the body into a graph which can depend
- # on input variables.
- self.heads = self.visit(func.body)
- graph_json = self._get_json()
-
- # Return the lowered functions as a list for homogeneous compilation.
- # Otherwise, for heterogeneous compilation, a dictionary containing
- # the device id to a list of lowered functions is returned. Both forms
- # are acceptable to tvm.build.
- if not isinstance(self.target, dict):
- lowered_funcs = list(list(self.lowered_funcs.values())[0])
- else:
- lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
- return graph_json, lowered_funcs, self.params
-
- def _get_unique_name(self, name):
- if name not in self._name_map:
- self._name_map[name] = 1
- return name
- index = self._name_map[name]
- self._name_map[name] += 1
- return self._get_unique_name(name + str(index))
+        self._codegen(func)
+        graph_json = self._get_graph_json()
+        lowered_funcs = self._get_lowered_funcs()
+        param_names = self._list_params_name()
+        params = {}
+        for name in param_names:
+            key = name.value
+            arr = self._get_param_by_name(key)
+            # Copy the parameter out of the codegen module so it remains
+            # valid after the module is released.
+            param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
+            arr.copyto(param)
+            params[key] = param
+        return graph_json, lowered_funcs, params
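
For orientation, a sketch of driving this wrapper directly; relay.build
normally performs these steps, and `func` and the module path are assumptions
here (`func` must already be optimized and fused into calls to primitive
functions, or the C++ codegen aborts):

    from tvm.relay.backend.graph_runtime_codegen import GraphRuntimeCodegen

    codegen = GraphRuntimeCodegen(None, target="llvm")
    graph_json, lowered_funcs, params = codegen.codegen(func)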
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file relay/backend/graph_codegen.cc
+ * \brief Graph runtime codegen
+ */
+
+#include <dmlc/any.h>
+#include <dmlc/json.h>
+#include <tvm/node/ir_functor.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/device_api.h>
+
+
+#include <list>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "utils.h"
+#include "compile_engine.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+class GraphNode;
+class GraphInputNode;
+class GraphOpNode;
+
+using IntegerArray = Array<Integer>;
+using ShapeVector = std::vector<std::vector<int64_t> >;
+using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
+using GraphNodePtr = std::shared_ptr<GraphNode>;
+using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
+using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
+using TargetsMap = std::unordered_map<std::string, Target>;
+
+/*! \brief Lowered outputs */
+struct LoweredOutput {
+ std::string graph_json;
+ Map<std::string, Array<LoweredFunc> > lowered_funcs;
+ std::unordered_map<std::string, tvm::runtime::NDArray> params;
+};
+
+/*! \brief Node types */
+enum GraphNodeType {
+ kGraphNop,
+ kGraphInputNode,
+ kGraphOpNode,
+};
+
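+/*!
+ * \brief Reference to (one output of) a graph node.
+ *
+ *  Serialized in the graph JSON as the triple [ident, index, version];
+ *  e.g. the second output of the node with id 3 appears as [3, 1, 0].
+ */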
+class GraphNodeRef {
+ public:
+ GraphNodeRef() {}
+ GraphNodeRef(int ident, int index, int version = 0)
+ : ident_(ident), index_(index), version_(version) {}
+
+ inline void Save(dmlc::JSONWriter* writer) const {
+ writer->BeginArray();
+ writer->WriteArrayItem(ident_);
+ writer->WriteArrayItem(index_);
+ writer->WriteArrayItem(version_);
+ writer->EndArray();
+ }
+
+ inline void Load(dmlc::JSONReader* reader) {
+ LOG(FATAL) << "Not implemented.";
+ }
+
+ protected:
+ int ident_;
+ int index_{0};
+ int version_{0};
+};
+
+/*! \brief Base Node class */
+class GraphNode {
+ public:
+ GraphNode() {}
+ virtual void Save(dmlc::JSONWriter* writer) const {}
+ virtual void Load(dmlc::JSONReader* reader) {}
+ virtual GraphNodeType Type() const { return kGraphNop; }
+ virtual ~GraphNode() {}
+
+ public:
+ int num_outputs_{1};
+ std::string name_;
+ GraphAttrs attrs_;
+};
+
+/*! \brief Input Node */
+class GraphInputNode : public GraphNode {
+ public:
+ GraphInputNode() {}
+ GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
+ name_ = name;
+ attrs_ = attrs;
+ }
+
+ GraphNodeType Type() const override { return kGraphInputNode; }
+
+ void Save(dmlc::JSONWriter* writer) const override {
+ const std::string op_name{"null"};
+ writer->BeginObject();
+ writer->WriteObjectKeyValue("op", op_name);
+ writer->WriteObjectKeyValue("name", this->name_);
+ writer->WriteObjectKeyValue("inputs", std::list<int>());
+ writer->EndObject();
+ }
+ static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
+ const GraphAttrs& attrs) {
+ auto ptr = std::make_shared<GraphInputNode>(name, attrs);
+ return std::dynamic_pointer_cast<GraphNode>(ptr);
+ }
+};
+
+/*! \brief Op Node */
+class GraphOpNode : public GraphNode {
+ public:
+ GraphOpNode() {}
+ GraphOpNode(const std::string& name,
+ const GraphAttrs& nd_attrs,
+ const std::string& op_name,
+ const std::vector<GraphNodeRef>& inputs,
+ const GraphAttrs& attrs,
+ size_t num_outputs = 1) {
+ name_ = name;
+ attrs_ = nd_attrs;
+ op_name_ = op_name;
+ inputs_ = inputs;
+    op_attrs_ = attrs;
+ num_outputs_ = num_outputs;
+ op_attrs_["func_name"] = op_name_;
+ op_attrs_["flatten_data"] = std::string("0");
+ op_attrs_["num_inputs"] = std::to_string(inputs_.size());
+ op_attrs_["num_outputs"] = std::to_string(num_outputs_);
+ }
+
+ GraphNodeType Type() const override { return kGraphOpNode; }
+
+ void Save(dmlc::JSONWriter* writer) const override {
+ GraphAttrs attrs = op_attrs_;
+ attrs["func_name"] = this->op_name_;
+ attrs["flatten_data"] = std::string("0");
+ attrs["num_inputs"] = std::to_string(this->inputs_.size());
+ attrs["num_outputs"] = std::to_string(this->num_outputs_);
+ writer->BeginObject();
+ writer->WriteObjectKeyValue("op", op_type_name_);
+ writer->WriteObjectKeyValue("name", name_);
+ writer->WriteObjectKeyValue("attrs", attrs);
+ writer->WriteObjectKeyValue("inputs", this->inputs_);
+ writer->EndObject();
+ }
+ static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
+ const GraphAttrs& nd_attrs,
+ const std::string& op_name,
+ const std::vector<GraphNodeRef>& inputs,
+ const GraphAttrs& attrs,
+ size_t num_outputs = 1) {
+ auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
+ return std::dynamic_pointer_cast<GraphNode>(ptr);
+ }
+
+ public:
+ std::string op_name_;
+ std::vector<GraphNodeRef> inputs_;
+ GraphAttrs op_attrs_;
+
+ private:
+ const std::string op_type_name_{"tvm_op"};
+};
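+
+// For reference, GraphOpNode::Save emits objects of this shape (values
+// illustrative):
+//   {"op": "tvm_op", "name": "fused_add_1",
+//    "attrs": {"func_name": "fused_add_1", "flatten_data": "0",
+//              "num_inputs": "2", "num_outputs": "1"},
+//    "inputs": [[0, 0, 0], [1, 0, 0]]}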
+
+/*! \brief Code generator for graph runtime */
+class GraphRuntimeCodegen
+ : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
+ public:
+ GraphRuntimeCodegen(runtime::Module* mod,
+ const std::unordered_map<std::string, std::string>& targets) : mod_(mod) {
+ compile_engine_ = CompileEngine::Global();
+ for (auto &kv : targets) {
+ targets_[kv.first] = Target::create(kv.second);
+ }
+ }
+
+ LoweredOutput Codegen(relay::Function func) {
+ auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+ storage_device_map_ = (*pf)(func);
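+    // The memory plan maps each expression to two integer arrays:
+    // [0] per-output storage ids and [1] per-output device types,
+    // which AddNode copies into the node attributes below.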
+ // First we convert all the parameters into input nodes.
+ for (auto param : func->params) {
+ auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
+ var_map_[param.get()] = AddNode(node_ptr, param);
+ }
+ heads_ = VisitExpr(func->body);
+ std::ostringstream os;
+ dmlc::JSONWriter writer(&os);
+ GetJSON(&writer);
+ LoweredOutput ret;
+ ret.graph_json = os.str();
+ ret.params = params_;
+ for (auto& kv : lowered_funcs_) {
+ if (ret.lowered_funcs.count(kv.first) == 0) {
+ ret.lowered_funcs.Set(kv.first, Array<LoweredFunc>());
+ }
+ auto& vec = ret.lowered_funcs[kv.first];
+ Array<LoweredFunc> tmp;
+ for (auto f : kv.second) {
+ tmp.push_back(f);
+ }
+ for (auto f : vec) {
+ tmp.push_back(f);
+ }
+ ret.lowered_funcs.Set(kv.first, tmp);
+ }
+ return ret;
+ }
+
+ protected:
+ /*!
+ * \brief Extract shape from expr to vector<int64_t>
+ *
+ * \param shape
+ * \return std::vector<int64_t>
+ */
+  std::vector<int64_t> _ShapeToJSON(tvm::Array<HalideIR::Expr> shape) {
+    std::vector<int64_t> ret;
+    for (IndexExpr dim : shape) {
+      const int64_t* pval = as_const_int(dim);
+      CHECK(pval != nullptr) << "symbolic shapes are not supported in the graph runtime";
+      ret.push_back(*pval);
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Add node to the graph, copying storage ids and device types
+   *  from the memory plan into the node's attributes.
+   *
+   * \param node the node to add
+   * \param expr the corresponding Relay expression
+   * \return std::vector<GraphNodeRef> a reference per output of the node
+   */
+ std::vector<GraphNodeRef> AddNode(GraphNodePtr node, Expr expr) {
+ auto checked_type = expr->checked_type();
+ size_t count = storage_device_map_.count(expr);
+    CHECK_GT(count, 0) << "Expr does not exist in storage plan";
+ auto storage_device_info = storage_device_map_[expr];
+ CHECK_EQ(storage_device_info.size(), 2);
+ // storage
+ std::vector<int64_t> storage_info;
+ for (auto& v : storage_device_info[0]) {
+ storage_info.push_back(v->value);
+ }
+ node->attrs_["storage_id"] = std::move(storage_info);
+ // type
+ std::vector<int64_t> device_types;
+ for (auto& v : storage_device_info[1]) {
+ device_types.push_back(v->value);
+ }
+ size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
+ if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
+ LOG(FATAL) << "The graph contains not annotated nodes for "
+ << "heterogeneous execution. All nodes must be "
+ << "annotated.";
+ }
+ if (num_unknown_devices == 0) {
+ node->attrs_["device_index"] = device_types;
+ }
+ auto node_id = nodes_.size();
+ nodes_.push_back(node);
+ // Tuple return value, flatten as tuple
+ if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
+ std::vector<GraphNodeRef> ret;
+ ShapeVector shape;
+ std::vector<std::string> dtype;
+ for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
+ if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
+ ret.push_back(GraphNodeRef(node_id, i));
+ shape.emplace_back(_ShapeToJSON(typ->shape));
+ dtype.emplace_back(DType2String(typ->dtype));
+ } else {
+ LOG(FATAL) << "type " << checked_type->type_key() << " not supported";
+ }
+ }
+ CHECK_EQ(node->Type(), kGraphOpNode);
+ auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
+ op_nd->attrs_["shape"] = shape;
+ op_nd->attrs_["dtype"] = dtype;
+ op_nd->num_outputs_ = tuple_type->fields.size();
+ return ret;
+ }
+ // Normal tensor return type
+ if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
+ ShapeVector shape;
+ std::vector<std::string> dtype;
+ shape.emplace_back(_ShapeToJSON(tensor_type->shape));
+ dtype.emplace_back(DType2String(tensor_type->dtype));
+ node->attrs_["shape"] = shape;
+ node->attrs_["dtype"] = dtype;
+ } else {
+ LOG(FATAL) << "type " << checked_type->type_key() << " not supported";
+ }
+ return {GraphNodeRef(node_id, 0)};
+ }
+
+ /*! \brief Visitors */
+ std::unordered_map<Expr, std::vector<GraphNodeRef>, NodeHash, NodeEqual> visitor_cache_;
+
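+  // Memoized dispatch: each expression is visited once and its node
+  // references are cached, so shared subexpressions map to a single
+  // graph node instead of being duplicated.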
+ std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
+ if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
+ std::vector<GraphNodeRef> res;
+ if (expr.as<ConstantNode>()) {
+ res = VisitExpr_(expr.as<ConstantNode>());
+ } else if (expr.as<TupleNode>()) {
+ res = VisitExpr_(expr.as<TupleNode>());
+ } else if (expr.as<VarNode>()) {
+ res = VisitExpr_(expr.as<VarNode>());
+ } else if (expr.as<GlobalVarNode>()) {
+ res = VisitExpr_(expr.as<GlobalVarNode>());
+ } else if (expr.as<FunctionNode>()) {
+ res = VisitExpr_(expr.as<FunctionNode>());
+ } else if (expr.as<CallNode>()) {
+ res = VisitExpr_(expr.as<CallNode>());
+ } else if (expr.as<LetNode>()) {
+ res = VisitExpr_(expr.as<LetNode>());
+ } else if (expr.as<IfNode>()) {
+ res = VisitExpr_(expr.as<IfNode>());
+ } else if (expr.as<OpNode>()) {
+ res = VisitExpr_(expr.as<OpNode>());
+ } else if (expr.as<TupleGetItemNode>()) {
+ res = VisitExpr_(expr.as<TupleGetItemNode>());
+ } else if (expr.as<RefCreateNode>()) {
+ res = VisitExpr_(expr.as<RefCreateNode>());
+ } else if (expr.as<RefReadNode>()) {
+ res = VisitExpr_(expr.as<RefReadNode>());
+ } else if (expr.as<RefWriteNode>()) {
+ res = VisitExpr_(expr.as<RefWriteNode>());
+ } else if (expr.as<ConstructorNode>()) {
+ res = VisitExpr_(expr.as<ConstructorNode>());
+ } else if (expr.as<MatchNode>()) {
+ res = VisitExpr_(expr.as<MatchNode>());
+ }
+ visitor_cache_[expr] = res;
+ return res;
+ }
+
+ std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
+ Expr expr = GetRef<Expr>(op);
+ return var_map_[expr.get()];
+ }
+
+ std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
+ Expr expr = GetRef<Expr>(op);
+ size_t index = params_.size();
+ std::string name = "p" + std::to_string(index);
+ params_[name] = op->data;
+ auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
+ return AddNode(node, expr);
+ }
+
+ std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
+ std::vector<GraphNodeRef> fields;
+ for (auto field : op->fields) {
+ auto ref_vec = VisitExpr(field);
+ for (auto ref : ref_vec) {
+ fields.push_back(ref);
+ }
+ }
+ return fields;
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
+ Expr expr = GetRef<Expr>(op);
+ Function func;
+ if (op->op.as<OpNode>()) {
+ LOG(FATAL) << "Operators should be transformed away; try applying"
+ << "the fuse_ops transformation to the expression.";
+ } else if (op->op.as<GlobalVarNode>()) {
+ LOG(FATAL) << "Not implemented";
+ } else if (op->op.as<FunctionNode>()) {
+ func = GetRef<Function>(op->op.as<FunctionNode>());
+ } else {
+ LOG(FATAL) << "TVM runtime does not support calls to " << op->op->type_key();
+ }
+ if (!func->IsPrimitive()) {
+ LOG(FATAL) << "TVM only support calls to primitive functions "
+ << "(i.e functions composed of fusable operator invocations)";
+ }
+
+    CHECK_GT(storage_device_map_.count(expr), 0);
+ auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
+ auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
+ auto &device_type = storage_device_map_[expr][1];
+    auto call_dev_type = device_type[0]->value;  // device type (int); converted to a string key below
+ Target target;
+ if (targets_.size() == 1) {
+ // homogeneous execution.
+ for (auto kv : targets_) {
+ target = kv.second;
+ }
+ } else {
+ // heterogeneous execution.
+ const auto call_dev_key = std::to_string(call_dev_type);
+ const auto call_dev_name = runtime::DeviceName(call_dev_type);
+ if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
+ LOG(FATAL) << "No target is provided for device "
+ << call_dev_name;
+ }
+ if (targets_.count(call_dev_key)) {
+ target = targets_[call_dev_key];
+ } else {
+ target = targets_[call_dev_name];
+ }
+ }
+ CCacheKey key = (*pf0)(func, target);
+    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
+ if (!lowered_funcs_.count(target->target_name)) {
+ lowered_funcs_[target->target_name] = {};
+ }
+    for (auto f : lowered_func->funcs) {
+ lowered_funcs_[target->target_name].insert(f);
+ }
+
+ std::vector<GraphNodeRef> inputs;
+ for (auto arg : op->args) {
+ auto res = VisitExpr(arg);
+ for (auto nr : res) {
+ inputs.push_back(nr);
+ }
+ }
+    auto& op_name = lowered_func->func_name;
+ auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
+ GraphAttrs(),
+ op_name,
+ inputs,
+ GraphAttrs());
+ return AddNode(node, expr);
+ }
+
+ std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
+ CHECK_EQ(var_map_.count(op->var.get()), 0);
+ var_map_[op->var.get()] = VisitExpr(op->value);
+ return VisitExpr(op->body);
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
+ auto vtuple = VisitExpr(op->tuple);
+ return {vtuple[op->index]};
+ }
+  std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
+    throw std::runtime_error("cannot compile op in non-eta-expanded form");
+    return {};
+  }
+  std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
+    throw std::runtime_error("GlobalVar is not supported");
+    return {};
+  }
+ std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
+ throw std::invalid_argument("if not supported");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
+ throw std::invalid_argument("function not supported");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
+ throw std::invalid_argument("reference not supported");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
+ throw std::invalid_argument("reference not supported");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
+ throw std::invalid_argument("reference not supported");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
+ throw std::invalid_argument("ADT constructor case not yet implemented");
+ return {};
+ }
+ std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
+ throw std::invalid_argument("match case not yet implemented");
+ return {};
+ }
+ /*!
+ * \brief Generate Graph JSON
+ *
+ * \param writer json writer
+ */
+ void GetJSON(dmlc::JSONWriter* writer) {
+ std::vector<size_t> arg_nodes;
+ for (size_t i = 0; i < nodes_.size(); ++i) {
+ auto node = nodes_[i];
+ if (node->Type() == kGraphInputNode) {
+ arg_nodes.push_back(i);
+ }
+ }
+ size_t num_entry = 0;
+ ShapeVector shapes;
+ std::vector<size_t> storage_ids;
+ std::vector<size_t> device_types;
+ std::vector<std::string> dltypes;
+ std::vector<size_t> node_row_ptr{0};
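+    // node_row_ptr is a prefix sum over per-node output counts, e.g. nodes
+    // with 1, 2 and 1 outputs yield [0, 1, 3, 4]; the runtime uses it to map
+    // a (node, output) pair to a flat entry index.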
+ for (auto node : nodes_) {
+ const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
+ const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
+ const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);
+
+ CHECK_EQ(node->num_outputs_, shape_vec.size());
+ num_entry += node->num_outputs_;
+
+ shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
+ dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
+ storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
+ if (node->attrs_.count("device_index")) {
+ const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
+ device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
+ }
+ node_row_ptr.push_back(num_entry);
+ }
+ writer->BeginObject();
+ writer->WriteObjectKeyValue("nodes", nodes_);
+ writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
+ writer->WriteObjectKeyValue("heads", heads_);
+ std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
+ attrs["shape"].emplace_back(std::string("list_shape"));
+ attrs["shape"].emplace_back(shapes);
+ attrs["storage_id"].emplace_back(std::string("list_int"));
+ attrs["storage_id"].emplace_back(storage_ids);
+ if (device_types.size()) {
+ attrs["device_index"].emplace_back(std::string("list_int"));
+ attrs["device_index"].emplace_back(device_types);
+ }
+ attrs["dltype"].emplace_back(std::string("list_str"));
+ attrs["dltype"].emplace_back(dltypes);
+ writer->WriteObjectKeyValue("attrs", attrs);
+ writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
+ writer->EndObject();
+ }
+
+  /*!
+   * \brief Get a unique name for a function, appending a running index on
+   *  collision (e.g. "fused_add", "fused_add1", "fused_add2", ...).
+   *
+   * \param name the base name
+   * \return std::string a name not used before
+   */
+ std::string _GetUniqueName(const std::string& name) {
+ if (!name_map_.count(name)) {
+ name_map_[name] = 1;
+ return name;
+ }
+ auto index = name_map_[name];
+ name_map_[name] += 1;
+ return _GetUniqueName(name + std::to_string(index));
+ }
+
+ protected:
+ /*! \brief nodes */
+ std::vector<GraphNodePtr> nodes_;
+ /*! \brief output of graph */
+ std::vector<GraphNodeRef> heads_;
+ /*! \brief mod */
+ runtime::Module* mod_;
+ /*! \brief variable map */
+ std::unordered_map<const Node*, std::vector<GraphNodeRef>> var_map_;
+ /*! \brief target device */
+ TargetsMap targets_;
+ /*! \brief params */
+ std::unordered_map<std::string, runtime::NDArray> params_;
+ /*! \brief plan memory of device result */
+ Map<Expr, Array<IntegerArray>> storage_device_map_;
+ /*! \brief lowered funcs */
+ std::unordered_map<std::string, std::unordered_set<LoweredFunc, NodeHash, NodeEqual>>
+ lowered_funcs_;
+ /*! \brief name map */
+ std::unordered_map<std::string, size_t> name_map_;
+ /*! \brief compile engine */
+ CompileEngine compile_engine_;
+};
+
+class GraphRuntimeCodegenModule : public runtime::ModuleNode {
+ public:
+ GraphRuntimeCodegenModule() {}
+ virtual PackedFunc GetFunction(const std::string& name,
+ const std::shared_ptr<ModuleNode>& sptr_to_self) {
+ if (name == "init") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.num_args, 2) << "Expected arguments are: "
+                                   << "runtime::Module mod and Map<str, StringImm> targets";
+ void* mod = args[0];
+ auto& sptr = args[1].node_sptr();
+ auto* node = static_cast<const ArrayNode*>(sptr.get());
+ auto& tmp_targets = node->data;
+ std::unordered_map<std::string, std::string> targets;
+ for (size_t i = 0; i < tmp_targets.size(); i += 2) {
+ std::string key;
+ auto sk = Expr(tmp_targets[i]).as<ir::StringImm>();
+ auto ik = Expr(tmp_targets[i]).as<ir::IntImm>();
+ if (sk) {
+ key = sk->value;
+ }
+ if (ik) {
+ key = std::to_string(ik->value);
+ }
+ auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
+ targets[key] = v->value;
+ }
+ codegen_ = std::make_shared<GraphRuntimeCodegen>(
+ reinterpret_cast<runtime::Module*>(mod), targets);
+ });
+ } else if (name == "codegen") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ Function func = args[0];
+ this->output_ = this->codegen_->Codegen(func);
+ });
+ } else if (name == "get_graph_json") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = this->output_.graph_json;
+ });
+ } else if (name == "list_params_name") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ Array<HalideIR::Expr> ret;
+ for (const auto &kv : this->output_.params) {
+ HalideIR::Expr name = ir::StringImm::make(kv.first);
+ ret.push_back(name);
+ }
+ *rv = ret;
+ });
+ } else if (name == "get_param_by_name") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ std::string key = args[0];
+ CHECK_GT(this->output_.params.count(key), 0);
+ *rv = this->output_.params[key];
+ });
+ } else if (name == "get_lowered_funcs") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = this->output_.lowered_funcs;
+ });
+ } else {
+ return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
+ }
+ }
+
+ const char* type_key() const final {
+ return "RelayGraphRuntimeCodegenModule";
+ }
+
+ private:
+ std::shared_ptr<GraphRuntimeCodegen> codegen_;
+ LoweredOutput output_;
+};
+
+runtime::Module CreateGraphCodegenMod() {
+ std::shared_ptr<GraphRuntimeCodegenModule> ptr =
+ std::make_shared<GraphRuntimeCodegenModule>();
+ return runtime::Module(ptr);
+}
+
+TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = CreateGraphCodegenMod();
+});
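+
+// The Python-side GraphRuntimeCodegen wrapper obtains this module via
+// build_module._GraphRuntimeCodegen() and drives it through the "init",
+// "codegen", "get_graph_json", etc. packed functions returned by
+// GetFunction above.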
+
+} // namespace backend
+} // namespace relay
+} // namespace tvm
+
+namespace dmlc {
+namespace json {
+// JSON utils
+template <typename T>
+inline bool SameType(const dmlc::any& data) {
+ return std::type_index(data.type()) == std::type_index(typeid(T));
+}
+
+template <>
+struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
+ data->Save(writer);
+ }
+ inline static void Read(dmlc::JSONReader* reader,
+ std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
+ LOG(FATAL) << "Not implemented.";
+ }
+};
+
+template <>
+struct Handler<std::unordered_map<std::string, dmlc::any>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const std::unordered_map<std::string, dmlc::any>& data) {
+ writer->BeginObject();
+ for (const auto& kv : data) {
+ auto k = kv.first;
+ const dmlc::any& v = kv.second;
+ if (SameType<std::string>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
+ } else if (SameType<int>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
+ } else if (SameType<std::vector<size_t>>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
+ } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
+ } else if (SameType<std::vector<std::string>>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
+ } else {
+ LOG(FATAL) << "Not supported";
+ }
+ }
+ writer->EndObject();
+ }
+ inline static void Read(dmlc::JSONReader* reader,
+ std::unordered_map<std::string, dmlc::any>* data) {
+ LOG(FATAL) << "Not implemented.";
+ }
+};
+
+template <>
+struct Handler<std::vector<dmlc::any>> {
+ inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) {
+ writer->BeginArray();
+ for (const auto& v : data) {
+ if (SameType<std::string>(v)) {
+ writer->WriteArrayItem(dmlc::get<std::string>(v));
+ } else if (SameType<int>(v)) {
+ writer->WriteArrayItem(dmlc::get<int>(v));
+ } else if (SameType<std::vector<size_t>>(v)) {
+ writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v));
+ } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
+ writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v));
+ } else if (SameType<std::vector<std::string>>(v)) {
+ writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v));
+ } else {
+ LOG(FATAL) << "Not supported";
+ }
+ }
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) {
+ LOG(FATAL) << "Not implemented.";
+ }
+};
+} // namespace json
+} // namespace dmlc