# specific language governing permissions and limitations
# under the License.
-if(USE_DNNL_CODEGEN STREQUAL "ON")
- file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
+# USE_DNNL_CODEGEN accepts ON/JSON (build the DNNL JSON runtime) or C_SRC
+# (build the legacy C source module path).
+if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON"))
+ add_definitions(-DUSE_JSON_RUNTIME=1)
+ file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc)
 list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
+ # NOTE(review): JSON_RELAY_CONTRIB_SRC is not defined in this module — presumably
+ # populated by another cmake module; verify, otherwise this append is a no-op.
+ list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})
 find_library(EXTERN_LIBRARY_DNNL dnnl)
 list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
- file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
+ file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc)
 list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
- message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
+ message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL})
+elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC")
+ # Legacy path: same relay-side codegen sources, but only the DNNL C wrapper
+ # (dnnl.cc) goes into the runtime.
+ file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc)
+ list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
+
+ find_library(EXTERN_LIBRARY_DNNL dnnl)
+ list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
+ file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl.cc)
+ list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
+ message(STATUS "Build with DNNL C source module: " ${EXTERN_LIBRARY_DNNL})
 endif()
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/contrib/codegen_json.h
+ * \brief Utilities for json codegen and runtime
+ */
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+
+#include <dmlc/any.h>
+#include <dmlc/json.h>
+#include <tvm/node/container.h>
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/container.h>
+#include <tvm/tir/op.h>
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../../../runtime/contrib/json/json_runtime.h"
+#include "../../utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+using ShapeVector = std::vector<std::vector<int64_t>>;
+using TypeVector = std::vector<std::string>;
+using JSONGraphObjectPtr = std::shared_ptr<JSONGraphNode>;
+
+/*!
+ * \brief Helper class to extract all attributes of a certain op and save them
+ * into text format.
+ */
+class OpAttrExtractor : public AttrVisitor {
+ public:
+  explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {}
+
+  // Convert a floating point value to a decimal string with max_digits10
+  // precision so the value round-trips losslessly through the JSON text.
+  template <typename T = double, typename = std::enable_if_t<std::is_floating_point<T>::value>>
+  std::string Fp2String(const T value) {
+    std::ostringstream out;
+    out.precision(std::numeric_limits<T>::max_digits10);
+    out << value;
+    return out.str();
+  }
+
+  // Attach a list-of-strings attribute to the wrapped JSON graph node.
+  void SetNodeAttr(const char* key, const std::vector<std::string>& value) {
+    std::vector<dmlc::any> attr;
+    attr.emplace_back(value);
+    node_->SetAttr(key, attr);
+  }
+
+  // Scalar attributes are stored as single-element string lists.
+  void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); }
+
+  void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); }
+
+  void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); }
+
+  void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); }
+
+  void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); }
+
+  void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); }
+
+  // A void dtype is serialized as the empty string.
+  void Visit(const char* key, DataType* value) final {
+    if (!value->is_void()) {
+      SetNodeAttr(key, {runtime::DLDataType2String(*value)});
+    } else {
+      SetNodeAttr(key, {""});
+    }
+  }
+
+  // Handles Array (of IntImm/FloatImm/String), NullValue, IntImm, FloatImm and
+  // String attributes; any other object type is fatal.
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    if (const auto* an = (*value).as<ArrayNode>()) {
+      std::vector<std::string> attr;
+      for (size_t i = 0; i < an->size(); ++i) {
+        if (const auto* im = (*an)[i].as<IntImmNode>()) {
+          attr.push_back(std::to_string(im->value));
+        } else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
+          attr.push_back(Fp2String(fm->value));
+        } else if (const auto* str = (*an)[i].as<StringObj>()) {
+          String s = GetRef<String>(str);
+          attr.push_back(s);
+        } else {
+          LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey();
+        }
+      }
+      SetNodeAttr(key, attr);
+    } else if (!(*value).defined()) {  // Skip NullValue
+      SetNodeAttr(key, std::vector<std::string>{""});
+    } else if (const auto* im = (*value).as<IntImmNode>()) {
+      SetNodeAttr(key, std::vector<std::string>{std::to_string(im->value)});
+    } else if (const auto* fm = (*value).as<FloatImmNode>()) {
+      SetNodeAttr(key, std::vector<std::string>{Fp2String(fm->value)});
+    } else if (const auto* str = (*value).as<StringObj>()) {
+      String s = GetRef<String>(str);
+      SetNodeAttr(key, std::vector<std::string>{s});
+    } else {
+      LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value;
+    }
+  }
+
+  void Visit(const char* key, runtime::NDArray* value) final {
+    LOG(FATAL) << "NDArray is not allowed in op attribute";
+  }
+
+  void Visit(const char* key, void** value) final {
+    LOG(FATAL) << "void pointer is not allowed in op attribute";
+  }
+
+  // Visit all reflected attributes of `node`; a nullptr (op without attrs) is a no-op.
+  void Extract(Object* node) {
+    if (node) {
+      reflection_->VisitAttrs(node, this);
+    }
+  }
+
+ private:
+  /*! \brief The JSON node that receives the extracted attributes. */
+  JSONGraphObjectPtr node_;
+  /*! \brief Global reflection table used to enumerate attributes. */
+  ReflectionVTable* reflection_ = ReflectionVTable::Global();
+};
+
+/*! \brief Serialize a Relay expression to JSON. */
+class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEntry>> {
+ public:
+  /*!
+   * \brief Constructor
+   *
+   * \param symbol The symbol that represents the graph being converted.
+   * \param expr The Relay expression to be converted to the JSON form.
+   */
+  JSONSerializer(const std::string& symbol, const Expr& expr) : symbol_(symbol), func_(expr) {}
+
+  // Convert the stored function to the JSON graph: parameters become "input"
+  // nodes (seeding memo_), then the body is visited to produce the heads.
+  void serialize() {
+    relay::Function func = Downcast<relay::Function>(func_);
+    // First we convert all the parameters into input nodes.
+    for (const auto& param : func->params) {
+      auto node_ptr = std::make_shared<JSONGraphNode>(param->name_hint(), "input" /* op_type_ */);
+      memo_[param] = AddNode(node_ptr, param);
+    }
+    heads_ = VisitExpr(func->body);
+  }
+
+  /*!\brief Return the required params. */
+  Array<String> GetParams() const { return params_; }
+
+  /*!\brief Return the generated json. */
+  std::string GetJSON() {
+    std::ostringstream os;
+    dmlc::JSONWriter writer(&os);
+    Save(&writer);
+    return os.str();
+  }
+
+ protected:
+  /*!
+   * \brief Add a node to graph.
+   *
+   * \param node A graph node. It is a shared pointer. Some attributes of it
+   * will be added, i.e. shape and type. These attributes are attached to
+   * the JSON graph in the end.
+   * \param expr The relay expression.
+   * \return A list of graph entry nodes. If the relay expr is a tuple type, we
+   * will flatten it.
+   */
+  std::vector<JSONGraphNodeEntry> AddNode(JSONGraphObjectPtr node, const Expr& expr) {
+    auto checked_type = expr->checked_type();
+    auto node_id = nodes_.size();
+    nodes_.push_back(node);
+    std::vector<JSONGraphNodeEntry> ret;
+    ShapeVector shape;
+    TypeVector dtype;
+    // Flatten tuple node: one graph entry per field.
+    if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
+      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
+        const auto* tensor_type = tuple_type->fields[i].as<TensorTypeNode>();
+        // Error message fixed: a stray '.' made the reported type unreadable.
+        CHECK(tensor_type) << "Expect TensorType, but received: "
+                           << tuple_type->fields[i]->GetTypeKey();
+        ret.push_back(JSONGraphNodeEntry(node_id, i));
+        shape.emplace_back(GetIntShape(tensor_type->shape));
+        dtype.emplace_back(DType2String(tensor_type->dtype));
+      }
+      node->SetNumOutput(tuple_type->fields.size());
+    } else {
+      const auto* tensor_type = checked_type.as<TensorTypeNode>();
+      CHECK(tensor_type) << "Expect TensorType, but received: " << checked_type->GetTypeKey();
+      shape.emplace_back(GetIntShape(tensor_type->shape));
+      dtype.emplace_back(DType2String(tensor_type->dtype));
+      ret.push_back(JSONGraphNodeEntry(node_id, 0));
+    }
+    // Attach shape/dtype so the runtime can allocate and check buffers.
+    std::vector<dmlc::any> shape_attrs;
+    shape_attrs.emplace_back(shape);
+    node->SetAttr("shape", shape_attrs);
+
+    std::vector<dmlc::any> type_attrs;
+    type_attrs.emplace_back(dtype);
+    node->SetAttr("dtype", type_attrs);
+    return ret;
+  }
+
+  // Export call attributes to the JSON node: for op calls, all reflected op
+  // attrs via OpAttrExtractor; for composite functions, only the
+  // "PartitionedFromPattern" marker.
+  void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) {
+    if (cn->op.as<OpNode>()) {
+      OpAttrExtractor extractor(node);
+      const Object* call_attr = cn->attrs.get();
+      extractor.Extract(const_cast<Object*>(call_attr));
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto pattern = fn->GetAttr<String>(attr::kPartitionedFromPattern);
+      CHECK(pattern.defined());
+      std::vector<std::string> values;
+      values.push_back(pattern.value());
+      std::vector<dmlc::any> attr;
+      attr.emplace_back(values);
+      node->SetAttr("PartitionedFromPattern", attr);
+    }
+  }
+
+  // Fallback: any expression kind without an explicit overload is unsupported.
+  std::vector<JSONGraphNodeEntry> VisitExprDefault_(const Object* op) {
+    LOG(FATAL) << "JSON runtime currently doesn't support " << op->GetTypeKey();
+    return {};
+  }
+
+  // Variables must already be memoized (params are seeded in serialize(),
+  // let-bound vars in the LetNode visitor).
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const VarNode* vn) {
+    CHECK(memo_.count(GetRef<Expr>(vn)));
+    return memo_[GetRef<Expr>(vn)];
+  }
+
+  // Constants become "const" nodes named <symbol>_const_<i> and are recorded
+  // as required params of the generated module.
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* cn) {
+    std::string name = symbol_ + "_const_" + std::to_string(params_.size());
+    params_.push_back(name);
+    auto node = std::make_shared<JSONGraphNode>(name, "const" /* op_type_ */);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+
+  // Tuples are flattened into the concatenation of their field entries.
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const TupleNode* tn) {
+    std::vector<JSONGraphNodeEntry> fields;
+    for (const auto& field : tn->fields) {
+      auto ref = VisitExpr(field);
+      fields.insert(fields.end(), ref.begin(), ref.end());
+    }
+    return fields;
+  }
+
+  // Calls to ops or composite functions become "kernel" nodes; the node name
+  // is the op name or the composite name respectively.
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) {
+    Expr expr = GetRef<Expr>(cn);
+    std::string name;
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      name = op_node->name;
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto comp = fn->GetAttr<String>(attr::kComposite);
+      CHECK(comp.defined()) << "JSON runtime only supports composite functions.";
+      name = comp.value();
+    } else {
+      LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey();
+    }
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : cn->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
+                                                "kernel", /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    SetCallNodeAttribute(node, cn);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+
+  // Let bindings memoize the value under the variable, then visit the body.
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const LetNode* ln) {
+    CHECK_EQ(memo_.count(ln->var), 0);
+    memo_[ln->var] = VisitExpr(ln->value);
+    return VisitExpr(ln->body);
+  }
+
+  // Select the indexed entry from the flattened tuple entries.
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const TupleGetItemNode* gtn) {
+    auto vtuple = VisitExpr(gtn->tuple);
+    return {vtuple[gtn->index]};
+  }
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const FunctionNode* fn) {
+    CHECK(fn->GetAttr<String>(attr::kComposite).defined())
+        << "JSON runtime only supports composite functions";
+    // FunctionNode should be handled by the caller.
+    return {};
+  }
+
+  /*!
+   * \brief Save to JSON graph
+   *
+   * \param writer A json writer
+   */
+  void Save(dmlc::JSONWriter* writer) {
+    // Leaf nodes (inputs/constants) are listed as "arg_nodes".
+    std::vector<size_t> arg_nodes;
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      auto node = nodes_[i];
+      if (node->IsLeaf()) {
+        arg_nodes.push_back(i);
+      }
+    }
+    // node_row_ptr is the running prefix sum of per-node output counts, so
+    // entry id = node_row_ptr[node] + output index.
+    size_t num_entry = 0;
+    std::vector<size_t> node_row_ptr{0};
+    for (auto node : nodes_) {
+      num_entry += node->GetNumOutput();
+      node_row_ptr.push_back(num_entry);
+    }
+    writer->BeginObject();
+    writer->WriteObjectKeyValue("nodes", nodes_);
+    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
+    writer->WriteObjectKeyValue("heads", heads_);
+    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
+    writer->EndObject();
+  }
+
+ private:
+ /*! \brief The symbol that represents the json graph. */
+ std::string symbol_;
+ /*! \brief The function to be serialized. */
+ const Expr func_;
+ /*! \brief JSON graph nodes. */
+ std::vector<JSONGraphObjectPtr> nodes_;
+ /*! \brief Output of the JSON graph. */
+ std::vector<JSONGraphNodeEntry> heads_;
+ /*! \brief The list of required constants. */
+ Array<String> params_;
+};
+
+} // namespace contrib
+} // namespace backend
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
#include <sstream>
#include "../../utils.h"
+
+#ifdef USE_JSON_RUNTIME
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../codegen_json/codegen_json.h"
+#else
#include "../codegen_c/codegen_c.h"
+#endif
namespace tvm {
namespace relay {
using namespace backend;
+#ifndef USE_JSON_RUNTIME // C source runtime
inline size_t GetShape1DSize(const Type& type) {
const auto shape = GetShape(type);
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
std::ostringstream code_stream_;
};
+#else // DNNL JSON runtime
+
+/*! \brief DNNL-specific serializer: maps supported composite functions to a
+ *  single JSON "kernel" node whose attributes come from the root conv2d call.
+ */
+class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
+  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+ public:
+  DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
+    Expr expr = GetRef<Expr>(cn);
+    std::string name;
+    const CallNode* call = cn;
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      name = op_node->name;
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto comp = fn->GetAttr<String>(attr::kComposite);
+      CHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
+      name = comp.value();
+
+      // Resolve the composite body to its root nn.conv2d call so that the
+      // conv2d attributes are exported on the fused kernel node.
+      if (name == "dnnl.conv2d_bias_relu") {
+        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+      } else if (name == "dnnl.conv2d_relu") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+        CHECK(call->op.as<OpNode>()) << "Not op node";
+      } else {
+        LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
+      }
+    } else {
+      LOG(FATAL) << "DNNL JSON runtime does not support calls to " << cn->op->GetTypeKey();
+    }
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : cn->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
+                                                "kernel", /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    // Attributes are taken from the resolved root call, not the composite wrapper.
+    SetCallNodeAttribute(node, call);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+};
+
+/*!
+ * \brief Get the external symbol of the Relay function name.
+ *
+ * \param func The provided function.
+ *
+ * \return An external symbol.
+ */
+std::string GetExtSymbol(const Function& func) {
+  const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  // Partitioned external functions are expected to carry kGlobalSymbol.
+  CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
+  return std::string(name_node.value());
+}
+#endif
+
/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
*/
runtime::Module DNNLCompiler(const ObjectRef& ref) {
+#ifdef USE_JSON_RUNTIME
+  // JSON path: serialize the partitioned function to a JSON graph and create
+  // the DNNL JSON runtime module through its registered factory.
+  CHECK(ref->IsInstance<FunctionNode>());
+  auto func = Downcast<Function>(ref);
+  auto func_name = GetExtSymbol(func);
+  DNNLJSONSerializer serializer(func_name, func);
+  serializer.serialize();
+  std::string graph_json = serializer.GetJSON();
+  auto params = serializer.GetParams();
+
+  const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
+  CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
+  auto mod = (*pf)(func_name, graph_json, params);
+  return mod;
+#else
+  // Legacy path: emit a C source module instead of a JSON graph.
 DNNLModuleCodegen dnnl;
 return dnnl.CreateCSourceModule(ref);
+#endif
}
TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler);
LOG(FATAL) << "Not implemented.";
}
};
-
template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
inline static void Write(dmlc::JSONWriter* writer,
writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
} else if (SameType<std::vector<std::string>>(v)) {
writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
+ } else if (SameType<std::vector<dmlc::any>>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::vector<dmlc::any>>(v));
} else {
LOG(FATAL) << "Not supported";
}
CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf);
}
+
+/*!
+ * \brief Extract shape from an IndexExpr array to std::vector<int64_t>
+ *
+ * \param shape The shape in Array
+ * \return The converted shape in std::vector<int64_t>
+ */
+inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) {
+  std::vector<int64_t> ret;
+  for (const auto& dim : shape) {
+    const int64_t* pval = tir::as_const_int(dim);
+    // Dynamic (non-constant) dimensions are rejected: the JSON graph needs static shapes.
+    CHECK(pval) << "Expect integer, but received: " << dim->GetTypeKey();
+    ret.push_back(*pval);
+  }
+  return ret;
+}
+
/*!
* \brief Convert type to string
*
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+ * \brief A simple JSON runtime for DNNL.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "dnnl.hpp"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class DNNLJSONRuntime : public JSONRuntimeBase {
+ using tag = dnnl::memory::format_tag;
+ using dt = dnnl::memory::data_type;
+
+ public:
+ DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+ const Array<String> const_names)
+ : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+ const char* type_key() const { return "dnnl_json"; }
+
+  // Build the DNNL engine/primitive network, then bind the weight constants.
+  void Init(const Array<NDArray>& consts) override {
+    BuildEngine();
+
+    CHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number of required.";
+
+    // Setup constants entries for weights.
+    SetupConstants(consts);
+  }
+
+  // Copy inputs into DNNL memories, execute the primitive stream, then copy
+  // the results back out. Offsets stored in entry_out_mem_ are in elements and
+  // are scaled by 4 bytes here (BindDNNLMemory enforces 32-bit dtypes).
+  void Run() override {
+    // Fill in the input buffers.
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto eid = EntryID(input_nodes_[i], 0);
+      // TODO(@comaniac): Support other data lengths.
+      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t buffer_size = GetDataSize(*data_entry_[eid]);
+      write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
+                           offset_in_bytes);
+    }
+
+    // Invoke the engine through interpreting the stream.
+    for (size_t i = 0; i < net_.size(); ++i) {
+      net_.at(i).execute(stream_, net_args_.at(i));
+    }
+    stream_.wait();
+
+    // Read output buffers.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto eid = EntryID(outputs_[i]);
+      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t buffer_size = GetDataSize(*data_entry_[eid]);
+      read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
+                            offset_in_bytes);
+    }
+  }
+
+ private:
+  // Build up the engine based on the input graph.
+  // Dispatches each "kernel" node to the matching primitive builder.
+  void BuildEngine() {
+    engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
+    stream_ = dnnl::stream(engine_);
+
+    // Build subgraph engine.
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        // NOTE(review): this CHECK_EQ repeats the enclosing condition and always holds.
+        CHECK_EQ(node.GetOpType(), "kernel");
+        auto op_name = node.GetOpName();
+        if ("nn.conv2d" == op_name) {
+          Conv2d(nid);
+        } else if ("dnnl.conv2d_relu" == op_name) {
+          Conv2d(nid, true, false);
+        } else if ("dnnl.conv2d_bias_relu" == op_name) {
+          Conv2d(nid, true, true);
+        } else if ("nn.dense" == op_name) {
+          Dense(nid);
+        } else if ("nn.batch_norm" == op_name) {
+          BatchNorm(nid);
+        } else if ("nn.relu" == op_name) {
+          Relu(nid);
+        } else if ("add" == op_name) {
+          Add(nid);
+        } else {
+          LOG(FATAL) << "Unsupported op: " << op_name;
+        }
+      }
+    }
+  }
+
+  // Bind a JSON graph node entry to a DNNL memory.
+  // On first use a new memory is created from `mem_desc`; on later calls the
+  // already-bound memory is returned and mem_desc/offset are ignored.
+  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc,
+                              size_t offset = 0) {
+    auto eid = EntryID(entry);
+    if (entry_out_mem_.count(eid) == 0) {
+      return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
+    }
+    return entry_out_mem_[eid].first;
+  }
+
+  // Bind a JSON graph node entry to a given DNNL memory.
+  // `offset` is stored in elements; Run() converts it to bytes (x4).
+  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem,
+                              size_t offset = 0) {
+    auto eid = EntryID(entry);
+    // Since the DNNL memory has been created before calling this function, we assume the entry
+    // has not yet been bound to the other DNNL memory; otherwise it may have memory leak.
+    CHECK_EQ(entry_out_mem_.count(eid), 0);
+
+    // TODO(@comanic): Support other data types (i.e., int8).
+    auto data_node = nodes_[entry.id_];
+    auto dltype = data_node.GetOpDataType()[entry.index_];
+    CHECK_EQ(dltype.bits, 32);
+
+    entry_out_mem_[eid] = {mem, offset};
+    return entry_out_mem_[eid].first;
+  }
+
+  // Build a DNNL conv2d primitive from JSON node `nid`, optionally fused with
+  // bias add (has_bias) and ReLU post-op (has_relu), and push it onto the network.
+  void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
+    auto node = nodes_[nid];
+
+    // Setup attributes.
+    auto data_entry = node.GetInputs()[0];
+    auto weight_entry = node.GetInputs()[1];
+    dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
+    std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
+    dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
+
+    dnnl::memory::dim N = input_shape[0],       // batch size
+        IC = input_shape[1],                    // input channels
+        IH = input_shape[2],                    // input height
+        IW = input_shape[3],                    // input width (bug fix: was input_shape[2])
+        OC = weight_shape[0],                   // output channels
+        KH = weight_shape[2],                   // weight height
+        KW = weight_shape[3],                   // weight width
+        // NOTE(review): the padding index order below is inherited as-is; confirm it
+        // matches the relay conv2d padding convention (top, left, bottom, right).
+        PH_L = std::stoi(str_padding[1]),       // height padding: left
+        PH_R = std::stoi(str_padding[3]),       // height padding: right
+        PW_L = std::stoi(str_padding[0]),       // width padding: left
+        PW_R = std::stoi(str_padding[2]),       // width padding: right
+        SH = std::stoi(str_strides[0]),         // height-wise stride
+        SW = std::stoi(str_strides[1]),         // width-wise stride (bug fix: was str_strides[0])
+        OH = (IH - KH + PH_L + PH_R) / SH + 1,  // output height
+        OW = (IW - KW + PW_L + PW_R) / SW + 1;  // output width
+
+    // Memory shapes.
+    dnnl::memory::dims src_dims = {N, IC, IH, IW};
+    dnnl::memory::dims weights_dims = {OC, IC, KH, KW};
+    if (groups > 1) {
+      weights_dims = {groups, 1, IC / groups, KH, KW};
+    }
+    dnnl::memory::dims bias_dims = {OC};
+    dnnl::memory::dims dst_dims = {N, OC, OH, OW};
+    dnnl::memory::dims strides_dims = {SH, SW};
+    dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
+    dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
+
+    // Memory descriptions.
+    auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, tag::any);
+    auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, tag::any);
+    auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
+    auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::nchw);
+
+    // Conv2d description.
+    auto conv_desc = dnnl::convolution_forward::desc(
+        dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
+        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+
+    // Enable ReLU as a fused post-op when requested.
+    dnnl::primitive_attr attr;
+    if (has_relu) {
+      dnnl::post_ops ops;
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+      attr.set_post_ops(ops);
+    }
+
+    auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_);
+
+    // Push to the network.
+    auto conv = dnnl::convolution_forward(conv2d_prim_desc);
+    net_.push_back(conv);
+
+    // Data memory.
+    CHECK_EQ(node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCHW");
+    auto conv2d_src_memory = BindDNNLMemory(data_entry, {src_dims, dt::f32, tag::nchw});
+
+    // Weight memory.
+    CHECK_EQ(node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIHW");
+    auto conv2d_weights_memory = BindDNNLMemory(
+        weight_entry, {weights_dims, dt::f32, (groups > 1) ? tag::goihw : tag::oihw});
+
+    // Bias memory. The primitive always takes a bias, so a zero bias is
+    // supplied when the graph has none.
+    auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      BindDNNLMemory(bias_entry, conv2d_bias_memory);
+    } else {
+      // std::vector replaces the non-standard VLA `float bias[OC]` (OC is a runtime value).
+      std::vector<float> bias(OC, 0.0f);
+      write_to_dnnl_memory(bias.data(), conv2d_bias_memory, OC * sizeof(float));
+    }
+
+    // Output memory.
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc());
+
+    // Bind memory buffers.
+    net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory},
+                         {DNNL_ARG_WEIGHTS, conv2d_weights_memory},
+                         {DNNL_ARG_BIAS, conv2d_bias_memory},
+                         {DNNL_ARG_DST, conv2d_dst_memory}});
+  }
+
+  // Build a DNNL inner-product (dense) primitive from JSON node `nid`.
+  // A zero-filled bias is always supplied since the primitive desc takes one.
+  void Dense(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    // Setup attributes.
+    auto data_entry = node.GetInputs()[0];
+    auto weight_entry = node.GetInputs()[1];
+    dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
+
+    dnnl::memory::dim B = input_shape[0],  // batch size
+        IC = input_shape[1],               // input channels
+        OC = weight_shape[0];              // output channels
+
+    // Memory shapes.
+    dnnl::memory::dims data_dims = {B, IC};
+    dnnl::memory::dims weight_dims = {OC, IC};
+    dnnl::memory::dims bias_dims = {OC};
+    dnnl::memory::dims out_dims = {B, OC};
+
+    // Memory descriptions.
+    auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc});
+    auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc});
+    auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x});
+    auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc});
+
+    // Dense description.
+    auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
+                                                        weight_md, bias_md, dst_md);
+    auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, engine_);
+
+    auto dense = dnnl::inner_product_forward(dense_prim_desc);
+    net_.push_back(dense);
+
+    // Memories.
+    auto data_memory = BindDNNLMemory(data_entry, data_md);
+    auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+    auto bias_memory = dnnl::memory(bias_md, engine_);
+    // std::vector replaces the non-standard VLA `float bias[OC]` (OC is a runtime value).
+    std::vector<float> bias(OC, 0.0f);
+    write_to_dnnl_memory(bias.data(), bias_memory, OC * sizeof(float));
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
+
+    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
+                         {DNNL_ARG_WEIGHTS, weight_memory},
+                         {DNNL_ARG_BIAS, bias_memory},
+                         {DNNL_ARG_DST, dst_memory}});
+  }
+
+  // Build a DNNL batch_normalization (inference, global stats) primitive
+  // from JSON node `nid`. Inputs: data, gamma, beta, mean, variance.
+  void BatchNorm(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto data_entry = node.GetInputs()[0];
+    auto gamma_entry = node.GetInputs()[1];
+    auto beta_entry = node.GetInputs()[2];
+    auto mean_entry = node.GetInputs()[3];
+    auto variance_entry = node.GetInputs()[4];
+    dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    dnnl::memory::dim IC = data_shape[1];
+    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+
+    // Memory description.
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+
+    // BN description.
+    auto bn_desc = dnnl::batch_normalization_forward::desc(
+        dnnl::prop_kind::forward_inference, data_md, epsilon,
+        dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift);
+    auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_);
+    auto bn = dnnl::batch_normalization_forward(bn_prim_desc);
+    net_.push_back(bn);
+
+    // Memories.
+    auto data_memory = BindDNNLMemory(data_entry, data_md);
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto out_memory = BindDNNLMemory(out_entry, data_md);
+    auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc());
+    auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc());
+
+    // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but
+    // assign an offset to beta data for runtime serialization.
+    // The offset IC is in elements; Run() scales stored offsets by 4 bytes.
+    auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0);
+    BindDNNLMemory(beta_entry, weight_memory, IC);
+
+    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
+                         {DNNL_ARG_DST, out_memory},
+                         {DNNL_ARG_SCALE_SHIFT, weight_memory},
+                         {DNNL_ARG_MEAN, mean_memory},
+                         {DNNL_ARG_VARIANCE, variance_memory}});
+  }
+
+  // Build a standalone ReLU primitive from JSON node `nid`.
+  void Relu(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto data_entry = node.GetInputs()[0];
+    dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    // NOTE(review): tag::abcd hard-codes a 4-D layout here (unlike Add/BatchNorm,
+    // which use GenDNNLMemDescByShape) — confirm all callers feed 4-D tensors.
+    auto data_md = dnnl::memory::desc{{shape}, dt::f32, tag::abcd};
+
+    auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
+                                                dnnl::algorithm::eltwise_relu, data_md, 0);
+    auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
+    CHECK(data_md == relu_prim_desc.dst_desc());
+
+    auto relu = dnnl::eltwise_forward(relu_prim_desc);
+    net_.push_back(relu);
+
+    auto data_memory = BindDNNLMemory(data_entry, data_md);
+    auto out_md = dnnl::memory::desc(shape, dt::f32, tag::abcd);
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto out_memory = BindDNNLMemory(out_entry, out_md);
+
+    net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}});
+  }
+
+  // Build a DNNL binary-add primitive from JSON node `nid`.
+  // Exactly two inputs of identical shape are required (no broadcasting).
+  void Add(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    // Memory and compute description.
+    std::vector<dnnl::memory::dims> data_dims;
+    std::vector<dnnl::memory::desc> data_mds;
+    std::vector<dnnl::memory> data_memories;
+
+    CHECK_EQ(node.GetInputs().size(), 2U);
+    for (auto entry : node.GetInputs()) {
+      auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
+      dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+
+      data_dims.push_back(data_shape);
+      data_mds.push_back(data_md);
+      data_memories.push_back(BindDNNLMemory(entry, data_md));
+    }
+    CHECK(data_dims[0] == data_dims[1]);
+    auto out_md = data_mds[0];
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto out_memory = BindDNNLMemory(out_entry, out_md);
+
+    auto add_desc =
+        dnnl::binary::desc(dnnl::algorithm::binary_add, data_mds[0], data_mds[1], out_md);
+    auto add_prim_desc = dnnl::binary::primitive_desc(add_desc, engine_);
+    auto add = dnnl::binary(add_prim_desc);
+    net_.push_back(add);
+
+    net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
+                         {DNNL_ARG_SRC_1, data_memories[1]},
+                         {DNNL_ARG_DST, out_memory}});
+  }
+
+  // Copy `size` bytes starting at byte `offset` of a DNNL memory buffer into
+  // the destination pointed to by `handle`.
+  inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
+                                    size_t offset = 0) {
+    uint8_t* begin = static_cast<uint8_t*>(mem.get_data_handle()) + offset;
+    uint8_t* out = static_cast<uint8_t*>(handle);
+    std::copy(begin, begin + size, out);
+  }
+
+  // Copy `size` bytes from the source pointed to by `handle` into a DNNL
+  // memory buffer, starting at byte `offset`.
+  inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
+                                   size_t offset = 0) {
+    uint8_t* src = reinterpret_cast<uint8_t*>(handle);
+    uint8_t* dst_begin = static_cast<uint8_t*>(mem.get_data_handle()) + offset;
+    std::copy(src, src + size, dst_begin);
+  }
+
+  // Build a DNNL memory descriptor for `shape`, choosing the plain row-major
+  // layout tag that matches the shape's rank. Ranks 2 through 5 are
+  // supported; anything else is fatal.
+  inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) {
+    // Index 0 corresponds to rank 2 (tag::ab), index 3 to rank 5 (tag::abcde).
+    static const tag kLayoutTags[] = {tag::ab, tag::abc, tag::abcd, tag::abcde};
+    const size_t ndim = shape.size();
+    if (ndim >= 2 && ndim <= 5) {
+      return dnnl::memory::desc({shape, dtype, kLayoutTags[ndim - 2]});
+    }
+    LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
+    return dnnl::memory::desc();
+  }
+
+  /* The DNNL engine used to create and execute all primitives. */
+  dnnl::engine engine_;
+  /* The DNNL stream the primitives in net_ are submitted to. */
+  dnnl::stream stream_;
+  /* The network layers, one DNNL primitive per operator, in execution order. */
+  std::vector<dnnl::primitive> net_;
+  /* Per-primitive argument maps (DNNL_ARG_* id -> bound memory); parallel to net_. */
+  std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
+  /* Maps a graph entry ID to its bound DNNL memory and byte offset within it. */
+  std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
+};
+
+// Factory used by the DNNL codegen: wrap a serialized JSON subgraph and its
+// constant names into a DNNL-backed runtime module.
+runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
+                                      const Array<String>& const_names) {
+  auto runtime_node = make_object<DNNLJSONRuntime>(symbol_name, graph_json, const_names);
+  return runtime::Module(runtime_node);
+}
+
+// Expose the module factory through the global function registry.
+TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate);
+
+// Deserializer used when loading a saved module of type "dnnl_json".
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/json/json_node.h
+ * \brief The graph nodes used by JSON runtime.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_
+#define TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_
+
+#include <dlpack/dlpack.h>
+#include <dmlc/json.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/container.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace json {
+
+using namespace tvm::runtime;
+using JSONGraphAttrs = std::unordered_map<std::string, dmlc::any>;
+
+/*!
+ * \brief The node entry in the serialized json graph.
+ */
+class JSONGraphNodeEntry {
+ public:
+  // Constructors.
+  JSONGraphNodeEntry() = default;
+  JSONGraphNodeEntry(int id, int index, int version = 0)
+      : id_(id), index_(index), version_(version) {}
+
+  /*!
+   * \brief Serialize a node entry.
+   * \param writer The json writer.
+   */
+  void Save(dmlc::JSONWriter* writer) const {
+    // An entry is stored as a 3-element array: [id, index, version].
+    writer->BeginArray();
+    writer->WriteArrayItem(id_);
+    writer->WriteArrayItem(index_);
+    writer->WriteArrayItem(version_);
+    writer->EndArray();
+  }
+
+  /*!
+   * \brief Deserialize the json string into a node entry.
+   * \param reader The json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    CHECK(reader->NextArrayItem()) << "invalid json format";
+    reader->Read(&id_);
+    CHECK(reader->NextArrayItem()) << "invalid json format";
+    reader->Read(&index_);
+    // The version element is optional for backward compatibility.
+    if (reader->NextArrayItem()) {
+      reader->Read(&version_);
+      CHECK(!reader->NextArrayItem()) << "invalid json format";
+    } else {
+      version_ = 0;
+    }
+  }
+
+  /*! \brief The json graph node ID. Zero-initialized so a default-constructed
+   * entry is well-defined. */
+  uint32_t id_{0};
+  /*! \brief The entry index. */
+  uint32_t index_{0};
+  /*! \brief The entry version; 0 when absent from the serialized graph. */
+  uint32_t version_{0};
+};
+
+/*!
+ * \brief The node of the serialized json graph. It includes an array of
+ * entries.
+ */
+class JSONGraphNode {
+ public:
+  // Constructors.
+  JSONGraphNode() = default;
+  JSONGraphNode(const std::string& name, const std::string& op_type,
+                const std::vector<JSONGraphNodeEntry>& inputs = {}, size_t num_outputs = 1) {
+    name_ = name;
+    op_type_ = op_type;
+    num_inputs_ = inputs.size();
+    inputs_ = inputs;
+    num_outputs_ = num_outputs;
+  }
+
+  /*!
+   * \brief Serialize a node so that it can be saved to disk.
+   * \param writer The json writer.
+   *
+   * Note: not const -- serializing a node with inputs records num_inputs /
+   * num_outputs into attrs_ via SetAttr so that LoadAttrs can recover them.
+   */
+  void Save(dmlc::JSONWriter* writer) {
+    writer->BeginObject();
+    writer->WriteObjectKeyValue("op", op_type_);
+    writer->WriteObjectKeyValue("name", name_);
+    if (!inputs_.empty()) {
+      SetAttr("num_inputs", std::to_string(inputs_.size()));
+      SetAttr("num_outputs", std::to_string(num_outputs_));
+      writer->WriteObjectKeyValue("inputs", this->inputs_);
+    }
+    if (!attrs_.empty()) {
+      writer->WriteObjectKeyValue("attrs", attrs_);
+    }
+    writer->EndObject();
+  }
+
+  /*!
+   * \brief Load the attribute of a node in the json string.
+   * \param reader The json reader.
+   *
+   * Known keys (num_inputs, num_outputs, dtype, shape) are parsed into typed
+   * fields; every other attribute value is expected to be a singleton array
+   * wrapping a list of strings and is kept verbatim in attrs_.
+   */
+  void LoadAttrs(dmlc::JSONReader* reader) {
+    std::string key, value;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      if (key == "num_inputs") {
+        reader->Read(&value);
+        num_inputs_ = strtoul(value.c_str(), nullptr, 10);
+      } else if (key == "num_outputs") {
+        reader->Read(&value);
+        num_outputs_ = strtoul(value.c_str(), nullptr, 10);
+      } else if (key == "dtype") {
+        // dtype is encoded as [["float32", ...]]: one string per output.
+        std::vector<std::string> tmp;
+        reader->BeginArray();
+        CHECK(reader->NextArrayItem());
+        reader->Read(&tmp);
+        CHECK(!reader->NextArrayItem());
+        for (const auto& it : tmp) {
+          dtype_.push_back(tvm::runtime::String2DLDataType(it));
+        }
+      } else if (key == "shape") {
+        // shape is encoded as [[[d0, d1, ...], ...]]: one shape per output.
+        reader->BeginArray();
+        CHECK(reader->NextArrayItem());
+        reader->Read(&shape_);
+        CHECK(!reader->NextArrayItem());
+      } else {
+        // Generic attribute: unwrap the singleton array into a string list.
+        reader->BeginArray();
+        CHECK(reader->NextArrayItem());
+        std::vector<std::string> tmp;
+        reader->Read(&tmp);
+        attrs_[key] = tmp;
+        CHECK(!reader->NextArrayItem());
+      }
+    }
+    // Every output must carry both a shape and a dtype.
+    CHECK_EQ(shape_.size(), dtype_.size());
+  }
+
+  /*!
+   * \brief Load a node in the json string.
+   * \param reader The json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginObject();
+    std::string key;
+    while (reader->NextObjectItem(&key)) {
+      if (key == "op") {
+        reader->Read(&op_type_);
+      } else if (key == "name") {
+        reader->Read(&name_);
+      } else if (key == "inputs") {
+        reader->Read(&inputs_);
+      } else if (key == "attr" || key == "attrs") {
+        this->LoadAttrs(reader);
+      } else {
+        LOG(FATAL) << "Unknown key: " << key;
+      }
+    }
+  }
+
+  /*!
+   * \brief Check if a node is a leaf node, i.e. input to the graph.
+   *
+   * \return True if the node has no input, otherwise, false.
+   */
+  bool IsLeaf() const { return inputs_.empty(); }
+
+  /*!
+   * \brief Return the number of outputs of the node.
+   *
+   * \return The number of the output.
+   */
+  uint32_t GetNumOutput() const { return num_outputs_; }
+
+  /*!
+   * \brief Return the input entries.
+   *
+   * \return The input entries (by value).
+   */
+  std::vector<JSONGraphNodeEntry> GetInputs() const { return inputs_; }
+
+  /*!
+   * \brief Return the op type.
+   *
+   * \return The op type.
+   */
+  std::string GetOpType() const { return op_type_; }
+
+  /*!
+   * \brief Return the op name.
+   *
+   * \return The op name.
+   */
+  std::string GetOpName() const { return name_; }
+
+  /*!
+   * \brief Return the op output shapes.
+   *
+   * \return The shapes, one per output.
+   */
+  std::vector<std::vector<int64_t>> GetOpShape() const { return shape_; }
+
+  /*!
+   * \brief Return the op types.
+   *
+   * \return The types, one per output.
+   */
+  std::vector<DLDataType> GetOpDataType() const { return dtype_; }
+
+  /*!
+   * \brief Set the number of outputs of the node.
+   *
+   * \param num_outputs The number of output.
+   */
+  void SetNumOutput(uint32_t num_outputs) { num_outputs_ = num_outputs; }
+
+  /*!
+   * \brief Get the value of an attribute in the node.
+   *
+   * \tparam T The return type.
+   * \param key The key for lookup.
+   *
+   * \return The value. Fatal if the key is missing; dmlc::get throws on a
+   *         type mismatch.
+   */
+  template <typename T>
+  T GetAttr(const std::string& key) const {
+    CHECK_GT(attrs_.count(key), 0U) << "Key: " << key << "is not found";
+    return dmlc::get<T>(attrs_.at(key));
+  }
+
+  /*!
+   * \brief Set an attribute for the node.
+   *
+   * \tparam ValueT The type of the value being stored.
+   * \param key The key of the attribute.
+   * \param value The value of the attribute. Overwrites any previous value.
+   */
+  template <typename ValueT>
+  void SetAttr(const std::string& key, const ValueT& value) {
+    attrs_[key] = value;
+  }
+
+  virtual ~JSONGraphNode() {}
+
+ private:
+  /*! \brief The number of input. */
+  uint32_t num_inputs_{0};
+  /*! \brief The number of output. */
+  uint32_t num_outputs_{1};
+  /*! \brief The name of the op. It is the symbol that used for runtime lookup. */
+  std::string name_;
+  /*! \brief The operator type, i.e. input is "null". */
+  std::string op_type_;
+  /*! \brief The shape of the node. */
+  std::vector<std::vector<int64_t>> shape_;
+  /*! \brief The type of the node. */
+  std::vector<DLDataType> dtype_;
+  /*! \brief The inputs of the node. */
+  std::vector<JSONGraphNodeEntry> inputs_;
+  /*!
+   * \brief Attribute of the node. For simplicity, we store all attribute as
+   * a list of std::string. It's the developer's resposibility to check the
+   * required attribute of a certain op and convert it into the needed type.
+   *
+   * For example, for conv2d, this map could contain:
+   *  attrs_["strides"] = ["1", "1"]
+   *  attrs_["padding"] = ["0", "0", "0", "0"]
+   *  attrs_["data_layout"] = ["NCHW"]
+   *
+   * when creating an execution engine, developers may need to use these
+   * attributes and they can convert it into the needed type, i.e. padding to
+   * int
+   */
+  JSONGraphAttrs attrs_;
+
+  friend class JSONRuntimeBase;
+};
+
+} // namespace json
+} // namespace runtime
+} // namespace tvm
+
+namespace dmlc {
+namespace json {
+// Check whether the dynamically-typed value holds exactly a T.
+template <typename T>
+inline bool SameType(const dmlc::any& data) {
+  const std::type_index held(data.type());
+  const std::type_index wanted(typeid(T));
+  return held == wanted;
+}
+
+template <>
+struct Handler<std::unordered_map<std::string, dmlc::any>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const std::unordered_map<std::string, dmlc::any>& data) {
+ for (const auto& kv : data) {
+ auto k = kv.first;
+ const dmlc::any& v = kv.second;
+ if (SameType<std::vector<dmlc::any>>(v)) {
+ writer->WriteObjectKeyValue(k, dmlc::get<std::vector<dmlc::any>>(v));
+ } else {
+ LOG(FATAL) << "Not supported";
+ }
+ }
+ writer->EndObject();
+ }
+ inline static void Read(dmlc::JSONReader* reader,
+ std::unordered_map<std::string, dmlc::any>* data) {
+ LOG(FATAL) << "Not implemented";
+ }
+};
+
+template <>
+struct Handler<std::shared_ptr<tvm::runtime::json::JSONGraphNode>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const std::shared_ptr<tvm::runtime::json::JSONGraphNode>& data) {
+ data->Save(writer);
+ }
+
+ inline static void Read(dmlc::JSONReader* reader,
+ std::shared_ptr<tvm::runtime::json::JSONGraphNode>* data) {
+ (*data)->Load(reader);
+ }
+};
+} // namespace json
+} // namespace dmlc
+
+#endif // TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/json/json_runtime.h
+ * \brief Utilities for json runtime.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "json_node.h"
+
+namespace tvm {
+namespace runtime {
+namespace json {
+
+/*!
+ * \brief A json runtime that executes the serialized JSON format. This runtime
+ * can be extended by user defined runtime for execution.
+ */
+class JSONRuntimeBase : public ModuleNode {
+ public:
+  /*!
+   * \brief Construct the base runtime and eagerly parse the json graph.
+   * \param symbol_name The subgraph symbol this module implements.
+   * \param graph_json The serialized json graph.
+   * \param const_names Names of the constants the graph requires, in order.
+   *        NOTE(review): taken by value; Array is presumably a cheap
+   *        ref-counted handle -- confirm before changing to a reference.
+   */
+  JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json,
+                  const Array<String> const_names)
+      : symbol_name_(symbol_name), graph_json_(graph_json), const_names_(const_names) {
+    LoadGraph(graph_json_);
+  }
+
+  const char* type_key() const { return "json"; }
+
+  /*! \brief Initialize a specific json runtime. */
+  virtual void Init(const Array<NDArray>& consts) = 0;
+
+  /*! \brief Invoke the execution engine to inteprete a specific json runtime. */
+  virtual void Run() = 0;
+
+  /*!
+   * \brief Get a packed function.
+   * \param name The name/symbol of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The packed function, or nullptr if `name` is not provided here.
+   */
+  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+    if (name == "get_symbol") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
+    } else if (name == "get_const_vars") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_names_; });
+    } else if (this->symbol_name_ == name) {
+      // The actual entry point: bind buffers, then dispatch to the engine.
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK(this->initialized_) << "The module has not been initialized";
+
+        // Bind argument tensors to data entries.
+        this->SetInputOutputBuffers(args);
+        // Execute the subgraph.
+        this->Run();
+      });
+    } else if ("__init_" + this->symbol_name_ == name) {
+      // The function to initialize constant tensors.
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.size(), 1U);
+        this->Init(args[0]);
+        this->initialized_ = true;
+        *rv = 0;
+      });
+    } else {
+      return PackedFunc(nullptr);
+    }
+  }
+
+  /*!
+   * \brief Serialize the module: symbol, json graph, then constant names.
+   * The order must match LoadFromBinary below.
+   */
+  virtual void SaveToBinary(dmlc::Stream* stream) {
+    // Save the symbol
+    stream->Write(symbol_name_);
+    // Save the graph
+    stream->Write(graph_json_);
+    // Save the required const names
+    std::vector<std::string> consts;
+    for (const auto& it : const_names_) {
+      consts.push_back(it);
+    }
+    stream->Write(consts);
+  }
+
+  /*!
+   * \brief Recreate a concrete json runtime module (T) from a binary stream
+   * written by SaveToBinary.
+   */
+  template <typename T,
+            typename = typename std::enable_if<std::is_base_of<JSONRuntimeBase, T>::value>::type>
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string symbol;
+    std::string graph_json;
+    std::vector<std::string> consts;
+    // Load the symbol
+    CHECK(stream->Read(&symbol)) << "Loading symbol name failed";
+    CHECK(stream->Read(&graph_json)) << "Loading graph json failed";
+    CHECK(stream->Read(&consts)) << "Loading the const name list failed";
+    Array<String> const_names;
+    for (const auto& it : consts) {
+      const_names.push_back(it);
+    }
+    auto n = make_object<T>(symbol, graph_json, const_names);
+    return Module(n);
+  }
+
+ protected:
+  /*!
+   * \brief Set up the input and output buffers by binding their DLTensor pointers to the
+   * corresponding data entry.
+   *
+   * \param args The packed args. Expected layout: all graph inputs first (in
+   *        input_var_idx_ order), followed by all outputs.
+   */
+  void SetInputOutputBuffers(const TVMArgs& args) {
+    CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size())
+        << "Found mismatch in the number of provided data entryies and required.";
+
+    for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
+      // Inputs map through (node id, 0); outputs use their recorded entry.
+      auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0)
+                                           : EntryID(outputs_[i - input_var_idx_.size()]);
+      CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
+          << "Expect NDArray or DLTensor as inputs";
+
+      const DLTensor* arg;
+      if (args[i].IsObjectRef<NDArray>()) {
+        NDArray arr = args[i];
+        arg = arr.operator->();
+      } else {
+        arg = args[i].operator DLTensor*();
+      }
+
+      // Assign input/output the NDArray pointers to data entry so that we can directly
+      // read/write host buffers.
+      data_entry_[eid] = arg;
+    }
+  }
+
+  /*!
+   * \brief Load the graph and record the entries for inputs and constants.
+   *
+   * Splits the graph's arg nodes into variables (input_var_idx_) and
+   * constants (const_idx_), and checks that the constants appear in exactly
+   * the order given by const_names_.
+   *
+   * \param graph_json The graph in the json format.
+   */
+  void LoadGraph(const std::string& graph_json) {
+    std::istringstream is(graph_json);
+    dmlc::JSONReader reader(&is);
+    this->Load(&reader);
+    std::vector<std::string> consts;
+    for (size_t i = 0; i < input_nodes_.size(); i++) {
+      uint32_t nid = input_nodes_[i];
+      std::string name = nodes_[nid].name_;
+      if (nodes_[nid].op_type_ == "input") {
+        input_var_idx_.push_back(nid);
+      } else {
+        CHECK_EQ(nodes_[nid].op_type_, "const");
+        auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
+        CHECK(pos != std::end(const_names_)) << "Found non-existent constant: " << name;
+        const_idx_.push_back(nid);
+        consts.push_back(name);
+      }
+    }
+    CHECK_EQ(consts.size(), const_names_.size())
+        << "Found mismatch for the number of constants in the graph and required.";
+
+    for (size_t i = 0; i < consts.size(); i++) {
+      CHECK_EQ(consts[i], const_names_[i])
+          << "The position of constant in the graph must be the same as the required.";
+    }
+
+    // Reserve data entries.
+    data_entry_.resize(NumEntries());
+  }
+
+  /*!
+   * \brief Set up the constants/weights for inference by binding their DLTensor pointer to
+   * the corresponding data entry.
+   *
+   * \param consts A list of constant NDArray to be used. Assumed to match
+   *        const_idx_ one-to-one in order (validated upstream by LoadGraph
+   *        against const_names_).
+   */
+  void SetupConstants(const Array<NDArray>& consts) {
+    for (size_t i = 0; i < consts.size(); ++i) {
+      data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->();
+    }
+  }
+
+  // Load the graph: expects the top-level keys nodes / arg_nodes /
+  // node_row_ptr / heads; anything else is fatal.
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginObject();
+    std::string key;
+    while (reader->NextObjectItem(&key)) {
+      if (key == "nodes") {
+        reader->Read(&nodes_);
+      } else if (key == "arg_nodes") {
+        reader->Read(&input_nodes_);
+      } else if (key == "node_row_ptr") {
+        reader->Read(&node_row_ptr_);
+      } else if (key == "heads") {
+        reader->Read(&outputs_);
+      } else {
+        LOG(FATAL) << "Unknown key: " << key;
+      }
+    }
+  }
+
+  // Get the node entry index.
+  uint32_t EntryID(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
+
+  // Get the node entry index.
+  uint32_t EntryID(const JSONGraphNodeEntry& e) const { return EntryID(e.id_, e.index_); }
+
+  // Number of node entries. Assumes node_row_ptr_ is non-empty, i.e. a graph
+  // has been loaded.
+  uint32_t NumEntries() const { return node_row_ptr_.back(); }
+
+ protected:
+  /*! \brief The only subgraph name for this module. */
+  std::string symbol_name_;
+  /*! \brief The graph. */
+  std::string graph_json_;
+  /*! \brief The required constant names. */
+  Array<String> const_names_;
+  /*! \brief The json graph nodes. */
+  std::vector<JSONGraphNode> nodes_;
+  /*! \brief The input nodes, including variables and constants. */
+  std::vector<uint32_t> input_nodes_;
+  /*! \brief Used for quick entry indexing. */
+  std::vector<uint32_t> node_row_ptr_;
+  /*! \brief Output entries. */
+  std::vector<JSONGraphNodeEntry> outputs_;
+  /*! \brief Data of that entry. */
+  std::vector<const DLTensor*> data_entry_;
+  /*! \brief Map the input name to node index. */
+  std::vector<uint32_t> input_var_idx_;
+  /*! \brief input const node index. */
+  std::vector<uint32_t> const_idx_;
+  /*! \brief Indicate if the engine has been initialized. */
+  bool initialized_{false};
+};
+
+} // namespace json
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
options=["-O2", "-std=c++14", "-I" + tmp_path.relpath("")])
def test_json_extern():
- if not tvm.get_global_func("module.loadfile_examplejson", True):
+ if not tvm.get_global_func("runtime.module.loadfile_examplejson", True):
print("Skip because JSON example runtime is not enabled.")
return
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for JSON codegen and runtime."""
+import os
+import sys
+
+import numpy as np
+
+import tvm
+import tvm.relay.op as reg
+import tvm.relay.testing
+from tvm import relay, runtime
+from tvm.contrib import util
+from tvm.relay import transform
+from tvm.relay.backend import compile_engine
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.op.contrib.register import get_pattern_table
+
+
+def set_func_attr(func, compile_name, symbol_name):
+ func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func = func.with_attr("Compiler", compile_name)
+ func = func.with_attr("global_symbol", symbol_name)
+ return func
+
+
+def check_result(mod,
+                 ref_mod,
+                 map_inputs,
+                 out_shape,
+                 tol=1e-5,
+                 target="llvm",
+                 ctx=tvm.cpu(),
+                 params=None):
+    # Compile and run `ref_mod` on the plain TVM backend to get a reference
+    # output, then check that the partitioned `mod` matches it under both the
+    # VM and the graph runtime. Only output 0 is compared.
+    if sys.platform == "win32":
+        print("Skip test on Windows for now")
+        return
+
+    # Run the reference result
+    compile_engine.get().clear()
+    with tvm.transform.PassContext(opt_level=3):
+        json, lib, param = relay.build(ref_mod, target=target, params=params)
+    rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+    for name, data in map_inputs.items():
+        rt_mod.set_input(name, data)
+    rt_mod.set_input(**param)
+    rt_mod.run()
+    out = tvm.nd.empty(out_shape, ctx=ctx)
+    out = rt_mod.get_output(0, out)
+    ref_result = out.asnumpy()
+
+    def check_vm_result():
+        # Round-trip the compiled executable through save/load to also cover
+        # (de)serialization of the external runtime module.
+        compile_engine.get().clear()
+        with relay.build_config(opt_level=3):
+            exe = relay.vm.compile(mod, target=target, params=params)
+        code, lib = exe.save()
+        exe = runtime.vm.Executable.load_exec(code, lib)
+        vm = runtime.vm.VirtualMachine(exe)
+        vm.init(ctx)
+        out = vm.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), ref_result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        compile_engine.get().clear()
+        with relay.build_config(opt_level=3):
+            json, lib, param = relay.build(mod, target=target, params=params)
+        rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+        for name, data in map_inputs.items():
+            rt_mod.set_input(name, data)
+        rt_mod.set_input(**param)
+        rt_mod.run()
+        out = tvm.nd.empty(out_shape, ctx=ctx)
+        out = rt_mod.get_output(0, out)
+        tvm.testing.assert_allclose(out.asnumpy(), ref_result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
+
+
+def test_conv2d():
+ """Test a subgraph with a single conv2d operator."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ def conv2d_direct():
+ dtype = 'float32'
+ ishape = (1, 32, 14, 14)
+ w1shape = (32, 32, 3, 3)
+
+ data0 = relay.var("data", shape=ishape, dtype=dtype)
+ weight0 = relay.var("weight", shape=w1shape, dtype=dtype)
+ out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1))
+
+ func = relay.Function([data0, weight0], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data = relay.var("data", shape=(ishape), dtype=dtype)
+ weight = relay.var("weight", shape=(w1shape), dtype=dtype)
+ main_f = relay.Function([data, weight], glb_var(data, weight))
+ mod["main"] = main_f
+
+ data0 = relay.var("data", shape=ishape, dtype=dtype)
+ weight0 = relay.var("weight", shape=w1shape, dtype=dtype)
+ out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1))
+ main_f = relay.Function([data0, weight0], out)
+ ref_mod = tvm.IRModule()
+ ref_mod['main'] = main_f
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+ return mod, ref_mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14)
+
+ def group_conv2d():
+ dtype = 'float32'
+ ishape = (1, 32, 14, 14)
+ w2shape = (32, 1, 3, 3)
+
+ data0 = relay.var("data", shape=(ishape), dtype=dtype)
+ weight0 = relay.var("weight", shape=(w2shape), dtype=dtype)
+ out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32)
+
+ func = relay.Function([data0, weight0], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data = relay.var("data", shape=(ishape), dtype=dtype)
+ weight = relay.var("weight", shape=(w2shape), dtype=dtype)
+ main_f = relay.Function([data, weight], glb_var(data, weight))
+ mod["main"] = main_f
+
+ data0 = relay.var("data", shape=(ishape), dtype=dtype)
+ weight0 = relay.var("weight", shape=(w2shape), dtype=dtype)
+ out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32)
+ main_f = relay.Function([data0, weight0], out)
+ ref_mod = tvm.IRModule()
+ ref_mod['main'] = main_f
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w_data = np.random.uniform(0, 1, w2shape).astype(dtype)
+
+ return mod, ref_mod, {"data": i_data, "weight": w_data}, (1, 32, 14, 14)
+
+ for mod, ref_mod, map_inputs, out_shape in [conv2d_direct(), group_conv2d()]:
+ check_result(mod, ref_mod, map_inputs, out_shape, tol=1e-5)
+
+
+def test_add():
+ """Test a subgraph with a single add operator."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ shape = (10, 10)
+
+ def gen_add():
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ data1 = relay.var("data1", shape=shape, dtype=dtype)
+ out = relay.add(data0, data1)
+
+ func = relay.Function([data0, data1], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ data1 = relay.var("data1", shape=shape, dtype=dtype)
+ main_f = relay.Function([data0, data1], glb_var(data0, data1))
+ mod["main"] = main_f
+
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ data1 = relay.var("data1", shape=shape, dtype=dtype)
+ out = relay.add(data0, data1)
+ main_f = relay.Function([data0, data1], out)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_f
+
+ return mod, ref_mod
+
+ mod, ref_mod = gen_add()
+
+ data0 = np.random.uniform(0, 1, shape).astype(dtype)
+ data1 = np.random.uniform(0, 1, shape).astype(dtype)
+ check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5)
+
+
+def test_relu():
+ """Test a subgraph with a single ReLU operator."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ shape = (1, 32, 14, 14)
+
+ def gen_relu():
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ out = relay.nn.relu(data0)
+
+ func = relay.Function([data0], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ main_f = relay.Function([data0], glb_var(data0))
+ mod["main"] = main_f
+
+ data0 = relay.var("data0", shape=shape, dtype=dtype)
+ out = relay.nn.relu(data0)
+ main_f = relay.Function([data0], out)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_f
+
+ return mod, ref_mod
+
+ mod, ref_mod = gen_relu()
+
+ data0 = np.random.uniform(-1, 1, shape).astype(dtype)
+ check_result(mod, ref_mod, {"data0": data0,}, (1, 32, 14, 14), tol=1e-5)
+
+
+def test_dense():
+ """Test a subgraph with a single dense operator."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ a_shape = (1, 512)
+ b_shape = (1024, 512)
+
+ def gen_dense():
+ a = relay.var("A", shape=a_shape, dtype=dtype)
+ b = relay.var("B", shape=b_shape, dtype=dtype)
+ out = relay.nn.dense(a, b)
+
+ func = relay.Function([a, b], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ a = relay.var("A", shape=a_shape, dtype=dtype)
+ b = relay.var("B", shape=b_shape, dtype=dtype)
+ main_f = relay.Function([a, b], glb_var(a, b))
+ mod["main"] = main_f
+
+ a = relay.var("A", shape=a_shape, dtype=dtype)
+ b = relay.var("B", shape=b_shape, dtype=dtype)
+ out = relay.nn.dense(a, b)
+ main_f = relay.Function([a, b], out)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_f
+
+ return mod, ref_mod
+
+ mod, ref_mod = gen_dense()
+
+ data_a = np.random.uniform(0, 1, a_shape).astype(dtype)
+ data_b = np.random.uniform(0, 1, b_shape).astype(dtype)
+ check_result(mod, ref_mod, {"A": data_a, "B": data_b}, (1, 1024), tol=1e-5)
+
+
+def test_bn():
+ """Test offloading a subgraph with a single batch_norm operator to the DNNL JSON runtime."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):  # registered only when TVM is built with the DNNL JSON runtime
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ d_shape = (1, 8)
+ c_shape = (8, )  # per-channel shape shared by gamma/beta/moving_mean/moving_var
+
+ def gen_bn():  # returns (partitioned module, unpartitioned reference module)
+ data = relay.var('data', shape=d_shape)
+ gamma = relay.var("gamma", shape=c_shape)
+ beta = relay.var("beta", shape=c_shape)
+ moving_mean = relay.var("moving_mean", shape=c_shape)
+ moving_var = relay.var("moving_var", shape=c_shape)
+ bn = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var)
+ out = bn[0]  # batch_norm returns a tuple; only the normalized output is used
+
+ func = relay.Function([data, gamma, beta, moving_mean, moving_var], out)
+ func = set_func_attr(func, "dnnl", "dnnl_0")  # mark the function for the external "dnnl" codegen
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data = relay.var('data', shape=d_shape)  # fresh vars for the "main" wrapper function
+ gamma = relay.var("gamma", shape=c_shape)
+ beta = relay.var("beta", shape=c_shape)
+ moving_mean = relay.var("moving_mean", shape=c_shape)
+ moving_var = relay.var("moving_var", shape=c_shape)
+ main_f = relay.Function([data, gamma, beta, moving_mean, moving_var],
+ glb_var(data, gamma, beta, moving_mean, moving_var))
+ mod["main"] = main_f
+
+ data = relay.var('data', shape=d_shape)  # rebuild the same graph, unpartitioned, as the reference
+ gamma = relay.var("gamma", shape=c_shape)
+ beta = relay.var("beta", shape=c_shape)
+ moving_mean = relay.var("moving_mean", shape=c_shape)
+ moving_var = relay.var("moving_var", shape=c_shape)
+ bn = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var)
+ out = bn[0]
+ main_f = relay.Function([data, gamma, beta, moving_mean, moving_var], out)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_f
+
+ return mod, ref_mod
+
+ mod, ref_mod = gen_bn()
+
+ data = np.random.uniform(-1, 1, d_shape).astype(dtype)
+ gamma = np.random.uniform(-1, 1, c_shape).astype(dtype)
+ beta = np.random.uniform(-1, 1, c_shape).astype(dtype)
+ moving_mean = np.random.uniform(-1, 1, c_shape).astype(dtype)
+ moving_var = np.random.uniform(-1, 1, c_shape).astype(dtype)
+ check_result(mod,
+ ref_mod, {
+ "data": data,
+ "gamma": gamma,
+ "beta": beta,
+ "moving_mean": moving_mean,
+ "moving_var": moving_var
+ },
+ d_shape,
+ tol=1e-5)
+
+
+def test_multiple_ops():
+ """Test partitioning and offloading a subgraph containing multiple operators (conv2d/relu chain)."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):  # registered only when TVM is built with the DNNL JSON runtime
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ ishape = (1, 32, 14, 14)
+ w1shape = (32, 32, 3, 3)
+ w2shape = (64, 32, 5, 5)
+
+ def get_net():  # build the reference network: conv2d -> relu -> conv2d -> relu
+ data = relay.var("data", relay.TensorType(ishape, dtype))
+ w1 = relay.var("w1", relay.TensorType(w1shape, dtype))
+ w2 = relay.var("w2", relay.TensorType(w2shape, dtype))
+
+ layer = relay.nn.conv2d(data=data, weight=w1, kernel_size=(3, 3), padding=(1, 1))
+ layer = relay.nn.relu(layer)
+ layer = relay.nn.conv2d(data=layer, weight=w2, kernel_size=(5, 5), padding=(2, 2))
+ layer = relay.nn.relu(layer)
+
+ main_f = relay.Function([data, w1, w2], layer)
+ mod = tvm.IRModule()
+ mod["main"] = main_f
+ return mod
+
+ def get_partitoned_mod(mod):  # partition all dnnl-supported ops into external subgraphs
+ remove_bn_pass = tvm.transform.Sequential([
+ transform.InferType(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.FoldScaleAxis(),
+ ])
+ byoc_pass = tvm.transform.Sequential([
+ remove_bn_pass,
+ transform.AnnotateTarget("dnnl"),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph()
+ ])
+
+ with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):  # NOTE(review): AlterOpLayout presumably disabled to keep layouts DNNL-compatible — confirm
+ return byoc_pass(mod)
+
+ ref_mod = get_net()
+ mod = get_partitoned_mod(ref_mod)
+
+ data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1 = np.random.uniform(0, 1, w1shape).astype(dtype)
+ w2 = np.random.uniform(0, 1, w2shape).astype(dtype)
+ check_result(mod, ref_mod, {
+ "data": data,
+ "w1": w1,
+ "w2": w2,
+ }, (1, 64, 14, 14), tol=1e-5)
+
+
+def test_composite():
+ """Test DNNL patterns and their composite functions."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):  # registered only when TVM is built with the DNNL JSON runtime
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+
+ def conv2d_relu():  # build (partitioned mod, reference mod, inputs, output shape) for a conv2d+relu composite
+ ishape = (1, 32, 14, 14)
+ w1shape = (32, 32, 3, 3)
+
+ # Composite function
+ in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
+ in_2 = relay.var("in_2", shape=w1shape, dtype=dtype)
+ conv2d = relay.nn.conv2d(in_1, in_2, kernel_size=(3, 3), padding=(1, 1))
+ relu = relay.nn.relu(conv2d)
+ func = relay.Function([in_1, in_2], relu)
+ func = func.with_attr('Composite', 'dnnl.conv2d_relu')  # pattern name the DNNL codegen matches on
+ func = func.with_attr('PartitionedFromPattern', 'nn.conv2d_nn.relu_')
+
+ # Partition function
+ arg_1 = relay.var("arg_1", shape=ishape, dtype=dtype)
+ arg_2 = relay.var("arg_2", shape=w1shape, dtype=dtype)
+ call = relay.Call(func, [arg_1, arg_2])  # invoke the composite from the partitioned function
+ p_func = relay.Function([arg_1, arg_2], call)
+ p_func = set_func_attr(p_func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = p_func
+
+ # Main function
+ data = relay.var("data", shape=ishape, dtype=dtype)
+ weight = relay.var("weight", shape=w1shape, dtype=dtype)
+ main_func = relay.Function([data, weight], glb_var(data, weight))
+ mod["main"] = main_func
+
+ # Reference module
+ data = relay.var("data", shape=ishape, dtype=dtype)
+ weight = relay.var("weight", shape=w1shape, dtype=dtype)
+ conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
+ relu = relay.nn.relu(conv2d)
+ main_func = relay.Function([data, weight], relu)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_func
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+ return mod, ref_mod, {'data': i_data, 'weight': w1_data}, (1, 32, 14, 14)
+
+ def conv2d_bias_relu():  # same structure as conv2d_relu but with a bias add in the composite
+ ishape = (1, 32, 14, 14)
+ w1shape = (32, 32, 3, 3)
+ bshape = (32, 1, 1)
+
+ # Composite function
+ in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
+ in_2 = relay.var("in_2", shape=w1shape, dtype=dtype)
+ in_3 = relay.var("in_3", shape=bshape, dtype=dtype)
+ conv2d = relay.nn.conv2d(in_1, in_2, kernel_size=(3, 3), padding=(1, 1))
+ add = relay.add(conv2d, in_3)
+ relu = relay.nn.relu(add)
+ func = relay.Function([in_1, in_2, in_3], relu)
+ func = func.with_attr('Composite', 'dnnl.conv2d_bias_relu')
+ func = func.with_attr('PartitionedFromPattern', 'nn.conv2d_add_nn.relu_')
+
+ # Partition function
+ arg_1 = relay.var("arg_1", shape=ishape, dtype=dtype)
+ arg_2 = relay.var("arg_2", shape=w1shape, dtype=dtype)
+ arg_3 = relay.var("arg_3", shape=bshape, dtype=dtype)
+ call = relay.Call(func, [arg_1, arg_2, arg_3])
+ p_func = relay.Function([arg_1, arg_2, arg_3], call)
+ p_func = set_func_attr(p_func, "dnnl", "dnnl_0")
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = p_func
+
+ # Main function
+ data = relay.var("data", shape=ishape, dtype=dtype)
+ weight = relay.var("weight", shape=w1shape, dtype=dtype)
+ bias = relay.var('bias', shape=bshape, dtype=dtype)
+ main_func = relay.Function([data, weight, bias], glb_var(data, weight, bias))
+ mod["main"] = main_func
+
+ # Reference module
+ data = relay.var("data", shape=ishape, dtype=dtype)
+ weight = relay.var("weight", shape=w1shape, dtype=dtype)
+ bias = relay.var('bias', shape=bshape, dtype=dtype)
+ conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
+ add = relay.add(conv2d, bias)
+ relu = relay.nn.relu(add)
+ main_func = relay.Function([data, weight, bias], relu)
+ ref_mod = tvm.IRModule()
+ ref_mod["main"] = main_func
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+ b_data = np.random.uniform(0, 1, bshape).astype(dtype)
+
+ return mod, ref_mod, {'data': i_data, 'weight': w1_data, 'bias': b_data}, (1, 32, 14, 14)
+
+ for mod, ref_mod, input_maps, out_shape in [conv2d_relu(), conv2d_bias_relu()]:
+ check_result(mod, ref_mod, input_maps, out_shape, tol=1e-5)
+
+
+def test_constant():
+ """Test the subgraph with (var, const, ...) arguments."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):  # registered only when TVM is built with the DNNL JSON runtime
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ ishape = (1, 32, 14, 14)
+ wshape = (32, 32, 3, 3)
+
+ data = relay.var("data", shape=ishape, dtype=dtype)
+ weight = relay.var("weight", shape=wshape, dtype=dtype)
+ bn_gamma = relay.var("bn_gamma")  # shapes left unspecified; inferred by InferType
+ bn_beta = relay.var("bn_beta")
+ bn_mmean = relay.var("bn_mean")
+ bn_mvar = relay.var("bn_var")
+
+ layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), padding=(1, 1))
+ bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ out = bn_output[0]  # batch_norm returns a tuple; only the normalized output is used
+ out = relay.nn.relu(out)
+
+ func = relay.Function(relay.analysis.free_vars(out), out)
+ ref_mod, params = tvm.relay.testing.create_workload(func)
+ ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)  # bind weights/stats so they become constant arguments
+
+ remove_bn_pass = tvm.transform.Sequential([
+ transform.InferType(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.FoldScaleAxis(),
+ ])
+
+ dnnl_patterns = get_pattern_table("dnnl")  # composite patterns registered by the DNNL codegen
+ composite_partition = tvm.transform.Sequential([
+ transform.MergeComposite(dnnl_patterns),
+ transform.AnnotateTarget("dnnl"),
+ transform.PartitionGraph()
+ ])
+
+ with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):  # NOTE(review): AlterOpLayout presumably disabled to keep layouts DNNL-compatible — confirm
+ ref_mod = remove_bn_pass(ref_mod)
+ mod = composite_partition(ref_mod)
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ check_result(mod, ref_mod, {'data': i_data}, (1, 32, 14, 14), tol=1e-5)
+
+def test_partial_constant():
+ """Test the subgraph with (const, var, const, var) arguments."""
+ if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ ishape = (10, 10)
+
+ in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
+ in_2 = relay.var("in_2", shape=ishape, dtype=dtype)
+ in_3 = relay.var("in_3", shape=ishape, dtype=dtype)
+ in_4 = relay.var("in_4", shape=ishape, dtype=dtype)
+
+ add1 = relay.add(in_1, in_2)
+ add2 = relay.add(add1, in_3)
+ add3 = relay.add(add2, in_3)
+ add4 = relay.add(add3, in_3)
+
+ func = relay.Function([in_1, in_2, in_3, in_4], add4)
+ ref_mod = tvm.IRModule.from_expr(func)
+ ref_mod = relay.transform.InferType()(ref_mod)
+
+ data1 = np.random.uniform(0, 1, ishape).astype(dtype)
+ data3 = np.random.uniform(0, 1, ishape).astype(dtype)
+
+ params = {
+ 'in_1': tvm.nd.array(data1, ctx=tvm.cpu(0)),
+ 'in_3': tvm.nd.array(data3, ctx=tvm.cpu(0))
+ }
+ ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)
+
+ opt_pass = tvm.transform.Sequential([
+ transform.InferType(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.FoldScaleAxis(),
+ transform.AnnotateTarget("dnnl"),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph()
+ ])
+
+ with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ mod = opt_pass(ref_mod)
+
+ data2 = np.random.uniform(0, 1, ishape).astype(dtype)
+ data4 = np.random.uniform(0, 1, ishape).astype(dtype)
+ check_result(mod, ref_mod, {'in_2': data2, 'in_4': data4}, (10, 10), tol=1e-5)
+
+
+if __name__ == "__main__":  # run every DNNL JSON runtime test when executed as a script
+ test_conv2d()
+ test_add()
+ test_relu()
+ test_dense()
+ test_bn()
+ test_multiple_ops()
+ test_composite()
+ test_constant()
+ test_partial_constant()
dtype = 'float32'
ishape = (1, 3, 224, 224)
- mod, params = relay.testing.mobilenet.get_workload(
- batch_size=1, dtype='float32')
-
- mod = transform.AnnotateTarget(["dnnl"])(mod)
+ ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype='float32')
+ mod = transform.AnnotateTarget(["dnnl"])(ref_mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
- ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
- dtype='float32')
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
ref_res = ref_ex.evaluate()(i_data, **params)
+ compile_engine.get().clear()
- check_result(mod, {"data": i_data},
- (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+ check_result(mod, {"data": i_data}, (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
def test_function_lifting():
echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake
echo set\(USE_VM_PROFILER ON\) >> config.cmake
echo set\(USE_EXAMPLE_EXT_RUNTIME ON\) >> config.cmake
+echo set\(USE_DNNL_CODEGEN ON\) >> config.cmake
echo set\(USE_LLVM llvm-config-10\) >> config.cmake
echo set\(USE_NNPACK ON\) >> config.cmake
echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake