From: Mahesh Ambule <15611578+maheshambule@users.noreply.github.com> Date: Wed, 15 Jul 2020 20:24:24 +0000 (+0530) Subject: [TARGET] ONNX codegen (#5052) X-Git-Tag: upstream/0.7.0~402 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5c73efe2dabf751d830a46d6ced64e1de39f0757;p=platform%2Fupstream%2Ftvm.git [TARGET] ONNX codegen (#5052) * Relay to ONNX converter * Relay to ONNX op test cases * Relay to ONNX end to end model test cases * Add test cases to jenkins * CI CD fixes * ONNX codegen * ONNX codegen * ONNX codegen * onnx testcases * ONNX codegen * test onnx * ONNX codegen * shape calculation * move onnx codegen to contrib/target * review comments * ONNX target use visitor * onnx fixes * lint fixes * doc string changes * review comments * review comment fixes * review comment * pytest skip * rename type to node type * test * Fix for constantshpae, add exp, fix for metadatamodule * Fix cpplint * change error tol values --- diff --git a/CMakeLists.txt b/CMakeLists.txt index a6e84b8..16389c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_COREML "Build with coreml support" OFF) +tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) if(USE_CPP_RPC AND UNIX) message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.") @@ -329,6 +330,7 @@ include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) +include(cmake/modules/contrib/ONNX.cmake) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cmake/config.cmake b/cmake/config.cmake index 81864a0..3f12d7c 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -219,5 +219,9 @@ set(USE_FALLBACK_STL_MAP OFF) set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) +# Whether to use ONNX codegen +set(USE_TARGET_ONNX OFF) + # Whether to compile the standalone C runtime. set(USE_STANDALONE_CRT ON) + diff --git a/cmake/modules/contrib/ONNX.cmake b/cmake/modules/contrib/ONNX.cmake new file mode 100644 index 0000000..2462980 --- /dev/null +++ b/cmake/modules/contrib/ONNX.cmake @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_TARGET_ONNX) + message(STATUS "Build with contrib.codegen_onnx") + file(GLOB ONNX_CONTRIB_SRC src/runtime/contrib/onnx/onnx_module.cc) + list(APPEND RUNTIME_SRCS ${ONNX_CONTRIB_SRC}) +endif(USE_TARGET_ONNX) diff --git a/python/tvm/contrib/target/__init__.py b/python/tvm/contrib/target/__init__.py index 7d81541..13a8339 100644 --- a/python/tvm/contrib/target/__init__.py +++ b/python/tvm/contrib/target/__init__.py @@ -14,5 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Codegen and runtime APIs for targets. -""" diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py new file mode 100644 index 0000000..7f6945a --- /dev/null +++ b/python/tvm/contrib/target/onnx.py @@ -0,0 +1,898 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines, redefined-builtin +"""Relay to ONNX codegen """ + +import os +import struct +import copy +import numpy +import onnx +import onnx.utils +from onnx import numpy_helper, OperatorSetIdProto, defs +import tvm +from tvm import relay +import tvm._ffi +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.ty import TupleType, TensorType + +ONNX_OPSET_VERSONS_SUPPORTED = [11] + + +def tvm_array_to_list(arr): + return tuple(x.value for x in arr) + + +def get_onnx_version(): + return onnx.__version__ + + +def infer_type(node): + """A method to infer the type of a relay expression.""" + mod = tvm.IRModule.from_expr(node) + mod = relay.transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(node, relay.Function) else entry.body + + +def call_node_infer_type(node): + """infer the output types of call node""" + infer_out = infer_type(node) + out_type = infer_out._checked_type_ + if isinstance(out_type, TensorType): + types = [out_type] + elif isinstance(out_type, TupleType): + types = list(out_type.fields) + else: + raise RuntimeError("Unsupported output type %s in operator %s" + % (type(out_type), node.op.nae)) + + return types + + +def add_input(data, name, model_container): + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[data.dtype] + tensor_value_info = onnx.helper.make_tensor_value_info(name, dtype, shape=data.shape) + model_container.add_inputs([tensor_value_info]) + data_tensor = numpy_helper.from_array(data, name) + model_container.add_initializers([data_tensor]) + + +class OpConverter(object): + """ Operator converter Base Class. + """ + + @classmethod + def convert_attributes(cls, attrs): + """convert Relay attributes to ONNX attributes. + The derived classes should implement this method + if attributes are required by the operator + otherwise by default no attributes are passed + """ + return {} + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + onnx_node = onnx.helper.make_node(cls.__name__, + node_entry['input_names'], + node_entry['output_names'], + **attrs) + model_container.add_nodes([onnx_node]) + + +def rename(op_name): + """ This method creates dynamic operator of name op_name with empty attributes + """ + return type(op_name, (OpConverter,), {}) + + +class Reshape(object): + """ Operator converter for Reshape. + """ + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + """Converts Relay operator Reshape to ONNX operator. + Relay operator accepts shape as attribute but ONNX operator + accepts it as a input. + """ + + shape = numpy.asarray([a.value for a in node_entry['relay_node'].attrs.newshape], + dtype=numpy.int64) + input_name = 'shape{}'.format(node_entry['name']) + node = onnx.helper.make_node(cls.__name__, [node_entry['input_names'][0], input_name], + node_entry['output_names']) + model_container.add_nodes([node]) + add_input(shape, input_name, model_container) + + +class Conv(OpConverter): + """ Operator converter for Conv. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'group': attrs.get_int("groups"), + 'pads': attrs.get_int_tuple("padding"), + 'strides': attrs.get_int_tuple("strides"), + 'dilations': attrs.get_int_tuple("dilation"), + 'kernel_shape': attrs.get_int_tuple("kernel_size"), + } + + +class MaxPool(OpConverter): + """ Operator converter for MaxPool. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'pads': attrs.get_int_tuple("padding"), + 'strides': attrs.get_int_tuple("strides"), + 'kernel_shape': attrs.get_int_tuple("pool_size"), + } + + +class Transpose(OpConverter): + """ Operator converter for Transpose. + """ + + @classmethod + def convert_attributes(cls, attrs): + return {'perm': attrs.get_int_tuple("axes")} if attrs["axes"] else {} + + +class MatMul(OpConverter): + """ Operator converter for MatMul. + """ + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + inter_output_name = 'inter{}'.format(node_entry['name']) + transpose_node = onnx.helper.make_node(Transpose.__name__, + [node_entry['input_names'][1]], + [inter_output_name], + perm=(1, 0)) + model_container.add_nodes([transpose_node]) + + inputs = [node_entry['input_names'][0], inter_output_name] + matmul_node = onnx.helper.make_node(cls.__name__, inputs, node_entry['output_names']) + model_container.add_nodes([matmul_node]) + + +class Flatten(OpConverter): + """ Operator converter for Flatten. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': 1, + } + + +class BatchNormalization(OpConverter): + """ Operator converter for BatchNormalization. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'epsilon': float(attrs.get_str('epsilon')), + 'axis': float(attrs.get_int('axis')), + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + """Converts Relay operator batch_norm to ONNX operator. + Relay operator has property axis to handle data in NHWC format. + """ + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + transpose_out_name = node_entry['input_names'][0] + inter_output_names = [node_entry['output_names'][0]] + # axis==3 means channel is specified along the 3rd axis + if attrs['axis'] == 3: + transpose_out_name = 'transpose_{}'.format(node_entry['name']) + node_transposed = onnx.helper.make_node(Transpose.__name__, + [node_entry['input_names'][0]], + [transpose_out_name], + perm=[0, 3, 1, 2]) + model_container.add_nodes([node_transposed]) + inter_output_names = ['batch_norm_{}'.format(node_entry['name'])] + + input_names = [transpose_out_name] + node_entry['input_names'][1:] + batch_norm_node = onnx.helper.make_node(cls.__name__, + input_names, + inter_output_names, + epsilon=attrs['epsilon']) + model_container.add_nodes([batch_norm_node]) + + if attrs['axis'] == 3: + node_transposed = onnx.helper.make_node(Transpose.__name__, + inter_output_names, + [node_entry['output_names'][0]], + perm=[0, 2, 3, 1]) + model_container.add_nodes([node_transposed]) + + +class Dropout(OpConverter): + """ Operator converter for Dropout. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'ratio': float(attrs.get_str('rate')), + } + + +class AveragePool(MaxPool): + """ Operator converter for AveragePool. + """ + + +class Concat(OpConverter): + """ Operator converter for Concat. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': attrs.get_int("axis"), + } + + +class BiasAdd(OpConverter): + """ Operator converter for BiasAdd. + """ + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node_entry can not be a Tuple" + input_node = input_node[0] + data_ndim = len(input_node['types'][0].shape) + axis = node_entry['relay_node'].attrs.get_int("axis") + if axis < 0: + axis = axis + data_ndim + new_axes = data_ndim - axis - 1 + if new_axes: + inter_output_name = 'inter{}'.format(node_entry['name']) + unsqueeze_node = onnx.helper.make_node('Unsqueeze', + [node_entry['input_names'][1]], + [inter_output_name], + axes=tuple(range(1, new_axes + 1))) + model_container.add_nodes([unsqueeze_node]) + else: + inter_output_name = node_entry['input_names'][1] + + inputs = [node_entry['input_names'][0], inter_output_name] + matmul_node = onnx.helper.make_node('Add', inputs, node_entry['output_names']) + model_container.add_nodes([matmul_node]) + + +class ReduceMean(OpConverter): + """ Operator converter for ReduceMean. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axes': attrs.axis, + 'keepdims': 0 if bool(attrs.get_int("keepdims", 0)) is False else 1 + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + shape = input_node['types'][0].shape + axis = node_entry['relay_node'].attrs.axis + axis = list(range(shape.size())) if not axis else tvm_array_to_list(axis) + exclude = 0 if not bool(node_entry['relay_node'].attrs.exclude) else 1 + keepdims = 0 if not bool(node_entry['relay_node'].attrs.keepdims) else 1 + if exclude: + all_axis = list(range(len(shape))) + axis = set(all_axis) - set(axis) + + node = onnx.helper.make_node(cls.__name__, + node_entry['input_names'], + node_entry['output_names'], + axes=axis, + keepdims=keepdims) + model_container.add_nodes([node]) + + +class Pad(OpConverter): + """ Operator converter for Pad. + """ + + @classmethod + def convert_attributes(cls, attrs): + before = [] + after = [] + for axis_pads in attrs.pad_width: + before.append(axis_pads[0]) + after.append(axis_pads[1]) + pads = before + after + pads = numpy.asarray(pads, dtype=pads[0].dtype) + return { + 'pads': pads, + 'mode': attrs.get_str('pad_mode'), + 'constant_value': attrs.pad_value + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + """Converts Relay operator Pad to ONNX operator. + Relay operator accepts pads as attribute but ONNX operator + accepts it as a input. + """ + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + + name = node_entry['name'] + data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64) + input_name = 'pads_{}'.format(name) + value = numpy.dtype(node_entry['types'][0].dtype).type(attrs['constant_value']) + input_value_name = 'value_{}'.format(name) + add_input(data, input_name, model_container) + add_input(value, input_value_name, model_container) + + input_names = [node_entry['input_names'][0], input_name, input_value_name] + node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names']) + model_container.add_nodes([node]) + + +class Softmax(OpConverter): + """ Operator converter for SoftMax. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': attrs.axis, + } + + +class Squeeze(OpConverter): + """ Operator converter for Squeeze. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axes': attrs.axis, + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + shape = input_node['types'][0].shape + axis = node_entry['relay_node'].attrs.get_int("axis") + if not axis: + axis = [] + for axis_idx, val in enumerate(shape): + if val.value == 1: + axis.append(axis_idx) + else: + axis = node_entry['relay_node'].attrs.get_int_tuple("axis") + + node = onnx.helper.make_node(cls.__name__, + node_entry['input_names'], + node_entry['output_names'], + axes=axis) + model_container.add_nodes([node]) + + +class Slice(OpConverter): + """ Operator converter for Slice. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'starts': attrs.get_int_tuple('begin'), + 'ends': attrs.get_int_tuple('end'), + 'steps': attrs.get_int_tuple('strides'), + 'slice_mode': attrs.get_str('slice_mode') + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + + name = node_entry['name'] + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + shape = input_node['types'][0].shape + + starts = list(attrs['starts']) + ends = list(attrs['ends']) + steps = list(attrs['steps']) + starts += [0] * (len(shape) - len(starts)) + ends += [shape[i] + 1 for i in range(len(ends), len(shape))] + axes = list(range(len(shape))) + + if attrs['slice_mode'] == 'size': + ends = [starts[i] + (shape[i] + 1 if ends[i] < 0 else ends[i]) + for i in range(len(shape))] + steps = [1] * len(shape) + else: + steps += [1] * (len(shape) - len(steps)) + + def _add_input(val, input_name): + val_arr = numpy.asarray(val).astype(numpy.int64) + input_name = '{}_{}'.format(name, input_name) + add_input(val_arr, input_name, model_container) + return input_name + + input_names = [] + input_names.append(_add_input(starts, 'starts')) + input_names.append(_add_input(ends, 'ends')) + input_names.append(_add_input(axes, 'axes')) + input_names.append(_add_input(steps, 'steps')) + + input_names = [node_entry['input_names'][0]] + input_names + + slice_node = onnx.helper.make_node(cls.__name__, + input_names, + node_entry['output_names']) + model_container.add_nodes([slice_node]) + + +class Split(OpConverter): + """ Operator converter for Split. + """ + + @classmethod + def convert_attributes(cls, attrs): + indices_or_sections = attrs['indices_or_sections'] + + if isinstance(indices_or_sections, (list, tvm.ir.container.Array)): + indices_or_sections = attrs.get_int_tuple('indices_or_sections') + if isinstance(indices_or_sections, tvm.ir.PrimExpr): + indices_or_sections = indices_or_sections.value + + return { + 'indices_or_section': indices_or_sections, + 'axis': attrs.get_int('axis'), + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + shape = input_node['types'][0].concrete_shape + + indices_or_sect = attrs["indices_or_section"] + axis = attrs["axis"] + axis_length = shape[axis] + + if isinstance(indices_or_sect, int): + split = [axis_length // indices_or_sect] * indices_or_sect + else: + split = [] + for i in range(len(indices_or_sect) + 1): + if i == 0: + split.append(indices_or_sect[0]) + elif i == len(indices_or_sect): + split.append(axis_length - indices_or_sect[-1]) + else: + split.append(indices_or_sect[i] - indices_or_sect[i - 1]) + + slice_node = onnx.helper.make_node(cls.__name__, + node_entry['input_names'], + node_entry['output_names'], + split=split, + axis=axis) + model_container.add_nodes([slice_node]) + + +class ConstantOfShapeZeros(OpConverter): + """ Operator converter for ConstantOfShape. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'value': 0 + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + dtype = input_node['types'][0].dtype + input_shape_name = 'shape_{}'.format(node_entry['name']) + shape = [val.value for val in input_node['types'][0].shape] + shape = numpy.asarray(shape).astype(numpy.int64) + add_input(shape, input_shape_name, model_container) + + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)] + tensor_value = onnx.helper.make_tensor("value", dtype, + [1], [attrs['value']]) + + node = onnx.helper.make_node('ConstantOfShape', + [input_shape_name], + node_entry['output_names'], + value=tensor_value) + model_container.add_nodes([node]) + + +class ConstantOfShapeOnes(ConstantOfShapeZeros): + """ Operator converter for ConstantOfShape. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'value': 1 + } + + +relay_to_onnx_op_mapping = { + 'reshape': Reshape, + 'nn.conv2d': Conv, + 'add': rename('Add'), + 'nn.relu': rename('Relu'), + 'transpose': Transpose, + 'nn.dense': MatMul, + 'nn.max_pool2d': MaxPool, + 'nn.batch_flatten': Flatten, + 'multiply': rename('Mul'), + 'nn.bias_add': BiasAdd, + 'nn.batch_norm': BatchNormalization, + 'nn.global_avg_pool2d': rename('GlobalAveragePool'), + 'concatenate': Concat, + 'nn.dropout': Dropout, + 'nn.avg_pool2d': AveragePool, + 'divide': rename('Div'), + 'mean': ReduceMean, + 'nn.pad': Pad, + 'nn.softmax': Softmax, + 'squeeze': Squeeze, + 'strided_slice': Slice, + 'greater': rename('Greater'), + 'less': rename('Less'), + 'equal': rename('Equal'), + 'zeros_like': ConstantOfShapeZeros, + 'ones_like': ConstantOfShapeOnes, + 'subtract': rename('Sub'), + 'split': Split, + 'exp': rename('Exp') +} + + +class ModelContainer(object): + """ A container class to hold different attributes of ONNX model graph + """ + + def __init__(self, name, opset_version): + self._name = name + self._opset_version = opset_version + self._inputs = [] + self._outputs = [] + self._nodes = [] + self._initializers = [] + + def add_inputs(self, inputs): + self._inputs.extend(inputs) + + def add_outputs(self, outputs): + self._outputs.extend(outputs) + + def add_nodes(self, nodes): + self._nodes.extend(nodes) + + def add_initializers(self, initializers): + self._initializers.extend(initializers) + + def _get_opsets(self): + opsets = [] + imp = OperatorSetIdProto() + imp.version = self._opset_version + opsets.append(imp) + return opsets + + def make_model(self): + """ Creates the onnx model from the graph """ + onnx_graph = onnx.helper.make_graph( + self._nodes, + self._name, + self._inputs, + self._outputs, + self._initializers + ) + kwargs = {} + kwargs["opset_imports"] = self._get_opsets() + kwargs["producer_name"] = 'TVM Relay' + kwargs["producer_version"] = tvm.__version__ + + return onnx.helper.make_model(onnx_graph, **kwargs) + + +class RelayToONNXConverter(ExprVisitor): + """A helper class to traverse the Relay graph and convert Relay nodes to ONNX model + + Parameters + ---------- + name : str + name of the model + + params : dict + dict of the parameter names and NDarray values + + opset_version : int + target onnx opset version + + """ + + def __init__(self, name, params, opset_version): + super().__init__() + self._name = name + self._mc = ModelContainer(name, opset_version) + self._params = params + self._node_dict = {} + self._node_count = 0 + self.last_node = None + + @classmethod + def _get_node_entry(cls, relay_node, name): + return {"relay_node": relay_node, + "inputs": [relay_node], # inputs in the form of relay nodes + "types": [], # output types in case of call nodes else self type + "name": name, # name of the node + "input_names": [name], # input names in case of call nodes else self name + "output_names": [name], # output names in case of call nodes else self name + "op": None, # op name in case of call node else None + } + + def convert_to_onnx(self, func): + """ Traverse Relay graph and generate a ONNX model""" + + self.visit(func) + self._add_output(self._node_dict[self.last_node]) + model = self._mc.make_model() + polished_model = onnx.utils.polish_model(model) + return polished_model + + def visit(self, expr): + self._node_count += 1 + super().visit(expr) + + def visit_constant(self, const): + node_index = self._node_count + name = self._name + "_const_" + str(node_index) + node_entry = self._get_node_entry(const, name) + node_entry["types"] = [const.checked_type] + + self._add_constant_input(node_entry, node_index) + self._node_dict[const] = [node_entry] + + def visit_var(self, var): + node_index = self._node_count + node_entry = self._get_node_entry(var, var.name_hint) + node_entry["types"] = [var.type_annotation] + + self._add_input(node_entry, node_index) + self._node_dict[var] = [node_entry] + + def visit_tuple(self, tup): + self._node_dict[tup] = [] + for f in tup.fields: + self.visit(f) + self._node_dict[tup].extend(self._node_dict[f]) + + self.last_node = tup + + def visit_tuple_getitem(self, t): + self.visit(t.tuple_value) + tup_node = self._node_dict[t.tuple_value] + if len(tup_node) > 1: + self._node_dict[t] = tup_node[t.index] + else: + node_entry = copy.deepcopy(tup_node[0]) + output_names = [node_entry["output_names"][t.index]] + node_entry["output_names"] = output_names + self._node_dict[t] = [node_entry] + self.last_node = t + + def visit_call(self, call): + node_index = self._node_count + op = call.op + name = "{}_{}".format(op, node_index) + node_entry = self._get_node_entry(call, name) + + node_entry["op"] = op + node_entry["input_names"] = [] + node_entry["inputs"] = [] + node_entry["output_names"] = None + for input_arg in call.args: + self.visit(input_arg) + input_names = [] + for arg_node_entry in self._node_dict[input_arg]: + input_names.extend(arg_node_entry["output_names"]) + node_entry["input_names"].extend(input_names) + node_entry["inputs"].extend([input_arg]) + + node_entry['types'] = call_node_infer_type(call) + node_entry["output_names"] = [] + for i in range(len(node_entry['types'])): + node_entry["output_names"].append(name + str(i)) + self.last_node = call + self._add_node(node_entry, node_index) + self._node_dict[call] = [node_entry] + + def _add_node(self, node_entry, idx): + """Convert Relay operator node to ONNX operator and add it to container nodes list""" + if node_entry['op'].name not in relay_to_onnx_op_mapping: + raise NotImplementedError("Currently the operator '{0}' is " + "not supported.".format(node_entry['op'].name)) + converter = relay_to_onnx_op_mapping[node_entry['op'].name]() + + return converter.convert(node_entry, self._mc, self._node_dict) + + def _add_params(self, node_entry, idx): + """Add param value to initializer and name to inputs""" + param_name = node_entry['name'] + assert param_name in self._params, "The parameter {0} is not present" \ + "in params dict provided.".format(param_name) + value = self._params[param_name] + numpy_array = value.asnumpy() + tensor = numpy_helper.from_array(numpy_array, param_name) + self._mc.add_initializers([tensor]) + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy_array.dtype] + input = onnx.helper.make_tensor_value_info(param_name, + dtype, + shape=numpy_array.shape) + self._mc.add_inputs([input]) + + def _add_constant_input(self, node_entry, idx): + """Create named input for constant and add it to container inputs. + If input is a parameter then add to param + """ + node = node_entry['relay_node'] + param_name = node_entry['name'] + self._params[param_name] = node.data + self._add_params(node_entry, idx) + + def _add_input(self, node_entry, idx): + """Add input node to container inputs. If input is a parameter then add to param""" + if node_entry['name'] in self._params: + self._add_params(node_entry, idx) + else: + node_type = node_entry['types'][0] + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] + input = onnx.helper.make_tensor_value_info(node_entry['name'], + dtype, + shape=node_type.concrete_shape) + self._mc.add_inputs([input]) + + def _add_output(self, node_entries): + """Add output node to container outputs.""" + + for node_entry in node_entries: + for node_type, output_name in zip(node_entry['types'], node_entry['output_names']): + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] + output = onnx.helper.make_tensor_value_info(output_name, + dtype, + shape=node_type.concrete_shape) + self._mc.add_outputs([output]) + + +def to_onnx(relay_ir, params, name, opset_version=11, path=None): + """Convert a Relay Function Module into an equivalent ONNX and serialize it to the path + + Parameters + ---------- + relay_ir : tvm.ir.IRModule or tvm.relay.Function + The relay module object + + params : dict + dict of the parameter names and NDarray values + + name : str + name of the output ONNX graph + + opset_version : int + target onnx opset version + + path : str + The path where ONNX model will be saved + + Returns + ------- + onnx_model : onnx.ModelProto + converted ONNX model as a ModelProto. + + """ + + if opset_version not in ONNX_OPSET_VERSONS_SUPPORTED: + raise NotImplementedError("Currently only opset version 11 is supported.") + + if opset_version > defs.onnx_opset_version(): + raise Exception("The ONNX package installed of version {} does not support the opset " + "version {}. Upgrade the ONNX package to latest version.".format( + get_onnx_version(), opset_version)) + + func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir + converter = RelayToONNXConverter(name, params, opset_version) + onnx_model = converter.convert_to_onnx(func) + + if path: + onnx.save(onnx_model, path) + return onnx_model + + +@tvm._ffi.register_func("relay.ext.onnx") +def onnx_compiler(func): + """Create a runtime module for ONNX from Relay Function + + :param func: Relay function + :return: runtime module for ONNX + """ + + assert isinstance(func, tvm.relay.function.Function) + name = str(func.attrs.global_symbol) + model = to_onnx(func, {}, name) + const_vars = [const.name for const in model.graph.initializer] + name_bytes = bytes(name, 'utf-8') + name_size = struct.pack('I', len(name_bytes)) + model_serialized = model.SerializeToString() + model_size = struct.pack('I', model.ByteSize()) + data = b'' + name_size + name_bytes + model_size + model_serialized + + runtime_func = "runtime.ONNXModuleCreate" + fcreate = tvm._ffi.get_global_func(runtime_func) + return fcreate(data.hex(), name, const_vars) + + +@tvm._ffi.register_func("relay.ext.onnx.save_to_file") +def save_to_file(hex_str, path=None, fmt="onnx"): + """ Store the ONNX subgraphs in the path folder + + :param hex_str: Subgrah names and corresponding serialized onnx hex string + :param path: path to which ONNX files to be stored + It is assumed that path exists + :param fmt: extension of the files to be stored + """ + onnx_ir = bytes.fromhex(hex_str) + + offset = 0 + while offset < len(onnx_ir): + stop = offset + 4 + (name_size,) = struct.unpack('I', onnx_ir[offset:stop]) + name = onnx_ir[stop:stop + name_size].decode("utf-8") + stop = stop + name_size + (model_size,) = struct.unpack('I', onnx_ir[stop:stop + 4]) + stop = stop + 4 + model_serialized = onnx_ir[stop:stop + model_size] + offset = stop + model_size + + model_onnx = onnx.load_model_from_string(model_serialized) + onnx.save(model_onnx, "{}{}{}.{}".format(path, os.path.sep, name, fmt)) diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc new file mode 100644 index 0000000..9574b86 --- /dev/null +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file onnx_module.cc + * \brief ONNX Module without runtime support + */ +#include +#include +#include + +namespace tvm { +namespace codegen { +using namespace tvm::runtime; + +class ONNXSourceModuleNode : public runtime::ModuleNode { + public: + explicit ONNXSourceModuleNode(const std::string& code, const std::string& symbol, + const Array& const_vars) + : code_(code), symbol_(symbol), const_vars_(const_vars) {} + const char* type_key() const { return "onnx"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + } else { + LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" + << " build TVM with 'onnx' runtime support"; + return PackedFunc(nullptr); + } + } + + std::string GetSource(const std::string& format) final { return code_; } + + void SaveToFile(const std::string& path, const std::string& format) final { + CHECK_EQ(format, "onnx") << "Can only save to onnx format"; + CHECK_NE(code_.length(), 0); + const PackedFunc* to_onnx_ = runtime::Registry::Get("relay.ext.onnx.save_to_file"); + (*to_onnx_)(code_, path, format); + } + + protected: + String code_; + std::string symbol_; + Array const_vars_; +}; + +Module ONNXSourceModuleNodeCreate(const String& code, const String& symbol, + const Array& const_vars) { + auto n = make_object(code.operator std::string(), + symbol.operator std::string(), const_vars); + return Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate").set_body_typed(ONNXSourceModuleNodeCreate); + +} // namespace codegen +} // namespace tvm diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py new file mode 100644 index 0000000..76b6bab --- /dev/null +++ b/tests/python/contrib/test_onnx.py @@ -0,0 +1,471 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Relay to ONNX serialization test cases""" +import pytest +pytest.importorskip('onnx') +pytest.importorskip('onnxruntime') + +import numpy as np +import onnxruntime as rt + +import tvm +from tvm import relay +from tvm.contrib.target.onnx import to_onnx + + +def func_to_onnx(func, name): + mod = tvm.IRModule() + mod['main'] = func + onnx_model = to_onnx(mod, {}, name, path=None) + return onnx_model.SerializeToString() + + +def run_onnx(onnx_model, input_data): + sess = rt.InferenceSession(onnx_model) + input_names = {} + for input, data in zip(sess.get_inputs(), input_data): + input_names[input.name] = data + output_names = [out.name for out in sess.get_outputs()] + res = sess.run(output_names, input_names) + return res + + +def run_relay(func, data_tuple): + target = 'llvm' + ctx = tvm.context('llvm', 0) + intrp = relay.create_executor("graph", ctx=ctx, target=target) + relay_res = intrp.evaluate(func)(*data_tuple) + + result = [] + relay_res = relay_res if isinstance(relay_res, list) else [relay_res] + for res in relay_res: + result.append(res.asnumpy()) + + return result + + +def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0): + relay_results = run_relay(relay_func, indata) + onnx_results = run_onnx(func_to_onnx(relay_func, test_name), indata) + + for relay_res, onnx_res in zip(relay_results, onnx_results): + np.testing.assert_allclose(relay_res, onnx_res, rtol=rtol, atol=atol) + + +def test_add(): + dtype = 'float32' + t1 = relay.TensorType((5, 10, 5)) + t2 = relay.TensorType((5, 10, 5)) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + x_data = np.random.rand(5, 10, 5).astype(dtype) + y_data = np.random.rand(5, 10, 5).astype(dtype) + + verify_results(func, [x_data, y_data], 'test_add') + + +def test_bias_add(): + for dtype in ['float16', 'float32']: + xshape = (10, 2, 3, 4) + bshape = (2,) + rtol = 1e-2 if dtype == 'float16' else 1e-5 + x = relay.var("x", shape=xshape, dtype=dtype) + bias = relay.var("bias", dtype=dtype) + z = relay.nn.bias_add(x, bias) + func = relay.Function([x, bias], z) + + x_data = np.random.uniform(size=xshape).astype(dtype) + y_data = np.random.uniform(size=bshape).astype(dtype) + + verify_results(func, [x_data, y_data], 'test_bias_add', rtol=rtol) + + +def test_conv2d(): + def verify_conv2d(dtype, scale, dshape, kshape, + padding=(1, 1), + groups=1, + dilation=(1, 1), + **attrs): + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + y = relay.nn.conv2d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + verify_results(func, [data, kernel], 'test_conv2d', rtol=1e-5, atol=1e-5) + + dshape = (1, 32, 18, 18) + kshape = (32, 1, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=32, kernel_size=(3, 3)) + + dshape = (1, 32, 18, 18) + kshape = (32, 4, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3)) + + # also group conv2d + dshape = (1, 32, 18, 18) + kshape = (64, 1, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3)) + + # normal conv2d + dshape = (1, 3, 224, 224) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3, 3)) + + dshape = (1, 3, 224, 224) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(2, 2), channels=10, kernel_size=(3, 3)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3, 3), dilation=(3, 3)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 2, 2) + verify_conv2d("float32", 1, dshape, kshape, + padding=(2, 2), channels=10, kernel_size=(2, 2), dilation=(1, 1)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 4, 4) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(4, 4)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 4, 4) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(4, 4)) + + +def test_reshape(): + def verify_reshape(shape, newshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.reshape(x, newshape=newshape) + + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + verify_results(func, [x_data], 'test_reshape', rtol=1e-5, atol=1e-5) + + verify_reshape((2, 3, 4), tuple(np.array([4, 2, 3], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([2, 0, 0], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([0, -1], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([-1, 0], dtype=np.int64))) + + +def test_transpose(): + def verify_reshape(shape, newshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.transpose(x, newshape) + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + verify_results(func, [x_data], 'test_transpose', rtol=1e-5, atol=1e-5) + + verify_reshape((1, 2, 3, 4), (0, 2, 3, 1)) + verify_reshape((1, 2, 3, 4), (0, 3, 2, 1)) + + +def test_dense(): + def verify_dense(d_shape, w_shape): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + weight = relay.var("weight", relay.TensorType(w_shape, "float32")) + func = relay.Function([data, weight], relay.nn.dense(data, weight)) + x_data = np.random.uniform(size=d_shape).astype("float32") + w_data = np.random.uniform(size=w_shape).astype("float32") + verify_results(func, [x_data, w_data], 'test_dense', rtol=1e-5, atol=1e-5) + + verify_dense((1, 8), (16, 8)) + verify_dense((1, 4), (3, 4)) + + +def test_max_pool(): + def verify_max_pool(x_shape, pool_size, strides, padding, ceil_mode): + x = relay.var("x", relay.TensorType(x_shape, "float32")) + y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding, + ceil_mode=ceil_mode) + func = relay.Function([x], y) + x_data = np.random.uniform(size=x_shape).astype("float32") + verify_results(func, [x_data], 'test_max_pool', rtol=1e-5, atol=1e-5) + + verify_max_pool((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False) + + +def test_batch_flatten(): + def verify_test_batch_flatten(d_shape): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + func = relay.Function([data], relay.nn.batch_flatten(data)) + x_data = np.random.uniform(size=d_shape).astype("float32") + verify_results(func, [x_data], 'test_batch_flatten', rtol=1e-5, atol=1e-5) + + verify_test_batch_flatten((1, 2, 3, 4)) + verify_test_batch_flatten((1, 8)) + + +def test_batch_norm(): + def verify_batch_norm(axis=1): + for dtype in ['float16', 'float32']: + data = relay.var("data", relay.TensorType((2, 4, 4, 1), dtype)) + gamma_shape = (data.type_annotation.shape[axis].value,) + beta = relay.var("beta", relay.TensorType(gamma_shape, dtype)) + gamma = relay.var("gamma", relay.TensorType(gamma_shape, dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType(gamma_shape, dtype)) + moving_var = relay.var("moving_var", relay.TensorType(gamma_shape, dtype)) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=axis) + func = relay.Function([data, gamma, beta, moving_mean, moving_var], y[0]) + + x_data = np.random.uniform(size=(2, 4, 4, 1)).astype(dtype) + beta = np.random.uniform(size=gamma_shape).astype(dtype) + gamma = np.random.uniform(size=gamma_shape).astype(dtype) + moving_mean = np.random.uniform(size=gamma_shape).astype(dtype) + moving_var = np.random.uniform(size=gamma_shape).astype(dtype) + verify_results(func, [x_data, gamma, beta, moving_mean, moving_var], 'test_batch_norm', rtol=1e-1, + atol=1e-1) + + verify_batch_norm(axis=1) + verify_batch_norm(axis=3) + + +def test_pad(): + def verify_pad(): + for dtype in ['float16', 'float32']: + dshape = (4, 10, 7, 7) + x = relay.var("x", shape=dshape, dtype=dtype) + y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4))) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], 'test_pad', rtol=1e-5, atol=1e-5) + + verify_pad() + + +def test_sofmax(): + def verify_sofmax(): + for dtype in ['float32']: + shape = (10, 4) + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.nn.softmax(x, axis=1) + func = relay.Function([x], y) + x_data = np.random.uniform(size=shape).astype(dtype) + verify_results(func, [x_data], 'test_softmax', rtol=1e-5, atol=1e-5) + + verify_sofmax() + + +def test_squeeze(): + def verify_squeeze(shape, dtype, axis): + x = relay.var("x", relay.TensorType(shape, dtype)) + z = relay.squeeze(x, axis=axis) + func = relay.Function([x], z) + x_data = np.random.random_sample(shape).astype(dtype) + verify_results(func, [x_data], 'test_squeeze', rtol=1e-5, atol=1e-5) + + verify_squeeze((1, 3, 2, 5), "float32", None) + verify_squeeze((1, 3, 1), "float32", [2, ]) + verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2]) + + +def test_mean(): + def verify_mean(data_shape, axis, exclude, keepdims): + dtype = "float32" + x = relay.var('x', shape=data_shape, dtype=dtype) + y = relay.mean(x, axis, keepdims, exclude) + func = relay.Function([x], y) + x_data = np.random.uniform(size=data_shape).astype(dtype) + verify_results(func, [x_data], 'test_mean', rtol=1e-5, atol=1e-5) + + verify_mean((1, 2), 0, False, False) + verify_mean((1, 2), 0, True, False) + verify_mean((1, 2), 0, True, True) + verify_mean((1, 2), 1, True, True) + verify_mean((3, 2, 1), 1, False, True) + + +def test_split(): + def verify_split(dshape, indices_or_sections, axis=None): + dtype = "float32" + x = relay.var("x", relay.ty.TensorType(dshape, "float32")) + y = relay.split(x, indices_or_sections, axis=axis) + func = relay.Function([x], y.astuple()) + x_data = np.random.uniform(size=dshape).astype(dtype) + + verify_results(func, [x_data], 'test_split', rtol=1e-5, atol=1e-5) + + verify_split((5, 5, 2, 2), 5, axis=1) + verify_split((5, 5, 2, 2), 5, axis=0) + verify_split((5, 5, 2, 2), [1, 3, 4], axis=0) + verify_split((5, 5, 2, 2), [1, 3, 4], axis=1) + + +def test_concatenate(): + def verify_concatenate(shapes, axis, dtype="float32"): + in_vars = [] + in_data = [] + for i, shape in enumerate(shapes): + in_vars.append(relay.var("x" + str(i), relay.ty.TensorType(shape, dtype))) + in_data.append(np.random.uniform(size=shape).astype(dtype)) + + out_tensor = relay.concatenate(in_vars, axis) + func = relay.Function(in_vars, out_tensor) + verify_results(func, in_data, 'test_concatenate', rtol=1e-5, atol=1e-5) + + verify_concatenate([(2,), (2,), (2,)], -1) + verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1) + verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1) + verify_concatenate([(5, 6, 7, 3), + (16, 6, 7, 3), + (12, 6, 7, 3), + (8, 6, 7, 3), + (2, 6, 7, 3)], 0) + verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1) + + +def test_strided_slice(): + def verify_strided_slice(dshape, begin, end, strides, mode): + x = relay.var("x", relay.TensorType(dshape, "float32")) + if mode == 'size': + strides = None + z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=mode) + func = relay.Function([x], z) + x_data = np.random.uniform(size=dshape).astype("float32") + verify_results(func, [x_data], 'test_strided_slice', rtol=1e-5, atol=1e-5) + + for mode in ['end', 'size']: + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 2, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -1, 3], [1, 2], mode) + verify_strided_slice((3, 4, 3), [1, ], [4, -3], None, mode) + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], mode) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, -3], [2, 1, 1], mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], mode) + verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], mode) + + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, mode) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2], mode) + + +def test_cmp_type(): + for op, ref in ((relay.greater, np.greater), + (relay.less, np.less), + (relay.equal, np.equal) + ): + x_shape = (10, 4) + y_shape = (5, 10, 1) + t1 = relay.TensorType(x_shape) + t2 = relay.TensorType(y_shape) + x = relay.var("x", t1) + y = relay.var("y", t2) + z = op(x, y) + x_data = np.random.rand(*x_shape).astype(t1.dtype) + y_data = np.random.rand(*y_shape).astype(t2.dtype) + func = relay.Function([x, y], z) + verify_results(func, [x_data, y_data], 'test_cmp_type', rtol=1e-5, atol=1e-5) + + +def test_unary_identity(): + for dtype in ["int16", "float32", "float64"]: + for op, ref in [(relay.zeros_like, np.zeros_like), + (relay.ones_like, np.ones_like)]: + shape = (8, 9, 4) + x = relay.var("x", relay.TensorType(shape, dtype)) + y = op(x) + func = relay.Function([x, ], y) + x_data = np.random.rand(*shape).astype(dtype) + verify_results(func, [x_data], 'test_cmp_type', rtol=1e-5, atol=1e-5) + + +def test_binary_op(): + def check_binary_op(opfunc, dtype): + t1 = relay.TensorType((5, 10, 5)) + t2 = relay.TensorType((5, 10, 5)) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) + z = opfunc(x, y) + x_data = np.random.rand(5, 10, 5).astype(dtype) + y_data = np.random.rand(5, 10, 5).astype(dtype) + func = relay.Function([x, y], z) + verify_results(func, [x_data, y_data], 'test_binary_op', rtol=1e-5, atol=1e-5) + + for opfunc, ref in [(relay.add, np.add), + (relay.subtract, np.subtract), + (relay.multiply, np.multiply), + (relay.divide, np.divide), + ]: + for dtype in ['float32']: + check_binary_op(opfunc, dtype) + + +def test_tuple_types(): + def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype = "float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.split(x, indices_or_sections, axis=axis) + z = relay.concatenate(y, axis=axis) + func = relay.Function([x], z) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + split_z = relay.split(z, indices_or_sections, axis=axis) + func = relay.Function([x], split_z.astuple()) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + out = relay.Tuple([y[0] + y[1], y[0] - y[1]]) + func = relay.Function([x], out) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + z = relay.concatenate(out, axis=axis) + func = relay.Function([x], z) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + verify_tuple_types((5, 5, 2, 2), 5, axis=1) + verify_tuple_types((5, 5, 2, 2), 5, axis=0) + verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=0) + verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=1) + + +if __name__ == '__main__': + test_add() + test_bias_add() + test_conv2d() + test_reshape() + test_transpose() + test_dense() + test_max_pool() + test_batch_flatten() + test_batch_norm() + test_pad() + test_mean() + test_split() + test_concatenate() + test_sofmax() + test_squeeze() + test_strided_slice() + test_cmp_type() + test_binary_op() + test_tuple_types() diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py new file mode 100644 index 0000000..8766c0d --- /dev/null +++ b/tests/python/contrib/test_onnx_model.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Relay to ONNX target test cases""" +import pytest +pytest.importorskip('onnx') +pytest.importorskip('onnxruntime') + +from collections import OrderedDict +import numpy as np +import onnxruntime as rt +import tvm +from tvm import relay +from tvm.contrib.target.onnx import to_onnx +import tvm.relay.testing +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.ir import IRModule +from tvm.relay import transform + + +def func_to_onnx(mod, params, name): + onnx_model = to_onnx(mod, params, name, path=None) + return onnx_model.SerializeToString() + + +def run_onnx(mod, params, name, input_data): + onnx_model = func_to_onnx(mod, params, name) + sess = rt.InferenceSession(onnx_model) + input_names = {} + for input, data in zip(sess.get_inputs(), input_data): + input_names[input.name] = data + output_names = [output.name for output in sess.get_outputs()] + res = sess.run(output_names, input_names) + return res[0] + + +def get_data(in_data_shapes, dtype='float32'): + in_data = OrderedDict() + for name, shape in in_data_shapes.items(): + in_data[name] = np.random.uniform(size=shape).astype(dtype) + return in_data + + +def run_relay(mod, params, in_data): + target = 'llvm' + ctx = tvm.context('llvm', 0) + intrp = relay.create_executor("graph", mod, ctx=ctx, target=target) + in_data = [tvm.nd.array(value) for value in in_data.values()] + return intrp.evaluate()(*in_data, **params).asnumpy() + + +def _verify_results(mod, params, in_data): + a = run_relay(mod, params, in_data) + b = run_onnx(mod, params, 'test_resent', in_data.values()) + np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) + + +def test_resnet(): + num_class = 1000 + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for n in [18, 34, 50, 101]: + mod, params = tvm.relay.testing.resnet.get_workload( + 1, num_class, num_layers=n) + _verify_results(mod, params, in_data) + + +def test_squeezenet(): + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for version in ['1.0', '1.1']: + mod, params = tvm.relay.testing.squeezenet.get_workload(1, version=version) + _verify_results(mod, params, in_data) + + +@pytest.mark.skip("USE_TARGET_ONNX should be ON") +def test_partition(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "onnx") + begin1 = compiler_begin(in_2, "onnx") + begin2 = compiler_begin(in_3, "onnx") + begin3 = compiler_begin(in_4, "onnx") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + end0 = compiler_end(node0, "onnx") + end1 = compiler_end(node1, "onnx") + begin4 = compiler_begin(end0, "onnx") + begin5 = compiler_begin(end1, "onnx") + node2 = relay.add(begin4, begin5) + end2 = compiler_end(node2, "onnx") + + dbegin0 = compiler_begin(in_5, "default") + dbegin1 = compiler_begin(in_6, "default") + node3 = relay.subtract(dbegin0, dbegin1) + dbegin2 = compiler_begin(in_7, "default") + dend1 = compiler_end(node3, "default") + dbegin3 = compiler_begin(dend1, "default") + node4 = relay.subtract(dbegin2, dbegin3) + dend2 = compiler_end(node4, "default") + + begin6 = compiler_begin(end2, "onnx") + begin7 = compiler_begin(dend2, "onnx") + node5 = relay.add(begin6, begin7) + end3 = compiler_end(node5, "onnx") + end4 = compiler_end(node5, "onnx") + dbegin4 = compiler_begin(in_8, "default") + dbegin5 = compiler_begin(end3, "default") + node6 = relay.subtract(dbegin4, dbegin5) + begin8 = compiler_begin(in_9, "onnx") + begin9 = compiler_begin(end4, "onnx") + node7 = relay.multiply(begin8, begin9) + end5 = compiler_end(node7, "onnx") + + dend3 = compiler_end(node6, "default") + begin10 = compiler_begin(dend3, "onnx") + begin11 = compiler_begin(end5, "onnx") + node8 = relay.add(begin10, begin11) + end6 = compiler_end(node8, "onnx") + begin12 = compiler_begin(in_10, "onnx") + begin13 = compiler_begin(end6, "onnx") + node9 = relay.add(begin12, begin13) + end7 = compiler_end(node9, "onnx") + + func = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7) + + target = 'llvm' + mod = IRModule.from_expr(func) + mod = transform.PartitionGraph()(mod) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=['FuseOps']): + graph_json, mod1, params = relay.build(mod, target) + + assert mod1.type_key == "metadata" + assert mod1.imported_modules[0].type_key == "llvm" + assert mod1.imported_modules[0].get_source() + assert mod1.imported_modules[1].type_key == "onnx" + assert mod1.imported_modules[1].get_source() + + +if __name__ == '__main__': + test_resnet() + test_squeezenet() + # test_partition needs USE_TARGET_ONNX to be ON + test_partition() +