From: lhutton1 <35535092+lhutton1@users.noreply.github.com> Date: Tue, 21 Jul 2020 15:30:26 +0000 (+0100) Subject: [BYOC][Contrib] Arm Compute Library integration (#5915) X-Git-Tag: upstream/0.7.0~374 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d8c9bb18907ed57916d94325f47f1c3dcfd34784;p=platform%2Fupstream%2Ftvm.git [BYOC][Contrib] Arm Compute Library integration (#5915) * [BYOC][Contrib] Arm Compute Library integration Arm Compute Library (ACL) integration using the BYOC infrastructure. This will enable offloading select operators from a relay graph to ACL so we can achieve faster inference times on Arm CPU's due to hand crafted optimized routines. The PR adds initial support for offloading FP32 conv2d, maxpool2d and reshape to ACL. ACL codegen is used to generate a JSON representation of an operator or 'ACL layer', the ACL runtime then uses this representation to construct a layer, cache it and create a packed function to for the graph runtime to call into. RFC here: https://discuss.tvm.ai/t/rfc-byoc-arm-compute-library-integration/7082 Change-Id: If756dcea787ea346b1508e9a191b7eed7bd02b7f * Refactor ACL integration to support JSON runtime * Now uses JSON runtime * Addresses tutorial comments * Rename acl to arm_compute_lib in user facing api Change-Id: I3b5ef80607f713e898363e82ab4398fbc2cf267a * Address comments Change-Id: I041fda14f3bf9975f3518ba8a4e3ab43ba98403d * Address comments * correct mistakes in tutorial * reshuffle runtime to use fewer macro blocks * preprocess module using "optimize" functionality * use new module api Change-Id: I219488e617e5767edd7489b43b8bfce876cd24b8 * Enable ACL codegen tests in CI * Skips runtime tests as these are not supported on x86. Change-Id: I6843c003a2604afe95cfdccf2323d2a336b56fe5 * Fix check for runtime Change-Id: I3f9eec15c599f01b1105d624fb053b73bfb6ed41 * Address comments * Add warning to ACL engine creation * Correct runtime check so it doesn't fail when codegen not present * Improve testing to check acl partitions is what is expected * Check results of multiple runs test Change-Id: I9522950930805b9b601dad03269adcf8ed3138cc * Address comments * Multiple style improvements * Use base class for creating json node for single op * Move GetSource to base class * Improve annotation checks Change-Id: I8219659c4b99e86df887cd914720157cb94c61a0 * Improve tutorial Change-Id: I8f610bd37af1e3740fd48c2d502bcc4727d9d712 * Initialize conv with nullptr Change-Id: I6c37f0d75a064001c74e171ff83b9f7a7c3f1918 --- diff --git a/CMakeLists.txt b/CMakeLists.txt index e87f75e..19d582a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,8 @@ tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_COREML "Build with coreml support" OFF) tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) +tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) +tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) if(USE_CPP_RPC AND UNIX) message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. 
Use the Makefile for non-Windows.") @@ -332,6 +334,7 @@ include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/ONNX.cmake) +include(cmake/modules/contrib/ArmComputeLib.cmake) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cmake/config.cmake b/cmake/config.cmake index 3f12d7c..4eae607 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -184,6 +184,20 @@ set(USE_SORT ON) # Whether use MKL-DNN (DNNL) codegen set(USE_DNNL_CODEGEN OFF) +# Whether to use Arm Compute Library (ACL) codegen +# We provide 2 separate flags since we cannot build the ACL runtime on x86. +# This is useful for cases where you want to cross-compile a relay graph +# on x86 then run on AArch. +# +# An example of how to use this can be found here: docs/deploy/arm_compute_lib.rst. +# +# USE_ARM_COMPUTE_LIB - Support for compiling a relay graph offloading supported +# operators to Arm Compute Library. OFF/ON +# USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME - Run Arm Compute Library annotated functions via the ACL +# runtime. OFF/ON/"path/to/ACL" +set(USE_ARM_COMPUTE_LIB OFF) +set(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME OFF) + # Build ANTLR parser for Relay text format # Possible values: # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar) diff --git a/cmake/modules/contrib/ArmComputeLib.cmake b/cmake/modules/contrib/ArmComputeLib.cmake new file mode 100644 index 0000000..ff9c8f7 --- /dev/null +++ b/cmake/modules/contrib/ArmComputeLib.cmake @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# We separate the codegen and runtime build since ACL can only be built +# for AArch. In the world where we take the cross compilation approach, +# which is common with arm devices, we need to be able to cross-compile +# a relay graph on x86 for AArch and then run the graph on AArch. +if(USE_ARM_COMPUTE_LIB) + file(GLOB ACL_RELAY_CONTRIB_SRC src/relay/backend/contrib/arm_compute_lib/*.cc) + file(GLOB ACL_RUNTIME_MODULE src/runtime/contrib/arm_compute_lib/acl_runtime.cc) + list(APPEND COMPILER_SRCS ${ACL_RELAY_CONTRIB_SRC}) + list(APPEND COMPILER_SRCS ${ACL_RUNTIME_MODULE}) + message(STATUS "Build with Arm Compute Library support...") +endif() + +if(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME) + set(ACL_PATH ${CMAKE_CURRENT_SOURCE_DIR}/acl) + # Detect custom ACL path. 
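  # (If a path is supplied instead of ON, the ACL headers and libraries are expected to be found
  # under that path, e.g. <path>/include together with <path>/lib or <path>/build, matching the
  # include and find_library hints below.)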
+  if (NOT USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME STREQUAL "ON")
+    set(ACL_PATH ${USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME})
+  endif()
+
+  file(GLOB ACL_CONTRIB_SRC src/runtime/contrib/arm_compute_lib/*)
+
+  set(ACL_INCLUDE_DIRS ${ACL_PATH}/include ${ACL_PATH})
+  include_directories(${ACL_INCLUDE_DIRS})
+
+  find_library(EXTERN_ACL_COMPUTE_LIB
+    NAMES arm_compute libarm_compute
+    HINTS "${ACL_PATH}" "${ACL_PATH}/lib" "${ACL_PATH}/build"
+  )
+  find_library(EXTERN_ACL_COMPUTE_CORE_LIB
+    NAMES arm_compute_core libarm_compute_core
+    HINTS "${ACL_PATH}" "${ACL_PATH}/lib" "${ACL_PATH}/build"
+  )
+  find_library(EXTERN_ACL_COMPUTE_GRAPH_LIB
+    NAMES arm_compute_graph libarm_compute_graph
+    HINTS "${ACL_PATH}" "${ACL_PATH}/lib" "${ACL_PATH}/build"
+  )
+
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_ACL_COMPUTE_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_ACL_COMPUTE_CORE_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_ACL_COMPUTE_GRAPH_LIB})
+  list(APPEND RUNTIME_SRCS ${ACL_CONTRIB_SRC})
+  message(STATUS "Build with Arm Compute Library graph runtime support: "
+                 ${EXTERN_ACL_COMPUTE_LIB} ", \n"
+                 ${EXTERN_ACL_COMPUTE_CORE_LIB} ", \n"
+                 ${EXTERN_ACL_COMPUTE_GRAPH_LIB})
+
+  # Set flag to detect ACL graph runtime support.
+  add_definitions(-DTVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB)
+endif()
diff --git a/docs/deploy/arm_compute_lib.rst b/docs/deploy/arm_compute_lib.rst
new file mode 100644
index 0000000..28abc9c
--- /dev/null
+++ b/docs/deploy/arm_compute_lib.rst
@@ -0,0 +1,139 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+Relay Arm|reg| Compute Library Integration
+==========================================
+
+Introduction
+------------
+
+Arm Compute Library (ACL) is an open source project that provides accelerated kernels for Arm CPUs
+and GPUs. Currently the integration offloads operators to ACL to use hand-crafted assembler
+routines in the library. By offloading select operators from a relay graph to ACL we can achieve
+a performance boost on such devices.
+
+Building with ACL support
+-------------------------
+
+The current implementation has two separate build options in cmake. The reason for this split is
+that ACL cannot be used on an x86 machine. However, we still want to be able to compile an ACL
+runtime module on an x86 machine.
+
+* USE_ARM_COMPUTE_LIB=ON/OFF - Enabling this flag will add support for compiling an ACL runtime module.
+* USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME=ON/OFF/path-to-acl - Enabling this flag will allow the graph runtime to
+  compute the ACL offloaded functions.
+
+These flags can be used in different scenarios depending on your setup.
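As an aside (illustrative only; both helpers referenced below are added by this patch), you can check
from Python which of these components your TVM build actually provides:

.. code:: python

    import tvm
    from tvm.relay.op.contrib.arm_compute_lib import is_arm_compute_runtime_enabled

    # True if TVM was built with USE_ARM_COMPUTE_LIB=ON (ACL codegen available).
    print(bool(tvm.get_global_func("relay.ext.arm_compute_lib", True)))
    # True if TVM was built with USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME enabled (ACL runtime available).
    print(is_arm_compute_runtime_enabled())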
+For example, if you want
+to compile an ACL module on an x86 machine and then run the module on a remote Arm device via RPC, you will
+need to use USE_ARM_COMPUTE_LIB=ON on the x86 machine and USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME=ON on the remote
+AArch64 device.
+
+Usage
+-----
+
+.. note::
+
+    This section may not stay up-to-date with changes to the API.
+
+Create a relay graph. This may be a single operator or a whole graph. The intention is that any
+relay graph can be used as input. The ACL integration will only pick supported operators to be offloaded,
+whilst the rest will be computed via TVM. (For this example we will use a single
+max_pool2d operator.)
+
+.. code:: python
+
+    import tvm
+    from tvm import relay
+
+    data_type = "float32"
+    data_shape = (1, 14, 14, 512)
+    strides = (2, 2)
+    padding = (0, 0, 0, 0)
+    pool_size = (2, 2)
+    layout = "NHWC"
+    output_shape = (1, 7, 7, 512)
+
+    data = relay.var('data', shape=data_shape, dtype=data_type)
+    out = relay.nn.max_pool2d(data, pool_size=pool_size, strides=strides, layout=layout, padding=padding)
+    module = tvm.IRModule.from_expr(out)
+
+
+Annotate and partition the graph for ACL.
+
+.. code:: python
+
+    from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
+    module = partition_for_arm_compute_lib(module)
+
+
+Build the Relay graph.
+
+.. code:: python
+
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        lib = relay.build(module, target=target)
+
+
+Export the module.
+
+.. code:: python
+
+    lib_path = '~/lib_acl.so'
+    cross_compile = 'aarch64-linux-gnu-c++'
+    lib.export_library(lib_path, cc=cross_compile)
+
+
+Run inference. This must be on an Arm device. If compiling on an x86 device and running on AArch64,
+consider using the RPC mechanism. A tutorial on using the RPC mechanism is available here:
+https://tvm.apache.org/docs/tutorials/cross_compilation_and_rpc.html#sphx-glr-tutorials-cross-compilation-and-rpc-py
+
+.. code:: python
+
+    import numpy as np
+    from tvm.contrib import graph_runtime
+
+    ctx = tvm.cpu(0)
+    loaded_lib = tvm.runtime.load_module('lib_acl.so')
+    gen_module = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+    d_data = np.random.uniform(0, 1, data_shape).astype(data_type)
+    map_inputs = {'data': d_data}
+    gen_module.set_input(**map_inputs)
+    gen_module.run()
+
+
+More examples
+-------------
+The example above only shows how ACL can be used to offload a single
+Maxpool2D. If you would like to see more examples for each implemented operator and for
+networks, refer to the tests: `tests/python/contrib/test_arm_compute_lib`. There you can modify
+`infrastructure.py` to use the remote device you have set up.
+
+
+Adding a new operator
+---------------------
+Adding a new operator requires changes in a number of places. This section gives a hint on
+what needs to be changed and where; it will not, however, dive into the complexities of an
+individual operator. That is left to the developer.
+
+There are a number of files we need to make changes to:
+
+* `python/relay/op/contrib/arm_compute_lib.py` In this file we define the operators we wish to offload using the
+  `op.register` decorator. This will mean the annotation pass recognizes this operator as ACL
+  offloadable.
+* `src/relay/backend/contrib/arm_compute_lib/codegen.cc` Implement `Create[OpName]JSONNode` method. This is where we
+  declare how the operator should be represented by JSON. This will be used to create the ACL module.
+* `src/runtime/contrib/arm_compute_lib/acl_kernel.h` Implement `Create[OpName]Layer` method. This is where we +define how the JSON representation can be used to create an ACL function. We simply define how to +translate from the JSON representation to ACL API. +* `tests/python/contrib/test_arm_compute_lib` Add unit tests for the given operator. diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index 53455ed..b38a7f5 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -68,3 +68,4 @@ target device without relying on RPC. see the following resources on how to do s android integrate hls + arm_compute_lib diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index f7ed122..03170ea 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -64,6 +64,9 @@ class GraphRuntimeFactoryModule(object): def get_json(self): return self.graph_json + def get_lib(self): + return self.lib + def __getitem__(self, item): return self.module.__getitem__(item) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 0e1b4b0..26ca78c 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -18,5 +18,6 @@ """Contrib modules.""" from .register import get_pattern_table, register_pattern_table +from .arm_compute_lib import * from .dnnl import * from .coreml import * diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py new file mode 100644 index 0000000..e5b2af5 --- /dev/null +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Arm Compute Library supported operators.""" +import tvm +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name + +from ...dataflow_pattern import wildcard, is_op, is_constant +from .register import register_pattern_table + + +def is_arm_compute_runtime_enabled(): + """Check if the ACL graph runtime is present. + + Returns + ------- + ret: bool + True if present, False if not. + """ + check_enabled = tvm.get_global_func("relay.op.is_arm_compute_runtime_enabled", True) + if check_enabled: + return check_enabled() + return False + + +def partition_for_arm_compute_lib(mod, params=None): + """Partition the graph greedily offloading supported + operators to Arm Compute Library. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + ret : annotated and partitioned module. 
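    Examples
    --------
    Illustrative usage only, assuming `mod` (a tvm.IRModule) and optional `params`
    have already been created by the caller:

        mod = partition_for_arm_compute_lib(mod, params)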
+ """ + if params: + mod['main'] = bind_params_by_name(mod['main'], params) + + seq = tvm.transform.Sequential([transform.MergeComposite(arm_compute_lib_pattern_table()), + transform.AnnotateTarget('arm_compute_lib'), + transform.PartitionGraph()]) + + return seq(mod) + + +@register_pattern_table("arm_compute_lib") +def arm_compute_lib_pattern_table(): + """Get the ACL pattern table.""" + + def conv_pattern(): + """Create a convolution pattern. + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the convolution pattern. + """ + pattern = is_op('nn.pad')(wildcard()) | wildcard() + pattern = is_op('nn.conv2d')(pattern, is_constant()) + pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant())) + pattern = pattern.optional(is_op('nn.relu')) + return pattern + + def check_conv(extract): + """Check conv pattern is supported by ACL.""" + call = extract + while call.op.name != "nn.conv2d": + call = call.args[0] + return conv2d(call.attrs, call.args) + + return [('arm_compute_lib.conv2d', conv_pattern(), check_conv)] + + +def _register_external_op_helper(op_name, supported=True): + @tvm.ir.register_op_attr(op_name, "target.arm_compute_lib") + def _func_wrapper(attrs, args): + return supported + + return _func_wrapper + + +_register_external_op_helper("reshape") + + +@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib") +def conv2d(attrs, args): + """Check if the external ACL codegen for conv2d should be used.""" + if attrs.groups != 1: + return False + if attrs.data_layout != "NHWC": + return False + if attrs.out_dtype != "float32" and attrs.out_dtype != "": + return False + data_typ = args[0].checked_type + if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "float32": + return False + kernel_typ = args[1].checked_type + if kernel_typ.dtype != "float32": + return False + return True + + +@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib") +def max_pool2d(attrs, args): + """Check if the external ACL codegen for maxpool2d should be used.""" + if attrs.layout != "NHWC": + return False + typ = args[0].checked_type + if typ.dtype != "float32": + return False + return True diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc new file mode 100644 index 0000000..8edbc15 --- /dev/null +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/arm_compute_lib/codegen.cc + * \brief Implementation of the Relay -> ACL JSON serializer. 
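 *        The JSON produced by this serializer is consumed by the matching ACL JSON runtime
 *        in src/runtime/contrib/arm_compute_lib/acl_runtime.cc.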
+ */ +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" +#include "../codegen_json/codegen_json.h" + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief Generates an ACLModule from a relay expression. This "compilation" + * does not require ACL since the actual conversion using ACL APIs is + * deferred until creation of the runtime. This step simply serializes the + * relay program into a JSON string. + */ +class ACLJSONSerializer : public backend::contrib::JSONSerializer { + using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; + + public: + ACLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {} + + /*! + * \brief Visit call nodes and generate appropriate JSON node. + * + * \param cn The current call node. + * \return A list of graph entry nodes. + */ + std::vector VisitExpr_(const CallNode* cn) override { + if (cn->op.as()) { + return JSONSerializer::VisitExpr_(cn); + } + if (!cn->op.as()) { + LOG(FATAL) << "Arm Compute Library JSON runtime does not support calls to " + << cn->op->GetTypeKey(); + } + auto fn = cn->op.as(); + auto comp = fn->GetAttr(attr::kComposite); + CHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports composite functions."; + const std::string name = comp.value(); + std::shared_ptr json_node; + if (name == "arm_compute_lib.conv2d") { + json_node = CreateCompositeConvJSONNode(cn); + } else { + LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name; + } + return AddNode(json_node, GetRef(cn)); + } + + private: + /*! + * \brief Create a JSON representation of a composite convolution. + * + * \param call The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeConvJSONNode(const CallNode* cn) { + const std::string name = "nn.conv2d"; + const CallNode* pad = nullptr; + const CallNode* conv = nullptr; + const CallNode* bias = nullptr; + bool has_activation = false; + + // Unpack composite function + const auto* fn = cn->op.as(); + CHECK(fn); + const auto* current_call = fn->body.as(); + if (backend::IsOp(current_call, "nn.relu")) { + has_activation = true; + current_call = current_call->args[0].as(); + } + if (backend::IsOp(current_call, "nn.bias_add")) { + bias = current_call; + current_call = current_call->args[0].as(); + } + CHECK(backend::IsOp(current_call, "nn.conv2d")); + conv = current_call; + if (!current_call->args.empty() && current_call->args[0]->IsInstance()) { + current_call = current_call->args[0].as(); + if (backend::IsOp(current_call, "nn.pad")) { + pad = current_call; + } + } + + const auto* conv_attr = conv->attrs.as(); + CHECK(conv_attr); + CHECK(conv_attr->kernel_layout == "OHWI") + << "Kernel layout must be OHWI, has the module been pre-processed correctly?"; + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + inputs.push_back(VisitExpr(conv->args[1])[0]); + if (bias) { + inputs.push_back(VisitExpr(bias->args[1])[0]); + } + + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, conv); + + // Override attributes + if (pad) { + const auto* pad_attr = pad->attrs.as(); + CHECK(pad_attr); + auto p = pad_attr->pad_width; + // Convert to TVM layout for now, conversion to ACL layout takes place in runtime. + // Standard convolution pad layout for TVM: top, left, bottom, right. 
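      // (With the NHWC layout enforced above, pad_width[1] holds the (top, bottom) padding of the
      // H dimension and pad_width[2] the (left, right) padding of the W dimension, hence the
      // [1][0], [2][0], [1][1], [2][1] indexing below.)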
+ std::vector padding = {std::to_string(p[1][0].as()->value), + std::to_string(p[2][0].as()->value), + std::to_string(p[1][1].as()->value), + std::to_string(p[2][1].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + json_node->SetAttr("padding", padding_attr); + } + if (has_activation) { + std::vector activation_type = {"relu"}; + std::vector act_attr; + act_attr.emplace_back(activation_type); + json_node->SetAttr("activation_type", act_attr); + } + return json_node; + } +}; + +/*! + * \brief Pre-process a module containing functions ready for ACL codegen. + * + * For now we enforce OHWI kernel layout and fold the transforms away. + * + * \param mod The module to be pre-processed. + * \return The processed module. + */ +IRModule PreProcessModule(const IRModule& mod) { + IRModule preprocessed_module; + tvm::Map> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}}}; + preprocessed_module = transform::ConvertLayout(desired_layouts)(mod); + preprocessed_module = transform::FoldConstant()(preprocessed_module); + return preprocessed_module; +} + +TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib.optimize").set_body_typed(PreProcessModule); + +/*! + * \brief Create a runtime module for ACL. + * + * This consists of a series of "serialized functions" which each represent a + * sub-graph to be computed by ACL and will each be executed independently from + * one another. Each function consists of serialized JSON describing the sub-graph + * and serialized constant tensors. + * + * \note The ACL runtime module only supports a single operator per + * sub-graph currently. + * + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module ACLCompiler(const ObjectRef& ref) { + CHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + ACLJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create"); + CHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib").set_body_typed(ACLCompiler); + +/*! + * \brief Check whether ACL graph runtime is used. + * + * \return True if ACL graph runtime is enabled, False if not. + */ +inline constexpr bool IsACLRuntimeEnabled() { +#if TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB + return true; +#else + return false; +#endif +} + +TVM_REGISTER_GLOBAL("relay.op.is_arm_compute_runtime_enabled").set_body_typed(IsACLRuntimeEnabled); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 32ab150..0d395b7 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -61,19 +61,6 @@ class CSourceModuleCodegenBase { * \return A runtime module. */ virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0; - - /*! - * \brief Get the external symbol of the Relay function name. - * - * \param func The provided function. - * - * \return An external symbol. 
- */ - std::string GetExtSymbol(const Function& func) const { - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node.value()); - } }; // The base class to generate the declaration functions in C. diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index d5a483d..bec9af0 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -468,19 +468,6 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { return AddNode(node, GetRef(cn)); } }; - -/*! - * \brief Get the external symbol of the Relay function name. - * - * \param func The provided function. - * - * \return An external symbol. - */ -std::string GetExtSymbol(const Function& func) { - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node.value()); -} #endif /*! diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 1fe14b8..d6edd10 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -248,6 +248,18 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, return GetRootCall(next_call, depth - 1, expected_op_names); } +/*! + * \brief Get the external symbol of the Relay function name. + * + * \param func The provided function. + * \return An external symbol. + */ +inline std::string GetExtSymbol(const Function& func) { + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(name_node.defined()) << "Fail to retrieve external symbol."; + return std::string(name_node.value()); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_allocator.cc b/src/runtime/contrib/arm_compute_lib/acl_allocator.cc new file mode 100644 index 0000000..2feb5b0 --- /dev/null +++ b/src/runtime/contrib/arm_compute_lib/acl_allocator.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/arm_compute_lib/acl_allocator.cc + * \brief ACL Allocator implementation that requests memory from TVM. 
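 *        Requests are forwarded to TVM's DeviceAPI (AllocWorkspace for scratch allocations,
 *        AllocDataSpace for memory regions), so memory used by ACL layers is owned and tracked by TVM.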
+ */ + +#include "acl_allocator.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +void* ACLAllocator::allocate(size_t size, size_t alignment) { + CHECK_GT(size, 0) << "Cannot allocate size less than or equal to zero"; + return this->device_api_->AllocWorkspace(this->ctx_, size, {}); +} + +void ACLAllocator::free(void* ptr) { this->device_api_->FreeWorkspace(this->ctx_, ptr); } + +std::unique_ptr ACLAllocator::make_region(size_t size, + size_t alignment) { + return std::make_unique(size, alignment); +} + +ACLMemoryRegion::ACLMemoryRegion(size_t size, size_t alignment) + : IMemoryRegion(size), ptr_(nullptr) { + if (size != 0) { + this->ptr_ = this->device_api_->AllocDataSpace(this->ctx_, size, alignment, {}); + } +} + +ACLMemoryRegion::ACLMemoryRegion(void* ptr, size_t size) + : IMemoryRegion(size), ptr_(nullptr), is_subregion_(true) { + if (size != 0) { + this->ptr_ = ptr; + } +} + +ACLMemoryRegion::~ACLMemoryRegion() { + if (this->ptr_ != nullptr && !is_subregion_) { + this->device_api_->FreeDataSpace(this->ctx_, this->ptr_); + } +} + +std::unique_ptr ACLMemoryRegion::extract_subregion(size_t offset, + size_t size) { + if (this->ptr_ != nullptr && (offset < _size) && (_size - offset >= size)) { + return std::make_unique(static_cast(this->ptr_) + offset, size); + } else { + return nullptr; + } +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_allocator.h b/src/runtime/contrib/arm_compute_lib/acl_allocator.h new file mode 100644 index 0000000..49d0d0c --- /dev/null +++ b/src/runtime/contrib/arm_compute_lib/acl_allocator.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/arm_compute_lib/acl_allocator.h + * \brief ACL Allocator implementation that requests memory from TVM. + */ + +#ifndef TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_ALLOCATOR_H_ +#define TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +/*! + * \brief Override ACL memory allocator and replace with TVM workspace based allocation. + */ +class ACLAllocator : public arm_compute::IAllocator { + public: + ACLAllocator() = default; + + /*! + * \brief Allocate bytes to ACL runtime. + * + * Specific implementation requests memory from TVM using their device api. + * + * \param size Size to allocate. + * \param alignment Alignment that the returned pointer should comply with. + * \return A pointer to the allocated memory. + */ + void* allocate(size_t size, size_t alignment) override; + + /*! + * \brief Free memory from ACL runtime. + * + * \param ptr Pointer to workspace to free. 
+ */ + void free(void* ptr) override; + + /*! + * \brief Create self-managed memory region. + * + * \param size Size of the memory region. + * \param alignment Alignment of the memory region. + * \return The memory region object. + */ + std::unique_ptr make_region(size_t size, size_t alignment) override; + + private: + /*! \brief Always allocate data in the context of the current CPU. */ + const TVMContext ctx_{kDLCPU, 0}; + /*! \brief Device API which allows requests for memory from TVM. */ + runtime::DeviceAPI* device_api_ = runtime::DeviceAPI::Get(ctx_); +}; + +/*! + * \brief Memory region that can request TVM memory for ACL to use. + */ +class ACLMemoryRegion : public arm_compute::IMemoryRegion { + public: + ACLMemoryRegion(size_t size, size_t alignment); + ACLMemoryRegion(void* ptr, size_t size); + + ~ACLMemoryRegion() override; + + /*! \brief Prevent instances of this class from being copied (As this class contains + * pointers). */ + ACLMemoryRegion(const ACLMemoryRegion&) = delete; + /*! \brief Default move constructor. */ + ACLMemoryRegion(ACLMemoryRegion&&) = default; + /*! \brief Prevent instances of this class from being copied (As this class + * contains pointers) */ + ACLMemoryRegion& operator=(const ACLMemoryRegion&) = delete; + /*! Default move assignment operator. */ + ACLMemoryRegion& operator=(ACLMemoryRegion&&) = default; + + void* buffer() override { return this->ptr_; } + + const void* buffer() const override { return this->ptr_; } + + /*! + * \brief Extract a sub-region from the memory. + * + * \warning Ownership is maintained by the parent memory, + * while a wrapped raw memory region is returned by this function. + * Thus parent memory should not be released before this. + * + * \param offset Offset to the region. + * \param size Size of the region. + * \return A wrapped memory sub-region with no ownership of the + * underlying memory. + */ + std::unique_ptr extract_subregion(size_t offset, + size_t size) override; + + private: + /*! \brief Points to a region of memory allocated by TVM. */ + void* ptr_; + /*! \brief A subregion doesn't manage TVM memory so we don't need to free it. */ + bool is_subregion_ = false; + /*! \brief Always allocate data in the context of the current CPU. */ + const TVMContext ctx_{kDLCPU, 0}; + /*! \brief Device API which allows requests for memory from TVM. */ + runtime::DeviceAPI* device_api_ = runtime::DeviceAPI::Get(ctx_); +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_ALLOCATOR_H_ diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc new file mode 100644 index 0000000..e8cdef7 --- /dev/null +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/arm_compute_lib/acl_runtime.cc + * \brief A simple JSON runtime for Arm Compute Library. + */ + +#include +#include + +#include "../../file_util.h" +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB +#include +#include +#include +#include + +#include "acl_allocator.h" +#include "acl_utils.h" +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; + +class ACLRuntime : public JSONRuntimeBase { + public: + /*! + * \brief The ACL runtime module. Deserialize the provided functions + * on creation and store in the layer cache. + * + * \param symbol_name The name of the function. + * \param graph_json serialized JSON representation of a sub-graph. + * \param const_names The names of each constant in the sub-graph. + */ + explicit ACLRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + /*! + * \brief The type key of the module. + * + * \return module type key. + */ + const char* type_key() const override { return "arm_compute_lib"; } + + /*! + * \brief Initialize runtime. Create ACL layer from JSON + * representation. + * + * \param consts The constant params from compiled model. + */ + void Init(const Array& consts) override { + CHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + SetupConstants(consts); + BuildEngine(); + } + +#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB + /*! + * \brief Unpack inputs and outputs and run inference on a given layer. + * + * \param args Access inputs and outputs. + * \param function The layer to execute inference on. + * \return Status of inference. + */ + void Run() override { + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + uint32_t eid = EntryID(nid, 0); + if (nodes_[nid].GetOpType() == "input") { + void* data = data_entry_[eid]->data; + CheckACLError(layer_.inputs[i].allocator()->import_memory(data)); + } + } + + for (size_t i = 0; i < outputs_.size(); ++i) { + uint32_t eid = EntryID(outputs_[i]); + void* data = data_entry_[eid]->data; + CheckACLError(layer_.outputs[i].allocator()->import_memory(data)); + } + + this->layer_.function->run(); + } + + private: + /*! + * \brief Build ACL layer from JSON representation and cache. + * + * \note For the time being only one layer or operator is supported + * per engine. 
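   *       Auxiliary (working) memory is only requested from the memory manager when the layer
   *       needs it, i.e. when a convolution is present (num_pools > 0).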
+ */ + void BuildEngine() { + std::shared_ptr mm = MakeMemoryManager(); + int num_pools = 0; + + for (size_t i = 0; i < input_nodes_.size(); ++i) { + uint32_t nid = input_nodes_[i]; + const auto& node = nodes_[nid]; + if (node.GetOpType() == "input") { + layer_.inputs.push_back(MakeTensor(node)); + } else if (node.GetOpType() == "const") { + uint32_t eid = EntryID(nid, 0); + void* data = data_entry_[eid]->data; + layer_.const_inputs.push_back(MakeTensor(node, data)); + } + } + + bool found_kernel_node = false; + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (found_kernel_node) { + LOG(FATAL) + << "Arm Compute Library runtime module only supports one kernel node per function."; + } + if (node.GetOpType() == "kernel") { + found_kernel_node = true; + auto op_name = node.GetOpName(); + if ("nn.conv2d" == op_name) { + CreateConvolution2DLayer(&layer_, node, mm); + num_pools++; + } else if ("nn.max_pool2d" == op_name) { + CreatePoolingLayer(&layer_, node); + } else if ("reshape" == op_name) { + CreateReshapeLayer(&layer_, node); + } else { + LOG(FATAL) << "Unsupported op: " << op_name; + } + } + } + + this->layer_.function->prepare(); + if (num_pools > 0) mm->populate(this->allocator_, num_pools); + } + + /*! + * \brief ACL objects we cache in order to avoid needing to construct + * a new layer each time. + */ + struct CachedLayer { + std::shared_ptr function; + std::vector inputs; + std::vector const_inputs; + std::vector outputs; + }; + + /*! + * \brief Create a 2D convolution layer. + * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. + * \param node The JSON representation of the operator. + * \param mm The ACL conv2d layer can request auxiliary memory from TVM. + */ + static void CreateConvolution2DLayer( + CachedLayer* layer, const JSONGraphNode& node, + const std::shared_ptr& mm) { + std::vector padding = node.GetAttr>("padding"); + std::vector strides = node.GetAttr>("strides"); + std::vector dilation = node.GetAttr>("dilation"); + arm_compute::PadStrideInfo pad_stride_info = ToACLPadStride(padding, strides); + + int groups = std::stoi(node.GetAttr>("groups")[0]); + CHECK(groups == 1) << "Arm Compute Library NEON convolution only supports group size of 1."; + + arm_compute::ActivationLayerInfo act_info; + if (node.HasAttr("activation_type")) { + std::string activation_type = node.GetAttr>("activation_type")[0]; + if (activation_type == "relu") { + act_info = arm_compute::ActivationLayerInfo( + arm_compute::ActivationLayerInfo::ActivationFunction::RELU); + } else { + LOG(FATAL) << "Unsupported activation function"; + } + } + + arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1])); + + layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0])); + + auto function = std::make_shared(mm); + function->configure(&layer->inputs[0], &layer->const_inputs[0], + layer->const_inputs.size() > 1 ? &layer->const_inputs[1] : nullptr, + &layer->outputs[0], pad_stride_info, arm_compute::WeightsInfo(), + dilation_2d, act_info); + layer->function = function; + } + + /*! + * \brief Create a pooling layer. + * + * \note Currently only maxpool is supported. + * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. + * \param node The JSON representation of the operator. 
+ */ + static void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) { + std::vector padding = node.GetAttr>("padding"); + std::vector strides = node.GetAttr>("strides"); + arm_compute::PadStrideInfo pad_stride_info = ToACLPadStride(padding, strides); + + auto attr_pool_size = node.GetAttr>("pool_size"); + int pool_size_h = std::stoi(attr_pool_size[0]); + int pool_size_w = std::stoi(attr_pool_size[1]); + + arm_compute::PoolingType pool_type; + if (node.GetOpName() == "nn.max_pool2d") { + pool_type = arm_compute::PoolingType::MAX; + } else { + LOG(FATAL) << "Pooling type not supported"; + } + + arm_compute::PoolingLayerInfo pool_info = + arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w), + arm_compute::DataLayout::NHWC, pad_stride_info); + + layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0])); + + auto function = std::make_shared(); + function->configure(&layer->inputs[0], &layer->outputs[0], pool_info); + layer->function = function; + } + + /*! + * \brief Create a reshape layer. + * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. + * \param node The JSON representation of the operator. + */ + static void CreateReshapeLayer(CachedLayer* layer, const JSONGraphNode& node) { + layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0])); + auto function = std::make_shared(); + function->configure(&layer->inputs[0], &layer->outputs[0]); + layer->function = function; + } + + /*! \brief Allow ACL functions to request auxiliary memory from TVM. */ + ACLAllocator allocator_; + /*! + * \brief The network layers represented by acl functions. + * \note Currently only supports a single layer. + */ + CachedLayer layer_; +#else + void Run() override { + LOG(FATAL) << "Cannot call run on Arm Compute Library module without runtime enabled. " + << "Please build with USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME."; + } + + void BuildEngine() { + LOG(WARNING) << "Arm Compute Library engine is not initialized. " + << "Please build with USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME."; + } +#endif +}; + +runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc new file mode 100644 index 0000000..ad278ba --- /dev/null +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/arm_compute_lib/acl_utils.cc + * \brief Utils and common functions for the interface. + */ + +#include "acl_utils.h" + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +void CheckACLError(const arm_compute::Status& status) { + CHECK(status.error_code() == arm_compute::ErrorCode::OK) << "ACL: " << status.error_description(); +} + +arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data) { + CHECK(tensor_rep.GetOpType() == "input" || tensor_rep.GetOpType() == "const"); + arm_compute::Tensor tensor; + arm_compute::TensorInfo info = MakeTensorInfo(tensor_rep.GetOpShape()[0]); + tensor.allocator()->init(info); + if (data != nullptr) { + CheckACLError(tensor.allocator()->import_memory(data)); + } + return tensor; +} + +arm_compute::Tensor MakeOutputTensor(const std::vector& shape) { + arm_compute::Tensor tensor; + tensor.allocator()->init(MakeTensorInfo(shape)); + return tensor; +} + +arm_compute::TensorInfo MakeTensorInfo(const std::vector& shape) { + arm_compute::TensorShape acl_shape = MakeTensorShape(shape); + return arm_compute::TensorInfo(acl_shape, 1, arm_compute::DataType::F32, + arm_compute::DataLayout::NHWC); +} + +arm_compute::TensorShape MakeTensorShape(const std::vector& shape) { + arm_compute::TensorShape acl_shape; + for (unsigned int i = shape.size(); i > 0; --i) { + acl_shape.set(shape.size() - i, shape[i - 1]); + } + return acl_shape; +} + +std::shared_ptr MakeMemoryManager() { + auto lifetime_mgr = std::make_shared(); + auto pool_mgr = std::make_shared(); + return std::make_shared(lifetime_mgr, pool_mgr); +} + +arm_compute::PadStrideInfo ToACLPadStride(const std::vector& pad, + const std::vector& stride) { + int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0; + int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]); + size_t size = pad.size(); + if (size == 1) { + int pad_v = std::stoi(pad[0]); + pad_0 = pad_v; + pad_1 = pad_v; + pad_2 = pad_v; + pad_3 = pad_v; + } else if (size == 2) { + // TVM: height, width -> ACL: left, right, top, bottom + int pad_h = std::stoi(pad[0]); + int pad_w = std::stoi(pad[1]); + pad_0 = pad_w; + pad_1 = pad_w; + pad_2 = pad_h; + pad_3 = pad_h; + } else if (size == 4) { + // TVM: top, left, bottom, right -> ACL: left, right, top, bottom + pad_0 = std::stoi(pad[1]); + pad_1 = std::stoi(pad[3]); + pad_2 = std::stoi(pad[0]); + pad_3 = std::stoi(pad[2]); + } else { + LOG(FATAL) << "Unsupported padding dimensions"; + } + + return arm_compute::PadStrideInfo(stride_0, stride_1, pad_0, pad_1, pad_2, pad_3, + arm_compute::DimensionRoundingType::FLOOR); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h new file mode 100644 index 0000000..6a92780 --- /dev/null +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/arm_compute_lib/acl_utils.h + * \brief Utils and common functions for the interface. + */ + +#ifndef TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_UTILS_H_ +#define TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_UTILS_H_ + +#include +#include +#include + +#include +#include +#include + +#include "../json/json_node.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +/*! + * \brief Check if there are any errors from acl and forward them to TVM. + * + * Status values: + * - 0 => OK + * - 1 => RUNTIME_ERROR + * - 2 => UNSUPPORTED_EXTENSION_USE + * + * \param status status of called function. + */ +void CheckACLError(const arm_compute::Status& status); + +/*! + * \brief Make an acl tensor from JSON tensor representation. + * + * \param tensor_rep A JSON tensor representation. + * \param data (optional) Initialize the tensor with memory. + * \return arm_compute::Tensor. + */ +arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data = nullptr); + +/*! + * \brief Make an acl tensor from type and shape, without having a JSON representation. + * + * \param shape The shape of the tensor to create. + * \return arm_compute::Tensor. + */ +arm_compute::Tensor MakeOutputTensor(const std::vector& shape); + +/*! + * \brief Make an acl tensor info object from JSON tensor + * representation. + * + * \param shape The shape of the tensor to create. + * \return arm_compute::TensorInfo. + */ +arm_compute::TensorInfo MakeTensorInfo(const std::vector& shape); + +/*! + * \brief Convert vector object to acl TensorShape. + * \note This requires reversing the given vector. + * + * \param shape The shape of the tensor as a vector. + * \return arm_compute::TensorShape. + */ +arm_compute::TensorShape MakeTensorShape(const std::vector& shape); + +/*! + * \brief Create a memory manager for use with a layer that + * requires working memory. + * + * \return reference counted memory manager. + */ +std::shared_ptr MakeMemoryManager(); + +/*! + * \brief Convert TVM padding and stride format to acl PadStrideInfo. + * + * \param pad The pad vector. + * \param stride The stride vector. + * \return arm_compute::PadStrideInfo + */ +arm_compute::PadStrideInfo ToACLPadStride(const std::vector& pad, + const std::vector& stride); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_ARM_COMPUTE_LIB_ACL_UTILS_H_ diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h index 7468feb..7cb17de 100644 --- a/src/runtime/contrib/json/json_node.h +++ b/src/runtime/contrib/json/json_node.h @@ -272,6 +272,15 @@ class JSONGraphNode { attrs_[key] = value; } + /*! + * \brief Check if node has attribute. + * + * \param key The key of the attribute. 
+ * + * \return True if attribute exists, false otherwise. + */ + bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); } + virtual ~JSONGraphNode() {} private: diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index c4f126e..92830e6 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -130,6 +130,14 @@ class JSONRuntimeBase : public ModuleNode { return Module(n); } + /*! + * \brief Get the JSON generated by codegen. + * + * \param format the format to return. + * \return A string of JSON. + */ + std::string GetSource(const std::string& format = "json") override { return graph_json_; } + protected: /*! * \brief Set up the input and output buffers by binding their DLTensor pointers to the diff --git a/tests/python/contrib/test_arm_compute_lib/__init__.py b/tests/python/contrib/test_arm_compute_lib/__init__.py new file mode 100644 index 0000000..fd14be1 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for Arm Compute Library""" diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py new file mode 100644 index 0000000..ea486b0 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from itertools import zip_longest, combinations +import json + +import tvm +from tvm import relay +from tvm import rpc +from tvm.contrib import graph_runtime +from tvm.relay.op.contrib import arm_compute_lib +from tvm.contrib import util + + +class Device: + """Adjust the following settings to connect to and use a remote device for tests.""" + use_remote = False + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon" + # Enable cross compilation when connecting a remote device from a non-arm platform. 
+ cross_compile = None + # cross_compile = "aarch64-linux-gnu-g++" + + def __init__(self): + """Keep remote device for lifetime of object.""" + self.device = self._get_remote() + + @classmethod + def _get_remote(cls): + """Get a remote (or local) device to use for testing.""" + if cls.use_remote: + # Here you may adjust settings to run the ACL unit tests via a remote + # device using the RPC mechanism. Use this in the case you want to compile + # an ACL module on a different machine to what you run the module on i.e. + # x86 -> AArch64. + # + # Use the following to connect directly to a remote device: + # device = rpc.connect( + # hostname="0.0.0.0", + # port=9090) + # + # Or connect via a tracker: + # device = tvm.autotvm.measure.request_remote( + # host="0.0.0.0", + # port=9090, + # device_key="device_key", + # timeout=1000) + # + # return device + raise NotImplementedError( + "Please adjust these settings to connect to your remote device.") + else: + device = rpc.LocalSession() + return device + + +def get_cpu_op_count(mod): + """Traverse graph counting ops offloaded to TVM.""" + class Counter(tvm.relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + c = Counter() + c.visit(mod["main"]) + return c.count + + +def skip_runtime_test(): + """Skip test if it requires the runtime and it's not present.""" + # ACL codegen not present. + if not tvm.get_global_func("relay.ext.arm_compute_lib", True): + print("Skip because Arm Compute Library codegen is not available.") + return True + + # Remote device is in use or ACL runtime not present + if not Device.use_remote and not arm_compute_lib.is_arm_compute_runtime_enabled(): + print("Skip because runtime isn't present or a remote device isn't being used.") + return True + + +def skip_codegen_test(): + """Skip test if it requires the ACL codegen and it's not present.""" + if not tvm.get_global_func("relay.ext.arm_compute_lib", True): + print("Skip because Arm Compute Library codegen is not available.") + return True + + +def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1): + """Build module with option to build for ACL.""" + if isinstance(mod, tvm.relay.expr.Call): + mod = tvm.IRModule.from_expr(mod) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + if enable_acl: + mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params) + tvm_op_count = get_cpu_op_count(mod) + assert tvm_op_count == tvm_ops, \ + "Got {} TVM operators, expected {}".format(tvm_op_count, tvm_ops) + partition_count = 0 + for global_var in mod.get_global_vars(): + if "arm_compute_lib" in global_var.name_hint: + partition_count += 1 + + assert acl_partitions == partition_count, \ + "Got {} Arm Compute Library partitions, expected {}".format( + partition_count, acl_partitions) + relay.backend.compile_engine.get().clear() + return relay.build(mod, target=target, params=params) + + +def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1, + tvm_ops=0, acl_partitions=1): + """Build and run the relay module.""" + lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions) + lib = update_lib(lib, device.device, device.cross_compile) + gen_module = graph_runtime.GraphModule(lib['default'](device.device.cpu(0))) + gen_module.set_input(**inputs) + out = [] + for _ in range(no_runs): + gen_module.run() + 
out.append([gen_module.get_output(i) for i in range(outputs)]) + return out + + +def update_lib(lib, device, cross_compile): + """Export the library to the remote/local device.""" + lib_name = "mod.so" + temp = util.tempdir() + lib_path = temp.relpath(lib_name) + if cross_compile: + lib.export_library(lib_path, cc=cross_compile) + else: + lib.export_library(lib_path) + device.upload(lib_path) + lib = device.load_module(lib_name) + return lib + + +def verify(answers, atol, rtol): + """Compare the array of answers. Each entry is a list of outputs.""" + if len(answers) < 2: + raise RuntimeError( + f"No results to compare: expected at least two, found {len(answers)}") + for answer in zip_longest(*answers): + for outs in combinations(answer, 2): + tvm.testing.assert_allclose( + outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol) + + +def extract_acl_modules(module): + """Get the ACL module(s) from llvm module.""" + return list(filter(lambda mod: mod.type_key == "arm_compute_lib", + module.get_lib().imported_modules)) + + +def verify_codegen(module, known_good_codegen, num_acl_modules, + target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon"): + """Check acl codegen against a known good output.""" + module = build_module(module, target) + acl_modules = extract_acl_modules(module) + + assert len(acl_modules) == num_acl_modules, \ + f"The number of Arm Compute Library modules produced ({len(acl_modules)}) does not " \ + f"match the expected value ({num_acl_modules})." + + for mod in acl_modules: + source = mod.get_source("json") + codegen = json.loads(source)["nodes"] + # remove input and const names as these cannot be predetermined + for node in range(len(codegen)): + if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": + codegen[node]["name"] = "" + codegen_str = json.dumps(codegen, sort_keys=True, indent=2) + known_good_codegen_str = json.dumps(known_good_codegen, sort_keys=True, indent=2) + + assert codegen_str == known_good_codegen_str, \ + f"The JSON produced by codegen does not match the expected result. \n" \ + f"Actual={codegen_str} \n" \ + f"Expected={known_good_codegen_str}" diff --git a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py new file mode 100644 index 0000000..8765878 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Arm Compute Library integration conv2d tests.""" + +import numpy as np + +import tvm +from tvm import relay + +from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \ + verify, verify_codegen +from .infrastructure import Device + + +def _get_model(shape, kernel_size, padding, strides, + dilation, groups, dtype, channels, + var_names, has_bias=False, has_activation=False, has_pad=False): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=shape, dtype=dtype) + if has_pad: + p = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0)) + a = relay.nn.pad(a, pad_width=p) + padding = (0, 0, 0, 0) + else: + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + shape = (shape[0], shape[1] + padding[0] * 2, + shape[2] + padding[1] * 2, shape[3]) + weight_shape = (kernel_size, kernel_size, shape[3] // groups, channels) + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + out = relay.nn.conv2d( + a, + weights, + kernel_size=(kernel_size, kernel_size), + data_layout="NHWC", + kernel_layout="HWIO", + dilation=(1, 1), + strides=strides, + padding=padding, + groups=groups, + channels=channels + ) + params = {"w": w} + if has_bias: + b = tvm.nd.array(np.random.uniform(-128, 127, weight_shape[3]).astype(dtype)) + biasc = relay.const(b, dtype) + out = relay.nn.bias_add(out, biasc, axis=3) + params["b"] = b + if has_activation: + out = relay.nn.relu(out) + return out, params + + +def _get_expected_codegen(shape, kernel_size, padding, strides, + dilation, groups, dtype, channels, + has_bias=False, has_activation=False): + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + weight_shape = (channels, kernel_size, kernel_size, shape[3] // groups) + output_height = ((shape[1] - kernel_size + padding[0] + padding[2]) / strides[0]) + 1 + output_width = ((shape[2] - kernel_size + padding[1] + padding[3]) / strides[1]) + 1 + output_shape = (1, int(output_height), int(output_width), channels) + + node = { + "op": "kernel", + "name": "nn.conv2d", + "inputs": [[0, 0, 0], [1, 0, 0]], + "attrs": { + "groups": [["1"]], + "num_inputs": str(3 if has_bias else 2), + "num_outputs": "1", + "data_layout": [["NHWC"]], + "kernel_layout": [["OHWI"]], + "channels": [["1"]], + "dilation": [["1", "1"]], + "out_layout": [[""]], + "out_dtype": [[""]], + "kernel_size": [[str(kernel_size), str(kernel_size)]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "padding": [[str(p) for p in padding]], + "strides": [[str(s) for s in strides]] + }, + } + + if has_activation: + node["attrs"]["activation_type"] = [["relu"]] + + input = { + "op": "input", + "name": "", + "attrs": {"shape": [[list(shape)]], "dtype": [["float32"]]}} + kernel = { + "op": "const", + "name": "", + "attrs": {"shape": [[list(weight_shape)]], "dtype": [["float32"]]}} + + if has_bias: + bias = { + "op": "const", + "name": "", + "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [["float32"]]}} + node["inputs"].append([2, 0, 0]) + return [input, kernel, bias, node] + else: + return [input, kernel, node] + + +def test_conv2d(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + shape = (1, 14, 14, 32) + dtype = "float32" + + inputs = { + "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)), + } + + for kernel_size in [1, 2, 3]: + outputs = [] + func, params = _get_model(shape, kernel_size, + (0, 0), (1, 1), 1, 1, + 
dtype, 1, iter(inputs)) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, + params, device, + enable_acl=acl)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + for pad_ksize in [((1, 1), 3), ((2, 2), 5), ((2, 1), 3)]: + outputs = [] + func, params = _get_model(shape, pad_ksize[1], pad_ksize[0], + (1, 1), 1, 1, dtype, 1, iter(inputs)) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, + params, device, + enable_acl=acl)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + for strides in [(1, 1), (2, 2)]: + outputs = [] + func, params = _get_model(shape, 2, (0, 0), strides, + 1, 1, dtype, 1, iter(inputs)) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, + params, device, + enable_acl=acl)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + # Test composite convolution: (has_pad, has_bias, has_activation). + for composite in [(False, True, False), (False, False, True), (False, True, True), + (True, False, False)]: + outputs = [] + func, params = _get_model(shape, 2, (1, 1), (1, 1), + 1, 1, dtype, 1, iter(inputs), + has_pad=composite[0], + has_bias=composite[1], + has_activation=composite[2]) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, + params, device, + enable_acl=acl)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + +def test_codegen_conv2d(): + if skip_codegen_test(): + return + + shape = (1, 25, 25, 1) + dtype = "float32" + inputs = {"a"} + + for pad_ksize in [((1, 1), 3), ((2, 1), 3)]: + args = (shape, pad_ksize[1], pad_ksize[0], (1, 1), 1, 1, dtype, 1) + func, params = _get_model(*args, var_names=iter(inputs)) + exp_codegen = _get_expected_codegen(*args) + verify_codegen(func, exp_codegen, 1) + # Test composite convolution: (has_pad, has_bias, has_activation). + for composite in [(False, True, False), (False, False, True), (False, True, True), + (True, False, False)]: + args = (shape, 2, (1, 1), (1, 1), 1, 1, dtype, 1) + func, params = _get_model(*args, var_names=iter(inputs), + has_pad=composite[0], + has_bias=composite[1], + has_activation=composite[2]) + exp_codegen = _get_expected_codegen(*args, + has_bias=composite[1], + has_activation=composite[2]) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_conv2d() + test_codegen_conv2d() diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py new file mode 100644 index 0000000..8648a01 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Arm Compute Library network tests.""" + +import numpy as np + +from tvm import relay + +from .infrastructure import skip_runtime_test, build_and_run, verify +from .infrastructure import Device + + +def _build_and_run_keras_network(mod, params, inputs, device, tvm_ops, acl_partitions): + """Helper function to build and run a network from the Keras frontend.""" + data = {} + np.random.seed(0) + for name, shape in inputs.items(): + data[name] = np.random.uniform(-128, 127, shape).astype("float32") + + outputs = [] + for acl in [False, True]: + outputs.append(build_and_run(mod, data, 1, params, + device, enable_acl=acl, + tvm_ops=tvm_ops, + acl_partitions=acl_partitions)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + +def test_vgg16(): + if skip_runtime_test(): + return + + device = Device() + + def get_model(): + from keras.applications import VGG16 + vgg16 = VGG16(include_top=True, weights='imagenet', + input_shape=(224, 224, 3), classes=1000) + inputs = {vgg16.input_names[0]: (1, 224, 224, 3)} + mod, params = relay.frontend.from_keras(vgg16, inputs, layout="NHWC") + return mod, params, inputs + + _build_and_run_keras_network(*get_model(), device=device, + tvm_ops=10, acl_partitions=18) + + +def test_mobilenet(): + if skip_runtime_test(): + return + + device = Device() + + def get_model(): + from keras.applications import MobileNet + mobilenet = MobileNet(include_top=True, weights='imagenet', + input_shape=(224, 224, 3), classes=1000) + inputs = {mobilenet.input_names[0]: (1, 224, 224, 3)} + mod, params = relay.frontend.from_keras(mobilenet, inputs, layout="NHWC") + return mod, params, inputs + + _build_and_run_keras_network(*get_model(), device=device, + tvm_ops=74, acl_partitions=17) + + +if __name__ == "__main__": + test_vgg16() + test_mobilenet() diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py new file mode 100644 index 0000000..aac7795 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Arm Compute Library integration pooling tests.""" + +import numpy as np + +import tvm +from tvm import relay + +from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \ + verify, verify_codegen +from .infrastructure import Device + + +def _get_model(shape, typef, sizes, strides, padding, + ceil_mode, var_names): + """Return a model and any parameters it may have.""" + var = relay.var(next(var_names), shape=shape, dtype="float32") + pool = typef(var, pool_size=sizes, strides=strides, padding=padding, + ceil_mode=ceil_mode, layout="NHWC") + return pool + + +def _get_expected_codegen(shape, typef, sizes, strides, padding, + ceil_mode): + if len(padding) == 2: + padding = (padding[1], padding[1], padding[0], padding[0]) + output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1 + output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1 + output_shape = (1, int(output_height), int(output_width), shape[3]) + + node = { + "op": "kernel", + "name": "nn.max_pool2d", + "inputs": [[0, 0, 0]], + "attrs": { + "num_inputs": "1", + "num_outputs": "1", + "layout": [["NHWC"]], + "shape": [[list(output_shape)]], + "dtype": [["float32"]], + "padding": [[str(p) for p in padding]], + "strides": [[str(s) for s in strides]], + "pool_size": [[str(s) for s in sizes]], + "ceil_mode": [[str(1 if ceil_mode else 0)]] + }, + } + + input = { + "op": "input", + "name": "", + "attrs": {"shape": [[list(shape)]], "dtype": [["float32"]]}} + return [input, node] + + +def test_pooling(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + for size in [(2, 2), (3, 3)]: + for stride in [(2, 2)]: + shape = (1, size[0] + stride[0] * 5, + size[1] + stride[1] * 5, 16) + + inputs = { + "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype("float32")), + } + + outputs = [] + func = _get_model(shape, relay.nn.max_pool2d, size, + stride, (0, 0), True, iter(inputs)) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl)[0]) + verify(outputs, atol=0.001, rtol=0.001) + + +def test_codegen_pooling(): + if skip_codegen_test(): + return + + inputs = {"a"} + + for size in [(2, 2), (3, 3)]: + for stride in [(2, 2)]: + shape = (1, size[0] + stride[0] * 5, + size[1] + stride[1] * 5, 16) + args = (shape, relay.nn.max_pool2d, size, + stride, (0, 0), True) + func = _get_model(*args, iter(inputs)) + exp_codegen = _get_expected_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_pooling() + test_codegen_pooling() diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py b/tests/python/contrib/test_arm_compute_lib/test_reshape.py new file mode 100644 index 0000000..cb9f295 --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm Compute Library integration reshape tests.""" + +import numpy as np + +import tvm +from tvm import relay + +from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \ + verify, verify_codegen +from .infrastructure import Device + + +def _get_model(input_shape, output_shape, var_names): + """Return a model and any parameters it may have.""" + a = relay.var(next(var_names), shape=input_shape, dtype="float32") + reshape = relay.reshape(a, output_shape) + return reshape + + +def _get_expected_codegen(input_shape, output_shape): + node = { + "op": "kernel", + "name": "reshape", + "inputs": [[0, 0, 0]], + "attrs": { + "num_inputs": "1", + "num_outputs": "1", + "newshape": [[str(s) for s in output_shape]], + "shape": [[list(output_shape)]], + "dtype": [["float32"]], + "reverse": [["0"]] + }, + } + + input = { + "op": "input", + "name": "", + "attrs": {"shape": [[list(input_shape)]], "dtype": [["float32"]]}} + + return [input, node] + + +def test_reshape(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + inputs = { + "a": tvm.nd.array( + np.random.uniform(-128, 127, (1, 1, 1, 1000)).astype("float32")) + } + + for shape in [(1, 1000), (10, 10, 10)]: + outputs = [] + func = _get_model(inputs["a"].shape, shape, iter(inputs)) + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl)[0]) + verify(outputs, atol=1e-7, rtol=1e-7) + + +def test_codegen_reshape(): + if skip_codegen_test(): + return + + shape = (1, 1, 1, 1000) + inputs = {"a"} + + for new_shape in [(1, 1000), (10, 10, 10)]: + args = (shape, new_shape) + func = _get_model(*args, iter(inputs)) + exp_codegen = _get_expected_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_reshape() + test_codegen_reshape() diff --git a/tests/python/contrib/test_arm_compute_lib/test_runtime.py b/tests/python/contrib/test_arm_compute_lib/test_runtime.py new file mode 100644 index 0000000..2bb17ad --- /dev/null +++ b/tests/python/contrib/test_arm_compute_lib/test_runtime.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Arm Compute Library runtime tests.""" + +import numpy as np + +import tvm +from tvm import relay + +from .infrastructure import skip_runtime_test, build_and_run, verify +from .infrastructure import Device + + +def test_multiple_ops(): + """ + Test multiple operators destined for ACL. + The ACL runtime will expect these ops as 2 separate functions for + the time being. + """ + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + def get_model(input_shape, var_names): + """Return a model and any parameters it may have.""" + a = relay.var(next(var_names), shape=input_shape, dtype="float32") + out = relay.reshape(a, (1, 1, 1000)) + out = relay.reshape(out, (1, 1000)) + return out + + inputs = { + "a": tvm.nd.array(np.random.uniform(0, 1, (1, 1, 1, 1000)).astype("float32")) + } + + outputs = [] + for acl in [False, True]: + func = get_model(inputs["a"].shape, iter(inputs)) + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl, acl_partitions=2)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + +def test_heterogeneous(): + """ + Test to check if offloading only supported operators works, + while leaving unsupported operators computed via tvm. + """ + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + def get_model(input_shape, var_names): + """Return a model and any parameters it may have.""" + a = relay.var(next(var_names), shape=input_shape, dtype="float32") + out = relay.reshape(a, (1, 1, 1000)) + out = relay.sigmoid(out) + out = relay.reshape(out, (1, 1000)) + return out + + inputs = { + "a": tvm.nd.array(np.random.uniform(-127, 128, (1, 1, 1, 1000)).astype("float32")) + } + + outputs = [] + for acl in [False, True]: + func = get_model(inputs["a"].shape, iter(inputs)) + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl, tvm_ops=1, + acl_partitions=2)[0]) + verify(outputs, atol=0.002, rtol=0.01) + + +def test_multiple_runs(): + """ + Test that multiple runs of an operator work. + """ + if skip_runtime_test(): + return + + device = Device() + + def get_model(): + a = relay.var("a", shape=(1, 28, 28, 512), dtype="float32") + w = tvm.nd.array(np.ones((256, 1, 1, 512), dtype="float32")) + weights = relay.const(w, "float32") + conv = relay.nn.conv2d( + a, + weights, + kernel_size=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1) + ) + params = {"w": w} + return conv, params + + inputs = { + "a": tvm.nd.array(np.random.uniform(-127, 128, (1, 28, 28, 512)).astype("float32")), + } + + func, params = get_model() + outputs = build_and_run(func, inputs, 1, + params, device, + enable_acl=True, + no_runs=3) + verify(outputs, atol=0.002, rtol=0.01) + + +if __name__ == "__main__": + test_multiple_ops() + test_heterogeneous() + test_multiple_runs() diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 529b996..d1c076d 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -31,6 +31,7 @@ echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake echo set\(USE_VM_PROFILER ON\) >> config.cmake echo set\(USE_EXAMPLE_EXT_RUNTIME ON\) >> config.cmake echo set\(USE_DNNL_CODEGEN ON\) >> config.cmake +echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(USE_NNPACK ON\) >> config.cmake echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake