import tvm
from ... import nd as _nd
from .. import ir_pass
+from .. import transform as _transform
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
- # Try to infer shape by precompute prune if possible.
- # TODO: good to check inputs to be in params.
- # to be enhanced when relay support list_input_names API of NNVM
- logging.warning("Infering Reshape argument by precompute")
- func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+ data, shape = inputs
+ logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
+ shape_params = ir_pass.free_vars(shape)
+ func = _expr.Function(shape_params, shape)
+ mod = _module.Module.from_expr(func)
+ seq = _transform.Sequential([_transform.InferType(),
+ _transform.FoldConstant(),
+ _transform.FuseOps(0),
+ _transform.InferType()])
+ with tvm.relay.PassContext(opt_level=2):
+ mod = seq(mod)
with tvm.relay.build_config(opt_level=0):
- graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
- ctx = tvm.context("llvm", 0)
- from tvm.contrib import graph_runtime
- m = graph_runtime.create(graph, lib, ctx)
- m.set_input(**params)
- m.run()
- params_new = m.get_output(0)
- inputs.pop(1)
- out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
+ ex = tvm.relay.create_executor("debug", mod=mod)
+ inputs = []
+ for sp in shape_params:
+ if not sp.name_hint in params:
+ sh = [int(i) for i in sp.type_annotation.shape]
+ inputs.append(
+ tvm.nd.array(np.random.rand(*sh).astype('float32')))
+ static_shape = ex.evaluate()(*inputs, **params)
+ out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
return out
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ # TODO(@jroesch): use shape_of once it has been fixed)
return _op.shape_of(inputs[0])
class Cast(OnnxOpConverter):
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1
- self._params[node.output[0]] = self._parse_array(t_proto)
- self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
+ # We should convert scalar integers to int32, to normalize.
+ array = self._parse_array(t_proto)
+ if len(array.shape) == 0 and array.dtype == 'int64':
+ array = _nd.array(array.asnumpy().astype('int32'))
+ self._params[node.output[0]] = array
+ self._nodes[node.output[0]] = new_var(
+ node.output[0],
+ shape=list(t_proto.dims),
+ dtype=array.dtype)
else:
if op_name == "ConstantFill":
fill_value = attr.get('value', 0.0)
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import attr
import numpy as np
import math
+import torch
+import torchvision
import topi
import topi.testing
import tvm
'LogSoftmax',
{'axis': 1})
+
+def check_torch_conversion(model, input_size):
+ dummy_input = torch.randn(*input_size)
+ file_name = '{}.onnx'.format(model.__name__)
+ # Set verbose=True for more output
+ torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+ onnx_model = onnx.load(file_name)
+ shapes = { '0' : input_size }
+ expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
+
+def test_resnet():
+ check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+ # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
+
+# def test_alexnet():
+ # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+ # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+
+# Torch's ONNX export does not support the adaptive pooling used by vgg16?
+# def test_vgg16():
+# check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_squeezenet():
+# # Torch's ONNX export does not support the max pooling used by Squezenet
+# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
+
+def test_densenet():
+ check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+
+def test_inception():
+ check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_googlenet():
+# check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_shufflenetv2():
+# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
+
+
if __name__ == '__main__':
test_flatten()
test_reshape()
test_ParametricSoftplus()
test_Scale()
test_LogSoftmax()
+ test_resnet()
+ test_inception()
+ test_densenet()