--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import rpc, relay
+from tvm.contrib.download import download_testdata
+from tvm.relay.expr_functor import ExprMutator
+from tvm.relay import transform
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay.quantize.quantize import prerequisite_optimize
+from tvm.contrib import util, xcode, graph_runtime, coreml_runtime
+from tvm.contrib.target import coreml as _coreml
+
+import os
+import re
+import sys
+import numpy as np
+from mxnet import gluon
+from PIL import Image
+import coremltools
+
+# Set to the address of the TVM RPC proxy.
+proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"]
+# Set your destination via an environment variable.
+# It should be in the format "platform=iOS,id=<the test device uuid>"
+destination = os.environ["TVM_IOS_RPC_DESTINATION"]
+
+if not re.match(r"^platform=.*,id=.*$", destination):
+ print("Bad format: {}".format(destination))
+ print("Example of expected string: platform=iOS,id=1234567890abcabcabcabc1234567890abcabcab")
+ sys.exit(1)
+
+proxy_port = 9090
+key = "iphone"
+
+# Change the target configuration; this is the setting for iPhone 6s.
+#arch = "x86_64"
+#sdk = "iphonesimulator"
+arch = "arm64"
+sdk = "iphoneos"
+target_host = "llvm -target=%s-apple-darwin" % arch
+
+# Override the Metal compiler to compile for iPhone.
+@tvm.register_func("tvm_callback_metal_compile")
+def compile_metal(src):
+ return xcode.compile_metal(src, sdk=sdk)
+
+def prepare_input():
+ img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
+ img_name = 'cat.png'
+ synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+ '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+ '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+ 'imagenet1000_clsid_to_human.txt'])
+ synset_name = 'imagenet1000_clsid_to_human.txt'
+    img_path = download_testdata(img_url, img_name, module='data')
+ synset_path = download_testdata(synset_url, synset_name, module='data')
+ with open(synset_path) as f:
+ synset = eval(f.read())
+ image = Image.open(img_path).resize((224, 224))
+
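+    # Normalize with the standard ImageNet mean/std, then lay out as NCHW.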
+ image = np.array(image) - np.array([123., 117., 104.])
+ image /= np.array([58.395, 57.12, 57.375])
+ image = image.transpose((2, 0, 1))
+ image = image[np.newaxis, :]
+ return image.astype('float32'), synset
+
+
+def get_model(model_name, data_shape):
+ gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
+ mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
+    # We want a probability, so add a softmax operator.
+ func = mod["main"]
+    func = relay.Function(func.params, relay.nn.softmax(func.body),
+                          None, func.type_params, func.attrs)
+
+ return func, params
+
+
+def test_mobilenet():
+ temp = util.tempdir()
+ image, synset = prepare_input()
+ model, params = get_model('mobilenetv2_1.0', image.shape)
+
+ def run(mod, target):
+ with relay.build_config(opt_level=3):
+ graph, lib, _params = relay.build(mod, target=target,
+ target_host=target_host, params=params)
+ path_dso = temp.relpath("deploy.dylib")
+ lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
+ xcode.codesign(path_dso)
+
+ # Start RPC test server that contains the compiled library.
+ xcode.popen_test_rpc(proxy_host, proxy_port, key,
+ destination=destination, libs=[path_dso])
+
+ # connect to the proxy
+ remote = rpc.connect(proxy_host, proxy_port, key=key)
+
+ if target == "metal":
+ ctx = remote.metal(0)
+ else:
+ ctx = remote.cpu(0)
+ lib = remote.load_module("deploy.dylib")
+ m = graph_runtime.create(graph, lib, ctx)
+
+ m.set_input('data', tvm.nd.array(image, ctx))
+ m.set_input(**_params)
+ m.run()
+ tvm_output = m.get_output(0)
+ top1 = np.argmax(tvm_output.asnumpy()[0])
+ print('TVM prediction top-1:', top1, synset[top1])
+
+ # evaluate
+ ftimer = m.module.time_evaluator("run", ctx, number=3, repeat=10)
+ prof_res = np.array(ftimer().results) * 1000
+ print("%-19s (%s)" % ("%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+
+ def annotate(func, compiler):
+ """
+ An annotator for Core ML.
+ """
+ # Bind free variables to the constant values.
+ bind_dict = {}
+ for arg in func.params:
+ name = arg.name_hint
+ if name in params:
+ bind_dict[arg] = relay.const(params[name])
+
+ func = relay.bind(func, bind_dict)
+
+ # Annotate the entire graph for Core ML
+ mod = tvm.IRModule()
+ mod["main"] = func
+
+ seq = tvm.transform.Sequential([
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.FoldScaleAxis(),
+ transform.AnnotateTarget(compiler),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph()
+ ])
+
+ with relay.build_config(opt_level=3):
+ mod = seq(mod)
+
+ return mod
+
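+    # Run the same model three ways: on CPU, on GPU via Metal, and
+    # partitioned through the Core ML codegen.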
+ # CPU
+ run(model, target_host)
+ # Metal
+ run(model, "metal")
+ # CoreML
+ run(annotate(model, "coremlcompiler"), target_host)
+
+if __name__ == "__main__":
+ test_mobilenet()
import tvm._ffi
from ..rpc import base as rpc_base
-def create(compiled_model_path, output_names, ctx):
+def create(model_dir, ctx):
"""Create a runtime executor module given a coreml model and context.
Parameters
----------
- compiled_model_path : str
- The path of the compiled model to be deployed.
- output_names : list of str
- The output names of the model.
+ model_dir : str
+ The directory where the compiled models are located.
ctx : TVMContext
The context to deploy the module. It can be local or remote when there
is only one TVMContext.
else:
fcreate = tvm._ffi.get_global_func(runtime_func)
- return CoreMLModule(fcreate(compiled_model_path, ctx, *output_names))
+ return CoreMLModule(fcreate(model_dir))
class CoreMLModule(object):
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Codegen and runtime APIs for targets.
+"""
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, import-outside-toplevel
+"""Utility to compile CoreML models"""
+
+import os
+import shutil
+
+import tvm._ffi
+from ...relay.expr_functor import ExprVisitor
+from .. import xcode, coreml_runtime
+
+def _convert_add(builder, name, inputs, outputs, args, attrs):
+ builder.add_elementwise(
+ name=name,
+ input_names=inputs,
+ output_name=outputs[0],
+ mode='ADD'
+ )
+
+def _convert_multiply(builder, name, inputs, outputs, args, attrs):
+ builder.add_elementwise(
+ name=name,
+ input_names=inputs,
+ output_name=outputs[0],
+ mode='MULTIPLY'
+ )
+
+def _convert_clip(builder, name, inputs, outputs, args, attrs):
+ builder.add_clip(
+ name=name,
+ input_name=inputs[0],
+ output_name=outputs[0],
+ min_value=attrs.a_min,
+ max_value=attrs.a_max
+ )
+
+def _convert_batch_flatten(builder, name, inputs, outputs, args, attrs):
+ builder.add_flatten_to_2d(
+ name=name,
+ input_name=inputs[0],
+ output_name=outputs[0]
+ )
+
+def _convert_softmax(builder, name, inputs, outputs, args, attrs):
+ builder.add_softmax_nd(
+ name=name,
+ input_name=inputs[0],
+ output_name=outputs[0],
+ axis=int(attrs['axis'])
+ )
+
+def _convert_conv2d(builder, name, inputs, outputs, args, attrs):
+ weight = args[1].data.asnumpy()
+ if attrs['kernel_layout'] == 'OIHW':
+ # convert to 'HWIO'
+ weight = weight.transpose([2, 3, 1, 0])
+ kh, kw, kc, oc = weight.shape
+
+ builder.add_convolution(
+ name=name,
+ kernel_channels=kc,
+ output_channels=oc,
+ height=kh,
+ width=kw,
+        stride_height=int(attrs['strides'][0]),
+        stride_width=int(attrs['strides'][1]),
+ border_mode="valid",
+ groups=int(attrs['groups']),
+ W=weight,
+ b=None,
+ has_bias=False,
+ input_name=inputs[0],
+ output_name=outputs[0],
+ dilation_factors=[int(v) for v in attrs['dilation']],
+ padding_top=int(attrs['padding'][0]),
+ padding_bottom=int(attrs['padding'][2]),
+ padding_left=int(attrs['padding'][1]),
+ padding_right=int(attrs['padding'][3])
+ )
+
+def _convert_global_avg_pool2d(builder, name, inputs, outputs, args, attrs):
+ builder.add_pooling(
+ name=name,
+ height=1,
+ width=1,
+ stride_height=1,
+ stride_width=1,
+ layer_type='AVERAGE',
+ padding_type='VALID',
+ input_name=inputs[0],
+ output_name=outputs[0],
+ is_global=True
+ )
+
+_convert_map = {
+ 'add' : _convert_add,
+ 'multiply' : _convert_multiply,
+ 'clip' : _convert_clip,
+ 'nn.batch_flatten' : _convert_batch_flatten,
+ 'nn.softmax' : _convert_softmax,
+ 'nn.conv2d' : _convert_conv2d,
+ 'nn.global_avg_pool2d' : _convert_global_avg_pool2d,
+}
+
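+# The _convert_map table above is the single extension point for operator
+# coverage. As an illustrative sketch only (not part of this patch), a
+# hypothetical 'nn.relu' converter would follow the same pattern, using
+# coremltools' NeuralNetworkBuilder.add_activation API:
+#
+#   def _convert_relu(builder, name, inputs, outputs, args, attrs):
+#       builder.add_activation(
+#           name=name,
+#           non_linearity='RELU',
+#           input_name=inputs[0],
+#           output_name=outputs[0])
+#
+#   _convert_map['nn.relu'] = _convert_relu
+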
+class CodegenCoreML(ExprVisitor):
+ """
+ A visitor to traverse subgraphs and build Core ML models.
+ """
+ def __init__(self, model_name, function):
+ import coremltools
+ from coremltools.models.neural_network import NeuralNetworkBuilder
+
+ ExprVisitor.__init__(self)
+ self.model_name = model_name
+ self.function = function
+ self.out_map = {}
+ self.model_inputs_ = []
+ self.buf_idx_ = 0
+
+ # Update inputs and outputs after we visit all the nodes.
+ # Set dummy values for now.
+ # TODO: support multiple outputs
+ inputs = [('', coremltools.models.datatypes.Array(1,)) for _ in self.function.params]
+ outputs = [('', coremltools.models.datatypes.Array(1,))]
+ self.builder = NeuralNetworkBuilder(inputs, outputs,
+ disable_rank5_shape_mapping=True)
+
+ def visit_constant(self, const):
+ output = "buf_" + str(self.buf_idx_)
+ self.builder.add_load_constant_nd(
+ name=output,
+ output_name=output,
+ constant_value=const.data.asnumpy(),
+ shape=const.data.shape
+ )
+ self.buf_idx_ = self.buf_idx_ + 1
+ self.out_map[const] = [output]
+
+ def visit_var(self, var):
+ name = var.name_hint
+ shape = [int(n) for n in var.type_annotation.shape]
+ dtype = var.type_annotation.dtype
+ self.model_inputs_.append((name, shape, dtype))
+ self.out_map[var] = [name]
+
+ def visit_call(self, call):
+ inputs = []
+ for arg in call.args:
+ super().visit(arg)
+ for out in self.out_map[arg]:
+ inputs.append(out)
+ outputs = ["buf_" + str(self.buf_idx_)]
+ op_name = call.op.name
+ layer_name = op_name + "_" + str(self.buf_idx_)
+
+ assert op_name in _convert_map, "{} is not supported".format(op_name)
+ _convert_map[op_name](self.builder, layer_name, inputs, outputs,
+ call.args, call.attrs)
+
+ self.buf_idx_ = self.buf_idx_ + 1
+ self.out_map[call] = outputs
+
+ def compile(self, out_dir):
+ """
+        Build a Core ML model and compile it with the Xcode toolchain.
+ """
+ import coremltools
+ from coremltools.proto.Model_pb2 import ArrayFeatureType
+
+ FEATURE_TYPE_MAP = {
+ "float32": ArrayFeatureType.FLOAT32,
+ "float64": ArrayFeatureType.DOUBLE,
+ "int32": ArrayFeatureType.INT32,
+ }
+
+ input_names, input_dims, input_dtypes = zip(*self.model_inputs_)
+ self.builder.set_input(input_names, input_dims)
+ for i, dtype in enumerate(input_dtypes):
+ assert dtype in FEATURE_TYPE_MAP
+ input_desc = self.builder.spec.description.input
+ input_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype]
+
+ output_dim = [int(n) for n in self.function.ret_type.shape]
+ self.builder.set_output(self.out_map[self.function.body], [output_dim])
+ for i, dtype in enumerate([self.function.ret_type.dtype]):
+ assert dtype in FEATURE_TYPE_MAP
+ output_desc = self.builder.spec.description.output
+ output_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype]
+
+ model = coremltools.models.MLModel(self.builder.spec)
+ xcode.compile_coreml(model, self.model_name, out_dir)
+
+
+@tvm._ffi.register_func("relay.ext.coremlcompiler")
+def coreml_compiler(ref):
+ """
+ Create a CoreML runtime from a Relay module.
+ """
+ model_dir = os.getcwd()
+ if isinstance(ref, tvm.ir.module.IRModule):
+ for var, func in ref.functions.items():
+ name = var.name_hint
+ builder = CodegenCoreML(name, func)
+ builder.visit(func.body)
+ mlmodelc_path = "{}/{}.mlmodelc".format(model_dir, name)
+ if os.path.exists(mlmodelc_path):
+ shutil.rmtree(mlmodelc_path)
+ builder.compile(model_dir)
+
+ ctx = tvm.cpu(0)
+ return coreml_runtime.create(model_dir, ctx).module
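+
+# A minimal sketch of the flow that reaches this codegen, assuming a Relay
+# module `mod` whose operators are all listed in _convert_map: annotation and
+# partitioning mark subgraphs with Compiler="coremlcompiler", and relay.build
+# then dispatches each partitioned function to coreml_compiler above.
+#
+#   mod = transform.AnnotateTarget("coremlcompiler")(mod)
+#   mod = transform.MergeCompilerRegions()(mod)
+#   mod = transform.PartitionGraph()(mod)
+#   graph, lib, params = relay.build(mod, target="llvm")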
import os
import sys
import subprocess
+import json
from .._ffi.base import py_str
from . import util
return libbin
-def compile_coreml(model, out_dir="."):
+def compile_coreml(model, model_name="main", out_dir="."):
"""Compile coreml model and return the compiled model path.
"""
- mlmodel_path = os.path.join(out_dir, "tmp.mlmodel")
+ mlmodel_path = os.path.join(out_dir, model_name + ".mlmodel")
+ mlmodelc_path = os.path.join(out_dir, model_name + ".mlmodelc")
+ metadata = {
+ "inputs": list(model.input_description),
+ "outputs": list(model.output_description)
+ }
+    # Use the model description field to carry the input/output names to the
+    # CoreML runtime as JSON; the compiled .mlmodelc keeps this description,
+    # and the Objective-C runtime parses it back to recover I/O ordering.
+ model.short_description = json.dumps(metadata)
model.save(mlmodel_path)
- xcrun(["coremlcompiler", "compile", mlmodel_path, out_dir])
+ res = xcrun(["coremlcompiler", "compile", mlmodel_path, out_dir])
+ if not os.path.isdir(mlmodelc_path):
+ raise RuntimeError("Compile failed: %s" % res)
- return os.path.join(out_dir, "tmp.mlmodelc")
+ return mlmodelc_path
class XCodeRPCServer(object):
from .register import get_pattern_table, register_pattern_table
from .dnnl import *
+from .coreml import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""CoreML codegen supported operators."""
+import tvm.ir
+from tvm.contrib.target.coreml import _convert_map
+from ...expr import Constant
+
+
+def _register_coreml_op(op_name):
+ """Register a function to check the given operator is supported by Core ML.
+
+ Paramters
+ ---------
+ op_name : Str
+ The name of operator that will be registered.
+
+ """
+ def _check_supported(attrs, args):
+ if op_name == 'nn.conv2d':
+ if not isinstance(args[1], Constant):
+ return False
+ if attrs['kernel_layout'] not in ['HWIO', 'OIHW']:
+ return False
+ return True
+
+ tvm.ir.register_op_attr(op_name, "target.coremlcompiler", _check_supported)
+
+
+for op in _convert_map:
+ _register_coreml_op(op)
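+
+# Each registration above attaches a "target.coremlcompiler" attribute to the
+# op, which transform.AnnotateTarget("coremlcompiler") queries when deciding
+# which calls to offload. Illustrative check, assuming the relay Op registry
+# API:
+#
+#   checker = tvm.relay.op.get("nn.softmax").get_attr("target.coremlcompiler")
+#   assert checker is not None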
#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>
namespace tvm {
namespace runtime {
/*!
- * \brief CoreML runtime.
- *
- * This runtime can be accessed in various language via
- * TVM runtime PackedFunc API.
+ * \brief CoreML model.
*/
-class CoreMLRuntime : public ModuleNode {
+class CoreMLModel {
public:
/*!
- * \brief Get member function to front-end.
- * \param name The name of the function.
- * \param sptr_to_self The pointer to the module node.
- * \return The corresponding member function.
- */
- virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
-
- /*!
- * \return The type key of the executor.
+   * \brief Constructor.
+   * \param url The URL of the compiled model (.mlmodelc) directory.
*/
- const char* type_key() const { return "CoreMLRuntime"; }
-
+ explicit CoreMLModel(NSURL* url) {
+ url_ = url;
+ model_ = [MLModel modelWithContentsOfURL:url error:nil];
+ input_dict_ = [NSMutableDictionary dictionary];
+ output_ = nil;
+ }
/*!
* \brief Invoke the coreml prediction.
*/
void Invoke();
-
- /*!
- * \brief Initialize the coreml runtime with coreml model and context.
- * \param model_path The compiled model path.
- * \param ctx The context where the coreml model will be executed on.
- * \param output_names The output names of the model.
- */
- void Init(const std::string& model_path, TVMContext ctx,
- const std::vector<NSString*>& output_names);
-
/*!
* \brief set input to the model.
* \param key The input name.
*/
int GetNumOutputs() const;
+ // CoreML model url
+ NSURL* url_;
// CoreML model
MLModel* model_;
// CoreML model input dictionary
NSMutableDictionary<NSString*, id>* input_dict_;
// CoreML model output
id<MLFeatureProvider> output_;
- // List of output names
- std::vector<NSString*> output_names_;
- // TVM context
- TVMContext ctx_;
+};
+
+/*!
+ * \brief CoreML runtime.
+ *
+ * This runtime can be accessed in various languages via the
+ * TVM runtime PackedFunc API.
+ */
+class CoreMLRuntime : public ModuleNode {
+ public:
+ /*!
+ * \brief Get member function to front-end.
+ * \param name The name of the function.
+ * \param sptr_to_self The pointer to the module node.
+ * \return The corresponding member function.
+ */
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+ /*!
+   * \brief Serialize the content of each mlmodelc directory and save it to
+   * a binary stream.
+ * \param stream The binary stream to save to.
+ */
+ void SaveToBinary(dmlc::Stream* stream) final;
+
+ /*!
+ * \return The type key of the executor.
+ */
+ const char* type_key() const { return "coreml"; }
+
+ /*!
+   * \brief Initialize the CoreML runtime with a directory of compiled models.
+ * \param model_dir The directory where compiled models are located.
+ */
+ void Init(const std::string& model_dir);
+
+ /*!
+ * \brief Get coreml model.
+ * \param model_name The name of the model.
+ */
+ CoreMLModel& GetModel(const std::string& model_name);
+
+  // Map of the available CoreML models
+ std::unordered_map<std::string, std::unique_ptr<CoreMLModel>> model_map_;
};
} // namespace runtime
namespace tvm {
namespace runtime {
-MLModel* load_coreml_model(const std::string& model_path) {
- NSBundle* bundle = [NSBundle mainBundle];
- NSString* base = [bundle privateFrameworksPath];
- NSString* fname = [NSString stringWithUTF8String:("tvm/" + model_path).c_str()];
- NSString* assetPath = [base stringByAppendingPathComponent:fname];
-
- if (![[NSFileManager defaultManager] fileExistsAtPath:assetPath]) {
- assetPath = [NSString stringWithCString:model_path.c_str() encoding:NSUTF8StringEncoding];
- }
-
- NSURL* url = [NSURL fileURLWithPath:assetPath];
-
- MLModel* model = [MLModel modelWithContentsOfURL:url error:nil];
- if (model == nil) {
- NSLog(@"modelc %@ not found", url);
- }
- return model;
-}
-
-void CoreMLRuntime::Init(const std::string& model_path, TVMContext ctx,
- const std::vector<NSString*>& output_names) {
- model_ = load_coreml_model(model_path);
- ctx_ = ctx;
- input_dict_ = [NSMutableDictionary dictionary];
- output_names_ = output_names;
-}
-
-void CoreMLRuntime::Invoke() {
+void CoreMLModel::Invoke() {
id<MLFeatureProvider> input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_
error:nil];
output_ = [model_ predictionFromFeatures:input error:nil];
}
-void CoreMLRuntime::SetInput(const std::string& key, DLTensor* data_in) {
+void CoreMLModel::SetInput(const std::string& key, DLTensor* data_in) {
int64_t size = 1;
NSMutableArray* shape = [[NSMutableArray alloc] init];
for (int64_t i = 0; i < data_in->ndim; ++i) {
[input_dict_ setObject:dest forKey:nsKey];
}
-NDArray CoreMLRuntime::GetOutput(int index) const {
- NSString* name = output_names_[index];
+NDArray CoreMLModel::GetOutput(int index) const {
MLModelDescription* model_desc = model_.modelDescription;
+ NSString* metadata = [model_desc metadata][MLModelDescriptionKey];
+ NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding];
+ NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data
+ options:NSJSONReadingAllowFragments
+ error:nil];
+ NSString* name = json[@"outputs"][index];
MLFeatureDescription* output_desc = model_desc.outputDescriptionsByName[name];
MLMultiArrayConstraint* data_desc = output_desc.multiArrayConstraint;
std::vector<int64_t> shape;
LOG(FATAL) << "unexpected data type " << data_desc.dataType;
}
MLMultiArray* src = [output_ featureValueForName:name].multiArrayValue;
- NDArray ret = NDArray::Empty(shape, dtype, ctx_);
+ TVMContext cpu_ctx = {
+ .device_type = kDLCPU,
+ .device_id = 0,
+ };
+ NDArray ret = NDArray::Empty(shape, dtype, cpu_ctx);
ret.CopyFromBytes(src.dataPointer, size);
return ret;
}
-int CoreMLRuntime::GetNumOutputs() const { return output_names_.size(); }
+int CoreMLModel::GetNumOutputs() const {
+ MLModelDescription* model_desc = model_.modelDescription;
+ return [[model_desc outputDescriptionsByName] count];
+}
+
+void CoreMLRuntime::Init(const std::string& _model_dir) {
+ NSString* model_dir = [NSString stringWithUTF8String:(_model_dir).c_str()];
+ if (![model_dir hasPrefix:@"/"]) {
+ // find models in the bundle's framework
+ NSBundle* bundle = [NSBundle mainBundle];
+ NSString* base = [bundle privateFrameworksPath];
+ model_dir = [base stringByAppendingPathComponent:model_dir];
+ }
+  NSFileManager* fileManager = [NSFileManager defaultManager];
+  NSArray<NSString*>* files = [fileManager contentsOfDirectoryAtPath:model_dir error:nil];
+ for (NSString* file in files) {
+ if ([[file pathExtension] isEqualToString:@"mlmodelc"]) {
+ NSString* model_path = [model_dir stringByAppendingPathComponent:file];
+ NSURL* url = [NSURL fileURLWithPath:model_path];
+ const std::string& model_name = [[file stringByDeletingPathExtension] UTF8String];
+ model_map_[model_name] = std::unique_ptr<CoreMLModel>(new CoreMLModel(url));
+ }
+ }
+}
+
+CoreMLModel& CoreMLRuntime::GetModel(const std::string& model_name) {
+ CHECK(model_map_.count(model_name) > 0) << "No such model in this module: " << model_name;
+ return *model_map_[model_name];
+}
PackedFunc CoreMLRuntime::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "invoke") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); });
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { GetModel("main").Invoke(); });
} else if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
const auto& input_name = args[0].operator std::string();
- this->SetInput(input_name, args[1]);
+ GetModel("main").SetInput(input_name, args[1]);
});
} else if (name == "get_output") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = GetModel("main").GetOutput(args[0]);
+ });
} else if (name == "get_num_outputs") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetNumOutputs(); });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = GetModel("main").GetNumOutputs();
+ });
} else {
- return PackedFunc();
+ // Return the packedfunc which executes the subgraph.
+ return PackedFunc([sptr_to_self, name, this](TVMArgs args, TVMRetValue* rv) {
+ CoreMLModel& model = GetModel(name);
+ MLModelDescription* model_desc = [model.model_ modelDescription];
+ NSString* metadata = [model_desc metadata][MLModelDescriptionKey];
+ NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding];
+ NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data
+ options:NSJSONReadingAllowFragments
+ error:nil];
+ NSArray<NSString*>* input_names = json[@"inputs"];
+
+ // Copy input tensors to corresponding data entries.
+ for (auto i = 0; i < args.size() - 1; ++i) {
+ CHECK(args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMNDArrayHandle)
+ << "Expect NDArray or DLTensor as inputs\n";
+ if (args[i].type_code() == kTVMDLTensorHandle) {
+ model.SetInput([input_names[i] UTF8String], args[i]);
+ } else {
+ LOG(FATAL) << "Not implemented";
+ }
+ }
+
+ // Execute the subgraph.
+ model.Invoke();
+
+ // TODO: Support multiple outputs.
+ NDArray out = model.GetOutput(0);
+ if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) {
+ DLTensor* arg = args[args.size() - 1];
+ out.CopyTo(arg);
+ } else {
+ NDArray arg = args[args.size() - 1];
+ out.CopyTo(arg);
+ }
+ *rv = out;
+ });
}
}
-Module CoreMLRuntimeCreate(const std::string& model_path, TVMContext ctx,
- const std::vector<NSString*>& output_names) {
+Module CoreMLRuntimeCreate(const std::string& model_dir) {
auto exec = make_object<CoreMLRuntime>();
- exec->Init(model_path, ctx, output_names);
+ exec->Init(model_dir);
return Module(exec);
}
TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
- std::vector<NSString*> output_names;
- for (size_t i = 2; i < args.size(); i++) {
- const std::string& name = args[i];
- output_names.push_back([NSString stringWithUTF8String:name.c_str()]);
- }
- *rv = CoreMLRuntimeCreate(args[0], args[1], output_names);
+ *rv = CoreMLRuntimeCreate(args[0]);
});
+
+void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) {
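+  // Stream layout: a uint32 model count, then for each model its name, the
+  // byte length of its serialized .mlmodelc directory, and the raw bytes
+  // produced by NSFileWrapper's serializedRepresentation.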
+ stream->Write((uint32_t)model_map_.size());
+ for (const auto& kv : model_map_) {
+ const std::string& model_name = kv.first;
+ NSURL* url = kv.second->url_;
+ NSFileWrapper* dirWrapper = [[[NSFileWrapper alloc] initWithURL:url options:0
+ error:nil] autorelease];
+ NSData* dirData = [dirWrapper serializedRepresentation];
+ stream->Write(model_name);
+ stream->Write((uint64_t)[dirData length]);
+ stream->Write([dirData bytes], [dirData length]);
+ LOG(INFO) << "Save " << model_name << " (" << [dirData length] << " bytes)";
+ }
+}
+
+/*!
+ * \brief Load a CoreML module from stream.
+ *
+ * \param strm The binary stream to load from.
+ *
+ * \return The created CoreML module.
+ */
+Module CoreMLRuntimeLoadFromBinary(void* strm) {
+ dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+
+ uint32_t nr_models;
+ stream->Read(&nr_models);
+
+ NSString* tempBaseDir = NSTemporaryDirectory();
+ if (tempBaseDir == nil) tempBaseDir = @"/tmp";
+
+ NSString* templateStr = [tempBaseDir stringByAppendingPathComponent:@"tvm.XXXXXX"];
+ const char* fsTemplate = [templateStr fileSystemRepresentation];
+ NSMutableData* bufferData = [NSMutableData dataWithBytes:fsTemplate
+ length:strlen(fsTemplate) + 1];
+ char* buffer = (char*)[bufferData mutableBytes];
+ char* result = mkdtemp(buffer);
+ NSString* tempDir = [NSString stringWithUTF8String:result];
+
+ for (int i = 0; i < nr_models; i++) {
+ std::string model_name;
+ stream->Read(&model_name);
+ uint64_t length;
+ stream->Read(&length);
+ void* ptr = new char[length];
+ stream->Read(ptr, length);
+ NSData* data = [[NSData alloc] initWithBytesNoCopy:ptr length:length];
+ NSFileWrapper* dirWrapper =
+ [[[NSFileWrapper alloc] initWithSerializedRepresentation:data] autorelease];
+ NSString* model_dir = [tempDir
+ stringByAppendingPathComponent:[NSString stringWithUTF8String:(model_name + ".mlmodelc")
+ .c_str()]];
+ NSURL* url = [NSURL fileURLWithPath:model_dir];
+ BOOL res = [dirWrapper writeToURL:url options:0 originalContentsURL:nil error:nil];
+ CHECK(res) << "Failed to create model directory " << [model_dir UTF8String];
+ }
+
+ auto exec = make_object<CoreMLRuntime>();
+ exec->Init([tempDir UTF8String]);
+ return Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_coreml").set_body_typed(CoreMLRuntimeLoadFromBinary);
+
} // namespace runtime
} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+from unittest import mock
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.contrib import graph_runtime
+from tvm.contrib.target import coreml as _coreml
+
+pytest.importorskip("coremltools")
+
+
+def _has_xcode():
+ try:
+ tvm.contrib.xcode.xcrun([])
+ return True
+ except FileNotFoundError:
+ pass
+
+ return False
+
+
+def _create_graph():
+ shape = (10, 10)
+ mod = tvm.IRModule()
+
+ x = relay.var('x', shape=shape)
+ y = relay.var('y', shape=shape)
+ z = x + x
+ p = y * y
+ func = relay.Function([x, y], p - z)
+ mod["main"] = func
+
+ return mod
+
+
+def _create_graph_annotated():
+ shape = (10, 10)
+ target = "coremlcompiler"
+ mod = tvm.IRModule()
+
+    # function 0: gv0 computes y * y, offloaded to Core ML
+ f0_i0 = relay.var(target + "_0_i0", shape=shape)
+ func0 = relay.Function([f0_i0], f0_i0 * f0_i0)
+
+ func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Compiler", target)
+ func0 = func0.with_attr("global_symbol", target + "_0")
+ gv0 = relay.GlobalVar(target + "_0")
+ mod[gv0] = func0
+
+    # function 2: gv2 computes x + x, offloaded to Core ML
+ f2_i0 = relay.var(target + "_2_i0", shape=shape)
+ func2 = relay.Function([f2_i0], f2_i0 + f2_i0)
+
+ func2 = func2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func2 = func2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func2 = func2.with_attr("Compiler", target)
+ func2 = func2.with_attr("global_symbol", target + "_2")
+ gv2 = relay.GlobalVar(target + "_2")
+ mod[gv2] = func2
+
+ # body
+ x = relay.var('x', shape=shape)
+ y = relay.var('y', shape=shape)
+ func = relay.Function([x, y], gv0(y) - gv2(x))
+ mod["main"] = func
+
+ return mod
+
+
+def test_annotate():
+ mod = _create_graph()
+ mod = transform.AnnotateTarget("coremlcompiler")(mod)
+ mod = transform.PartitionGraph()(mod)
+
+ expected = _create_graph_annotated()
+ assert tvm.ir.structural_equal(mod, expected, map_free_vars=True)
+
+
+@mock.patch('tvm.contrib.coreml_runtime.create')
+@mock.patch('tvm.contrib.xcode.compile_coreml')
+def test_construct_model(m1, m2):
+ mod = _create_graph_annotated()
+
+ fcompile = tvm._ffi.get_global_func("relay.ext.coremlcompiler")
+
+ for var, func in mod.functions.items():
+ if func.attrs and 'Compiler' in func.attrs and \
+ func.attrs['Compiler'] == 'coremlcompiler':
+ fcompile(tvm.IRModule.from_expr(func.body))
+
+
+@pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available")
+def test_compile_and_run():
+    ctx = tvm.cpu()
+    target = "llvm"
+    tol = 1e-3
+
+ with relay.build_config(opt_level=3):
+ json, lib, params = relay.build(_create_graph_annotated(), target=target)
+    m = graph_runtime.create(json, lib, ctx)
+
+ shape = (10, 10)
+ x_data = np.random.rand(*shape).astype('float32')
+ y_data = np.random.rand(*shape).astype('float32')
+
+ m.set_input("x", x_data)
+ m.set_input("y", y_data)
+ m.set_input(**params)
+ m.run()
+ out = tvm.nd.empty(shape, ctx=ctx)
+ out = m.get_output(0, out)
+
+ expected = (y_data * y_data) - (x_data + x_data)
+ tvm.testing.assert_allclose(out.asnumpy(), expected, rtol=tol, atol=tol)
+
+
+if __name__ == "__main__":
+ test_annotate()
+ test_construct_model()
+ test_compile_and_run()
from tvm import rpc
from tvm.contrib import util, xcode, coreml_runtime
+import pytest
import os
proxy_host = os.environ.get("TVM_IOS_RPC_PROXY_HOST", "localhost")
destination = os.environ.get("TVM_IOS_RPC_DESTINATION", "")
key = "iphone"
-def skipped_test_coreml_runtime():
+@pytest.mark.skip('skip because coremltools is not available in CI')
+def test_coreml_runtime():
import coremltools
from coremltools.models.neural_network import NeuralNetworkBuilder
mode='MULTIPLY')
return coremltools.models.MLModel(builder.spec)
- def verify(coreml_model, compiled_model_path, ctx):
+ def verify(coreml_model, model_dir, ctx):
coreml_model = create_coreml_model()
out_spec = coreml_model.output_description._fd_spec
coreml_outputs = [coreml_model.predict(inputs)[name] for name in out_names]
# inference via tvm coreml runtime
- runtime = coreml_runtime.create(compiled_model_path, out_names, ctx)
+ runtime = coreml_runtime.create(model_dir, ctx)
for name in inputs:
runtime.set_input(name, tvm.nd.array(inputs[name], ctx))
runtime.invoke()
compiled_model = xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir)
xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination,
libs=[compiled_model])
- compiled_model = os.path.basename(compiled_model)
remote = rpc.connect(proxy_host, proxy_port, key=key)
ctx = remote.cpu(0)
- verify(coreml_model, compiled_model, ctx)
+ verify(coreml_model, "tvm", ctx)
def check_local(coreml_model):
temp = util.tempdir()
- compiled_model = xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir)
+ xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir)
ctx = tvm.cpu(0)
- verify(coreml_model, compiled_model, ctx)
+ verify(coreml_model, temp.temp_dir, ctx)
coreml_model = create_coreml_model()
check_remote(coreml_model)
if __name__ == "__main__":
- # skipped_test_coreml_runtime()
- pass
+ test_coreml_runtime()