From: Hyeongseok Oh
Date: Wed, 28 Aug 2024 08:36:59 +0000 (+0900)
Subject: Imported Upstream version 1.28.0
X-Git-Tag: upstream/1.28.0^0
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=refs%2Fheads%2Fupstream;p=platform%2Fcore%2Fml%2Fnnfw.git

Imported Upstream version 1.28.0
---

diff --git a/.ahub/sam/exclude.txt b/.ahub/sam/exclude.txt
index 3c2b71f06..cd808b047 100644
--- a/.ahub/sam/exclude.txt
+++ b/.ahub/sam/exclude.txt
@@ -1,52 +1,49 @@
 # External code: Android NN API
-/ONE/compiler/ann-api/include/NeuralNetworks.h
-/ONE/compiler/ann-ref
+compiler/ann-api/include/NeuralNetworks.h
+compiler/ann-ref/

 # Eigen
-/ONE/compiler/nnc/backends/soft_backend/code_snippets/eigen.def
+compiler/nnc/backends/soft_backend/code_snippets/eigen.def

 # Frontend test tools that are needed for release package build
-/ONE/compiler/circlechef
-/ONE/compiler/circle-verify
-/ONE/compiler/luci/tester
+compiler/circlechef/
+compiler/circle-verify/
+compiler/luci/tester/

 # Exclude IR headers which have lots of similar patterns
 # TODO remove this when refactoring is possible
-/ONE/compiler/luci/lang/include/luci/IR/Nodes
-/ONE/compiler/luci/import/include/luci/Import/Nodes
-/ONE/compiler/loco/include/loco/IR
-/ONE/compiler/tflchef/tflite/src/Op/include
+compiler/luci/lang/include/luci/IR/Nodes/
+compiler/luci/import/include/luci/Import/Nodes/
+compiler/loco/include/loco/IR/
+compiler/tflchef/tflite/src/Op/include/

 # Exclude interpreter kernels which have similar patterns
-/ONE/compiler/luci-interpreter/src/kernels
-/ONE/compiler/locomotiv/src/Node
-
-# Test codes
-/ONE/tests
+compiler/luci-interpreter/src/kernels/
+compiler/locomotiv/src/Node/

 # Flatbuffers generated
-/ONE/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h
-/ONE/runtime/onert/frontend/tflite/src/tflite_schema_generated.h
+runtime/libs/circle-schema/include/circle_schema_generated.h
+runtime/libs/circle-schema/include/circle_traininfo_generated.h
+runtime/onert/core/src/loader/tflite_schema_generated.h

 # External code: Android NN API
-/ONE/runtime/nnapi-header/include/NeuralNetworks.h
-/ONE/runtime/nnapi-header/include/NeuralNetworksExtensions.h
+runtime/nnapi-header/include/NeuralNetworks.h
+runtime/nnapi-header/include/NeuralNetworksExtensions.h

 # External code: Tensorflow lite
-/ONE/runtime/libs/nnapi
-/ONE/runtime/libs/profiling
+runtime/libs/profiling/

 # External code: 3rd party
-/ONE/runtime/3rdparty
+runtime/3rdparty/

 # External code: compute libraries
-/ONE/compute
+compute/

 # Experimental subprojects not for release
-/ONE/runtime/contrib
+runtime/contrib/

 # Downloaded externals
-/ONE/externals
+externals/

 # Intermediate code for runtime build (refer nnfw.spec file's nncc_workspace)
-/ONE/build/nncc/
+build/nncc/

diff --git a/.ahub/tcchecker-tca/config.yaml b/.ahub/tcchecker-tca/config.yaml
index 12fbabefd..ecae5f5a3 100644
--- a/.ahub/tcchecker-tca/config.yaml
+++ b/.ahub/tcchecker-tca/config.yaml
@@ -12,10 +12,12 @@ test:
       - /tests/nnfw_api

     testFile:
-      - extension: test.cpp
-        any: true
-      - extension: test.cc
-        any: true
+      - extension: cpp
+        ends:
+          - .test
+      - extension: cc
+        ends:
+          - .test

     testCase:
       - condition:
         - functionName:
@@ -136,9 +138,9 @@ test:
       - /compiler/vconone

     testFile:
-      - extension: .test.cpp
-        any: true
-
+      - extension: cpp
+        ends:
+          - .test
     testCase:
       - condition:
         - functionName:

diff --git a/.clang-format b/.clang-format
index 9243c9a2b..0f20ca442 100644
--- a/.clang-format
+++ b/.clang-format
@@ -21,6 +21,7 @@ BinPackArguments: true
 BinPackParameters: true
 BraceWrapping:
   AfterClass: true
+  AfterCaseLabel: true
   AfterControlStatement: true
   AfterEnum: true
   AfterFunction: true

diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml
index dcfb8d5e8..ba8d4bff9 100644
--- a/.github/workflows/check-format.yml
+++ b/.github/workflows/check-format.yml
@@ -21,18 +21,21 @@ jobs:
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: Setup python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.x'

-      # C format: clang-format-8
+      # C format: clang-format-16
       # Python format: yapf==0.22.0
       - name: Install packages
         run: |
-          sudo apt-get install -y clang-format-8
+          sudo apt-get install -y gnupg2 software-properties-common
+          wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
+          sudo add-apt-repository "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main"
+          sudo apt-get update && sudo apt-get install -qqy clang-format-16
           python -m pip install --upgrade pip
           pip install yapf==0.22.0

@@ -41,7 +44,7 @@ jobs:
       # Upload patch file if failed
       - name: Store archive
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v4
         if: failure()
         with:
           name: format-patch

@@ -54,7 +57,7 @@ jobs:
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           # Fetch all history and branch (default: 1)
           # Require all history to get file creation date

diff --git a/.github/workflows/check-pr-commit.yml b/.github/workflows/check-pr-commit.yml
index 7fa84b166..5dc2de1c3 100644
--- a/.github/workflows/check-pr-commit.yml
+++ b/.github/workflows/check-pr-commit.yml
@@ -24,7 +24,7 @@ jobs:
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           # Checkout PR head commit
           # Checkout Action use merge commit as default

diff --git a/.github/workflows/deploy-github-pages.yml b/.github/workflows/deploy-github-pages.yml
index d474a2754..5f7a04bd9 100644
--- a/.github/workflows/deploy-github-pages.yml
+++ b/.github/workflows/deploy-github-pages.yml
@@ -20,7 +20,7 @@ jobs:
     steps:
       - name: 'Checkout'
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: 'Generate HTML'
         uses: mattnotmitt/doxygen-action@v1.9
         with:
@@ -28,7 +28,7 @@ jobs:
       - name: 'Tar artifact'
         run: tar -zcf doxygen.tar.gz -C doxygen/html ./
       - name: 'Generate artifact'
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: doxygen-html
           path: doxygen.tar.gz

diff --git a/.github/workflows/run-onert-micro-unit-tests.yml b/.github/workflows/run-onert-micro-unit-tests.yml
index 8b27e638b..3256c758c 100644
--- a/.github/workflows/run-onert-micro-unit-tests.yml
+++ b/.github/workflows/run-onert-micro-unit-tests.yml
@@ -31,7 +31,7 @@ jobs:
         with:
           release: '12.2.Rel1' # <-- The compiler release to use
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           # Checkout PR head commit
           # Checkout Action use merge commit as default
@@ -43,5 +43,5 @@ jobs:
           mkdir build
           cd build
           cmake ../infra/onert-micro/ -DENABLE_ONERT_MICRO_TEST=1 -DENABLE_TEST=1
-          make -j$(nproc) luci_interpreter_kernels_micro_test
-          ./onert-micro/eval-driver/luci-interpreter/src/kernels/luci_interpreter_kernels_micro_test
+          make -j$(nproc) onert_micro_execute_kernels_test
+          ./onert-micro/eval-driver/onert-micro/src/execute/onert_micro_execute_kernels_test

diff --git a/.readthedocs.yml b/.readthedocs.yml
index 701a526bd..206831ec0 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -3,7 +3,13 @@
 # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

 # Required
-version: 1.4
+version: 2
+
+# Set the OS, Python version and other tools you might need
+build:
+  os: ubuntu-lts-latest
+  tools:
+    python: "3"

 # Build documentation in the docs/ directory with Sphinx
 sphinx:
@@ -17,6 +23,5 @@ formats:

 # Optionally set the version of Python and requirements required to build your docs
 python:
-  version: 3.7
   install:
     - requirements: docs/requirements.txt

diff --git a/Makefile.template b/Makefile.template
index 7621a2f7a..cecafca2b 100644
--- a/Makefile.template
+++ b/Makefile.template
@@ -8,6 +8,8 @@ HOST_OS?=linux
 TARGET_OS?=linux
 COVERAGE_BUILD?=0
 OPTIONS?=
+OPTIONS_NNCC?=
+INSTALL_OPTIONS?=

 # make TARGET and TYPE to lowercase
 HOST_ARCH_LC=$(shell echo $(HOST_ARCH) | tr A-Z a-z)
@@ -36,6 +38,7 @@ endif
 ifeq ($(CROSS_BUILD),1)
     TOOLCHAIN_FILE=cmake/buildtool/cross/toolchain_$(TARGET_ARCH_LC)-$(TARGET_OS).cmake
     OPTIONS+= -DCMAKE_TOOLCHAIN_FILE=$(TOOLCHAIN_FILE)
+    OPTIONS_NNCC+= -DCMAKE_TOOLCHAIN_FILE=$(TOOLCHAIN_FILE)
 endif

 ifneq ($(filter create-covsuite,$(MAKECMDGOALS)),)
@@ -50,10 +53,12 @@ endif

 ifneq ($(EXTERNAL_VOLUME),)
     OPTIONS+= -DNNAS_EXTERNALS_DIR=$(EXTERNAL_VOLUME)
+    OPTIONS_NNCC+= -DNNAS_EXTERNALS_DIR=$(EXTERNAL_VOLUME)
 endif

 ifeq ($(TARGET_OS),android)
     OPTIONS+= -DNDK_DIR=$(NDK_DIR)
+    OPTIONS_NNCC+= -DNDK_DIR=$(NDK_DIR)
 endif

 ifneq ($(ANDROID_BUILD_TOOLS_DIR),)
@@ -78,13 +83,19 @@ else
     NPROCS?=1
 endif

+ifeq ($(BUILD_TYPE_LC),release)
+    INSTALL_OPTIONS+= --strip
+endif
+
 WORKHOME=$(CURDIR)/Product
 WORKFOLDER=$(TARGET_ARCH_LC)-$(TARGET_OS).$(BUILD_TYPE_LC)
 WORKSPACE=$(WORKHOME)/$(WORKFOLDER)
+BUILDTOOL_WORKSPACE=$(WORKHOME)/buildtool

 INSTALL_PATH?=$(WORKSPACE)/out
 OVERLAY_FOLDER?=$(WORKSPACE)/overlay
 INSTALL_ALIAS=$(WORKHOME)/out
+BUILDTOOL_PATH?=$(BUILDTOOL_WORKSPACE)/out

 TIMESTAMP_CONFIGURE=$(WORKSPACE)/CONFIGURE
 TIMESTAMP_BUILD=$(WORKSPACE)/BUILD
@@ -104,11 +115,13 @@ export NNCC_WORKSPACE=$(NNCC_FOLDER)
 ###
 ### Default target
 ###
-all: install
+all: prepare-buildtool prepare-nncc configure build install

 ###
-### Command (public)
+### Command (build step)
 ###
+prepare-buildtool: prepare_buildtool_internal
+
 prepare-nncc: prepare_nncc_internal

 configure: configure_internal
@@ -117,13 +130,16 @@ build: build_internal

 install: install_all_internal

-create-package: runtime_tar_internal
+###
+### Command (public)
+###
+create-package: all runtime_tar_internal

-create-aclpack: acl_tar_internal
+create-aclpack: configure acl_tar_internal

-create-testsuite: test_suite_internal
+create-testsuite: all test_suite_internal

-create-covsuite: coverage_suite_internal
+create-covsuite: all coverage_suite_internal

 clean:
     rm -rf $(WORKSPACE)
@@ -133,57 +149,59 @@ distclean:
     rm -rf externals
     rm -rf tests/nnapi/src/generated/

-# create_package, create_acl_tar: to be removed
-create_package: runtime_tar_internal
-create_acl_tar: acl_tar_internal
-
 ###
 ### Command (internal)
 ###
 $(WORKSPACE):
     mkdir -p $@

+prepare_buildtool_internal: $(WORKSPACE)
+    cmake -S infra/buildtool -B $(BUILDTOOL_WORKSPACE)/obj -DBUILDTOOL_PATH=$(BUILDTOOL_PATH)
+    cmake --build $(BUILDTOOL_WORKSPACE)/obj/ -j$(NPROCS)
+
 prepare_nncc_internal: $(WORKSPACE)
-ifneq ($(CROSS_BUILD),1)
-    ./nncc configure -DBUILD_GTEST=OFF -DENABLE_TEST=OFF -DEXTERNALS_BUILD_THREADS=$(NPROCS) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
+ifeq (,$(findstring android,$(TARGET_OS)))
+    EXTERNAL_FLATC=$(BUILDTOOL_PATH)/bin/flatc ./nncc configure -DBUILD_GTEST=OFF -DENABLE_TEST=OFF -DEXTERNALS_BUILD_THREADS=$(NPROCS) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
         -DCMAKE_INSTALL_PREFIX=$(OVERLAY_FOLDER) \
-        -DBUILD_WHITELIST="luci;foder;pepper-csv2vec;loco;locop;logo;logo-core;mio-circle06;luci-compute;oops;hermes;hermes-std;angkor;pp;pepper-strcast;pepper-str"
+        -DBUILD_WHITELIST="luci;foder;pepper-csv2vec;loco;locop;logo;logo-core;mio-circle08;luci-compute;oops;hermes;hermes-std;angkor;pp;pepper-strcast;pepper-str" \
+        $(OPTIONS_NNCC)
     ./nncc build -j$(NPROCS)
-    cmake --install $(NNCC_FOLDER)
+    cmake --install $(NNCC_FOLDER) $(INSTALL_OPTIONS)
 # install angkor TensorIndex and oops InternalExn header (TODO: Remove this)
     @mkdir -p ${OVERLAY_FOLDER}/include/nncc/core/ADT/tensor
     @mkdir -p ${OVERLAY_FOLDER}/include/oops
+    @mkdir -p ${OVERLAY_FOLDER}/include/luci/IR
     @cp compiler/angkor/include/nncc/core/ADT/tensor/Index.h ${OVERLAY_FOLDER}/include/nncc/core/ADT/tensor
     @cp compiler/oops/include/oops/InternalExn.h ${OVERLAY_FOLDER}/include/oops
-endif
+    @cp compiler/luci/lang/include/luci/IR/CircleNodes.lst ${OVERLAY_FOLDER}/include/luci/IR
     @echo "Done prepare-nncc"
+endif

 configure_internal: $(WORKSPACE)
ifneq ($(DEBIAN_BUILD),)
     test -d externals || mkdir -p externals
     find packaging/ -type f -name "*.tar.gz" | xargs -i tar xf {} -C externals
endif
-    NNFW_INSTALL_PREFIX=$(INSTALL_PATH) ./nnfw configure \
+    ./nnfw configure \
         -DCMAKE_BUILD_TYPE=$(BUILD_TYPE_LC) \
         -DNNFW_OVERLAY_DIR=$(OVERLAY_FOLDER) \
         -DEXTERNALS_BUILD_THREADS=$(NPROCS) \
         $(OPTIONS)

-build_internal: configure_internal
+build_internal:
     ./nnfw build -j $(NPROCS)

-install_internal: build_internal
-    ./nnfw install
+install_internal:
+    ./nnfw install --prefix $(INSTALL_PATH) $(INSTALL_OPTIONS)
     rm -rf $(INSTALL_ALIAS)
     ln -s $(INSTALL_PATH) $(INSTALL_ALIAS)

-runtime_tar_internal: build_internal install_internal
+runtime_tar_internal:
     tar -zcf $(WORKSPACE)/onert-package.tar.gz -C $(INSTALL_PATH) lib
     tar -zcf $(WORKSPACE)/onert-devel-package.tar.gz -C $(INSTALL_PATH) include/nnfw
     tar -zcf $(WORKSPACE)/onert-plugin-devel-package.tar.gz -C $(INSTALL_PATH) include/onert
     tar -zcf $(WORKSPACE)/onert-test-package.tar.gz -C $(INSTALL_PATH) $(shell ls $(INSTALL_PATH) -I lib -I include)

-acl_tar_internal: configure_internal
+acl_tar_internal:
     tar -zcf $(WORKSPACE)/onert-acl.tar.gz -C ${OVERLAY_FOLDER} lib/libarm_compute.so lib/libarm_compute_core.so lib/libarm_compute_graph.so

 install_acl_internal:

diff --git a/README.md b/README.md
index e3ed259c7..e47d42cf4 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,6 @@ as Tensorflow or PyTorch in a unified form at runtime.
 - [Background](docs/overview/background.md)
 - [Roadmap](docs/overview/roadmap.md)
-- [Overall Architecture](docs/overview/overall-architecture.md)

 ## Getting started

diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt
index ef13df857..445979f81 100644
--- a/compiler/CMakeLists.txt
+++ b/compiler/CMakeLists.txt
@@ -71,6 +71,14 @@ function(add_compiler_directory DIR)
     endforeach(ACCEPTED_DIR)
   else()
     set(ENABLE ${BUILD_COMPILER_${PREFIX}})
+    if(ENABLE_EXCLUDE_ME)
+      # skip if "exclude.me" file exist
+      set(EXCLUDE_ME_FILE "${CMAKE_CURRENT_SOURCE_DIR}/${DIR}/exclude.me")
+      if(EXISTS ${EXCLUDE_ME_FILE})
+        message(STATUS "Exclude ${PREFIX}")
+        set(ENABLE OFF)
+      endif()
+    endif()
   endif()

   # This line prevents some errors in this CMakeLists.txt

diff --git a/compiler/adtidas/exclude.me b/compiler/adtidas/exclude.me
new file mode 100644
index 000000000..e69de29bb

diff --git a/compiler/ann-api/exclude.me b/compiler/ann-api/exclude.me
new file mode 100644
index 000000000..e69de29bb

diff --git a/compiler/ann-ref/exclude.me b/compiler/ann-ref/exclude.me
new file mode 100644
index 000000000..e69de29bb

diff --git a/compiler/arser/include/arser/arser.h b/compiler/arser/include/arser/arser.h
index 43f99dc5e..1a7dd9db1 100644
--- a/compiler/arser/include/arser/arser.h
+++ b/compiler/arser/include/arser/arser.h
@@ -252,7 +252,16 @@ public:
   Argument &help(std::string help_message)
   {
-    _help_message = help_message;
+    _help_message.emplace_back(help_message);
+    return *this;
+  }
+
+  Argument &help(std::initializer_list<std::string> help_messages)
+  {
+    if (help_messages.size() == 0)
+      throw std::runtime_error("Empty help message list");
+
+    _help_message = help_messages;
     return *this;
   }

@@ -304,7 +313,7 @@ private:
   std::string _short_name;
   std::vector<std::string> _names;
   std::string _type = "string";
-  std::string _help_message;
+  std::vector<std::string> _help_message;
   std::function<void(void)> _func;
   uint32_t _nargs{1};
   bool _is_required{false};
@@ -600,11 +609,14 @@ public:
     {
       stream.width(length_of_longest_arg);
       stream << std::left << arser::internal::make_comma_concatenated(arg._names) << "\t";
-      for (size_t i = 0; i < arg._help_message.length(); i += message_width)
+      for (size_t i = 0; i < arg._help_message.size(); i++)
       {
-        if (i)
-          stream << std::string(length_of_longest_arg, ' ') << "\t";
-        stream << arg._help_message.substr(i, message_width) << std::endl;
+        for (size_t j = 0; j < arg._help_message[i].length(); j += message_width)
+        {
+          if (i || j)
+            stream << std::string(length_of_longest_arg, ' ') << "\t";
+          stream << arg._help_message[i].substr(j, message_width) << std::endl;
+        }
       }
     }
     std::cout << std::endl;
@@ -638,19 +650,22 @@ template <typename T> T Arser::get_impl(const std::string &arg_name, T *)
     throw std::runtime_error(
       "Type mismatch. "
       "You called get using a type different from the one you specified."
-      "Accumulated argument is returned as std::vector of the specified type");
+      "Accumulated argument is returned as std::vector of the specified type: " +
+      arg_name);

   if (arg->second->_type != TypeName<T>::Get())
     throw std::runtime_error("Type mismatch. "
                              "You called get() method with a type different "
                              "from the one you specified. "
                              "Please check the type of what you specified in "
-                             "add_argument() method.");
+                             "add_argument() method: " +
+                             arg_name);

   if (arg->second->_values.size() == 0)
     throw std::runtime_error("Wrong access. "
                              "You must make sure that the argument is given before accessing it. "
" - "You can do it by calling arser[\"argument\"]."); + "You can do it by calling arser[\"" + + arg_name + "\"]."); return internal::lexical_cast(arg->second->_values[0]); } @@ -667,8 +682,10 @@ template std::vector Arser::get_impl(const std::string &arg_name if (arg->second->_is_accumulated) { if (arg->second->_type != TypeName::Get()) - throw std::runtime_error("Type mismatch. " - "You called get using a type different from the one you specified."); + throw std::runtime_error( + "Type mismatch. " + "You called get using a type different from the one you specified: " + + arg_name); std::vector data; for (auto values : arg->second->_accum_values) @@ -680,8 +697,10 @@ template std::vector Arser::get_impl(const std::string &arg_name } if (arg->second->_type != TypeName>::Get()) - throw std::runtime_error("Type mismatch. " - "You called get using a type different from the one you specified."); + throw std::runtime_error( + "Type mismatch. " + ". You called get using a type different from the one you specified: " + + arg_name); std::vector data; std::transform(arg->second->_values.begin(), arg->second->_values.end(), std::back_inserter(data), @@ -702,13 +721,15 @@ std::vector> Arser::get_impl(const std::string &arg_name, if (not arg->second->_is_accumulated) throw std::runtime_error("Type mismatch. " - "You called get using a type different from the one you specified."); + "You called get using a type different from the one you specified: " + + arg_name); if (arg->second->_type != TypeName>::Get()) throw std::runtime_error( "Type mismatch. " "You called get using a type different from the one you specified." - "Accumulated argument is returned as std::vector of the specified type"); + "Accumulated argument is returned as std::vector of the specified type: " + + arg_name); std::vector> result; for (auto values : arg->second->_accum_values) diff --git a/compiler/arser/tests/arser.test.cpp b/compiler/arser/tests/arser.test.cpp index 63121b845..1357b8155 100644 --- a/compiler/arser/tests/arser.test.cpp +++ b/compiler/arser/tests/arser.test.cpp @@ -478,3 +478,38 @@ TEST(BasicTest, AccumulateScalarOptions_WrongType_NEG) EXPECT_THROW(arser.get("--specify"), std::runtime_error); } + +TEST(HelpMessageTest, MultilineHelp) +{ + /* arrange */ + Arser arser; + + arser.add_argument("-v", "--verbose") + .nargs(0) + .help({"Provides additional details", "Default: No"}); + + std::ostringstream oss; + std::string expected_out = "Usage: ./arser [-h] [-v] \n" + "\n" + "[Optional argument]\n" + "-h, --help \tShow help message and exit\n" + "-v, --verbose\tProvides additional details\n" + " \tDefault: No\n"; + + test::Prompt prompt("./arser -v"); + /* act */ + arser.parse(prompt.argc(), prompt.argv()); + oss << arser; + + /* assert */ + EXPECT_EQ(expected_out, oss.str()); +} + +TEST(HelpMessageTest, MultilineHelpEmpty_NEG) +{ + /* arrange */ + Arser arser; + std::initializer_list help_msg = {}; + + EXPECT_THROW(arser.add_argument("-v", "--verbose").nargs(0).help(help_msg), std::runtime_error); +} diff --git a/compiler/bino/exclude.me b/compiler/bino/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/caffe2circle/exclude.me b/compiler/caffe2circle/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/caffegen/exclude.me b/compiler/caffegen/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/circle-eval-diff/src/InputDataLoader.cpp b/compiler/circle-eval-diff/src/InputDataLoader.cpp index 7b491a37a..231f25113 100644 --- 
--- a/compiler/circle-eval-diff/src/InputDataLoader.cpp
+++ b/compiler/circle-eval-diff/src/InputDataLoader.cpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include

 #include
 #include
@@ -55,7 +56,7 @@ std::vector<size_t> getEachByteSizeOf(const std::vector<loco::Node *> &nodes)
   for (const auto node : nodes)
   {
     const auto input_node = loco::must_cast<const luci::CircleInput *>(node);
-    const auto dtype_size = loco::size(input_node->dtype());
+    const auto dtype_size = luci::size(input_node->dtype());

     size_t element_size = 1;
     for (uint32_t index = 0; index < input_node->rank(); index++)
@@ -76,7 +77,7 @@ size_t getTotalByteSizeOf(const std::vector<loco::Node *> &nodes)
   for (const auto node : nodes)
   {
     const auto input_node = loco::must_cast<const luci::CircleInput *>(node);
-    size_t byte_size = loco::size(input_node->dtype());
+    size_t byte_size = luci::size(input_node->dtype());

     for (uint32_t index = 0; index < input_node->rank(); index++)
     {

diff --git a/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h b/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h
index 811aa67c3..988015ad5 100644
--- a/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h
+++ b/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h
@@ -18,6 +18,7 @@
 #define CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_LINUX_H

 #include "IScratchpadHelper.h"
+#include
 #include

 namespace circle_planner
@@ -45,14 +46,14 @@ public:
     for (int32_t i = 0; i < lhs->rank(); ++i)
       scratchpad_size *= lhs->dim(i).value();

-    scratchpad_sizes.push_back(scratchpad_size * loco::size(lhs->dtype()));
+    scratchpad_sizes.push_back(scratchpad_size * luci::size(lhs->dtype()));

     // Scratchpad for rhs
     scratchpad_size = 1;
     for (int32_t i = 0; i < rhs->rank(); ++i)
       scratchpad_size *= rhs->dim(i).value();

-    scratchpad_sizes.push_back(scratchpad_size * loco::size(rhs->dtype()));
+    scratchpad_sizes.push_back(scratchpad_size * luci::size(rhs->dtype()));

     return scratchpad_sizes;
   }

diff --git a/compiler/circle-inspect/CMakeLists.txt b/compiler/circle-inspect/CMakeLists.txt
index 8edfde483..76e65ddc6 100644
--- a/compiler/circle-inspect/CMakeLists.txt
+++ b/compiler/circle-inspect/CMakeLists.txt
@@ -1,6 +1,6 @@
-if(NOT TARGET mio_circle06)
+if(NOT TARGET mio_circle08)
   return()
-endif(NOT TARGET mio_circle06)
+endif(NOT TARGET mio_circle08)

 set(DRIVER "driver/Driver.cpp")

@@ -10,6 +10,6 @@ add_executable(circle-inspect ${DRIVER} ${SOURCES})
 target_include_directories(circle-inspect PRIVATE src)
 target_link_libraries(circle-inspect arser)
 target_link_libraries(circle-inspect foder)
-target_link_libraries(circle-inspect mio_circle06)
-target_link_libraries(circle-inspect mio_circle06_helper)
+target_link_libraries(circle-inspect mio_circle08)
+target_link_libraries(circle-inspect mio_circle08_helper)
 target_link_libraries(circle-inspect safemain)

diff --git a/compiler/circle-inspect/requires.cmake b/compiler/circle-inspect/requires.cmake
index b3a2638ef..8a57c8f11 100644
--- a/compiler/circle-inspect/requires.cmake
+++ b/compiler/circle-inspect/requires.cmake
@@ -1,4 +1,4 @@
 require("arser")
 require("foder")
-require("mio-circle06")
+require("mio-circle08")
 require("safemain")

diff --git a/compiler/circle-inspect/src/Dump.cpp b/compiler/circle-inspect/src/Dump.cpp
index 868fc2ba8..373ac67a2 100644
--- a/compiler/circle-inspect/src/Dump.cpp
+++ b/compiler/circle-inspect/src/Dump.cpp
@@ -36,7 +36,7 @@ void DumpOperators::run(std::ostream &os, const circle::Model *model)
   auto ops = reader.operators();

   // dump operators
-  for (uint32_t i = 0; i < ops->Length(); ++i)
+  for (uint32_t i = 0; i < ops->size(); ++i)
   {
     const auto op = ops->Get(i);
@@ -56,7 +56,7 @@ const circle::Operator *operator_match_output(mio::circle::Reader &reader, const
 {
   auto ops = reader.operators();

-  for (uint32_t i = 0; i < ops->Length(); ++i)
+  for (uint32_t i = 0; i < ops->size(); ++i)
   {
     const auto op = ops->Get(i);

@@ -75,7 +75,7 @@ size_t tensor_buffer_size(mio::circle::Reader &reader, const int32_t tensor_id)
 {
   auto tensors = reader.tensors();

-  if (tensor_id < 0 || tensor_id >= tensors->Length())
+  if (tensor_id < 0 || tensor_id >= tensors->size())
   {
     throw std::runtime_error("Invalid Tensor ID");
   }

@@ -105,7 +105,7 @@ void DumpConv2DWeight::run(std::ostream &os, const circle::Model *model)
   auto ops = reader.operators();

   // dump Conv2D, DepthwiseConv2D and its weight input operator
-  for (uint32_t i = 0; i < ops->Length(); ++i)
+  for (uint32_t i = 0; i < ops->size(); ++i)
   {
     const auto op = ops->Get(i);
     auto bc = reader.builtin_code(op);

@@ -158,7 +158,7 @@ void DumpOperatorVersion::run(std::ostream &os, const circle::Model *model)
   auto ops = reader.operators();

   // Dump operators' version
-  for (uint32_t i = 0; i < ops->Length(); ++i)
+  for (uint32_t i = 0; i < ops->size(); ++i)
   {
     const auto op = ops->Get(i);

@@ -192,7 +192,7 @@ void DumpTensorDType::run(std::ostream &os, const circle::Model *model)
     reader.select_subgraph(g);
     auto tensors = reader.tensors();

-    for (uint32_t i = 0; i < tensors->Length(); ++i)
+    for (uint32_t i = 0; i < tensors->size(); ++i)
     {
       const auto tensor = tensors->Get(i);

@@ -217,7 +217,7 @@ void DumpConstants::run(std::ostream &os, const circle::Model *model)
     reader.select_subgraph(g);
     auto tensors = reader.tensors();

-    for (uint32_t i = 0; i < tensors->Length(); ++i)
+    for (uint32_t i = 0; i < tensors->size(); ++i)
     {
       const auto tensor = tensors->Get(i);
       if (tensor->is_variable())

diff --git a/compiler/circle-interpreter-test/requires.cmake b/compiler/circle-interpreter-test/requires.cmake
index 5ca5749ca..da5d3b688 100644
--- a/compiler/circle-interpreter-test/requires.cmake
+++ b/compiler/circle-interpreter-test/requires.cmake
@@ -1,3 +1,3 @@
 require("common-artifacts")
 require("circle-interpreter")
-require("luci-value-test")
+require("luci-value-py-test")

diff --git a/compiler/circle-interpreter/src/CircleInterpreter.cpp b/compiler/circle-interpreter/src/CircleInterpreter.cpp
index 48c29a581..9c4c87982 100644
--- a/compiler/circle-interpreter/src/CircleInterpreter.cpp
+++ b/compiler/circle-interpreter/src/CircleInterpreter.cpp
@@ -16,6 +16,7 @@

 #include
 #include
+#include

 #include
 #include
@@ -51,7 +52,7 @@ void writeDataToFile(const std::string &filename, const char *data, size_t data_
 template <typename NodeT> size_t getTensorSize(const NodeT *node)
 {
-  uint32_t tensor_size = loco::size(node->dtype());
+  uint32_t tensor_size = luci::size(node->dtype());
   for (uint32_t i = 0; i < node->rank(); ++i)
     tensor_size *= node->dim(i).value();
   return tensor_size;
@@ -135,7 +136,8 @@ int entry(int argc, char **argv)
   for (int i = 0; i < module->graph()->outputs()->size(); i++)
   {
     const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
-    std::vector<char> output_data(getTensorSize(output_node));
+    size_t output_size = interpreter.getOutputTensorSize(output_node);
+    std::vector<char> output_data(output_size);
     interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());

     // Output data is written in ${output_file}n

diff --git a/compiler/circle-mpqsolver/CMakeLists.txt b/compiler/circle-mpqsolver/CMakeLists.txt
index 9af9fc2a3..ffc7f3731 100644
--- a/compiler/circle-mpqsolver/CMakeLists.txt
+++ b/compiler/circle-mpqsolver/CMakeLists.txt
@@ -37,6 +37,15 @@ set(TEST_SOURCES
   "src/core/Quantizer.cpp"
   "src/bisection/VISQErrorApproximator.cpp"
   "src/core/ErrorMetric.cpp"
+  "src/pattern/PatternResolver.cpp"
+  "src/pattern/PatternSolver.cpp"
+  "src/core/Dumper.cpp"
+  "src/core/DumpingHooks.cpp"
+  "src/core/Evaluator.cpp"
+  "src/MPQSolver.cpp"
+  "src/core/SolverOutput.cpp"
+  "src/bisection/BisectionSolver.cpp"
+  "src/core/DataProvider.cpp"
 )

 nnas_find_package(GTest REQUIRED)
@@ -46,3 +55,8 @@ target_include_directories(circle_mpqsolver_test PRIVATE ${Jsoncpp_INCLUDE_DIRS}
 target_link_libraries(circle_mpqsolver_test ${Jsoncpp_STATIC_LIB})
 target_link_libraries(circle_mpqsolver_test luci_service)
 target_link_libraries(circle_mpqsolver_test luci_pass)
+target_link_libraries(circle_mpqsolver_test luci_testhelper)
+target_link_libraries(circle_mpqsolver_test luci_import)
+target_link_libraries(circle_mpqsolver_test luci_export)
+target_link_libraries(circle_mpqsolver_test luci_interpreter)
+target_link_libraries(circle_mpqsolver_test dio_hdf5)

diff --git a/compiler/circle-mpqsolver/src/CircleMPQSolver.cpp b/compiler/circle-mpqsolver/src/CircleMPQSolver.cpp
index 12981be40..0e0ff7e14 100644
--- a/compiler/circle-mpqsolver/src/CircleMPQSolver.cpp
+++ b/compiler/circle-mpqsolver/src/CircleMPQSolver.cpp
@@ -20,7 +20,11 @@
 #include

 #include "bisection/BisectionSolver.h"
-#include
+#include "core/DataProvider.h"
+#include "core/SolverOutput.h"
+#include "pattern/PatternSolver.h"
+#include "core/SolverOutput.h"
+#include "core/Quantizer.h"

 #include
 #include
@@ -47,6 +51,9 @@ int handleAutoAlgorithm(arser::Arser &arser, mpqsolver::bisection::BisectionSolv
 int entry(int argc, char **argv)
 {
   const std::string bisection_str = "--bisection";
+  const std::string patterns_str = "--patterns";
+  const std::string layernorm_str = "--u8_layernorm_with_s16_variance";
+  const std::string softmax_str = "--u8_softmax_with_s16_sub_exp";
   const std::string save_intermediate_str = "--save_intermediate";

   arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
@@ -55,7 +62,8 @@
   arser::Helper::add_version(arser, print_version);
   arser::Helper::add_verbose(arser);

-  arser.add_argument("--data").required(true).help("Path to the test data");
+  // if patterns are set we don't need data
+  arser.add_argument("--data").required(false).default_value("").help("Path to the test data");
   arser.add_argument("--data_format").required(false).help("Test data format (default: h5)");

   arser.add_argument("--qerror_ratio")
@@ -65,10 +73,26 @@
   arser.add_argument(bisection_str)
     .nargs(1)
+    .required(false)
     .type(arser::DataType::STR)
     .help("Single optional argument for bisection method. "
" "Whether input node should be quantized to Q16: 'auto', 'true', 'false'."); + arser.add_argument(patterns_str) + .nargs(0) + .required(false) + .help("Argument to define patterns applied (LayerNorm is the only supported) "); + + arser.add_argument(layernorm_str) + .nargs(0) + .required(false) + .help("Use int16 for computing variance in uint8 layer normalization"); + + arser.add_argument(softmax_str) + .nargs(0) + .required(false) + .help("Use int16 for computing sub and exp in uint8 softmax"); + arser.add_argument("--input_model") .required(true) .help("Input float model with min max initialized"); @@ -83,6 +107,30 @@ int entry(int argc, char **argv) .default_value("uint8") .help("Data type of quantized model's outputs (default: uint8)"); + arser.add_argument("--quantized_dtype") + .type(arser::DataType::STR) + .default_value("uint8") + .help("Data type of quantized model (supported: uint8 (default), int16)"); + + arser.add_argument("--granularity") + .type(arser::DataType::STR) + .default_value("channel") + .help("Granularity of quantization scheme on weight (supported: layer, channel (default)). " + "Activation is quantized per layer."); + + arser.add_argument("--TF-style_maxpool") + .nargs(0) + .default_value(false) + .help("Force MaxPool Op to have the same input/output quantparams. NOTE: This feature can " + "degrade accuracy of some models"); + + arser.add_argument("--save_min_max") + .nargs(0) + .default_value(false) + .help("Save recorded min/max values."); + + // TODO Support --config + arser.add_argument("--output_model").required(true).help("Output quantized model"); arser.add_argument("--visq_file") @@ -119,6 +167,10 @@ int entry(int argc, char **argv) auto output_model_path = arser.get("--output_model"); auto input_dtype = arser.get("--input_dtype"); auto output_dtype = arser.get("--output_dtype"); + auto quantized_dtype = arser.get("--quantized_dtype"); + auto granularity = arser.get("--granularity"); + auto TF_style_maxpool = arser["--TF-style_maxpool"] and arser.get("--TF-style_maxpool"); + auto save_min_max = arser["--save_min_max"] and arser.get("--save_min_max"); float qerror_ratio = arser.get("--qerror_ratio"); if (qerror_ratio < 0.0 || qerror_ratio > 1.f) @@ -127,24 +179,54 @@ int entry(int argc, char **argv) return EXIT_FAILURE; } + if (arser[bisection_str] && arser[patterns_str]) + { + // only one solver can be used for now + std::cerr << "ERROR: only one method is allowed to use" << std::endl; + return EXIT_FAILURE; + } + SolverOutput::get() << ">> Searching mixed precision configuration \n" << "model:" << input_model_path << "\n" << "dataset: " << data_path << "\n" + << "quantized dtype: " << quantized_dtype << "\n" + << "granularity: " << granularity << "\n" << "input dtype: " << input_dtype << "\n" - << "output dtype: " << output_dtype << "\n"; + << "output dtype: " << output_dtype << "\n" + << "TF_style_maxpool: " << (TF_style_maxpool ? "True" : "False") << "\n" + << "save_min_max: " << (save_min_max ? 
"True" : "False") << "\n"; + + std::unique_ptr solver; + + // Create quantizer parameters + mpqsolver::core::Quantizer::Context ctx; + { + ctx.output_model_dtype = quantized_dtype; + ctx.granularity = granularity; + ctx.input_type = input_dtype; + ctx.output_type = output_dtype; + ctx.save_min_max = save_min_max; + ctx.TF_style_maxpool = TF_style_maxpool; + } if (arser[bisection_str]) { // optimize + SolverOutput::get() << "Automatic mixed quantization using bisection\n"; + using namespace mpqsolver::bisection; - BisectionSolver solver(data_path, qerror_ratio, input_dtype, output_dtype); + auto bi_solver = std::make_unique(ctx, qerror_ratio); + auto input_data = + std::make_unique(data_path, input_model_path); + bi_solver->setInputData(std::move(input_data)); + { auto value = arser.get(bisection_str); if (value == "auto") { SolverOutput::get() << "algorithm: bisection (auto)\n"; - if (!handleAutoAlgorithm(arser, solver)) + if (!handleAutoAlgorithm(arser, *bi_solver)) { return EXIT_FAILURE; } @@ -152,12 +234,12 @@ int entry(int argc, char **argv) else if (value == "true") { SolverOutput::get() << "algorithm: bisection (Q16AtFront)"; - solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front); + bi_solver->algorithm(BisectionSolver::Algorithm::ForceQ16Front); } else if (value == "false") { SolverOutput::get() << "algorithm: bisection (Q8AtFront)"; - solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back); + bi_solver->algorithm(BisectionSolver::Algorithm::ForceQ16Back); } else { @@ -167,37 +249,26 @@ int entry(int argc, char **argv) } } - if (arser[save_intermediate_str]) - { - auto data_path = arser.get(save_intermediate_str); - if (!data_path.empty()) - { - solver.set_save_intermediate(data_path); - } - } - SolverOutput::get() << "qerror metric: MAE\n" << "target qerror ratio: " << qerror_ratio << "\n"; - auto optimized = solver.run(input_model_path); - if (optimized == nullptr) + solver.reset(bi_solver.release()); + } + else if (arser[patterns_str]) + { + SolverOutput::get() << "Automatic mixed quantization using patterns\n"; + + std::vector patterns; + if (arser[layernorm_str]) { - std::cerr << "ERROR: Failed to build mixed precision model" << input_model_path << std::endl; - return EXIT_FAILURE; + patterns.push_back(mpqsolver::pattern::QuantizationPattern::Q8LayerNormWithQ16Variance); } - - // save optimized + if (arser[softmax_str]) { - SolverOutput::get() << "Saving output model to " << output_model_path << "\n"; - luci::CircleExporter exporter; - luci::CircleFileExpContract contract(optimized.get(), output_model_path); - if (!exporter.invoke(&contract)) - { - std::cerr << "ERROR: Failed to export mixed precision model" << input_model_path - << std::endl; - return EXIT_FAILURE; - } + patterns.push_back(mpqsolver::pattern::QuantizationPattern::Q8SoftmaxWithQ16SubExp); } + + solver = std::make_unique(ctx, patterns); } else { @@ -205,5 +276,33 @@ int entry(int argc, char **argv) return EXIT_FAILURE; } + if (arser[save_intermediate_str]) + { + auto data_path = arser.get(save_intermediate_str); + if (!data_path.empty()) + { + solver->setSaveIntermediate(data_path); + } + } + + auto optimized = solver->run(input_model_path); + if (optimized == nullptr) + { + std::cerr << "ERROR: Failed to build mixed precision model" << input_model_path << std::endl; + return EXIT_FAILURE; + } + + // save optimized + { + SolverOutput::get() << "Saving output model to " << output_model_path << "\n"; + luci::CircleExporter exporter; + luci::CircleFileExpContract contract(optimized.get(), 
+    if (!exporter.invoke(&contract))
+    {
+      std::cerr << "ERROR: Failed to export mixed precision model" << input_model_path << std::endl;
+      return EXIT_FAILURE;
+    }
+  }
+
   return EXIT_SUCCESS;
 }

diff --git a/compiler/circle-mpqsolver/src/MPQSolver.cpp b/compiler/circle-mpqsolver/src/MPQSolver.cpp
index 10cfbb65f..7ed749307 100644
--- a/compiler/circle-mpqsolver/src/MPQSolver.cpp
+++ b/compiler/circle-mpqsolver/src/MPQSolver.cpp
@@ -16,16 +16,29 @@

 #include "MPQSolver.h"

+#include
+#include
+
 using namespace mpqsolver;

-MPQSolver::MPQSolver(const std::string &input_data_path, float qerror_ratio,
-                     const std::string &input_quantization, const std::string &output_quantization)
-  : _input_data_path(input_data_path), _qerror_ratio(qerror_ratio),
-    _input_quantization(input_quantization), _output_quantization(output_quantization)
+MPQSolver::MPQSolver(const core::Quantizer::Context &ctx)
 {
+  _quantizer = std::make_unique<core::Quantizer>(ctx);
 }

-void MPQSolver::set_save_intermediate(const std::string &save_path)
+void MPQSolver::setSaveIntermediate(const std::string &save_path)
 {
-  _hooks = std::make_unique<core::DumpingHooks>(save_path);
+  _hooks = std::make_unique<core::DumpingHooks>(save_path, _quantizer->getContext());
+}
+
+std::unique_ptr<luci::Module> MPQSolver::readModule(const std::string &path)
+{
+  luci::ImporterEx importerex;
+  auto module = importerex.importVerifyModule(path);
+  if (module.get() == nullptr)
+  {
+    throw std::runtime_error("Failed to load model");
+  }
+
+  return module;
 }

diff --git a/compiler/circle-mpqsolver/src/MPQSolver.h b/compiler/circle-mpqsolver/src/MPQSolver.h
index 6c5d25dad..09038064a 100644
--- a/compiler/circle-mpqsolver/src/MPQSolver.h
+++ b/compiler/circle-mpqsolver/src/MPQSolver.h
@@ -14,10 +14,13 @@
  * limitations under the License.
  */

-#ifndef __MPQSOLVER_MPQSOLEVR_SOLVER_H__
-#define __MPQSOLVER_MPQSOLEVR_SOLVER_H__
+#ifndef __MPQSOLVER_MPQSOLVER_SOLVER_H__
+#define __MPQSOLVER_MPQSOLVER_SOLVER_H__

-#include
+#include "core/Quantizer.h"
+#include "core/DumpingHooks.h"
+
+#include

 #include
 #include
@@ -28,13 +31,8 @@ namespace mpqsolver

 class MPQSolver
 {
 public:
-  /**
-   * @brief construct Solver using input_data_path for .h5 file,
-   * qerror_ratio to set target qerror, and input_quantization/output_quantization to set
-   * quantization type at input/output respectively
-   */
-  MPQSolver(const std::string &input_data_path, float qerror_ratio,
-            const std::string &input_quantization, const std::string &output_quantization);
+  MPQSolver(const core::Quantizer::Context &ctx);
+
   virtual ~MPQSolver() = default;

   /**
@@ -45,16 +43,18 @@ public:
   /**
    * @brief set all intermediate artifacts to be saved
    */
-  void set_save_intermediate(const std::string &save_path);
+  void setSaveIntermediate(const std::string &save_path);
+
+protected:
+  std::unique_ptr<luci::Module> readModule(const std::string &path);

 protected:
-  std::string _input_data_path;
   std::string _input_quantization;
   std::string _output_quantization;
-  float _qerror_ratio = 0.f; // quantization error ratio
+  std::unique_ptr<core::Quantizer> _quantizer;
   std::unique_ptr<core::DumpingHooks> _hooks;
 };

 } // namespace mpqsolver

-#endif //__MPQSOLVER_MPQSOLEVR_SOLVER_H__
+#endif //__MPQSOLVER_MPQSOLVER_SOLVER_H__

diff --git a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp
index 976dac550..d36ccfe8f 100644
--- a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp
+++ b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp
@@ -18,8 +18,9 @@
 #include "DepthParameterizer.h"
 #include "VISQErrorApproximator.h"

-#include
-#include
+#include "core/DataProvider.h" +#include "core/ErrorMetric.h" +#include "core/SolverOutput.h" #include @@ -72,36 +73,22 @@ bool front_has_higher_error(const NodeDepthType &nodes_depth, const std::string return error_at_input > error_at_output; } -std::unique_ptr read_module(const std::string &path) -{ - luci::ImporterEx importerex; - auto module = importerex.importVerifyModule(path); - if (module.get() == nullptr) - { - std::cerr << "ERROR: Failed to load " << path << std::endl; - return nullptr; - } - - return module; -} - } // namespace -BisectionSolver::BisectionSolver(const std::string &input_data_path, float qerror_ratio, - const std::string &input_quantization, - const std::string &output_quantization) - : MPQSolver(input_data_path, qerror_ratio, input_quantization, output_quantization) +BisectionSolver::BisectionSolver(const mpqsolver::core::Quantizer::Context &ctx, float qerror_ratio) + : MPQSolver(ctx), _qerror_ratio(qerror_ratio) { - _quantizer = std::make_unique(_input_quantization, _output_quantization); } float BisectionSolver::evaluate(const core::DatasetEvaluator &evaluator, const std::string &flt_path, const std::string &def_quant, core::LayerParams &layers) { - auto model = read_module(flt_path); + auto model = readModule(flt_path); + assert(model != nullptr); + // get fake quantized model for evaluation - if (!_quantizer->fake_quantize(model.get(), def_quant, layers)) + if (!_quantizer->fakeQuantize(model.get(), def_quant, layers)) { throw std::runtime_error("Failed to produce fake-quantized model."); } @@ -113,9 +100,15 @@ void BisectionSolver::algorithm(Algorithm algorithm) { _algorithm = algorithm; } void BisectionSolver::setVisqPath(const std::string &visq_path) { _visq_data_path = visq_path; } +void BisectionSolver::setInputData(std::unique_ptr &&data) +{ + _input_data = std::move(data); +} + std::unique_ptr BisectionSolver::run(const std::string &module_path) { - auto module = read_module(module_path); + auto module = readModule(module_path); + assert(module != nullptr); float min_depth = 0.f; float max_depth = 0.f; @@ -130,7 +123,11 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat SolverOutput::get() << "\n>> Computing baseline qerrors\n"; std::unique_ptr metric = std::make_unique(); - core::DatasetEvaluator evaluator(module.get(), _input_data_path, *metric.get()); + if (!_input_data) + { + throw std::runtime_error("no input data"); + } + core::DatasetEvaluator evaluator(module.get(), *_input_data.get(), *metric.get()); core::LayerParams layer_params; float int16_qerror = @@ -140,10 +137,10 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat float uint8_qerror = evaluate(evaluator, module_path, "uint8" /* default quant_dtype */, layer_params); SolverOutput::get() << "Full uint8 model qerror: " << uint8_qerror << "\n"; - _quantizer->set_hook(_hooks.get()); + _quantizer->setHook(_hooks.get()); if (_hooks) { - _hooks->on_begin_solver(module_path, uint8_qerror, int16_qerror); + _hooks->onBeginSolver(module_path, uint8_qerror, int16_qerror); } if (int16_qerror > uint8_qerror) @@ -154,19 +151,46 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat _qerror = int16_qerror + _qerror_ratio * std::fabs(uint8_qerror - int16_qerror); SolverOutput::get() << "Target qerror: " << _qerror << "\n"; - if (uint8_qerror <= _qerror) + // it'is assumed that int16_qerror <= _qerror <= uint8_qerror, + if (int16_qerror >= _qerror) + { + // return Q16 model (we can not make it more accurate) + if (!_quantizer->quantize(module.get(), 
"int16", layer_params)) + { + std::cerr << "ERROR: Failed to quantize model" << std::endl; + return nullptr; + } + + if (_hooks) + { + _hooks->onEndSolver(layer_params, "int16", int16_qerror); + } + + SolverOutput::get() << "The best configuration is int16 configuration\n"; + return module; + } + else if (uint8_qerror <= _qerror) { - // no need for bisectioning just return Q8 model + // return Q8 model (we can not make it less accurate) if (!_quantizer->quantize(module.get(), "uint8", layer_params)) { std::cerr << "ERROR: Failed to quantize model" << std::endl; return nullptr; } + + if (_hooks) + { + _hooks->onEndSolver(layer_params, "uint8", uint8_qerror); + } + + SolverOutput::get() << "The best configuration is uint8 configuration\n"; + return module; } + // search for optimal mixed precision quantization configuration int last_depth = -1; float best_depth = -1; - float best_accuracy = -1; + float best_error = -1; // minimal error core::LayerParams best_params; if (module->size() != 1) { @@ -207,11 +231,6 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat while (true) { - if (_hooks) - { - _hooks->on_begin_iteration(); - } - int cut_depth = static_cast(std::floor(0.5f * (min_depth + max_depth))); if (last_depth == cut_depth) @@ -219,6 +238,11 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat break; } + if (_hooks) + { + _hooks->onBeginIteration(); + } + SolverOutput::get() << "Looking for the optimal configuration in [" << min_depth << " , " << max_depth << "] depth segment\n"; @@ -249,34 +273,34 @@ std::unique_ptr BisectionSolver::run(const std::string &module_pat } } - float cur_accuracy = evaluate(evaluator, module_path, "uint8", layer_params); + float cur_error = evaluate(evaluator, module_path, "uint8", layer_params); if (_hooks) { - _hooks->on_end_iteration(layer_params, "uint8", cur_accuracy); + _hooks->onEndIteration(layer_params, "uint8", cur_error); } - if (cur_accuracy < _qerror) + if (cur_error < _qerror) { - SolverOutput::get() << "Qerror at depth " << cut_depth << " is " << cur_accuracy + SolverOutput::get() << "Qerror at depth " << cut_depth << " is " << cur_error << " < target qerror (" << _qerror << ")\n"; int16_front ? (max_depth = cut_depth) : (min_depth = cut_depth); best_params = layer_params; best_depth = cut_depth; - best_accuracy = cur_accuracy; + best_error = cur_error; } else { - SolverOutput::get() << "Qerror at depth " << cut_depth << " is " << cur_accuracy - << (cur_accuracy > _qerror ? " > " : " == ") << "target qerror (" - << _qerror << ")\n"; + SolverOutput::get() << "Qerror at depth " << cut_depth << " is " << cur_error + << (cur_error > _qerror ? " > " : " == ") << "target qerror (" << _qerror + << ")\n"; int16_front ? 
     }
   }

   if (_hooks)
   {
-    _hooks->on_end_solver(best_params, "uint8", best_accuracy);
+    _hooks->onEndSolver(best_params, "uint8", best_error);
   }

   SolverOutput::get() << "Found the best configuration at depth " << best_depth << "\n";

diff --git a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h
index 83851c0c8..a3191c584 100644
--- a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h
+++ b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h
@@ -17,9 +17,8 @@
 #ifndef __MPQSOLVER_BISECTION_SOLVER_H__
 #define __MPQSOLVER_BISECTION_SOLVER_H__

-#include
-#include
-#include
+#include "core/Evaluator.h"
+#include "MPQSolver.h"

 #include

@@ -46,12 +45,13 @@ public:

 public:
   /**
-   * @brief construct Solver using input_data_path for .h5 file,
-   * qerror_ratio to set target qerror, and input_quantization/output_quantization to set
-   * quantization type at input/output respectively
+   * @brief Construct a new Bisection Solver object
+   *
+   * @param ctx - quantizer context
+   * @param qerror_ratio - target error ratio
    */
-  BisectionSolver(const std::string &input_data_path, float qerror_ratio,
-                  const std::string &input_quantization, const std::string &output_quantization);
+  BisectionSolver(const mpqsolver::core::Quantizer::Context &ctx, float qerror_ratio);
+
   BisectionSolver() = delete;

   /**
@@ -59,6 +59,11 @@ public:
    */
   std::unique_ptr<luci::Module> run(const std::string &module_path) override;

+  /**
+   * @brief set data provider
+   */
+  void setInputData(std::unique_ptr<core::DataProvider> &&data);
+
   /**
    * @brief set used algorithm
    */
@@ -76,10 +81,11 @@ private:
                  const std::string &def_quant, core::LayerParams &layers);

 private:
-  float _qerror = 0.f; // quantization error
+  const float _qerror_ratio = 0.f; // quantization error ratio
+  float _qerror = 0.f;             // quantization error
   Algorithm _algorithm = Algorithm::ForceQ16Front;
-  std::unique_ptr<core::Quantizer> _quantizer;
   std::string _visq_data_path;
+  std::unique_ptr<core::DataProvider> _input_data;
 };

 } // namespace bisection

diff --git a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.test.cpp b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.test.cpp
new file mode 100644
index 000000000..a7a7d2473
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.test.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include + +#include "BisectionSolver.h" + +#include "core/SolverOutput.h" +#include "core/TestHelper.h" + +#include +#include + +namespace +{ + +class CircleMPQSolverBisectionSolverTestF : public ::testing::Test +{ +public: + CircleMPQSolverBisectionSolverTestF() + { + char module_template[] = "CircleMPQSolverBisectionSolverTest-CIRCLE-XXXXXX"; + mpqsolver::test::io_utils::makeTemporaryFile(module_template); + _module_path = module_template; + } + + ~CircleMPQSolverBisectionSolverTestF() { unlink(_module_path.c_str()); } + +protected: + mpqsolver::test::models::AddGraph _g; + std::string _module_path; +}; + +} // namespace + +TEST_F(CircleMPQSolverBisectionSolverTestF, verifyResultsTest) +{ + // create network + auto m = luci::make_module(); + _g.init(); + _g.transfer_to(m.get()); + + // export to _module_path + luci::CircleExporter exporter; + luci::CircleFileExpContract contract(m.get(), _module_path); + EXPECT_TRUE(exporter.invoke(&contract)); + + // create solver + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::bisection::BisectionSolver solver(ctx, 0.5); + auto data = mpqsolver::test::data_utils::getAllZeroSingleDataProvider(); + solver.setInputData(std::move(data)); + solver.algorithm(mpqsolver::bisection::BisectionSolver::Algorithm::ForceQ16Back); + SolverOutput::get().TurnOn(false); + + // run solver + auto res = solver.run(_module_path); + EXPECT_TRUE(res.get() != nullptr); +} + +TEST(CircleMPQSolverBisectionSolverTest, empty_path_NEG) +{ + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::bisection::BisectionSolver solver(ctx, 0.0); + solver.algorithm(mpqsolver::bisection::BisectionSolver::Algorithm::ForceQ16Back); + EXPECT_ANY_THROW(solver.run("")); +} diff --git a/compiler/circle-mpqsolver/src/bisection/DepthParameterizer.test.cpp b/compiler/circle-mpqsolver/src/bisection/DepthParameterizer.test.cpp index 504032d6b..c87424702 100644 --- a/compiler/circle-mpqsolver/src/bisection/DepthParameterizer.test.cpp +++ b/compiler/circle-mpqsolver/src/bisection/DepthParameterizer.test.cpp @@ -17,14 +17,14 @@ #include #include "DepthParameterizer.h" -#include +#include "core/TestHelper.h" #include namespace { -class NConvGraph final : public SimpleGraph +class NConvGraph final : public mpqsolver::test::models::SimpleGraph { protected: loco::Node *insertGraphBody(loco::Node *input) override @@ -43,7 +43,7 @@ protected: _conv->padding(luci::Padding::SAME); _conv->fusedActivationFunction(luci::FusedActFunc::NONE); _conv->dtype(loco::DataType::FLOAT32); - _conv->shape({1, _width, _height, _channel_size}); + _conv->shape({1, _height, _width, _channel_size}); _conv->name("conv"); _conv->filter(_filter); _conv->bias(_bias); diff --git a/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.cpp b/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.cpp index ee6376a48..c810497f6 100644 --- a/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.cpp +++ b/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.cpp @@ -44,7 +44,7 @@ void VISQErrorApproximator::init(std::istream &visq_data) auto layers = completeJsonData["error"][0]; auto names = layers.getMemberNames(); - for (auto name : names) + for (const auto &name : names) { auto value = layers[name].asFloat(); _layer_errors[name] = value; diff --git a/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.test.cpp b/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.test.cpp index ccacb1ab7..5542abe03 100644 --- 
--- a/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.test.cpp
+++ b/compiler/circle-mpqsolver/src/bisection/VISQErrorApproximator.test.cpp
@@ -16,32 +16,11 @@

 #include "VISQErrorApproximator.h"

+#include "core/TestHelper.h"
+
 #include
-#include
 #include

-namespace
-{
-
-void writeDataToFile(const std::string &path, const std::string &data)
-{
-  std::ofstream file;
-  file.open(path);
-  file << data;
-  file.close();
-}
-
-void makeTemporaryFile(char *name_template)
-{
-  int fd = mkstemp(name_template);
-  if (fd == -1)
-  {
-    throw std::runtime_error{"mkstemp failed"};
-  }
-}
-
-} // namespace
-
 TEST(CircleMPQSolverVISQErrorApproximatorTest, verifyResultsTest)
 {
   static std::string errors_key = "error";
@@ -57,8 +36,8 @@
   auto data = Json::writeString(builder, error_data);

   char path[] = "VISQErrorApproximator-TEST-XXXXXX";
-  makeTemporaryFile(path);
-  writeDataToFile(path, data);
+  mpqsolver::test::io_utils::makeTemporaryFile(path);
+  mpqsolver::test::io_utils::writeDataToFile(path, data);

   mpqsolver::bisection::VISQErrorApproximator approximator;
   EXPECT_NO_THROW(approximator.init(path));
@@ -74,8 +53,8 @@
   auto data = Json::writeString(builder, error_data);

   char path[] = "VISQErrorApproximator-TEST-NEG-XXXXXX";
-  makeTemporaryFile(path);
-  writeDataToFile(path, data);
+  mpqsolver::test::io_utils::makeTemporaryFile(path);
+  mpqsolver::test::io_utils::writeDataToFile(path, data);

   mpqsolver::bisection::VISQErrorApproximator approximator;
   EXPECT_THROW(approximator.init(path), std::exception);

diff --git a/compiler/circle-mpqsolver/src/core/DataProvider.cpp b/compiler/circle-mpqsolver/src/core/DataProvider.cpp
new file mode 100644
index 000000000..889e27896
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/core/DataProvider.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "DataProvider.h" + +#include +#include + +using namespace mpqsolver::core; + +using Shape = std::vector; + +namespace +{ + +// Check the type and the shape of input_node +// Throw an exception if type or shape does not match +void verifyTypeShape(const luci::CircleNode *input_node, const loco::DataType &dtype, + const Shape &shape) +{ + assert(input_node != nullptr); // FIX_CALLER_UNLESS + + // Type check + if (dtype != input_node->dtype()) + throw std::runtime_error("Wrong input type."); + + if (shape.size() != input_node->rank()) + throw std::runtime_error("Input rank mismatch."); + + for (uint32_t i = 0; i < shape.size(); i++) + { + if (not(shape.at(i) == input_node->dim(i))) + throw std::runtime_error("Input shape mismatch."); + } +} + +} // namespace + +H5FileDataProvider::H5FileDataProvider(const std::string &h5file, const std::string &module_path) + : _importer(h5file) +{ + _importer.importGroup("value"); + _is_raw_data = _importer.isRawData(); + + luci::ImporterEx importerex; + _module = importerex.importVerifyModule(module_path); + if (_module.get() != nullptr) + { + _input_nodes = loco::input_nodes(_module.get()->graph()); + } +} + +size_t H5FileDataProvider::numSamples() const { return _importer.numData(); } + +uint32_t H5FileDataProvider::numInputs(uint32_t sample) const +{ + return static_cast(_importer.numInputs(sample)); +} + +void H5FileDataProvider::getSampleInput(uint32_t sample, uint32_t input, InputData &data) const +{ + if (_is_raw_data) + { + _importer.readTensor(sample, input, data.data().data(), data.data().size()); + } + else + { + loco::DataType dtype; + Shape shape; + _importer.readTensor(sample, input, &dtype, &shape, data.data().data(), data.data().size()); + + // Check the type and the shape of the input data is valid + auto input_node = loco::must_cast(_input_nodes.at(input)); + verifyTypeShape(input_node, dtype, shape); + } +} diff --git a/compiler/circle-mpqsolver/src/core/DataProvider.h b/compiler/circle-mpqsolver/src/core/DataProvider.h new file mode 100644 index 000000000..d234ab2f1 --- /dev/null +++ b/compiler/circle-mpqsolver/src/core/DataProvider.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __MPQSOLVER_DATA_PROVIDER_H__ +#define __MPQSOLVER_DATA_PROVIDER_H__ + +#include + +#include + +#include +#include + +namespace mpqsolver +{ +namespace core +{ + +struct InputData +{ + InputData(size_t size) : _data(size) {} + + const std::vector &data() const { return _data; } + + std::vector &data() { return _data; } + +private: + std::vector _data; +}; + +class DataProvider +{ +public: + virtual ~DataProvider() = default; + virtual size_t numSamples() const = 0; + virtual uint32_t numInputs(uint32_t sample) const = 0; + virtual void getSampleInput(uint32_t sample, uint32_t input, InputData &data) const = 0; +}; + +class H5FileDataProvider final : public DataProvider +{ +public: + H5FileDataProvider(const std::string &h5file, const std::string &module_path); + size_t numSamples() const override; + uint32_t numInputs(uint32_t sample) const override; + void getSampleInput(uint32_t sample, uint32_t input, InputData &data) const override; + +private: + std::vector _input_nodes; + std::unique_ptr _module; + dio::hdf5::HDF5Importer _importer; + bool _is_raw_data = false; +}; + +} // namespace core +} // namespace mpqsolver + +#endif //__MPQSOLVER_DATA_PROVIDER_H__ diff --git a/compiler/circle-mpqsolver/src/core/Dumper.cpp b/compiler/circle-mpqsolver/src/core/Dumper.cpp index 3a94cb3fa..29a0bd1ea 100644 --- a/compiler/circle-mpqsolver/src/core/Dumper.cpp +++ b/compiler/circle-mpqsolver/src/core/Dumper.cpp @@ -40,14 +40,14 @@ const std::string layer_granularity_key = "granularity"; Dumper::Dumper(const std::string &dir_path) : _dir_path(dir_path) {} -void Dumper::set_model_path(const std::string &model_path) { _model_path = model_path; } +void Dumper::setModelPath(const std::string &model_path) { _model_path = model_path; } -void Dumper::dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype, - const std::string &path) const +void Dumper::dumpMPQConfiguration(const LayerParams &layers, const std::string &def_dtype, + const std::string &def_granularity, const std::string &path) const { Json::Value mpq_data; mpq_data[default_dtype_key] = def_dtype; - mpq_data[default_granularity_key] = "channel"; + mpq_data[default_granularity_key] = def_granularity; mpq_data[model_key] = _model_path; Json::Value layers_data; @@ -64,10 +64,10 @@ void Dumper::dump_MPQ_configuration(const LayerParams &layers, const std::string Json::StreamWriterBuilder builder; auto data = Json::writeString(builder, mpq_data); - write_data_to_file(path, data); + writeDataToFile(path, data); } -void Dumper::prepare_directory(const std::string &dir_path) const +void Dumper::prepareDirectory(const std::string &dir_path) const { struct stat sb; if (stat(dir_path.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) @@ -79,22 +79,23 @@ void Dumper::prepare_directory(const std::string &dir_path) const } } -void Dumper::dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype, - int step) const +void Dumper::dumpMPQConfiguration(const LayerParams &layers, const std::string &def_dtype, + const std::string &def_granularity, int step) const { - prepare_directory(_dir_path); + prepareDirectory(_dir_path); std::string path = _dir_path + "/Configuration_" + std::to_string(step) + ".mpq.json"; - dump_MPQ_configuration(layers, def_dtype, path); + dumpMPQConfiguration(layers, def_dtype, def_granularity, path); } -void Dumper::dump_final_MPQ(const LayerParams &layers, const std::string &def_dtype) const +void Dumper::dumpFinalMPQ(const LayerParams &layers, const std::string &def_dtype, + const 
std::string &def_granularity) const { - prepare_directory(_dir_path); + prepareDirectory(_dir_path); std::string path = _dir_path + "/FinalConfiguration" + ".mpq.json"; - dump_MPQ_configuration(layers, def_dtype, path); + dumpMPQConfiguration(layers, def_dtype, def_granularity, path); } -void Dumper::write_data_to_file(const std::string &path, const std::string &data) const +void Dumper::writeDataToFile(const std::string &path, const std::string &data) const { std::ofstream file; file.open(path); @@ -102,7 +103,7 @@ void Dumper::write_data_to_file(const std::string &path, const std::string &data file.close(); } -void Dumper::save_circle(luci::Module *module, std::string &path) const +void Dumper::saveCircle(luci::Module *module, std::string &path) const { luci::CircleExporter exporter; luci::CircleFileExpContract contract(module, path); @@ -112,13 +113,13 @@ void Dumper::save_circle(luci::Module *module, std::string &path) const } } -void Dumper::dump_quantized(luci::Module *module, uint32_t step) const +void Dumper::dumpQuantized(luci::Module *module, uint32_t step) const { std::string path = _dir_path + "/quantized_" + std::to_string(step) + ".mpq.circle"; - save_circle(module, path); + saveCircle(module, path); } -void Dumper::dump_error(float error, const std::string &tag, const std::string &path) const +void Dumper::dumpError(float error, const std::string &tag, const std::string &path) const { std::ofstream file; file.open(path, std::ios_base::app); @@ -126,35 +127,35 @@ void Dumper::dump_error(float error, const std::string &tag, const std::string & file.close(); } -void Dumper::prepare_for_error_dumping() const +void Dumper::prepareForErrorDumping() const { - prepare_directory(_dir_path); - std::string path = get_error_path(); + prepareDirectory(_dir_path); + std::string path = getErrorPath(); std::ofstream file; file.open(path); // create empty file.close(); } -void Dumper::dump_Q8_error(float error) const +void Dumper::dumpQ8Error(float error) const { - std::string path = get_error_path(); - dump_error(error, "Q8", path); + std::string path = getErrorPath(); + dumpError(error, "Q8", path); } -void Dumper::dump_Q16_error(float error) const +void Dumper::dumpQ16Error(float error) const { - std::string path = get_error_path(); - dump_error(error, "Q16", path); + std::string path = getErrorPath(); + dumpError(error, "Q16", path); } -void Dumper::dump_MPQ_error(float error, uint32_t step) const +void Dumper::dumpMPQError(float error, uint32_t step) const { - std::string path = get_error_path(); - dump_error(error, std::to_string(step), path); + std::string path = getErrorPath(); + dumpError(error, std::to_string(step), path); } -void Dumper::dump_MPQ_error(float error) const +void Dumper::dumpMPQError(float error) const { - std::string path = get_error_path(); - dump_error(error, "FINAL", path); + std::string path = getErrorPath(); + dumpError(error, "FINAL", path); } diff --git a/compiler/circle-mpqsolver/src/core/Dumper.h b/compiler/circle-mpqsolver/src/core/Dumper.h index 220b54a20..48762f287 100644 --- a/compiler/circle-mpqsolver/src/core/Dumper.h +++ b/compiler/circle-mpqsolver/src/core/Dumper.h @@ -39,67 +39,70 @@ public: /** * @brief sets model path for further usage */ - void set_model_path(const std::string &model_path); + void setModelPath(const std::string &model_path); /** * @brief dumps mpq configuration * @param layers specific quantization parameters * @param def_dtype default quantization data type + * @param def_granularity default granularity * @param step id of mpq 
configuration */ - void dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype, - int step) const; + void dumpMPQConfiguration(const LayerParams &layers, const std::string &def_dtype, + const std::string &def_granularity, int step) const; /** * @brief dumps final mpq configuration * @param layers specific quantization parameters * @param def_dtype default quantization data type + * @param def_granularity default granularity */ - void dump_final_MPQ(const LayerParams &layers, const std::string &def_dtype) const; + void dumpFinalMPQ(const LayerParams &layers, const std::string &def_dtype, + const std::string &def_granularity) const; /** * @brief dumps quantized module * @param layers specific quantization parameters * @param step id of quantized module */ - void dump_quantized(luci::Module *module, uint32_t step) const; + void dumpQuantized(luci::Module *module, uint32_t step) const; /** * @brief create file for error dumping */ - void prepare_for_error_dumping() const; + void prepareForErrorDumping() const; /** * @brief append error of Q8 quantization */ - void dump_Q8_error(float error) const; + void dumpQ8Error(float error) const; /** * @brief append error of Q16 quantization */ - void dump_Q16_error(float error) const; + void dumpQ16Error(float error) const; /** * @brief append error of mpq quantization * @param error error of quantization * @param step id of error */ - void dump_MPQ_error(float error, uint32_t step) const; + void dumpMPQError(float error, uint32_t step) const; /** * @brief dump final error * @param error final error of quantization */ - void dump_MPQ_error(float error) const; + void dumpMPQError(float error) const; private: - void write_data_to_file(const std::string &path, const std::string &data) const; - void dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype, - const std::string &path) const; - void prepare_directory(const std::string &dir_path) const; - void save_circle(luci::Module *module, std::string &path) const; - void dump_error(float error, const std::string &tag, const std::string &path) const; - std::string get_error_path() const { return _dir_path + "/errors" + ".mpq.txt"; } + void writeDataToFile(const std::string &path, const std::string &data) const; + void dumpMPQConfiguration(const LayerParams &layers, const std::string &def_dtype, + const std::string &def_granularity, const std::string &path) const; + void prepareDirectory(const std::string &dir_path) const; + void saveCircle(luci::Module *module, std::string &path) const; + void dumpError(float error, const std::string &tag, const std::string &path) const; + std::string getErrorPath() const { return _dir_path + "/errors" + ".mpq.txt"; } private: std::string _dir_path; diff --git a/compiler/circle-mpqsolver/src/core/Dumper.test.cpp b/compiler/circle-mpqsolver/src/core/Dumper.test.cpp new file mode 100644 index 000000000..40c4d7fec --- /dev/null +++ b/compiler/circle-mpqsolver/src/core/Dumper.test.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "Dumper.h" +#include "core/TestHelper.h" + +#include +#include +#include + +namespace +{ + +class CircleMPQSolverDumperTest : public ::testing::Test +{ +public: + CircleMPQSolverDumperTest() + { + char folderTemplate[] = "CircleMPQSolverDumperTestXXXXXX"; + _folder = mpqsolver::test::io_utils::makeTemporaryFolder(folderTemplate); + } + + ~CircleMPQSolverDumperTest() + { + // cleanup + auto callback = [](const char *child, const struct stat *, int, struct FTW *) { + return remove(child); + }; + nftw(_folder.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS); + } + +protected: + std::string _folder; +}; + +} // namespace + +TEST_F(CircleMPQSolverDumperTest, verifyResultsTest) +{ + mpqsolver::core::Dumper dumper(_folder); + dumper.setModelPath(""); + mpqsolver::core::LayerParams params; + auto const step = 0; + dumper.dumpMPQConfiguration(params, "uint8", "channel", step); + + std::string step_path = _folder + "/Configuration_" + std::to_string(step) + ".mpq.json"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(step_path)); + + dumper.dumpFinalMPQ(params, "uint8", "channel"); + std::string fin_path = _folder + "/FinalConfiguration" + ".mpq.json"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(fin_path)); + + dumper.prepareForErrorDumping(); + std::string errors_path = _folder + "/errors" + ".mpq.txt"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(errors_path)); +} + +TEST_F(CircleMPQSolverDumperTest, empty_path_NEG) +{ + mpqsolver::core::Dumper dumper(""); + dumper.setModelPath(""); + + mpqsolver::core::LayerParams params; + auto const step = 0; + EXPECT_THROW(dumper.dumpMPQConfiguration(params, "uint8", "channel", step), std::runtime_error); + EXPECT_THROW(dumper.dumpFinalMPQ(params, "uint8", "channel"), std::runtime_error); + EXPECT_THROW(dumper.prepareForErrorDumping(), std::runtime_error); +} diff --git a/compiler/circle-mpqsolver/src/core/DumpingHooks.cpp b/compiler/circle-mpqsolver/src/core/DumpingHooks.cpp index 4d0522bdd..d1517faf1 100644 --- a/compiler/circle-mpqsolver/src/core/DumpingHooks.cpp +++ b/compiler/circle-mpqsolver/src/core/DumpingHooks.cpp @@ -15,48 +15,61 @@ */ #include "DumpingHooks.h" +#include using namespace mpqsolver::core; -DumpingHooks::DumpingHooks(const std::string &save_path) - : _save_path(save_path), _dumper(_save_path) +DumpingHooks::DumpingHooks(const std::string &save_path, const Quantizer::Context &ctx) + : _save_path(save_path), _dumper(_save_path), _ctx(ctx) { } -void DumpingHooks::on_begin_solver(const std::string &model_path, float q8error, float q16error) +void DumpingHooks::onBeginSolver(const std::string &model_path, float q8error, float q16error) { _model_path = model_path; - _dumper.set_model_path(_model_path); - _dumper.prepare_for_error_dumping(); - _dumper.dump_Q8_error(q8error); - _dumper.dump_Q16_error(q16error); + _dumper.setModelPath(_model_path); + if (!std::isnan(q8error) || !std::isnan(q16error)) + { + _dumper.prepareForErrorDumping(); + } + if (!std::isnan(q8error)) + { + _dumper.dumpQ8Error(q8error); + } + if (!std::isnan(q16error)) + { + _dumper.dumpQ16Error(q16error); + } } -void DumpingHooks::on_begin_iteration() +void DumpingHooks::onBeginIteration() { _in_iterations = true; _num_of_iterations += 1; } -void DumpingHooks::on_end_iteration(const LayerParams &layers, const std::string &def_type, - float error) const +void DumpingHooks::onEndIteration(const LayerParams 
&layers, const std::string &def_type, + float error) { - _dumper.dump_MPQ_configuration(layers, def_type, _num_of_iterations); - _dumper.dump_MPQ_error(error, _num_of_iterations); + _dumper.dumpMPQConfiguration(layers, def_type, _ctx.granularity, _num_of_iterations); + _dumper.dumpMPQError(error, _num_of_iterations); + _in_iterations = false; } -void DumpingHooks::on_end_solver(const LayerParams &layers, const std::string &def_dtype, - float qerror) +void DumpingHooks::onEndSolver(const LayerParams &layers, const std::string &def_dtype, + float qerror) { - _dumper.dump_final_MPQ(layers, def_dtype); - _dumper.dump_MPQ_error(qerror); - _in_iterations = false; + _dumper.dumpFinalMPQ(layers, def_dtype, _ctx.granularity); + if (!std::isnan(qerror)) + { + _dumper.dumpMPQError(qerror); + } } -void DumpingHooks::on_quantized(luci::Module *module) const +void DumpingHooks::onQuantized(luci::Module *module) const { if (_in_iterations) { - _dumper.dump_quantized(module, _num_of_iterations); + _dumper.dumpQuantized(module, _num_of_iterations); } } diff --git a/compiler/circle-mpqsolver/src/core/DumpingHooks.h b/compiler/circle-mpqsolver/src/core/DumpingHooks.h index c432a9a40..0367af2f0 100644 --- a/compiler/circle-mpqsolver/src/core/DumpingHooks.h +++ b/compiler/circle-mpqsolver/src/core/DumpingHooks.h @@ -19,9 +19,9 @@ #include -#include -#include -#include +#include "core/Quantizer.h" +#include "core/SolverHooks.h" +#include "core/Dumper.h" #include @@ -40,42 +40,42 @@ public: * @brief DumpingHooks constructor * @param save_path directory where all intermediate data will be saved */ - DumpingHooks(const std::string &save_path); + DumpingHooks(const std::string &save_path, const Quantizer::Context &ctx); /** * @brief called on successfull quantization */ - virtual void on_quantized(luci::Module *module) const override; + virtual void onQuantized(luci::Module *module) const override; /** - * @brief called on the start of iterative search + * @brief called on the start of mpq search */ - virtual void on_begin_solver(const std::string &model_path, float q8error, - float q16error) override; + virtual void onBeginSolver(const std::string &model_path, float q8error, float q16error) override; /** * @brief called on the start of current iteration */ - virtual void on_begin_iteration() override; + virtual void onBeginIteration() override; /** * @brief called at the end of current iteration */ - virtual void on_end_iteration(const LayerParams &layers, const std::string &def_dtype, - float error) const override; + virtual void onEndIteration(const LayerParams &layers, const std::string &def_dtype, + float error) override; /** - * @brief called at the end of iterative search + * @brief called at the end of mpq search */ - virtual void on_end_solver(const LayerParams &layers, const std::string &def_dtype, - float qerror) override; + virtual void onEndSolver(const LayerParams &layers, const std::string &def_dtype, + float qerror) override; -protected: +private: std::string _model_path; std::string _save_path; Dumper _dumper; uint32_t _num_of_iterations = 0; bool _in_iterations = false; + Quantizer::Context _ctx; }; } // namespace core diff --git a/compiler/circle-mpqsolver/src/core/DumpingHooks.test.cpp b/compiler/circle-mpqsolver/src/core/DumpingHooks.test.cpp new file mode 100644 index 000000000..192862b8d --- /dev/null +++ b/compiler/circle-mpqsolver/src/core/DumpingHooks.test.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "DumpingHooks.h" +#include "core/TestHelper.h" + +#include +#include +#include + +namespace +{ + +class CircleMPQSolverDumpingHooksTest : public ::testing::Test +{ +public: + CircleMPQSolverDumpingHooksTest() + { + char folderTemplate[] = "CircleMPQSolverDumpingHooksTestXXXXXX"; + _folder = mpqsolver::test::io_utils::makeTemporaryFolder(folderTemplate); + } + + ~CircleMPQSolverDumpingHooksTest() + { + // cleanup + auto callback = [](const char *child, const struct stat *, int, struct FTW *) { + return remove(child); + }; + nftw(_folder.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS); + } + +protected: + std::string _folder; +}; + +} // namespace + +TEST_F(CircleMPQSolverDumpingHooksTest, verifyResultsTest) +{ + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::core::DumpingHooks hooks(_folder, ctx); + EXPECT_NO_THROW(hooks.onBeginSolver("model_path.circle", 0.0, 1.0)); + std::string errors_path = _folder + "/errors" + ".mpq.txt"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(errors_path)); + + hooks.onBeginIteration(); + + EXPECT_NO_THROW( + hooks.onEndIteration(mpqsolver::core::LayerParams(), ctx.output_model_dtype, 0.0)); + std::string current_mpq_path = _folder + "/Configuration_" + std::to_string(1) + ".mpq.json"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(current_mpq_path)); + + EXPECT_NO_THROW(hooks.onEndSolver(mpqsolver::core::LayerParams(), ctx.output_model_dtype, 0.5)); + std::string final_mpq_path = _folder + "/FinalConfiguration" + ".mpq.json"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(final_mpq_path)); +} + +TEST_F(CircleMPQSolverDumpingHooksTest, verify_NAN_results_test) +{ + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::core::DumpingHooks hooks(_folder, ctx); + EXPECT_NO_THROW(hooks.onBeginSolver("model_path.circle", NAN, NAN)); + std::string errors_path = _folder + "/errors" + ".mpq.txt"; + EXPECT_TRUE(not mpqsolver::test::io_utils::isFileExists(errors_path)); + + EXPECT_NO_THROW(hooks.onEndSolver(mpqsolver::core::LayerParams(), ctx.output_model_dtype, NAN)); + std::string final_mpq_path = _folder + "/FinalConfiguration" + ".mpq.json"; + EXPECT_TRUE(mpqsolver::test::io_utils::isFileExists(final_mpq_path)); +} + +TEST_F(CircleMPQSolverDumpingHooksTest, empty_path_NEG) +{ + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::core::DumpingHooks hooks("", ctx); + EXPECT_ANY_THROW(hooks.onBeginSolver("", -1, -1)); + hooks.onBeginIteration(); + EXPECT_ANY_THROW(hooks.onQuantized(nullptr)); + EXPECT_ANY_THROW( + hooks.onEndIteration(mpqsolver::core::LayerParams(), ctx.output_model_dtype, -1)); + EXPECT_ANY_THROW(hooks.onEndSolver(mpqsolver::core::LayerParams(), ctx.output_model_dtype, -1)); +} + +TEST_F(CircleMPQSolverDumpingHooksTest, empty_NAN_path_NEG) +{ + mpqsolver::core::Quantizer::Context ctx; + mpqsolver::core::DumpingHooks hooks("", ctx); + EXPECT_NO_THROW(hooks.onBeginSolver("", NAN, NAN)); + 
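+  // NAN marks an error as "not measured": onBeginSolver then skips creating the error file,
+  // while onEndSolver still dumps the final MPQ configuration and so needs a writable folder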
+  EXPECT_ANY_THROW(hooks.onEndSolver(mpqsolver::core::LayerParams(), ctx.output_model_dtype, NAN));
+}
diff --git a/compiler/circle-mpqsolver/src/core/ErrorMetric.cpp b/compiler/circle-mpqsolver/src/core/ErrorMetric.cpp
index 23ddfcb7d..23e17dcb2 100644
--- a/compiler/circle-mpqsolver/src/core/ErrorMetric.cpp
+++ b/compiler/circle-mpqsolver/src/core/ErrorMetric.cpp
@@ -16,6 +16,7 @@
 #include "ErrorMetric.h"
 
+#include
 #include
 #include
 
@@ -29,10 +30,13 @@ using namespace mpqsolver::core;
  */
 float MAEMetric::compute(const WholeOutput &first, const WholeOutput &second) const
 {
-  assert(first.size() == second.size());
+  if (first.size() != second.size())
+  {
+    throw std::runtime_error("Cannot compare vectors of different sizes");
+  }
 
-  float error = 0.f;
-  size_t output_size = 0;
+  double output_errors = 0.; // accumulates per-output mean errors (averaged at the end)
+  size_t num_output_errors = 0;
 
   for (size_t sample_index = 0; sample_index < first.size(); ++sample_index)
   {
@@ -42,24 +46,30 @@ float MAEMetric::compute(const WholeOutput &first, const WholeOutput &second) const
       const Buffer &first_elementary = first[sample_index][out_index];
       const Buffer &second_elementary = second[sample_index][out_index];
       assert(first_elementary.size() == second_elementary.size());
-      size_t cur_size = first_elementary.size() / loco::size(loco::DataType::FLOAT32);
+      size_t cur_size = first_elementary.size() / luci::size(loco::DataType::FLOAT32);
+
+      double output_error = 0.; // mean error over the current output
 
       const float *first_floats = reinterpret_cast<const float *>(first_elementary.data());
       const float *second_floats = reinterpret_cast<const float *>(second_elementary.data());
       for (size_t index = 0; index < cur_size; index++)
       {
-        float ref_value = *(first_floats + index);
-        float cur_value = *(second_floats + index);
-        error += std::fabs(ref_value - cur_value);
+        double ref_value = static_cast<double>(*(first_floats + index));
+        double cur_value = static_cast<double>(*(second_floats + index));
+        output_error += std::fabs(ref_value - cur_value);
+      }
+      if (cur_size != 0)
+      {
+        output_errors += (output_error / cur_size);
+        num_output_errors += 1;
       }
-      output_size += cur_size;
     }
   }
 
-  if (output_size == 0)
+  if (num_output_errors == 0)
   {
-    throw std::runtime_error("nothing to compare");
+    throw std::runtime_error("Nothing to compare");
   }
 
-  return error / output_size;
+  return static_cast<float>(output_errors / num_output_errors);
 }
diff --git a/compiler/circle-mpqsolver/src/core/ErrorMetric.test.cpp b/compiler/circle-mpqsolver/src/core/ErrorMetric.test.cpp
index 232d9bc60..a271138dc 100644
--- a/compiler/circle-mpqsolver/src/core/ErrorMetric.test.cpp
+++ b/compiler/circle-mpqsolver/src/core/ErrorMetric.test.cpp
@@ -48,6 +48,84 @@ TEST(CircleMPQSolverMAEMetricTest, verifyResultsTest)
   EXPECT_FLOAT_EQ(value, 1.f);
 }
 
+TEST(CircleMPQSolverMAEMetricTest, verifyComplexResultsTest)
+{
+  // test for model with a pair of outputs
+
+  size_t num_elements = 512;
+  mpqsolver::core::WholeOutput target, source;
+  // construct target
+  {
+    // let first target output be 0.0
+    std::vector<float> float_buffer(num_elements, 0.f);
+    auto const char_buffer = reinterpret_cast<const char *>(float_buffer.data());
+    auto const char_buffer_size = num_elements * sizeof(float) / sizeof(char);
+    std::vector<char> buffer(char_buffer, char_buffer + char_buffer_size);
+
+    // first target output
+    mpqsolver::core::Output out = mpqsolver::core::Output(1, buffer);
+    {
+      // let second target output be 0.25
+      std::vector<float> float_buffer(num_elements / 2, 0.25f);
+      auto const char_buffer = reinterpret_cast<const char *>(float_buffer.data());
+      auto const char_buffer_size = num_elements / 2 * sizeof(float) / sizeof(char);
+      std::vector<char> buffer(char_buffer, char_buffer + char_buffer_size);
+      // second target output
+      out.emplace_back(buffer);
+    }
+
+    // target consists of single sample (each sample consists of two outputs)
+    target = mpqsolver::core::WholeOutput(1, out);
+  }
+
+  // construct source
+  {
+    // let first source output be 1.f
+    std::vector<float> float_buffer(num_elements, 1.f);
+    auto const char_buffer = reinterpret_cast<const char *>(float_buffer.data());
+    auto const char_buffer_size = num_elements * sizeof(float) / sizeof(char);
+    std::vector<char> buffer(char_buffer, char_buffer + char_buffer_size);
+    // first source output
+    mpqsolver::core::Output out = mpqsolver::core::Output(1, buffer);
+    {
+      // let second source output be 1.f
+      std::vector<float> float_buffer(num_elements / 2, 1.0f);
+      auto const char_buffer = reinterpret_cast<const char *>(float_buffer.data());
+      auto const char_buffer_size = num_elements / 2 * sizeof(float) / sizeof(char);
+      std::vector<char> buffer(char_buffer, char_buffer + char_buffer_size);
+      // second source output
+      out.emplace_back(buffer);
+    }
+
+    // source consists of single sample (each sample consists of two outputs)
+    source = mpqsolver::core::WholeOutput(1, out);
+  }
+
+  mpqsolver::core::MAEMetric metric;
+  float value = metric.compute(target, source);
+  EXPECT_FLOAT_EQ(value, 0.875f); // (|0 - 1| + |0.25 - 1|) / 2 = 1.75 / 2 = 0.875
+}
+
+TEST(CircleMPQSolverMAEMetricTest, different_samples_NEG)
+{
+  mpqsolver::core::MAEMetric metric;
+
+  size_t num_elements = 512;
+  std::vector<float> float_buffer(num_elements, 0.f);
+  auto const char_buffer = reinterpret_cast<const char *>(float_buffer.data());
+  auto const char_buffer_size = num_elements * sizeof(float) / sizeof(char);
+  std::vector<char> buffer(char_buffer, char_buffer + char_buffer_size);
+  mpqsolver::core::Output out = mpqsolver::core::Output(1, buffer);
+
+  // let target consist of two zero-filled samples
+  mpqsolver::core::WholeOutput target = mpqsolver::core::WholeOutput(2, out);
+
+  // let source be empty (no samples, so sizes differ)
+  mpqsolver::core::WholeOutput source;
+
+  EXPECT_ANY_THROW(metric.compute(target, source));
+}
+
 TEST(CircleMPQSolverMAEMetricTest, verifyResultsTest_NEG)
 {
   mpqsolver::core::MAEMetric metric;
diff --git a/compiler/circle-mpqsolver/src/core/Evaluator.cpp b/compiler/circle-mpqsolver/src/core/Evaluator.cpp
index c7afda5c2..98decb5e5 100644
--- a/compiler/circle-mpqsolver/src/core/Evaluator.cpp
+++ b/compiler/circle-mpqsolver/src/core/Evaluator.cpp
@@ -16,6 +16,10 @@
 #include "Evaluator.h"
 
+#include "core/DataProvider.h"
+
+#include
+
 #include
 #include
 
@@ -31,20 +35,20 @@ using namespace luci;
 
 template <typename NodeT> size_t get_tensor_size(const NodeT *node)
 {
-  uint32_t tensor_size = loco::size(node->dtype());
+  uint32_t tensor_size = luci::size(node->dtype());
   for (uint32_t i = 0; i < node->rank(); ++i)
     tensor_size *= node->dim(i).value();
   return tensor_size;
 }
 
-WholeOutput compute_outputs(const luci::Module *module, const std::string &h5file)
+WholeOutput compute_outputs(const luci::Module *module, const DataProvider *data_provider)
 {
-  dio::hdf5::HDF5Importer importer{h5file};
-  importer.importGroup("value");
-
-  bool is_raw_data = importer.isRawData();
+  if (data_provider == nullptr)
+  {
+    throw std::runtime_error("No data");
+  }
 
-  const auto num_records = importer.numData();
+  const auto num_records = data_provider->numSamples();
   if (num_records == 0)
     throw std::runtime_error("The input data file does not contain any record.");
   const auto input_nodes = loco::input_nodes(module->graph());
@@ -54,31 +58,19 @@ WholeOutput
compute_outputs(const luci::Module *module, const std::string &h5fil // Create interpreter. luci_interpreter::Interpreter interpreter(module); - for (int32_t record_idx = 0; record_idx < num_records; record_idx++) + for (uint32_t record_idx = 0; record_idx < num_records; record_idx++) { - if (num_inputs != static_cast(importer.numInputs(record_idx))) + if (num_inputs != data_provider->numInputs(record_idx)) throw std::runtime_error("Wrong number of inputs."); for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++) { const auto *input_node = loco::must_cast(input_nodes[input_idx]); assert(input_node->index() == input_idx); - std::vector input_data(get_tensor_size(input_node)); - - if (!is_raw_data) - { - loco::DataType dtype; - Shape shape; - importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data(), - input_data.size()); - } - else - { - // Skip type/shape check for raw data - importer.readTensor(record_idx, input_idx, input_data.data(), input_data.size()); - } - - interpreter.writeInputTensor(input_node, input_data.data(), input_data.size()); + InputData input_data(get_tensor_size(input_node)); + data_provider->getSampleInput(record_idx, input_idx, input_data); + + interpreter.writeInputTensor(input_node, input_data.data().data(), input_data.data().size()); } interpreter.interpret(); @@ -103,11 +95,11 @@ WholeOutput compute_outputs(const luci::Module *module, const std::string &h5fil } // namespace -DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file, +DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const DataProvider &provider, const ErrorMetric &metric) - : _ref_module(ref_module), _h5file(h5file), _metric(&metric) + : _ref_module(ref_module), _provider(&provider), _metric(&metric) { - _ref_output = compute_outputs(_ref_module, _h5file); + _ref_output = compute_outputs(_ref_module, _provider); } void DatasetEvaluator::validate(const luci::Module *trgt_fq_module) const @@ -132,7 +124,7 @@ float DatasetEvaluator::evaluate(const luci::Module *trgt_fq_module) const validate(trgt_fq_module); - const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _h5file); + const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _provider); float error = _metric->compute(_ref_output, cur_output); return error; } diff --git a/compiler/circle-mpqsolver/src/core/Evaluator.h b/compiler/circle-mpqsolver/src/core/Evaluator.h index 9820508bc..0f50bb55b 100644 --- a/compiler/circle-mpqsolver/src/core/Evaluator.h +++ b/compiler/circle-mpqsolver/src/core/Evaluator.h @@ -19,6 +19,8 @@ #include "ErrorMetric.h" +#include "DataProvider.h" + #include #include @@ -34,9 +36,9 @@ class DatasetEvaluator final { public: /** - * @brief create Evaluator for comparing output of ref_module on h5file + * @brief create Evaluator for comparing output of ref_module on provider */ - DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file, + DatasetEvaluator(const luci::Module *ref_module, const DataProvider &provider, const ErrorMetric &metric); DatasetEvaluator() = delete; ~DatasetEvaluator() = default; @@ -55,7 +57,7 @@ private: private: const luci::Module *_ref_module = nullptr; - std::string _h5file; + const DataProvider *_provider = nullptr; WholeOutput _ref_output; const ErrorMetric *_metric = nullptr; }; diff --git a/compiler/circle-mpqsolver/src/core/Evaluator.test.cpp b/compiler/circle-mpqsolver/src/core/Evaluator.test.cpp new file mode 100644 index 000000000..757257a3e --- /dev/null +++ 
b/compiler/circle-mpqsolver/src/core/Evaluator.test.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "Evaluator.h" + +#include "DataProvider.h" +#include "TestHelper.h" + +TEST(CircleMPQSolverEvaluatorTest, verifyResultsTest) +{ + // create nn module + auto m = luci::make_module(); + mpqsolver::test::models::AddGraph g; + g.init(); + g.transfer_to(m.get()); + + mpqsolver::core::MAEMetric metric; + auto data = mpqsolver::test::data_utils::getAllZeroSingleDataProvider(); + mpqsolver::core::DatasetEvaluator evaluator(m.get(), *data.get(), metric); + float value = evaluator.evaluate(m.get()); + EXPECT_FLOAT_EQ(value, 0.f); +} + +TEST(CircleMPQSolverEvaluatorTest, empty_path_NEG) +{ + mpqsolver::core::MAEMetric metric; + EXPECT_ANY_THROW(mpqsolver::core::H5FileDataProvider data("", ""); + mpqsolver::core::DatasetEvaluator evaluator(nullptr, data, metric)); +} diff --git a/compiler/circle-mpqsolver/src/core/Quantizer.cpp b/compiler/circle-mpqsolver/src/core/Quantizer.cpp index 421793197..3eea87d37 100644 --- a/compiler/circle-mpqsolver/src/core/Quantizer.cpp +++ b/compiler/circle-mpqsolver/src/core/Quantizer.cpp @@ -49,12 +49,7 @@ bool make_model_fake_quantized(luci::Module *module) } // namespace -Quantizer::Quantizer(const std::string &input_dtype, const std::string &output_dtype) - : _input_dtype(input_dtype), _output_dtype(output_dtype) -{ -} - -void Quantizer::set_hook(const QuantizerHook *hook) { _hook = hook; } +void Quantizer::setHook(const QuantizerHook *hook) { _hook = hook; } /** * @brief quantize recorded module (min/max initialized) with specified parameters @@ -67,7 +62,6 @@ bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype, return false; static const std::string default_dtype = "float32"; - static const std::string granularity_type = "channel"; luci::CircleQuantizer quantizer; @@ -76,10 +70,18 @@ bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype, options->param(AlgorithmParameters::Quantize_input_model_dtype, default_dtype); options->param(AlgorithmParameters::Quantize_output_model_dtype, quant_dtype); - options->param(AlgorithmParameters::Quantize_granularity, granularity_type); - options->param(AlgorithmParameters::Quantize_input_type, _input_dtype); - options->param(AlgorithmParameters::Quantize_output_type, _output_dtype); - options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "False"); + // Only channel-wise quantization is supported for int16 + // TODO Fix this if this assumption breaks + if (quant_dtype == "int16") + options->param(AlgorithmParameters::Quantize_granularity, "channel"); + else + options->param(AlgorithmParameters::Quantize_granularity, _ctx.granularity); + + options->param(AlgorithmParameters::Quantize_input_type, _ctx.input_type); + options->param(AlgorithmParameters::Quantize_output_type, _ctx.output_type); + 
options->param(AlgorithmParameters::Quantize_TF_style_maxpool, + _ctx.TF_style_maxpool ? "True" : "False"); + options->param(AlgorithmParameters::Quantize_save_min_max, _ctx.save_min_max ? "True" : "False"); if (!layer_params.empty()) { @@ -108,18 +110,27 @@ bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype, if (_hook) { - _hook->on_quantized(module); + _hook->onQuantized(module); } return true; } +/** + * @brief quantize recorded module (min/max initialized) with specified parameters + * returns true on success + */ +bool Quantizer::quantize(luci::Module *module, LayerParams &layer_params) +{ + return quantize(module, _ctx.output_model_dtype, layer_params); +} + /** * @brief fake_quantize recorded module (min/max initialized) with specified parameters * returns true on success */ -bool Quantizer::fake_quantize(luci::Module *module, const std::string &quant_dtype, - LayerParams &layer_params) +bool Quantizer::fakeQuantize(luci::Module *module, const std::string &quant_dtype, + LayerParams &layer_params) { if (!quantize(module, quant_dtype, layer_params)) return false; diff --git a/compiler/circle-mpqsolver/src/core/Quantizer.h b/compiler/circle-mpqsolver/src/core/Quantizer.h index 259d5c4b0..2001ce096 100644 --- a/compiler/circle-mpqsolver/src/core/Quantizer.h +++ b/compiler/circle-mpqsolver/src/core/Quantizer.h @@ -38,18 +38,30 @@ struct QuantizerHook * @brief called on successfull quantization * @param module quantized module */ - virtual void on_quantized(luci::Module *module) const = 0; + virtual void onQuantized(luci::Module *module) const = 0; }; class Quantizer { public: - Quantizer(const std::string &input_dtype, const std::string &output_type); + struct Context + { + std::string output_model_dtype = "uint8"; + std::string granularity = "channel"; + std::string input_type = "uint8"; + std::string output_type = "uint8"; + bool TF_style_maxpool = false; + bool save_min_max = false; + // TODO Support layer info + }; + +public: + Quantizer(const Context &ctx) : _ctx(ctx) {} /** * @brief set hook on the end of quantization event */ - void set_hook(const QuantizerHook *callback); + void setHook(const QuantizerHook *callback); /** * @brief quantize recorded module (min/max initialized) with specified parameters @@ -57,16 +69,23 @@ public: */ bool quantize(luci::Module *module, const std::string &quant_dtype, LayerParams &layer_params); + /** + * @brief quantize recorded module (min/max initialized) with specified parameters + * returns true on success + */ + bool quantize(luci::Module *module, LayerParams &layer_params); + /** * @brief fake_quantize recorded module (min/max initialized) with specified parameters * returns true on success */ - bool fake_quantize(luci::Module *module, const std::string &quant_dtype, - LayerParams &layer_params); + bool fakeQuantize(luci::Module *module, const std::string &quant_dtype, + LayerParams &layer_params); + + const Context &getContext() const { return _ctx; } private: - std::string _input_dtype = "uint8"; - std::string _output_dtype = "uint8"; + Context _ctx; const QuantizerHook *_hook = nullptr; }; diff --git a/compiler/circle-mpqsolver/src/core/Quantizer.test.cpp b/compiler/circle-mpqsolver/src/core/Quantizer.test.cpp index 7d7e74fdc..93c7ac00e 100644 --- a/compiler/circle-mpqsolver/src/core/Quantizer.test.cpp +++ b/compiler/circle-mpqsolver/src/core/Quantizer.test.cpp @@ -20,60 +20,7 @@ #include -#include - -namespace -{ - -class AddGraph final : public SimpleGraph -{ -protected: - void initInput(loco::Node *input) 
override - { - auto ci_input = loco::must_cast(input); - initMinMax(ci_input); - } - - void initMinMax(luci::CircleNode *node) - { - auto qparam = std::make_unique(); - qparam->min.assign(1, _a_min); - qparam->max.assign(1, _a_max); - node->quantparam(std::move(qparam)); - } - - loco::Node *insertGraphBody(loco::Node *input) override - { - _add = _g->nodes()->create(); - _beta = _g->nodes()->create(); - - _add->dtype(loco::DataType::FLOAT32); - _beta->dtype(loco::DataType::FLOAT32); - - uint32_t channel_size = 16; - _add->shape({1, _channel_size, _width, _height}); - _beta->shape({1, _channel_size, _width, _height}); - - _beta->size(channel_size); - _add->x(input); - _add->y(_beta); - _add->fusedActivationFunction(luci::FusedActFunc::NONE); - - _add->name("add"); - _beta->name("beta"); - initMinMax(_add); - - return _add; - } - -public: - float _a_min = -1.f; - float _a_max = 1.f; - luci::CircleAdd *_add = nullptr; - luci::CircleConst *_beta = nullptr; -}; - -} // namespace +using namespace mpqsolver::test::models; TEST(CircleMPQSolverQuantizerTest, verifyResultsTest) { @@ -84,10 +31,10 @@ TEST(CircleMPQSolverQuantizerTest, verifyResultsTest) float range = g._a_max - g._a_min; g.transfer_to(m.get()); - std::string def_quant = "uint8"; - mpqsolver::core::Quantizer quantizer(def_quant, def_quant); + mpqsolver::core::Quantizer::Context context; + mpqsolver::core::Quantizer quantizer(context); mpqsolver::core::LayerParams params; - auto res = quantizer.quantize(m.get(), def_quant, params); + auto res = quantizer.quantize(m.get(), context.output_model_dtype, params); EXPECT_TRUE(res); auto quant_param = add->quantparam(); EXPECT_TRUE(quant_param != nullptr); @@ -99,9 +46,9 @@ TEST(CircleMPQSolverQuantizerTest, verifyResultsTest) TEST(CircleMPQSolverQuantizerTest, verifyResultsTest_NEG) { - std::string def_quant = "uint8"; - mpqsolver::core::Quantizer quantizer(def_quant, def_quant); + mpqsolver::core::Quantizer::Context context; + mpqsolver::core::Quantizer quantizer(context); mpqsolver::core::LayerParams params; - auto res = quantizer.quantize(nullptr, def_quant, params); + auto res = quantizer.quantize(nullptr, context.output_model_dtype, params); EXPECT_TRUE(!res); } diff --git a/compiler/circle-mpqsolver/src/core/SolverHooks.h b/compiler/circle-mpqsolver/src/core/SolverHooks.h index 851a69993..06f8d766f 100644 --- a/compiler/circle-mpqsolver/src/core/SolverHooks.h +++ b/compiler/circle-mpqsolver/src/core/SolverHooks.h @@ -19,7 +19,7 @@ #include -#include +#include "core/Quantizer.h" #include @@ -32,17 +32,17 @@ class SolverHooks { public: /** - * @brief called on the start of iterative search + * @brief called on the start of mpq search * @param model_path path of original float model to quantize - * @param q8error error of Q8 quantization - * @param q16error error of Q16 quantization + * @param q8error error of Q8 quantization (if NAN, then not applicable) + * @param q16error error of Q16 quantization (if NAN, then not applicable) */ - virtual void on_begin_solver(const std::string &model_path, float q8error, float q16error) = 0; + virtual void onBeginSolver(const std::string &model_path, float q8error, float q16error) = 0; /** * @brief called on the start of current iteration */ - virtual void on_begin_iteration() = 0; + virtual void onBeginIteration() = 0; /** * @brief called at the end of current iteration @@ -50,17 +50,17 @@ public: * @param def_dtype default quantization dtype * @param error error of quantization for current iteration */ - virtual void on_end_iteration(const 
LayerParams &layers, const std::string &def_dtype, - float error) const = 0; + virtual void onEndIteration(const LayerParams &layers, const std::string &def_dtype, + float error) = 0; /** - * @brief called at the end of iterative search + * @brief called at the end of mpq search * @param layers model nodes with specific quantization parameters * @param def_dtype default quantization dtype - * @param qerror final error of quantization + * @param qerror final error of quantization (if NAN, then not applicable) */ - virtual void on_end_solver(const LayerParams &layers, const std::string &def_dtype, - float qerror) = 0; + virtual void onEndSolver(const LayerParams &layers, const std::string &def_dtype, + float qerror) = 0; }; } // namespace core diff --git a/compiler/circle-mpqsolver/src/core/TestHelper.h b/compiler/circle-mpqsolver/src/core/TestHelper.h index f930738f9..5165e405b 100644 --- a/compiler/circle-mpqsolver/src/core/TestHelper.h +++ b/compiler/circle-mpqsolver/src/core/TestHelper.h @@ -17,55 +17,36 @@ #ifndef __MPQSOLVER_TEST_HELPER_H__ #define __MPQSOLVER_TEST_HELPER_H__ +#include "DataProvider.h" + #include #include +#include + +namespace mpqsolver +{ +namespace test +{ +namespace models +{ +/** + * @brief base class of simple graphs used for testing + */ class SimpleGraph { public: SimpleGraph() : _g(loco::make_graph()) {} public: - void init() - { - _input = _g->nodes()->create(); - _output = _g->nodes()->create(); - _input->name("input"); - _output->name("output"); - - auto graph_input = _g->inputs()->create(); - _input->index(graph_input->index()); - auto graph_output = _g->outputs()->create(); - _output->index(graph_output->index()); - - graph_input->dtype(loco::DataType::FLOAT32); - _input->dtype(loco::DataType::FLOAT32); - _output->dtype(loco::DataType::FLOAT32); - graph_output->dtype(loco::DataType::FLOAT32); - - graph_input->shape({1, _channel_size, _width, _height}); - _input->shape({1, _channel_size, _width, _height}); - _output->shape({1, _channel_size, _width, _height}); - graph_output->shape({1, _channel_size, _width, _height}); - - auto graph_body = insertGraphBody(_input); - _output->from(graph_body); - - initInput(_input); - } + void init(); virtual ~SimpleGraph() = default; - void transfer_to(luci::Module *module) - { - // WARNING: after g is transfered, _graph_inputs, _inputs - // and _graph_outputs, _outputs in TestOsGraphlet will be invalid. 
- // arrays are not cleared as this is just helpers to unit tests - module->add(std::move(_g)); - } + void transfer_to(luci::Module *module); protected: virtual loco::Node *insertGraphBody(loco::Node *input) = 0; - virtual void initInput(loco::Node *input){}; + virtual void initInput(loco::Node *){}; public: std::unique_ptr _g; @@ -76,4 +57,90 @@ public: uint32_t _height = 4; }; +/** + * @brief simple model with just an Add of input and constant + */ +class AddGraph final : public SimpleGraph +{ +private: + void initInput(loco::Node *input) override; + void initMinMax(luci::CircleNode *node); + + loco::Node *insertGraphBody(loco::Node *input) override; + +public: + float _a_min = -1.f; + float _a_max = 1.f; + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_beta = nullptr; +}; + +class SoftmaxGraphlet +{ +public: + SoftmaxGraphlet() = default; + virtual ~SoftmaxGraphlet() = default; + + void init(loco::Graph *g); + +protected: + void initMinMax(luci::CircleNode *node, float min, float max); + +public: + luci::CircleAbs *_ifm = nullptr; + luci::CircleReduceMax *_max = nullptr; + luci::CircleSub *_sub = nullptr; + luci::CircleExp *_exp = nullptr; + luci::CircleSum *_sum = nullptr; + luci::CircleDiv *_div = nullptr; + +protected: + luci::CircleConst *_softmax_indices = nullptr; +}; + +class SoftmaxTestGraph : public luci::test::TestIOGraph, public SoftmaxGraphlet +{ +public: + SoftmaxTestGraph() = default; + + void init(void); +}; + +} // namespace models + +namespace io_utils +{ + +/** + * @brief create valid name of temporary file + */ +void makeTemporaryFile(char *name_template); + +/** + * @brief write data to file_path + */ +void writeDataToFile(const std::string &file_path, const std::string &data); + +/** + * @brief create valid name of temporary folder + */ +std::string makeTemporaryFolder(char *name_template); + +/** + * @brief checks whether file exists + */ +bool isFileExists(const std::string &file_path); + +} // namespace io_utils + +namespace data_utils +{ + +std::unique_ptr getAllZeroSingleDataProvider(); + +} // namespace data_utils + +} // namespace test +} // namespace mpqsolver + #endif //__MPQSOLVER_TEST_HELPER_H__ diff --git a/compiler/circle-mpqsolver/src/core/TestHelper.test.cpp b/compiler/circle-mpqsolver/src/core/TestHelper.test.cpp new file mode 100644 index 000000000..e5e7c8e49 --- /dev/null +++ b/compiler/circle-mpqsolver/src/core/TestHelper.test.cpp @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "TestHelper.h"
+
+#include
+#include
+
+namespace mpqsolver
+{
+namespace test
+{
+namespace models
+{
+
+void SimpleGraph::init()
+{
+  _input = _g->nodes()->create<luci::CircleInput>();
+  _output = _g->nodes()->create<luci::CircleOutput>();
+  _input->name("input");
+  _output->name("output");
+
+  auto graph_input = _g->inputs()->create();
+  _input->index(graph_input->index());
+  auto graph_output = _g->outputs()->create();
+  _output->index(graph_output->index());
+
+  graph_input->dtype(loco::DataType::FLOAT32);
+  _input->dtype(loco::DataType::FLOAT32);
+  _output->dtype(loco::DataType::FLOAT32);
+  graph_output->dtype(loco::DataType::FLOAT32);
+
+  graph_input->shape({1, _height, _width, _channel_size});
+  _input->shape({1, _height, _width, _channel_size});
+  _output->shape({1, _height, _width, _channel_size});
+  graph_output->shape({1, _height, _width, _channel_size});
+
+  auto graph_body = insertGraphBody(_input);
+  _output->from(graph_body);
+
+  initInput(_input);
+}
+
+void SimpleGraph::transfer_to(luci::Module *module)
+{
+  // WARNING: after g is transferred, _graph_inputs, _inputs
+  // and _graph_outputs, _outputs in TestOsGraphlet will be invalid.
+  // arrays are not cleared as these are just helpers for unit tests
+  module->add(std::move(_g));
+}
+
+void AddGraph::initInput(loco::Node *input)
+{
+  auto ci_input = loco::must_cast<luci::CircleInput *>(input);
+  initMinMax(ci_input);
+}
+
+void AddGraph::initMinMax(luci::CircleNode *node)
+{
+  auto qparam = std::make_unique<luci::CircleQuantParam>();
+  qparam->min.assign(1, _a_min);
+  qparam->max.assign(1, _a_max);
+  node->quantparam(std::move(qparam));
+}
+
+loco::Node *AddGraph::insertGraphBody(loco::Node *input)
+{
+  _add = _g->nodes()->create<luci::CircleAdd>();
+  _beta = _g->nodes()->create<luci::CircleConst>();
+
+  _add->dtype(loco::DataType::FLOAT32);
+  _beta->dtype(loco::DataType::FLOAT32);
+
+  _add->shape({1, _height, _width, _channel_size});
+  _beta->shape({1, _height, _width, _channel_size});
+
+  _beta->size<loco::DataType::FLOAT32>(_channel_size * _width * _height);
+  _add->x(input);
+  _add->y(_beta);
+  _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+  _add->name("add");
+  _beta->name("beta");
+  initMinMax(_add);
+
+  return _add;
+}
+
+void SoftmaxGraphlet::initMinMax(luci::CircleNode *node, float min, float max)
+{
+  auto qparam = std::make_unique<luci::CircleQuantParam>();
+  qparam->min.assign(1, min);
+  qparam->max.assign(1, max);
+  node->quantparam(std::move(qparam));
+}
+
+void SoftmaxGraphlet::init(loco::Graph *g)
+{
+  _ifm = nullptr;
+
+  _ifm = g->nodes()->create<luci::CircleAbs>();
+  _max = g->nodes()->create<luci::CircleReduceMax>();
+  _sub = g->nodes()->create<luci::CircleSub>();
+  _exp = g->nodes()->create<luci::CircleExp>();
+  _sum = g->nodes()->create<luci::CircleSum>();
+  _div = g->nodes()->create<luci::CircleDiv>();
+  _softmax_indices = g->nodes()->create<luci::CircleConst>();
+
+  _ifm->name("ifm");
+  _max->name("maximum_of_ifm");
+  _sub->name("sub");
+  _exp->name("exp");
+  _sum->name("sum");
+  _div->name("div");
+  _softmax_indices->name("reduction_indices");
+
+  initMinMax(_ifm, 0, 1);
+  initMinMax(_max, 0, 1);
+  initMinMax(_sub, 0, 1);
+  initMinMax(_exp, 0, 1);
+  initMinMax(_sum, 0, 1);
+  initMinMax(_div, 0, 1);
+
+  _softmax_indices->dtype(loco::DataType::S32);
+  _softmax_indices->size<loco::DataType::S32>(1);
+  _softmax_indices->shape({1});
+  _softmax_indices->at<loco::DataType::S32>(0) = -1;
+  _softmax_indices->shape_status(luci::ShapeStatus::VALID);
+
+  _max->keep_dims(true);
+  _sum->keep_dims(true);
+}
+
+void SoftmaxTestGraph::init(void)
+{
+  TestIOGraph::init({1, 12, 11, 15}, {1, 12, 11, 15});
+  SoftmaxGraphlet::init(g());
+
+  _ifm->x(input());
+  _max->input(_ifm);
+  _max->reduction_indices(_softmax_indices);
+
+  _sub->x(_ifm);
+  _sub->y(_max);
+  _sub->fusedActivationFunction(luci::FusedActFunc::NONE);
+
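+  // wire the remaining ops so that div = exp(ifm - max(ifm)) / sum(exp(ifm - max(ifm)))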
+  _exp->x(_sub);
+  _sum->input(_exp);
+  _sum->reduction_indices(_softmax_indices);
+  _div->x(_exp);
+  _div->y(_sum);
+  _div->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+  output()->from(_div);
+
+  initMinMax(input(), 0, 1);
+  initMinMax(output(), 0, 1);
+}
+
+} // namespace models
+
+namespace io_utils
+{
+
+void makeTemporaryFile(char *name_template)
+{
+  int fd = mkstemp(name_template);
+  if (fd == -1)
+  {
+    throw std::runtime_error{"mkstemp failed"};
+  }
+}
+
+void writeDataToFile(const std::string &path, const std::string &data)
+{
+  std::ofstream file;
+  file.open(path);
+  file << data;
+  file.close();
+}
+
+std::string makeTemporaryFolder(char *name_template)
+{
+  auto const res = mkdtemp(name_template);
+  if (res == nullptr)
+  {
+    throw std::runtime_error{"mkdtemp failed"};
+  }
+  return res;
+}
+
+bool isFileExists(const std::string &path)
+{
+  std::ifstream f(path);
+  return f.good();
+}
+
+} // namespace io_utils
+
+namespace data_utils
+{
+
+class SingleDataProvider final : public core::DataProvider
+{
+public:
+  SingleDataProvider() = default;
+  size_t numSamples() const override { return 1; }
+  uint32_t numInputs(uint32_t) const override { return 1; }
+  void getSampleInput(uint32_t, uint32_t, core::InputData &data) const override
+  {
+    size_t size = data.data().size() / sizeof(float);
+    auto floats = reinterpret_cast<float *>(data.data().data());
+    for (uint32_t idx = 0; idx < size; idx++)
+    {
+      floats[idx] = 0.f; // all-zero input (see getAllZeroSingleDataProvider)
+    }
+  }
+};
+
+std::unique_ptr<core::DataProvider> getAllZeroSingleDataProvider()
+{
+  return std::make_unique<SingleDataProvider>();
+}
+
+} // namespace data_utils
+
+} // namespace test
+} // namespace mpqsolver
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternResolver.cpp b/compiler/circle-mpqsolver/src/pattern/PatternResolver.cpp
new file mode 100644
index 000000000..947a8f8c9
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternResolver.cpp
@@ -0,0 +1,349 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PatternResolver.h"
+
+#include
+#include
+#include
+
+using namespace mpqsolver::pattern;
+
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+
+namespace
+{
+
+/**
+ * SUBGRAPH PATTERN
+ * LayerNorm := (ifm - |ifm|) / sqrt(|(ifm - |ifm|)^2| + eps)
+ * - |x|: mean of x
+ * - The diagram below shows the LayerNorm pattern to quantize.
+ *
+ *        [In]
+ *          |
+ *          V
+ *         ifm -----+      (reduction indices)
+ *          |       |             |
+ *          |       V             V
+ *          |      mean_of_ifm
+ *          V       |
+ *   +------------sub <----+
+ *   |              |
+ *   |              V      (reduction indices)
+ *   |         sub_squared        |
+ *   |              |             |
+ *   |              V             |
+ *   |       mean_as_variance <---+      (const_as_eps)
+ *   |              |                          |
+ *   |              V                          |
+ *   |           add_eps<----------------------+
+ *   |              |
+ *   |              V
+ *   |            rsqrt
+ *   |              |
+ *   |              V
+ *   +------>mul_as_terminal
+ *                  |
+ *                  V
+ *                [Out]
+ *
+ */
+class LayerNormPattern final
+{
+public:
+  LayerNormPattern(luci::CircleMul *candidate)
+  {
+    assert(candidate);
+    mul_as_terminal = candidate;
+  }
+
+public:
+  bool matched();
+
+  std::vector<luci::CircleNode *> get_q16_nodes()
+  {
+    return {sub_squared, mean_as_variance, add_eps, rsqrt};
+  }
+
+  std::vector<luci::CircleNode *> get_q8_nodes() { return {mean_of_ifm, sub, mul_as_terminal}; }
+
+public:
+  loco::Node *ifm = nullptr;
+  luci::CircleMean *mean_of_ifm = nullptr;      // = |ifm|
+  luci::CircleSub *sub = nullptr;               // = ifm - |ifm|
+  luci::CircleMul *sub_squared = nullptr;       // = (ifm - |ifm|)^2
+  luci::CircleMean *mean_as_variance = nullptr; // = |(ifm - |ifm|)^2|
+  luci::CircleAdd *add_eps = nullptr;           // = |(ifm - |ifm|)^2| + eps
+  luci::CircleRsqrt *rsqrt = nullptr;           // = 1.0 / sqrt(|(ifm - |ifm|)^2| + eps)
+  luci::CircleMul *mul_as_terminal = nullptr;   // = (ifm - |ifm|) / sqrt(|(ifm - |ifm|)^2| + eps)
+};
+
+#define CHECK_OR_FALSE(condition) \
+  if (not(condition))             \
+    return false;
+
+bool LayerNormPattern::matched()
+{
+  sub = dynamic_cast<luci::CircleSub *>(mul_as_terminal->x());
+  rsqrt = dynamic_cast<luci::CircleRsqrt *>(mul_as_terminal->y());
+  if (!sub || !rsqrt)
+  {
+    sub = dynamic_cast<luci::CircleSub *>(mul_as_terminal->y());
+    rsqrt = dynamic_cast<luci::CircleRsqrt *>(mul_as_terminal->x());
+  }
+  CHECK_OR_FALSE(rsqrt != nullptr && sub != nullptr);
+
+  ifm = dynamic_cast<luci::CircleNode *>(sub->x());
+  mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
+  CHECK_OR_FALSE(ifm != nullptr && mean_of_ifm != nullptr);
+
+  add_eps = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
+  CHECK_OR_FALSE(add_eps != nullptr);
+
+  auto const *eps = dynamic_cast<const luci::CircleConst *>(add_eps->x());
+  mean_as_variance = dynamic_cast<luci::CircleMean *>(add_eps->y());
+  if (!eps || !mean_as_variance)
+  {
+    eps = dynamic_cast<const luci::CircleConst *>(add_eps->y());
+    mean_as_variance = dynamic_cast<luci::CircleMean *>(add_eps->x());
+  }
+
+  CHECK_OR_FALSE(eps != nullptr && mean_as_variance != nullptr);
+
+  // eps should be scalar value
+  CHECK_OR_FALSE(eps->size<loco::DataType::FLOAT32>() == 1);
+
+  sub_squared = dynamic_cast<luci::CircleMul *>(mean_as_variance->input());
+  CHECK_OR_FALSE(sub_squared != nullptr);
+
+  // check that sub_squared = sub * sub
+  CHECK_OR_FALSE(sub_squared->x() == sub_squared->y() && sub_squared->x() == sub);
+
+  auto const mean_as_variance_indices =
+    dynamic_cast<luci::CircleConst *>(mean_as_variance->reduction_indices());
+  auto const mean_of_ifm_indices =
+    dynamic_cast<luci::CircleConst *>(mean_of_ifm->reduction_indices());
+
+  // check validity of reduction indices
+  CHECK_OR_FALSE(mean_of_ifm_indices != nullptr && mean_as_variance_indices != nullptr);
+
+  // check dtype of reduction indices
+  CHECK_OR_FALSE(mean_of_ifm_indices->dtype() == loco::DataType::S32 &&
+                 mean_as_variance_indices->dtype() == loco::DataType::S32);
+
+  // reduction indices of both mean operations must be the same
+  CHECK_OR_FALSE(mean_as_variance_indices->size<loco::DataType::S32>() ==
+                 mean_of_ifm_indices->size<loco::DataType::S32>());
+
+  std::set<int32_t> set_of_mean_as_variance_indices;
+  std::set<int32_t> set_of_mean_of_ifm_indices;
+  for (uint32_t index = 0; index < mean_as_variance_indices->size<loco::DataType::S32>(); index++)
+  {
+    set_of_mean_as_variance_indices.insert(
+      mean_as_variance_indices->at<loco::DataType::S32>(index));
+    set_of_mean_of_ifm_indices.insert(mean_of_ifm_indices->at<loco::DataType::S32>(index));
+  }
+  // now make sure that reduction indices of mean_as_variance are the same as mean_of_ifm
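+  // (the set comparison is order-insensitive, so permuted reduction indices still match)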
+
+/**
+ * SUBGRAPH PATTERN
+ *  SoftmaxPattern := (exp(ifm - max(ifm))) / sum(exp(ifm - max(ifm)))
+ *  - Below diagram shows Softmax pattern to quantize.
+ *
+ *     [In]        [CircleConst(=-1)]
+ *      |    \        /
+ *      |     \      /
+ *      |  [CircleReduceMax]
+ *      |     /
+ *      |    /
+ *      |   /
+ *     [Sub]       [CircleConst(=-1)]
+ *      |                  |
+ *      |                  |
+ *     [Exp]               |
+ *      |   \              |
+ *      |    \             |
+ *      |     [CircleSum]--+
+ *      |    /
+ *      |   /
+ *     [Div]
+ *      |
+ *      |
+ *  [CircleNode]
+ */
+class SoftmaxPattern final
+{
+public:
+  SoftmaxPattern(luci::CircleDiv *candidate)
+  {
+    assert(candidate);
+    _div_as_terminal = candidate;
+  }
+
+public:
+  bool matched();
+
+  std::vector<luci::CircleNode *> get_q16_nodes() { return {_sub, _exp}; }
+
+  std::vector<luci::CircleNode *> get_q8_nodes() { return {_ifm, _max, _sum, _div_as_terminal}; }
+
+public:
+  luci::CircleNode *_ifm = nullptr;            // input feature map
+  luci::CircleReduceMax *_max = nullptr;       // = max(_ifm)
+  luci::CircleSub *_sub = nullptr;             // = _ifm - max(_ifm)
+  luci::CircleExp *_exp = nullptr;             // = exp(_ifm - max(_ifm))
+  luci::CircleSum *_sum = nullptr;             // = sum(exp(_ifm - max(_ifm)))
+  luci::CircleDiv *_div_as_terminal = nullptr; // = exp(_ifm - max(_ifm))/sum(exp(_ifm - max(_ifm)))
+};
+
+bool SoftmaxPattern::matched()
+{
+  _exp = dynamic_cast<luci::CircleExp *>(_div_as_terminal->x());
+  _sum = dynamic_cast<luci::CircleSum *>(_div_as_terminal->y());
+  CHECK_OR_FALSE(_exp != nullptr && _sum != nullptr);
+
+  CHECK_OR_FALSE(_sum->input() == _exp);
+  CHECK_OR_FALSE(_sum->keep_dims() == true);
+
+  auto const sum_indices = dynamic_cast<luci::CircleConst *>(_sum->reduction_indices());
+  CHECK_OR_FALSE(sum_indices != nullptr);
+
+  _sub = dynamic_cast<luci::CircleSub *>(_exp->x());
+  CHECK_OR_FALSE(_sub != nullptr);
+
+  _ifm = loco::must_cast<luci::CircleNode *>(_sub->x());
+
+  _max = dynamic_cast<luci::CircleReduceMax *>(_sub->y());
+  CHECK_OR_FALSE(_max != nullptr);
+
+  CHECK_OR_FALSE(_max->input() == _ifm);
+  CHECK_OR_FALSE(_max->keep_dims() == true);
+
+  auto const max_indices = dynamic_cast<luci::CircleConst *>(_max->reduction_indices());
+  CHECK_OR_FALSE(max_indices != nullptr);
+
+  // check dtype of reduction indices
+  CHECK_OR_FALSE(max_indices->dtype() == loco::DataType::S32 &&
+                 sum_indices->dtype() == loco::DataType::S32);
+
+  // reduction of max and sum must be done over the last (channel) dimension
+  {
+    CHECK_OR_FALSE(max_indices->size<loco::DataType::S32>() == 1 &&
+                   sum_indices->size<loco::DataType::S32>() == 1);
+
+    auto const rank = _ifm->rank();
+    int32_t last_dim = (rank == 0) ? 0 : rank - 1;
+
+    CHECK_OR_FALSE(max_indices->at<loco::DataType::S32>(0) == -1 ||
+                   max_indices->at<loco::DataType::S32>(0) == last_dim);
+
+    CHECK_OR_FALSE(sum_indices->at<loco::DataType::S32>(0) == -1 ||
+                   sum_indices->at<loco::DataType::S32>(0) == last_dim);
+  }
+
+  return true;
+}
+
+#undef CHECK_OR_FALSE
+
+} // namespace
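The check against "-1 or last_dim" relies on negative reduction indices counting from the back of the shape, so -1 and rank - 1 address the same channel axis. A one-line sketch of the equivalence:

// Negative axes alias axes counted from the back: on a rank-4 tensor,
// axis -1 normalizes to 3, which is exactly last_dim above.
int32_t normalize_axis(int32_t axis, uint32_t rank)
{
  return axis < 0 ? axis + static_cast<int32_t>(rank) : axis;
}
// normalize_axis(-1, 4) == 3 == normalize_axis(3, 4)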
+
+std::map<luci::CircleNode *, LayerParam>
+Q8LayerNormWithQ16VarianceResolver::resolve(const luci::Module *module)
+{
+  if (!module)
+  {
+    throw std::runtime_error("No module for pattern resolving");
+  }
+
+  std::map<luci::CircleNode *, LayerParam> nodes_params;
+  for (size_t idx = 0; idx < module->size(); ++idx)
+  {
+    auto graph = module->graph(idx);
+
+    for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+    {
+      auto const mul = dynamic_cast<luci::CircleMul *>(node);
+      if (!mul)
+        continue;
+
+      LayerNormPattern pattern(mul);
+      if (!pattern.matched())
+        continue;
+
+      // set quantization parameters of recognized pattern
+      for (auto q16_node : pattern.get_q16_nodes())
+      {
+        LayerParam param = {q16_node->name(), "int16", "channel"};
+        nodes_params[q16_node] = param;
+      }
+
+      for (auto q8_node : pattern.get_q8_nodes())
+      {
+        LayerParam param = {q8_node->name(), "uint8", "channel"};
+        nodes_params[q8_node] = param;
+      }
+    }
+  }
+
+  return nodes_params;
+}
+
+std::map<luci::CircleNode *, LayerParam>
+Q8SoftmaxWithQ16SubExpResolver::resolve(const luci::Module *module)
+{
+  if (!module)
+  {
+    throw std::runtime_error("No module for pattern resolving");
+  }
+
+  std::map<luci::CircleNode *, LayerParam> nodes_params;
+  for (size_t idx = 0; idx < module->size(); ++idx)
+  {
+    auto graph = module->graph(idx);
+
+    for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+    {
+      auto const div = dynamic_cast<luci::CircleDiv *>(node);
+      if (!div)
+        continue;
+
+      SoftmaxPattern pattern(div);
+      if (!pattern.matched())
+        continue;
+
+      // set quantization parameters of recognized pattern
+      for (auto q16_node : pattern.get_q16_nodes())
+      {
+        LayerParam param = {q16_node->name(), "int16", "channel"};
+        nodes_params[q16_node] = param;
+      }
+
+      for (auto q8_node : pattern.get_q8_nodes())
+      {
+        LayerParam param = {q8_node->name(), "uint8", "channel"};
+        nodes_params[q8_node] = param;
+      }
+    }
+  }
+
+  return nodes_params;
+}
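A resolver can also be driven on its own; a minimal sketch (the dump_prescriptions function and the iostream include are illustrative, not part of the tool):

#include "pattern/PatternResolver.h"

#include <iostream>

// Print each prescribed node name with its dtype and granularity.
void dump_prescriptions(const luci::Module *module)
{
  mpqsolver::pattern::Q8LayerNormWithQ16VarianceResolver resolver;
  for (const auto &entry : resolver.resolve(module))
  {
    const auto &param = entry.second;
    std::cout << param.name << " -> " << param.dtype << " / " << param.granularity << std::endl;
  }
}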
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternResolver.h b/compiler/circle-mpqsolver/src/pattern/PatternResolver.h
new file mode 100644
index 000000000..c005bc78e
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternResolver.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MPQSOLVER_PATTERN_RESOLVER_H__
+#define __MPQSOLVER_PATTERN_RESOLVER_H__
+
+#include <luci/CircleQuantizer.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/Module.h>
+
+#include <map>
+
+namespace mpqsolver
+{
+namespace pattern
+{
+
+class PatternResolver
+{
+public:
+  virtual ~PatternResolver() = default;
+  virtual std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam>
+  resolve(const luci::Module *module) = 0;
+};
+
+class Q8LayerNormWithQ16VarianceResolver : public PatternResolver
+{
+public:
+  /**
+   * @brief resolve all nodes of LayerNorm pattern as prescribed
+   */
+  std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam>
+  resolve(const luci::Module *module) override;
+};
+
+class Q8SoftmaxWithQ16SubExpResolver : public PatternResolver
+{
+public:
+  /**
+   * @brief resolve all nodes of Softmax pattern as prescribed
+   */
+  std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam>
+  resolve(const luci::Module *module) override;
+};
+
+} // namespace pattern
+} // namespace mpqsolver
+
+#endif //__MPQSOLVER_PATTERN_RESOLVER_H__
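New patterns plug in by deriving from PatternResolver. A skeletal example under the same interface (AllQ16Resolver is hypothetical and deliberately trivial):

// Hypothetical resolver: prescribe int16/channel for every active node.
// Mirrors the structure of the two resolvers shipped with the tool.
class AllQ16Resolver final : public mpqsolver::pattern::PatternResolver
{
public:
  std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam>
  resolve(const luci::Module *module) override
  {
    std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam> params;
    for (size_t idx = 0; idx < module->size(); ++idx)
    {
      for (auto node : loco::active_nodes(loco::output_nodes(module->graph(idx))))
      {
        auto circle_node = loco::must_cast<luci::CircleNode *>(node);
        params[circle_node] = {circle_node->name(), "int16", "channel"};
      }
    }
    return params;
  }
};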
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternResolver.test.cpp b/compiler/circle-mpqsolver/src/pattern/PatternResolver.test.cpp
new file mode 100644
index 000000000..1f08c9532
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternResolver.test.cpp
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PatternResolver.h"
+
+#include "core/TestHelper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+#include <map>
+#include <set>
+
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+
+namespace
+{
+
+using namespace luci::test;
+
+class LayerNormGraphlet
+{
+public:
+  LayerNormGraphlet() = default;
+  virtual ~LayerNormGraphlet() = default;
+
+  void init(loco::Graph *g)
+  {
+    ifm = nullptr;
+
+    ifm = g->nodes()->create<luci::CircleAbs>();
+    mean_of_ifm = g->nodes()->create<luci::CircleMean>();
+    sub = g->nodes()->create<luci::CircleSub>();
+    sub_squared = g->nodes()->create<luci::CircleMul>();
+    mean_as_variance = g->nodes()->create<luci::CircleMean>();
+    add_eps = g->nodes()->create<luci::CircleAdd>();
+    rsqrt = g->nodes()->create<luci::CircleRsqrt>();
+    mul = g->nodes()->create<luci::CircleMul>();
+    _eps = g->nodes()->create<luci::CircleConst>();
+    _mean_of_ifm_indices = g->nodes()->create<luci::CircleConst>();
+    _mean_as_variance_indices = g->nodes()->create<luci::CircleConst>();
+
+    ifm->name("ifm");
+    mean_of_ifm->name("mean_of_ifm");
+    sub->name("sub");
+    sub_squared->name("sub_squared");
+    mean_as_variance->name("mean_as_variance");
+    add_eps->name("add_eps");
+    rsqrt->name("rsqrt");
+    mul->name("mul");
+    _eps->name("eps");
+    _mean_of_ifm_indices->name("mean_of_ifm_indices");
+    _mean_as_variance_indices->name("mean_as_variance_indices");
+
+    _eps->dtype(loco::DataType::FLOAT32);
+    _eps->size<loco::DataType::FLOAT32>(1);
+    _eps->shape({1});
+    _eps->at<loco::DataType::FLOAT32>(0) = 1.e-05f;
+    _eps->shape_status(luci::ShapeStatus::VALID);
+
+    _mean_of_ifm_indices->dtype(loco::DataType::S32);
+    _mean_of_ifm_indices->size<loco::DataType::S32>(1);
+    _mean_of_ifm_indices->shape({1});
+    _mean_of_ifm_indices->at<loco::DataType::S32>(0) = -1;
+    _mean_of_ifm_indices->shape_status(luci::ShapeStatus::VALID);
+
+    _mean_as_variance_indices->dtype(loco::DataType::S32);
+    _mean_as_variance_indices->size<loco::DataType::S32>(1);
+    _mean_as_variance_indices->shape({1});
+    _mean_as_variance_indices->at<loco::DataType::S32>(0) = -1;
+    _mean_as_variance_indices->shape_status(luci::ShapeStatus::VALID);
+  }
+
+public:
+  luci::CircleAbs *ifm = nullptr;
+  luci::CircleMean *mean_of_ifm = nullptr;
+  luci::CircleSub *sub = nullptr;
+  luci::CircleMul *sub_squared = nullptr;
+  luci::CircleMean *mean_as_variance = nullptr;
+  luci::CircleAdd *add_eps = nullptr;
+  luci::CircleRsqrt *rsqrt = nullptr;
+  luci::CircleMul *mul = nullptr;
+
+protected:
+  luci::CircleConst *_eps = nullptr;
+  luci::CircleConst *_mean_of_ifm_indices = nullptr;
+  luci::CircleConst *_mean_as_variance_indices = nullptr;
+};
+
+class LayerNormTestGraph : public TestIOGraph, public LayerNormGraphlet
+{
+public:
+  LayerNormTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({1, 12, 11, 15}, {1, 12, 11, 15});
+    LayerNormGraphlet::init(g());
+
+    ifm->x(input());
+    mean_of_ifm->input(ifm);
+    mean_of_ifm->reduction_indices(_mean_of_ifm_indices);
+    sub->x(ifm);
+    sub->y(mean_of_ifm);
+    sub_squared->x(sub);
+    sub_squared->y(sub);
+    mean_as_variance->input(sub_squared);
+    mean_as_variance->reduction_indices(_mean_as_variance_indices);
+    add_eps->x(mean_as_variance);
+    add_eps->y(_eps);
+    rsqrt->x(add_eps);
+    mul->x(sub);
+    mul->y(rsqrt);
+
+    output()->from(mul);
+  }
+};
+
+} // namespace
+
+TEST(LayerNormPatternResolverTest, resolve_pattern)
+{
+  auto m = luci::make_module();
+  LayerNormTestGraph g;
+  g.init();
+  g.transfer_to(m.get());
+
+  std::map<luci::CircleNode *, LayerParam> params;
+  mpqsolver::pattern::Q8LayerNormWithQ16VarianceResolver resolver;
+  EXPECT_NO_THROW({ params = resolver.resolve(m.get()); });
+
+  std::set<luci::CircleNode *> q16_nodes = {g.sub_squared, g.mean_as_variance, g.add_eps, g.rsqrt};
+  std::set<luci::CircleNode *> q8_nodes = {g.mean_of_ifm, g.sub, g.mul};
+
+  // params of all valid layers are set
+  EXPECT_EQ(params.size(), q16_nodes.size() + q8_nodes.size());
+
+  for (auto param : params)
+  {
+    // params of all layers are set as prescribed
+    if (q16_nodes.find(param.first) != q16_nodes.end())
+    {
+      EXPECT_STREQ(param.second.dtype.c_str(), "int16");
+    }
+    else if (q8_nodes.find(param.first) != q8_nodes.end())
+    {
+      EXPECT_STREQ(param.second.dtype.c_str(), "uint8");
+    }
+  }
+}
+
+TEST(LayerNormPatternResolverTest, resolve_pattern_NEG)
+{
+  std::map<luci::CircleNode *, LayerParam> params;
+  mpqsolver::pattern::Q8LayerNormWithQ16VarianceResolver resolver;
+  EXPECT_ANY_THROW(resolver.resolve(nullptr));
+}
+
+TEST(SoftmaxResolverTest, resolve_pattern)
+{
+  auto m = luci::make_module();
+  mpqsolver::test::models::SoftmaxTestGraph g;
+  g.init();
+  g.transfer_to(m.get());
+
+  std::map<luci::CircleNode *, LayerParam> params;
+  mpqsolver::pattern::Q8SoftmaxWithQ16SubExpResolver resolver;
+  EXPECT_NO_THROW({ params = resolver.resolve(m.get()); });
+
+  std::set<luci::CircleNode *> q16_nodes = {g._sub, g._exp};
+  std::set<luci::CircleNode *> q8_nodes = {g._ifm, g._max, g._sum, g._div};
+
+  // params of all valid layers are set
+  EXPECT_EQ(params.size(), q16_nodes.size() + q8_nodes.size());
+
+  for (auto param : params)
+  {
+    // params of all layers are set as prescribed
+    if (q16_nodes.find(param.first) != q16_nodes.end())
+    {
+      EXPECT_STREQ(param.second.dtype.c_str(), "int16");
+    }
+    else if (q8_nodes.find(param.first) != q8_nodes.end())
+    {
+      EXPECT_STREQ(param.second.dtype.c_str(), "uint8");
+    }
+  }
+}
+
+TEST(SoftmaxPatternResolverTest, resolve_pattern_NEG)
+{
+  mpqsolver::pattern::Q8SoftmaxWithQ16SubExpResolver resolver;
+  EXPECT_ANY_THROW(resolver.resolve(nullptr));
+}
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternSolver.cpp b/compiler/circle-mpqsolver/src/pattern/PatternSolver.cpp
new file mode 100644
index 000000000..064a72646
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternSolver.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PatternSolver.h"
+
+#include "PatternResolver.h"
+
+#include <cmath>
+#include <memory>
+
+using namespace mpqsolver::pattern;
+
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+using LayerParams = luci::CircleQuantizer::Options::LayerParams;
+
+PatternSolver::PatternSolver(const mpqsolver::core::Quantizer::Context &ctx,
+                             const std::vector<QuantizationPattern> &patterns)
+  : MPQSolver(ctx)
+{
+  MPQOptions options{patterns};
+  setMPQOptions(options);
+}
+
+std::unique_ptr<luci::Module> PatternSolver::run(const std::string &module_path)
+{
+  auto module = readModule(module_path);
+  assert(module != nullptr);
+
+  _quantizer->setHook(_hooks.get());
+  if (_hooks)
+  {
+    _hooks->onBeginSolver(module_path, NAN, NAN);
+  }
+
+  resolvePatterns(module.get());
+
+  auto layer_params = getFrozenParams();
+
+  if (_hooks)
+  {
+    _hooks->onEndSolver(layer_params, _quantizer->getContext().output_model_dtype, NAN);
+  }
+
+  if (!_quantizer->quantize(module.get(), layer_params))
+  {
+    throw std::runtime_error("Failed to quantize model");
+  }
+
+  return module;
+}
+
+void PatternSolver::setMPQOptions(MPQOptions &options) { _options = options; }
+
+LayerParams PatternSolver::getFrozenParams() const
+{
+  LayerParams params;
+  for (const auto &node_to_param : _frozen._node_to_param)
+  {
+    params.push_back(std::make_shared<LayerParam>(node_to_param.second));
+  }
+
+  return params;
+}
+
+void PatternSolver::resolvePatterns(luci::Module *module)
+{
+  _frozen._node_to_param.clear();
+
+  for (auto pattern : _options._patterns)
+  {
+    std::unique_ptr<PatternResolver> resolver;
+    switch (pattern)
+    {
+      case QuantizationPattern::Q8LayerNormWithQ16Variance:
+        resolver = std::make_unique<Q8LayerNormWithQ16VarianceResolver>();
+        break;
+      case QuantizationPattern::Q8SoftmaxWithQ16SubExp:
+        resolver = std::make_unique<Q8SoftmaxWithQ16SubExpResolver>();
+        break;
+      default:
+        throw std::runtime_error("Unsupported pattern to resolve");
+    }
+
+    auto const resolved = resolver->resolve(module);
+    for (const auto &node_param : resolved)
+    {
+      auto const frozen = _frozen._node_to_param.find(node_param.first);
+      if (frozen == _frozen._node_to_param.end())
+      {
+        // node was not previously defined - just set it (no ambiguity)
+        _frozen._node_to_param[node_param.first] = node_param.second;
+      }
+      else if (frozen->second.dtype != node_param.second.dtype)
+      {
+        // ambiguity (incoming description conflicts with current)
+        throw std::runtime_error("Resolved patterns contradict each other");
+      }
+    }
+  }
+}
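resolvePatterns() merges prescriptions node by node: the first pattern to claim a node wins, a later claim is tolerated only if it agrees on the dtype, and a disagreement aborts the run. The rule in isolation, as a sketch (merge_param is not a function in the source; NodeKey and Param stand in for luci::CircleNode * and LayerParam):

#include <map>
#include <stdexcept>

// Sketch of the per-node merge rule used when several patterns resolve
// overlapping subgraphs.
template <typename NodeKey, typename Param>
void merge_param(std::map<NodeKey, Param> &frozen, NodeKey node, const Param &incoming)
{
  auto it = frozen.find(node);
  if (it == frozen.end())
    frozen[node] = incoming; // first prescription, no ambiguity
  else if (it->second.dtype != incoming.dtype)
    throw std::runtime_error("Resolved patterns contradict each other");
  // same dtype: the prescriptions agree, keep the existing entry
}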
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternSolver.h b/compiler/circle-mpqsolver/src/pattern/PatternSolver.h
new file mode 100644
index 000000000..52d740261
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternSolver.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MPQSOLVER_PATTERN_SOLVER_H__
+#define __MPQSOLVER_PATTERN_SOLVER_H__
+
+#include "MPQSolver.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace mpqsolver
+{
+namespace pattern
+{
+
+enum class QuantizationPattern
+{
+  Q8LayerNormWithQ16Variance,
+  Q8SoftmaxWithQ16SubExp,
+};
+
+struct MPQOptions
+{
+  std::vector<QuantizationPattern> _patterns;
+};
+
+struct FrozenNodes
+{
+  std::map<luci::CircleNode *, luci::CircleQuantizer::Options::LayerParam> _node_to_param;
+};
+
+class PatternSolver final : public MPQSolver
+{
+public:
+  /**
+   * @brief construct PatternSolver using quantization context and patterns to apply
+   */
+  PatternSolver(const mpqsolver::core::Quantizer::Context &ctx,
+                const std::vector<QuantizationPattern> &patterns);
+
+  /**
+   * @brief run solver for recorded float module at module_path
+   */
+  std::unique_ptr<luci::Module> run(const std::string &module_path) override;
+
+private:
+  /**
+   * @brief set quantization options
+   */
+  void setMPQOptions(MPQOptions &options);
+
+  /**
+   * @brief fill _frozen with prescribed quantization parameters of resolved nodes
+   */
+  void resolvePatterns(luci::Module *module);
+
+  /**
+   * @brief transform _frozen nodes to Quantizer friendly form
+   */
+  luci::CircleQuantizer::Options::LayerParams getFrozenParams() const;
+
+private:
+  MPQOptions _options; // options for mpq quantization
+  FrozenNodes _frozen; // nodes with prescribed quantization parameters
+};
+
+} // namespace pattern
+} // namespace mpqsolver
+
+#endif //__MPQSOLVER_PATTERN_SOLVER_H__
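End to end, the solver is a thin wrapper around a quantizer context. A hedged driver sketch (the path and context values are placeholders; only fields that the tests below also set are shown):

// Illustrative: quantize a float circle model while pinning the Softmax
// subgraph to the mixed Q8/Q16 scheme.
std::unique_ptr<luci::Module> quantize_with_patterns(const std::string &path)
{
  mpqsolver::core::Quantizer::Context ctx;
  ctx.output_model_dtype = "uint8";
  ctx.granularity = "channel";
  ctx.input_type = "uint8";
  ctx.output_type = "uint8";

  mpqsolver::pattern::PatternSolver solver(
    ctx, {mpqsolver::pattern::QuantizationPattern::Q8SoftmaxWithQ16SubExp});

  return solver.run(path); // throws on quantization failure
}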
diff --git a/compiler/circle-mpqsolver/src/pattern/PatternSolver.test.cpp b/compiler/circle-mpqsolver/src/pattern/PatternSolver.test.cpp
new file mode 100644
index 000000000..90ff0f89e
--- /dev/null
+++ b/compiler/circle-mpqsolver/src/pattern/PatternSolver.test.cpp
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "PatternSolver.h"
+
+#include "core/TestHelper.h"
+
+#include <luci/CircleExporter.h>
+#include <luci/CircleFileExpContract.h>
+
+using namespace mpqsolver::pattern;
+
+namespace
+{
+
+class CircleMPQSolverPatternSolverTest : public ::testing::Test
+{
+public:
+  CircleMPQSolverPatternSolverTest()
+  {
+    char module_template[] = "CircleMPQSolverPatternSolverTest-CIRCLE-XXXXXX";
+    mpqsolver::test::io_utils::makeTemporaryFile(module_template);
+    _module_path = module_template;
+  }
+
+  ~CircleMPQSolverPatternSolverTest() { unlink(_module_path.c_str()); }
+
+protected:
+  mpqsolver::test::models::SoftmaxTestGraph _g;
+  std::string _module_path;
+};
+
+} // namespace
+
+TEST_F(CircleMPQSolverPatternSolverTest, verify_results)
+{
+  auto m = luci::make_module();
+  _g.init();
+  _g.transfer_to(m.get());
+
+  // export to _module_path
+  luci::CircleExporter exporter;
+  luci::CircleFileExpContract contract(m.get(), _module_path);
+  EXPECT_TRUE(exporter.invoke(&contract));
+
+  // Create quantizer parameters
+  mpqsolver::core::Quantizer::Context ctx;
+  {
+    ctx.output_model_dtype = "uint8";
+    ctx.granularity = "channel";
+    ctx.input_type = "uint8";
+    ctx.output_type = "uint8";
+    ctx.save_min_max = false;
+    ctx.TF_style_maxpool = false;
+  }
+
+  // create solver
+  mpqsolver::pattern::PatternSolver solver(
+    ctx, std::vector<QuantizationPattern>(1, QuantizationPattern::Q8SoftmaxWithQ16SubExp));
+
+  // run solver
+  auto const res = solver.run(_module_path);
+  EXPECT_TRUE(res.get() != nullptr);
+  ASSERT_EQ(1, res.get()->size());
+
+  auto const graph = res.get()->graph();
+  ASSERT_NE(nullptr, graph);
+
+  uint32_t exp_count = 0;
+  for (auto node : loco::postorder_traversal(loco::output_nodes(graph)))
+  {
+    auto const exp = dynamic_cast<luci::CircleExp *>(node);
+    if (exp != nullptr)
+    {
+      exp_count += 1;
+      auto const dtype = exp->dtype();
+      // pattern was applied
+      ASSERT_EQ(loco::DataType::S16, dtype);
+    }
+  }
+
+  // the model has a single exp node
+  ASSERT_EQ(1, exp_count);
+}
+
+TEST_F(CircleMPQSolverPatternSolverTest, empty_patterns_NEG)
+{
+  auto m = luci::make_module();
+  _g.init();
+  _g.transfer_to(m.get());
+
+  // export to _module_path
+  luci::CircleExporter exporter;
+  luci::CircleFileExpContract contract(m.get(), _module_path);
+  EXPECT_TRUE(exporter.invoke(&contract));
+
+  // Create quantizer parameters
+  mpqsolver::core::Quantizer::Context ctx;
+  {
+    ctx.output_model_dtype = "uint8";
+    ctx.granularity = "channel";
+    ctx.input_type = "uint8";
+    ctx.output_type = "uint8";
+    ctx.save_min_max = false;
+    ctx.TF_style_maxpool = false;
+  }
+
+  // create solver
+  mpqsolver::pattern::PatternSolver solver(ctx, std::vector<QuantizationPattern>());
+
+  // run solver
+  auto const res = solver.run(_module_path);
+  EXPECT_TRUE(res.get() != nullptr);
+  ASSERT_EQ(1, res.get()->size());
+
+  auto const graph = res.get()->graph();
+  ASSERT_NE(nullptr, graph);
+
+  uint32_t exp_count = 0;
+  for (auto node : loco::postorder_traversal(loco::output_nodes(graph)))
+  {
+    auto const exp = dynamic_cast<luci::CircleExp *>(node);
+    if (exp != nullptr)
+    {
+      exp_count += 1;
+      auto const dtype = exp->dtype();
+      // pattern was not applied
+      ASSERT_EQ(loco::DataType::U8, dtype);
+    }
+  }
+
+  // the model has a single exp node
+  ASSERT_EQ(1, exp_count);
+}
+
+TEST_F(CircleMPQSolverPatternSolverTest, empty_path_NEG)
+{
+  // Create quantizer parameters
+  mpqsolver::core::Quantizer::Context ctx;
+  {
+    ctx.output_model_dtype = "uint8";
+    ctx.granularity = "channel";
+    ctx.input_type = "uint8";
+    ctx.output_type = "uint8";
+    ctx.save_min_max = false;
+    ctx.TF_style_maxpool = false;
+  }
+
+  // create solver
+  mpqsolver::pattern::PatternSolver solver(
+    ctx, std::vector<QuantizationPattern>(1, QuantizationPattern::Q8LayerNormWithQ16Variance));
+
+  EXPECT_ANY_THROW(solver.run(""));
+}
diff --git a/compiler/circle-operator/CMakeLists.txt b/compiler/circle-operator/CMakeLists.txt
index 33d9a96d0..a13e76eb8 100644
--- a/compiler/circle-operator/CMakeLists.txt
+++ b/compiler/circle-operator/CMakeLists.txt
@@ -1,6 +1,6 @@
-if(NOT TARGET mio_circle06)
+if(NOT TARGET mio_circle08)
   return()
-endif(NOT TARGET mio_circle06)
+endif(NOT TARGET mio_circle08)
 
 set(DRIVER "driver/Driver.cpp")
 
@@ -10,8 +10,8 @@ add_executable(circle-operator ${DRIVER} ${SOURCES})
 target_include_directories(circle-operator PRIVATE src)
 target_link_libraries(circle-operator arser)
 target_link_libraries(circle-operator foder)
-target_link_libraries(circle-operator mio_circle06)
-target_link_libraries(circle-operator mio_circle06_helper)
+target_link_libraries(circle-operator mio_circle08)
+target_link_libraries(circle-operator mio_circle08_helper)
 target_link_libraries(circle-operator safemain)
 
 install(TARGETS circle-operator DESTINATION bin)
diff --git a/compiler/circle-operator/requires.cmake b/compiler/circle-operator/requires.cmake
index b3a2638ef..8a57c8f11 100644
--- a/compiler/circle-operator/requires.cmake
+++ b/compiler/circle-operator/requires.cmake
@@ -1,4 +1,4 @@
 require("arser")
 require("foder")
-require("mio-circle06")
+require("mio-circle08")
 require("safemain")
diff --git a/compiler/circle-operator/src/Dump.cpp b/compiler/circle-operator/src/Dump.cpp
index 36bfe8632..dc2602238 100644
--- a/compiler/circle-operator/src/Dump.cpp
+++ b/compiler/circle-operator/src/Dump.cpp
@@ -27,7 +27,7 @@ namespace
 void dump_ops(std::ostream &os, mio::circle::Reader &reader, const cirops::DumpOption &option)
 {
   auto ops = reader.operators();
-  for (uint32_t i = 0; i < ops->Length(); ++i)
+  for (uint32_t i = 0; i < ops->size(); ++i)
   {
     const auto op = ops->Get(i);
     const auto op_name = reader.opcode_name(op);
diff --git a/compiler/circle-opselector/driver/Driver.cpp b/compiler/circle-opselector/driver/Driver.cpp
index 5ad2b9ca3..8cfa248ad 100644
--- a/compiler/circle-opselector/driver/Driver.cpp
+++ b/compiler/circle-opselector/driver/Driver.cpp
@@ -80,10 +80,9 @@ int entry(int argc, char **argv)
   auto module = opselector::getModule(input_path);
 
   // TODO support two or more subgraphs
-  if (module.get()->size() != 1)
+  if (module.get()->size() > 1)
   {
-    std::cerr << "ERROR: Not support two or more subgraphs" << std::endl;
-    return EXIT_FAILURE;
+    std::cout << "WARNING: Only first subgraph's operators will be selected" << std::endl;
   }
 
   opselector::OpSelector op_selector{module.get()};
diff --git a/compiler/circle-opselector/src/OpSelector.cpp b/compiler/circle-opselector/src/OpSelector.cpp
index 09a66548d..a069d8873 100644
--- a/compiler/circle-opselector/src/OpSelector.cpp
+++ b/compiler/circle-opselector/src/OpSelector.cpp
@@ -68,7 +68,7 @@ bool is_number(const std::vector<std::string> &vec)
 {
   for (const auto &s : vec)
   {
-    if (not::is_number(s))
+    if (not ::is_number(s))
     {
       return false;
     }
@@ -94,6 +94,15 @@ public:
   bool visit(const luci::CircleNode *) final { return false; }
 };
 
+class IsMultiGraphNode final : public luci::CircleNodeVisitor<bool>
+{
+public:
+  bool visit(const luci::CircleIf *) final { return true; }
+  bool visit(const luci::CircleWhile *) final { return true; }
+  // default is false
+  bool visit(const luci::CircleNode *) final { return false; }
+};
+
 std::unique_ptr<loco::Graph> make_graph(const std::vector<const luci::CircleNode *> nodes)
 {
   auto graph = loco::make_graph();
@@ -119,6 +128,14 @@ std::unique_ptr<loco::Graph> make_graph(const std::vector<const luci::CircleNode *> nodes)
+    auto circle_output_exclude = dynamic_cast<luci::CircleOutputExclude *>(arg);
+    if (circle_output_exclude)
+    {
+      auto clone = luci::clone_node(circle_output_exclude, graph.get());
+      ctx.emplace(circle_output_exclude, clone);
+      continue;
+    }
     auto circle_const = dynamic_cast<luci::CircleConst *>(arg);
     if (circle_const != nullptr)
     {
@@ -247,10 +264,7 @@ namespace opselector
 
 OpSelector::OpSelector(const luci::Module *module) : _module{module}
{
-  if (_module->size() != 1)
-  {
-    throw std::runtime_error{"ERROR: Not support two or more subgraphs"};
-  }
+  // DO NOTHING
 }
 
 template <>
@@ -262,7 +276,7 @@ OpSelector::select_by(const std::vector<std::string> &comma_tokens)
   for (const auto &comma_token : comma_tokens)
   {
     auto dash_tokens = ::split_into_vector(comma_token, '-');
-    if (not::is_number(dash_tokens))
+    if (not ::is_number(dash_tokens))
     {
       throw std::runtime_error{
         "ERROR: To select operator by id, please use these args: [0-9], '-', ','"};
@@ -367,8 +381,6 @@ std::unique_ptr<luci::Module> OpSelector::select_by(const std::string &str)
     throw std::runtime_error{"ERROR: Nothing was entered."};
   }
 
-  assert(_module->size() == 1);
-
   auto selected_nodes = select_by<SELECT_TYPE>(colon_tokens);
 
   // multiout node should be considered
@@ -387,6 +399,17 @@ std::unique_ptr<luci::Module> OpSelector::select_by(const std::string &str)
   }
   selected_nodes.insert(selected_nodes.end(), output_nodes.begin(), output_nodes.end());
 
+  // TODO support two or more subgraphs
+  for (const auto &n : selected_nodes)
+  {
+    IsMultiGraphNode multigraph_visitor;
+    bool isMultiGraph = n->accept(&multigraph_visitor);
+    if (isMultiGraph)
+    {
+      throw std::runtime_error{"ERROR: If or While operator can't be selected."};
+    }
+  }
+
   auto new_module = std::make_unique<luci::Module>();
   new_module->add(::make_graph(selected_nodes));
diff --git a/compiler/circle-part-driver/src/PModelsRunner.cpp b/compiler/circle-part-driver/src/PModelsRunner.cpp
index dd2ffe22d..b7ab32061 100644
--- a/compiler/circle-part-driver/src/PModelsRunner.cpp
+++ b/compiler/circle-part-driver/src/PModelsRunner.cpp
@@ -18,6 +18,7 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -78,7 +79,7 @@ void save_shape(const std::string &shape_filename, const luci::CircleOutput *out
 template <typename NodeT> size_t tensor_size(const NodeT *node)
 {
-  uint32_t tsize = loco::size(node->dtype());
+  uint32_t tsize = luci::size(node->dtype());
   for (uint32_t i = 0; i < node->rank(); ++i)
   {
     assert(node->dim(i).known());
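The PModelsRunner change swaps loco::size for luci::size so element sizes come from luci's dtype table. The surrounding computation, restated as a hedged sketch (NodeT is assumed to expose dtype/rank/dim as luci nodes do):

// Byte size of a node's tensor: element size times all dimensions.
template <typename NodeT> size_t tensor_byte_size(const NodeT *node)
{
  size_t size = luci::size(node->dtype()); // element size in bytes
  for (uint32_t i = 0; i < node->rank(); ++i)
    size *= node->dim(i).value(); // dims assumed known, as asserted above
  return size;
}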
diff --git a/compiler/circle-part-value-py-test/CMakeLists.txt b/compiler/circle-part-value-py-test/CMakeLists.txt
new file mode 100644
index 000000000..6a291314f
--- /dev/null
+++ b/compiler/circle-part-value-py-test/CMakeLists.txt
@@ -0,0 +1,110 @@
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+unset(RECIPE_LIST)
+unset(PARTITION_LIST)
+unset(OUTPUT_COUNT_LIST)
+unset(TEST_DEPS)
+
+macro(add RECIPE_NAME PARTITION_NAME OUTPUT_COUNT)
+  list(APPEND RECIPE_LIST ${RECIPE_NAME})
+  list(APPEND PARTITION_LIST ${PARTITION_NAME})
+  list(APPEND OUTPUT_COUNT_LIST ${OUTPUT_COUNT})
+endmacro(add)
+
+# Read "test.lst"
+include("test.lst")
+
+list(LENGTH RECIPE_LIST RECIPE_LENGTH)
+math(EXPR RECIPE_LENGTH_M1 "${RECIPE_LENGTH} - 1")
+
+foreach(IDX RANGE ${RECIPE_LENGTH_M1})
+  list(GET RECIPE_LIST ${IDX} RECIPE_NAME)
+  list(GET PARTITION_LIST ${IDX} PARTITION_NAME)
+  list(GET OUTPUT_COUNT_LIST ${IDX} OUTPUT_COUNT)
+
+  # NOTE about the name:
+  # Use '.recipe' name for source tflite and circle files
+  # Use '.part' name for actual test folder and test files
+
+  # Output to a folder
+  set(PARTITIONER_OUTPUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${PARTITION_NAME}")
+
+  add_custom_command(OUTPUT ${PARTITIONER_OUTPUT_PATH}
+    COMMAND ${CMAKE_COMMAND} -E make_directory "${PARTITIONER_OUTPUT_PATH}"
+    COMMENT "Make directory ${PARTITIONER_OUTPUT_PATH}"
+  )
+
+  # Copy tflite
+  set(TFLITE_SRC_PATH "${ARTIFACTS_BIN_PATH}/${RECIPE_NAME}.tflite")
+  set(TFLITE_DST_PATH "${PARTITIONER_OUTPUT_PATH}/${PARTITION_NAME}.tflite")
+
+  add_custom_command(OUTPUT ${TFLITE_DST_PATH}
+    COMMAND ${CMAKE_COMMAND} -E copy "${TFLITE_SRC_PATH}" "${TFLITE_DST_PATH}"
+    DEPENDS ${TFLITE_SRC_PATH}
+    COMMENT "Copy ${RECIPE_NAME}.tflite"
+  )
+  list(APPEND TEST_DEPS ${TFLITE_DST_PATH})
+
+  # Copy circle
+  set(CIRCLE_SRC_PATH "${ARTIFACTS_BIN_PATH}/${RECIPE_NAME}.circle")
+  set(CIRCLE_DST_PATH "${PARTITIONER_OUTPUT_PATH}/${PARTITION_NAME}.circle")
+
+  add_custom_command(OUTPUT ${CIRCLE_DST_PATH}
+    COMMAND ${CMAKE_COMMAND} -E copy "${CIRCLE_SRC_PATH}" "${CIRCLE_DST_PATH}"
+    DEPENDS ${CIRCLE_SRC_PATH}
+    COMMENT "Copy ${RECIPE_NAME}.circle"
+  )
+  list(APPEND TEST_DEPS ${CIRCLE_DST_PATH})
+
+  # Copy .part
+  set(PART_FILE "${PARTITION_NAME}.part")
+  set(PART_SRC_PATH "${CMAKE_CURRENT_SOURCE_DIR}/parts/${PART_FILE}")
+  set(PART_DST_PATH "${PARTITIONER_OUTPUT_PATH}/${PART_FILE}")
+
+  add_custom_command(OUTPUT ${PART_DST_PATH}
+    COMMAND ${CMAKE_COMMAND} -E copy "${PART_SRC_PATH}" "${PART_DST_PATH}"
+    DEPENDS ${PART_SRC_PATH}
+    COMMENT "Copy ${PART_FILE}"
+  )
+  list(APPEND TEST_DEPS ${PART_DST_PATH})
+
+  # Partition connection file to generate
+  set(PARTITIONER_CONN_JSON "${PARTITIONER_OUTPUT_PATH}/${PARTITION_NAME}.conn.json")
+
+  # Run partitioner
+  add_custom_command(OUTPUT ${PARTITIONER_CONN_JSON}
+    COMMAND circle-partitioner "--part_file" "${PART_FILE}" "--input_file"
+            "${PARTITION_NAME}.circle" "--work_path" "${PARTITIONER_OUTPUT_PATH}"
+    DEPENDS circle-partitioner ${PART_DST_PATH} ${CIRCLE_DST_PATH}
+    COMMENT "Partition ${RECIPE_NAME}.circle with ${PART_FILE}"
+  )
+  list(APPEND TEST_DEPS ${PARTITIONER_CONN_JSON})
+
+  # Write .excnt file; expected count of output models
+  set(COUNT_FILE "${PARTITION_NAME}.excnt")
+  set(COUNT_FILE_PATH "${PARTITIONER_OUTPUT_PATH}/${COUNT_FILE}")
+  add_custom_command(OUTPUT ${COUNT_FILE_PATH}
+    COMMAND echo ${OUTPUT_COUNT} > ${COUNT_FILE_PATH}
+    DEPENDS ${PART_SRC_PATH} ${PARTITIONER_OUTPUT_PATH}
+    COMMENT "Write ${COUNT_FILE} with ${OUTPUT_COUNT}"
+  )
+  list(APPEND TEST_DEPS ${COUNT_FILE_PATH})
+endforeach(IDX)
+
+add_custom_target(circle_part_value_py_test_prepare ALL DEPENDS ${TEST_DEPS})
+add_dependencies(circle_part_value_py_test_prepare common_artifacts_deps)
+
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1")
+set(TEST_LIST_FILE "test.lst")
+
+add_test(NAME circle_part_value_py_test
+  COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_circle_part_value.py
+          --test_list ${TEST_LIST_FILE}
+          --bin_dir ${CMAKE_CURRENT_BINARY_DIR}
+          --circle_part_driver $<TARGET_FILE:circle_part_driver>
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+)
diff --git a/compiler/circle-part-value-py-test/README.md b/compiler/circle-part-value-py-test/README.md
new file mode 100644
index 000000000..5ae38fab8
--- /dev/null
+++ b/compiler/circle-part-value-py-test/README.md
@@ -0,0 +1,15 @@
+# circle-part-value-py-test
+
+_circle-part-value-py-test_ evaluates partitioned models produced by circle-partitioner.
+
+### Process of evaluation
+
+The evaluation process follows the same flow as _luci-value-test_:
+ +1) generates random input and stores to reference input file(s) +2) executes tflite file from common-artifacts for reference output +3) partitions circle file with .part file and produces into output folder +4) executes produced partitioned circle models with reference input file(s) +5) saves output(s) of circle models to file(s) +6) compares reference output with saved output file(s) +7) fail test if values differ diff --git a/compiler/circle-part-value-py-test/conftest.py b/compiler/circle-part-value-py-test/conftest.py new file mode 100644 index 000000000..63a78f4b5 --- /dev/null +++ b/compiler/circle-part-value-py-test/conftest.py @@ -0,0 +1,32 @@ +import re + + +def extract_test_args(s): + p = re.compile('add\\((.*)\\)') + result = p.search(s) + return result.group(1) + + +def pytest_addoption(parser): + parser.addoption("--test_list", action="store", help="Path to test list") + parser.addoption("--bin_dir", action="store", help="Directory including artifacts") + parser.addoption( + "--circle_part_driver", action="store", help="Path to circle part driver") + + +def pytest_generate_tests(metafunc): + list_path = metafunc.config.getoption('test_list') + bin_dir = metafunc.config.getoption('bin_dir') + circle_part_driver = metafunc.config.getoption('circle_part_driver') + + with open(list_path) as f: + contents = [line.rstrip() for line in f] + + comment_removed = [line for line in contents if not line.startswith('#')] + newline_removed = [line for line in comment_removed if line.startswith('add(')] + test_args = [extract_test_args(line) for line in newline_removed] + # add(RECIPE_NAME PARTITION_NAME EXPECTED_OUTPUT_COUNT) + partition_list = [(arg.split()[1], bin_dir, circle_part_driver) for arg in test_args] + + if 'test_name' in metafunc.fixturenames: + metafunc.parametrize('test_name,bin_dir,part_driver_path', partition_list) diff --git a/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.001.part b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.001.part new file mode 100644 index 000000000..01b8c704e --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +ADD=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.002.part b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.002.part new file mode 100644 index 000000000..dc378a448 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.002.part @@ -0,0 +1,8 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SUB=acl_cl +DIV=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.003.part b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.003.part new file mode 100644 index 000000000..eee3fd1d1 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.003.part @@ -0,0 +1,9 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opname + +[OPNAME] +Mean_as_variance=acl_cl +Add_as_variance=acl_cl +Pow=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.part b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.part new file mode 100644 index 000000000..d4d439d27 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_InstanceNorm_003.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +DIV=acl_cl diff --git 
a/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.001.part b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.001.part new file mode 100644 index 000000000..496971e55 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=cpu +comply=opcode + +[OPCODE] +ADD=npu diff --git a/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.002.part b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.002.part new file mode 100644 index 000000000..9913fea96 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.002.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=cpu +comply=opcode + +[OPCODE] +UNPACK=npu diff --git a/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.part b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.part new file mode 100644 index 000000000..c63efc592 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Net_UnpackAdd_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +UNPACK=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_000.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_000.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_Rsqrt_000.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_Rsqrt_000.part new file mode 100644 index 000000000..c6dba9f94 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sqrt_Rsqrt_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +RSQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.001.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.001.part new file mode 100644 index 000000000..179cad191 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opname + +[OPNAME] +add1=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.part new file mode 100644 index 000000000..905137ce7 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SUB=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sub_001.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_001.part new file mode 100644 index 000000000..41ce4b23d --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opname + +[OPNAME] +some/node/add2;and/another=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.001.part b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.001.part new file mode 100644 index 000000000..030653e8a --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.001.part @@ -0,0 +1,9 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opname + +[OPNAME] +add1=cpu +add2=acl_cl +ofm=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.002.part 
b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.002.part new file mode 100644 index 000000000..837b36269 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Add_Sub_002.002.part @@ -0,0 +1,9 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opname + +[OPNAME] +add1=acl_cl +add2=acl_cl +ofm=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_000.001.part b/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_000.001.part new file mode 100644 index 000000000..01b8c704e --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_000.001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +ADD=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_001.001.part b/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_001.001.part new file mode 100644 index 000000000..01b8c704e --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_If_Add_Sub_001.001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +ADD=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part new file mode 100644 index 000000000..ad0842165 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +MUL=npu diff --git a/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part new file mode 100644 index 000000000..c82b741b0 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +SQRT=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part new file mode 100644 index 000000000..d9d2a8e59 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +FULLY_CONNECTED=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_Split_Add_000.part b/compiler/circle-part-value-py-test/parts/Part_Split_Add_000.part new file mode 100644 index 000000000..91af566cd --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Split_Add_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +SPLIT=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_000.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_000.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_001.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_001.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git 
a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_002.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_002.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_002.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_003.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_003.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_003.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_000.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_000.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_001.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_001.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_002.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_002.part new file mode 100644 index 000000000..402af87e9 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_002.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +SQRT=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_003.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_003.part new file mode 100644 index 000000000..0ec264c94 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_003.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +WWW=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_004.part b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_004.part new file mode 100644 index 000000000..febab2246 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Sqrt_Rsqrt_Add_004.part @@ -0,0 +1,6 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] diff --git a/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias.part b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias.part new file mode 100644 index 000000000..d4d439d27 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +DIV=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_001.part b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_001.part new file mode 100644 index 000000000..dbd174ee1 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +TANH=cpu diff --git 
a/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_002.part b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_002.part new file mode 100644 index 000000000..475439a9d --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_002.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=cpu +comply=opcode + +[OPCODE] +FULLY_CONNECTED=npu diff --git a/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_003.part b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_003.part new file mode 100644 index 000000000..d9d2a8e59 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_Tanh_FC_nobias_003.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,npu +default=npu +comply=opcode + +[OPCODE] +FULLY_CONNECTED=cpu diff --git a/compiler/circle-part-value-py-test/parts/Part_While_000.part b/compiler/circle-part-value-py-test/parts/Part_While_000.part new file mode 100644 index 000000000..e469eeb26 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_While_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +MAXIMUM=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/Part_While_001.part b/compiler/circle-part-value-py-test/parts/Part_While_001.part new file mode 100644 index 000000000..e469eeb26 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/Part_While_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +MAXIMUM=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_000.part b/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_000.part new file mode 100644 index 000000000..e469eeb26 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_000.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +MAXIMUM=acl_cl diff --git a/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_001.part b/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_001.part new file mode 100644 index 000000000..e469eeb26 --- /dev/null +++ b/compiler/circle-part-value-py-test/parts/SignatureDef_MultiOut_001.part @@ -0,0 +1,7 @@ +[partition] +backends=cpu,acl_cl +default=cpu +comply=opcode + +[OPCODE] +MAXIMUM=acl_cl diff --git a/compiler/circle-part-value-py-test/requires.cmake b/compiler/circle-part-value-py-test/requires.cmake new file mode 100644 index 000000000..a9301f947 --- /dev/null +++ b/compiler/circle-part-value-py-test/requires.cmake @@ -0,0 +1,3 @@ +require("common-artifacts") +require("circle-partitioner") +require("circle-part-driver") diff --git a/compiler/circle-part-value-py-test/test.lst b/compiler/circle-part-value-py-test/test.lst new file mode 100644 index 000000000..b7a3f403a --- /dev/null +++ b/compiler/circle-part-value-py-test/test.lst @@ -0,0 +1,58 @@ +# Add recipe names from /res/TensorFlowLiteRecipes to test. +# Only add items exist in common-artifacts test: tflite/circle files are copied as source. 
+# +# add(RECIPE_NAME PARTITION_NAME EXPECTED_OUTPUT_COUNT) +# EXPECTED_OUTPUT_COUNT: 0 for skip expected count test + +add(Part_Add_Sub_000 Part_Add_Sub_000 2) +add(Part_Sqrt_Rsqrt_000 Part_Sqrt_Rsqrt_000 2) +add(Part_Sqrt_Rsqrt_001 Part_Sqrt_Rsqrt_001 2) +add(Part_Sqrt_Rsqrt_002 Part_Sqrt_Rsqrt_002 4) +add(Part_Sqrt_Rsqrt_003 Part_Sqrt_Rsqrt_003 3) +add(Part_Sqrt_Rsqrt_Add_000 Part_Sqrt_Rsqrt_Add_000 3) +add(Part_Sqrt_Rsqrt_Add_001 Part_Sqrt_Rsqrt_Add_001 3) +add(Part_Sqrt_Rsqrt_Add_002 Part_Sqrt_Rsqrt_Add_002 4) +add(Part_Sqrt_Rsqrt_Add_003 Part_Sqrt_Rsqrt_Add_003 1) +add(Part_Sqrt_Rsqrt_Add_004 Part_Sqrt_Rsqrt_Add_004 1) +add(Part_Add_Sqrt_000 Part_Add_Sqrt_000 3) +add(Part_Add_Sqrt_Rsqrt_000 Part_Add_Sqrt_Rsqrt_000 3) +add(Net_InstanceNorm_003 Net_InstanceNorm_003 3) +add(Net_InstanceNorm_003 Net_InstanceNorm_003.001 5) +# skip expected count for now +add(Net_InstanceNorm_003 Net_InstanceNorm_003.002 0) + +# comply=opname +add(Part_Add_Sub_000 Part_Add_Sub_000.001 3) +add(Part_Add_Sub_001 Part_Add_Sub_001 3) +add(Part_Add_Sub_002 Part_Add_Sub_002.001 2) +add(Part_Add_Sub_002 Part_Add_Sub_002.002 2) +add(Net_InstanceNorm_003 Net_InstanceNorm_003.003 3) + +# IF with subgraphs +add(Part_If_Add_Sub_000 Part_If_Add_Sub_000.001 3) +add(Part_If_Add_Sub_001 Part_If_Add_Sub_001.001 3) + +# WHILE with subgraphs +add(Part_While_000 Part_While_000 3) +add(Part_While_001 Part_While_001 3) + +# UNPACK with multiple outputs +add(Net_UnpackAdd_001 Net_UnpackAdd_001 2) +add(Net_UnpackAdd_001 Net_UnpackAdd_001.001 2) +add(Net_UnpackAdd_001 Net_UnpackAdd_001.002 2) + +# Other multiple outputs +add(Part_Split_Add_000 Part_Split_Add_000 2) + +# test SignatureDef, with any OPCODE +add(SignatureDef_MultiOut_000 SignatureDef_MultiOut_000 0) +add(SignatureDef_MultiOut_001 SignatureDef_MultiOut_001 0) + +# FC with nobias +add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias 1) +add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_001 2) +add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_002 2) +add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_003 2) +add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_000 0) +add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_001 0) +add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_002 0) diff --git a/compiler/circle-part-value-py-test/test_circle_part_value.py b/compiler/circle-part-value-py-test/test_circle_part_value.py new file mode 100644 index 000000000..98524875f --- /dev/null +++ b/compiler/circle-part-value-py-test/test_circle_part_value.py @@ -0,0 +1,147 @@ +import numpy as np +import tensorflow as tf +import subprocess +import os +import json + + +# Compares the execution result of TFLite interpreter and partitioned model(s) from a circle model. 
+def part_eval(test_name, bin_dir, circle_part_driver): + artifacts_dir = os.path.join(bin_dir, test_name) + tflite_model = os.path.join(artifacts_dir, test_name + ".tflite") + circle_model = os.path.join(artifacts_dir, test_name + ".circle") + partition_conn_ini = os.path.join(artifacts_dir, test_name + ".conn.ini") + partition_conn_json = os.path.join(artifacts_dir, test_name + ".conn.json") + expected_count = os.path.join(artifacts_dir, test_name + ".excnt") + + # Check expected count of models from partitioning + try: + with open(expected_count, "r") as expected_count_file: + expected_count_line = expected_count_file.readline() + + expected_count_line = int(expected_count_line) + if expected_count_line: + with open(partition_conn_json) as json_file: + json_data = json.load(json_file) + parts_value = json_data["parts"] + if len(parts_value) != expected_count_line: + print("Partitioned model count differs from expected:", + expected_count_line) + assert False + + print("Partitioned model count expected: ", expected_count_line) + else: + print("Skip expected partitioned model count check: 0") + + except: + print("Skip expected partitioned model count check: error") + + # Build TFLite interpreter. + interpreter = tf.lite.Interpreter(tflite_model) + interpreter.allocate_tensors() + + # Read SignatureDef and get output tensor id orders for remapping + full_signatures = interpreter._get_full_signature_list() + full_signatures_outputs_remap = None + if full_signatures != None: + signature_serving_default = full_signatures.get('serving_default', None) + if signature_serving_default != None: + signature_outputs = signature_serving_default['outputs'] + + full_signatures_outputs_remap = [] + for index, (key, value) in enumerate(signature_outputs.items()): + full_signatures_outputs_remap.append(value) + + # Generate random input data. + num_inputs = len(interpreter.get_input_details()) + for i in range(num_inputs): + input_details = interpreter.get_input_details()[i] + if input_details["dtype"] == np.float32: + input_data = np.array( + np.random.random_sample(input_details["shape"]), input_details["dtype"]) + elif input_details["dtype"] == np.uint8: + input_data = np.array( + np.random.randint(0, 256, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int16: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int32: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int64: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.bool_: + input_data = np.array( + np.random.choice(a=[True, False], size=input_details["shape"]), + input_details["dtype"]) + else: + assert False, "Unsupported input dtype" + + interpreter.set_tensor(input_details["index"], input_data) + input_data.tofile(circle_model + ".input" + str(i)) + + # Do inference + interpreter.invoke() + + # Execute circle-part-driver. 
+ partition_command = [ + circle_part_driver, partition_conn_ini, + str(num_inputs), circle_model + ".input", circle_model + ".output" + ] + print("Run: ") + for arg in partition_command: + print(" ", arg, "\\") + print("", flush=True) + + # working directory into the folder as ini has relative filename of the model + subprocess.run(partition_command, check=True, cwd=artifacts_dir) + + # Compare the results. + inpt_output_details = interpreter.get_output_details() + for idx in range(len(inpt_output_details)): + output_details = inpt_output_details[idx] + output_data = np.fromfile(circle_model + ".output" + str(idx), + output_details["dtype"]) + shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r') + output_shape = [int(i) for i in shape_file.read().split(',')] + luci_output_data = np.reshape(output_data, output_shape) + output_tensor = output_details["index"] + if full_signatures_outputs_remap != None: + output_tensor = full_signatures_outputs_remap[idx] + intp_output_data = interpreter.get_tensor(output_tensor) + if output_details["dtype"] == np.uint8: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + elif output_details["dtype"] == np.float32: + assert np.allclose( + luci_output_data, intp_output_data, rtol=1.e-5, atol=1.e-5 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + elif output_details["dtype"] == np.int64: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + elif output_details["dtype"] == np.int32: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + elif output_details["dtype"] == np.int16: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + elif output_details["dtype"] == np.bool_: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0 + ), "Execution result of " + tflite_model + " does not match with " + circle_model + else: + assert False, "Unsupported data type: " + output_details["dtype"] + + +# arguments must be in sync with `conftest.py` +def test_circle_part_value(test_name: str, bin_dir: str, part_driver_path: str): + part_eval(test_name, bin_dir, part_driver_path) diff --git a/compiler/circle-part-value-test/exclude.me b/compiler/circle-part-value-test/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/circle-quantizer-dredd-recipe-test/test.lst b/compiler/circle-quantizer-dredd-recipe-test/test.lst index 309069bb8..c4ad8ae32 100644 --- a/compiler/circle-quantizer-dredd-recipe-test/test.lst +++ b/compiler/circle-quantizer-dredd-recipe-test/test.lst @@ -46,6 +46,7 @@ Add(Quant_Logistic_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_MaxPool2D_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_Mean_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_Mul_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) +Add(Quant_Mul_002 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_Neg_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_Pad_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) Add(Quant_PRelu_001 DTYPE int16 GRANULARITY channel USE_QCONFIG) diff --git a/compiler/circle-quantizer/src/CircleQuantizer.cpp 
diff --git a/compiler/circle-quantizer/src/CircleQuantizer.cpp b/compiler/circle-quantizer/src/CircleQuantizer.cpp
index 02b96f91e..f18642f90 100644
--- a/compiler/circle-quantizer/src/CircleQuantizer.cpp
+++ b/compiler/circle-quantizer/src/CircleQuantizer.cpp
@@ -26,35 +26,51 @@
 #include
 #include
-#include
 #include
 #include
 
-using OptionHook = std::function;
-
 using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+using LayerParams = luci::CircleQuantizer::Options::LayerParams;
+using LayerParamsSet = luci::CircleQuantizer::Options::LayerParamsSet;
 using Algorithms = luci::CircleQuantizer::Options::Algorithm;
 using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
 
-std::vector<std::shared_ptr<LayerParam>> read_layer_params(std::string &filename)
+struct QConfReader
 {
-  Json::Value root;
-  std::ifstream ifs(filename);
+  void init(const std::string &filename)
+  {
+    std::ifstream ifs;
+
+    ifs.open(filename);
+
+    // Failed to open cfg file
+    if (not ifs.is_open())
+      throw std::runtime_error("Cannot open config file. " + filename);
+
+    JSONCPP_STRING errs;
+    Json::CharReaderBuilder builder;
 
-  // Failed to open cfg file
-  if (not ifs.is_open())
-    throw std::runtime_error("Cannot open config file. " + filename);
+    // Failed to parse
+    if (not parseFromStream(builder, ifs, &_root, &errs))
+      throw std::runtime_error("Cannot parse config file (json format). " + errs);
 
-  Json::CharReaderBuilder builder;
-  JSONCPP_STRING errs;
+    ifs.close();
+  }
+
+  Json::Value &root(void) { return _root; }
 
-  // Failed to parse
-  if (not parseFromStream(builder, ifs, &root, &errs))
-    throw std::runtime_error("Cannot parse config file (json format). " + errs);
+private:
+  Json::Value _root;
+};
 
-  auto layers = root["layers"];
-  std::vector<std::shared_ptr<LayerParam>> p;
+LayerParams read_layer_params(std::string &filename)
+{
+  QConfReader qcr;
+  qcr.init(filename);
+
+  auto layers = qcr.root()["layers"];
+  LayerParams p;
 for (auto layer : layers)
 {
   if (layer.isMember("name"))
@@ -87,6 +103,46 @@ std::vector<std::shared_ptr<LayerParam>> read_layer_params(std::string &filename)
   return p;
 }
 
+LayerParamsSet read_layer_params_set(std::string &filename)
+{
+  LayerParamsSet lpss;
+
+  // read default values
+  LayerParams lps = read_layer_params(filename);
+  lpss.emplace_back(lps);
+
+  QConfReader qcr;
+  qcr.init(filename);
+
+  auto layers = qcr.root()["layers"];
+  // alternate names
+  for (const auto &layer : layers)
+  {
+    const std::string key_alt_names = "alternate";
+    if (layer.isMember(key_alt_names))
+    {
+      auto alternate = layer[key_alt_names];
+      for (const auto &altkey : alternate.getMemberNames())
+      {
+        LayerParams lps;
+        for (const auto &altvalue : alternate[altkey])
+        {
+          auto l = std::make_shared<LayerParam>();
+          {
+            l->name = altvalue.asString();
+            l->dtype = layer["dtype"].asString();
+            l->granularity = layer["granularity"].asString();
+          }
+          lps.emplace_back(l);
+        }
+        lpss.emplace_back(lps);
+      }
+    }
+  }
+
+  return lpss;
+}
+
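read_layer_params_set above builds one extra LayerParams list per key under a layer's "alternate" object, on top of the default list from read_layer_params. A hedged sketch of the JSON shape it accepts (field names taken from the code; the concrete values are illustrative only):

    import json

    # each "alternate" key maps to node names that inherit the enclosing
    # layer's dtype and granularity
    qconf = json.loads("""
    {
      "layers": [
        {
          "name": "conv_0",
          "dtype": "int16",
          "granularity": "channel",
          "alternate": { "opt": ["conv_0_opt", "conv_0_fused"] }
        }
      ]
    }
    """)

    for layer in qconf["layers"]:
        for altkey, names in layer.get("alternate", {}).items():
            print(altkey, "->", names)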
 void print_exclusive_options(void)
 {
   std::cout << "Use only one of the options below." << std::endl;
@@ -95,6 +151,7 @@ void print_exclusive_options(void)
   std::cout << " --requantize" << std::endl;
   std::cout << " --force_quantparam" << std::endl;
   std::cout << " --quantize_weights" << std::endl;
+  std::cout << " --quantize_onnx_fq_model" << std::endl;
 }
 
 void print_version(void)
@@ -112,6 +169,7 @@ int entry(int argc, char **argv)
 
   const std::string qdqw = "--quantize_dequantize_weights";
   const std::string qwmm = "--quantize_with_minmax";
+  const std::string qofm = "--quantize_onnx_fq_model";
   const std::string rq = "--requantize";
   const std::string fq = "--force_quantparam";
   const std::string cq = "--copy_quantparam";
@@ -123,11 +181,16 @@ int entry(int argc, char **argv)
 
   const std::string gpd = "--generate_profile_data";
 
+  const std::string save_min_max = "--save_min_max";
+
   arser::Arser arser("circle-quantizer provides circle model quantization");
 
   arser::Helper::add_version(arser, print_version);
   arser::Helper::add_verbose(arser);
 
+  arser.add_argument(qofm).nargs(0).default_value(false).help(
+    "Quantize ONNX fake-quantized (with QDQ) model");
+
   arser.add_argument(qdqw)
     .nargs(3)
     .type(arser::DataType::STR_VEC)
@@ -148,6 +211,11 @@ int entry(int argc, char **argv)
     .help("Force MaxPool Op to have the same input/output quantparams. NOTE: This feature can "
           "degrade accuracy of some models");
 
+  arser.add_argument(save_min_max)
+    .nargs(0)
+    .default_value(false)
+    .help("Save recorded min/max values.");
+
   arser.add_argument(fake_quant)
     .nargs(0)
     .help("Convert a quantized model to a fake-quantized model. NOTE: This feature will "
@@ -213,7 +281,7 @@ int entry(int argc, char **argv)
   }
 
   {
-    // only one of the qdqw, qwmm, rq, fq, cq, fake_quant, qw options can be used
+    // only one of the qdqw, qwmm, rq, fq, cq, fake_quant, qw, qofm options can be used
    int32_t opt_used = arser[qdqw] ? 1 : 0;
    opt_used += arser[qwmm] ? 1 : 0;
    opt_used += arser[rq] ? 1 : 0;
@@ -221,6 +289,7 @@ int entry(int argc, char **argv)
    opt_used += arser[cq] ? 1 : 0;
    opt_used += arser[fake_quant] ? 1 : 0;
    opt_used += arser[qw] ? 1 : 0;
+   opt_used += arser.get<bool>(qofm) ? 1 : 0;
    if (opt_used != 1)
    {
      print_exclusive_options();
@@ -257,6 +326,10 @@ int entry(int argc, char **argv)
 
       auto layer_params = read_layer_params(filename);
       options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
+
+      auto layer_params_set = read_layer_params_set(filename);
+
+      options->layer_params_set(layer_params_set);
     }
     catch (const std::runtime_error &e)
     {
@@ -291,6 +364,9 @@ int entry(int argc, char **argv)
   if (arser[tf_maxpool] and arser.get<bool>(tf_maxpool))
     options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "True");
 
+  if (arser[save_min_max] and arser.get<bool>(save_min_max))
+    options->param(AlgorithmParameters::Quantize_save_min_max, "True");
+
   if (arser[cfg])
   {
     auto filename = arser.get<std::string>(cfg);
@@ -299,6 +375,10 @@ int entry(int argc, char **argv)
 
       auto layer_params = read_layer_params(filename);
       options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
+
+      auto layer_params_set = read_layer_params_set(filename);
+
+      options->layer_params_set(layer_params_set);
     }
     catch (const std::runtime_error &e)
     {
@@ -308,6 +388,13 @@ int entry(int argc, char **argv)
     }
   }
 
+  if (arser.get<bool>(qofm))
+  {
+    options->enable(Algorithms::QuantizeOnnxFakeQuantizedModel);
+
+    options->param(AlgorithmParameters::Quantize_input_model_dtype, "onnx_fake_quant");
+  }
+
   if (arser[rq])
   {
     auto values = arser.get<std::vector<std::string>>(rq);
diff --git a/compiler/circle-tensordump/CMakeLists.txt b/compiler/circle-tensordump/CMakeLists.txt
index ed6ddc408..c65f634e8 100644
--- a/compiler/circle-tensordump/CMakeLists.txt
+++ b/compiler/circle-tensordump/CMakeLists.txt
@@ -1,6 +1,6 @@
-if(NOT TARGET mio_circle06)
+if(NOT TARGET mio_circle08)
   return()
-endif(NOT TARGET mio_circle06)
+endif(NOT TARGET mio_circle08)
 
 nnas_find_package(HDF5 COMPONENTS STATIC QUIET)
 
@@ -19,8 +19,8 @@ target_include_directories(circle-tensordump PRIVATE ${HDF5_INCLUDE_DIRS})
 target_link_libraries(circle-tensordump PRIVATE ${HDF5_CXX_LIBRARIES})
 target_link_libraries(circle-tensordump PRIVATE arser)
 target_link_libraries(circle-tensordump PRIVATE foder)
-target_link_libraries(circle-tensordump PRIVATE mio_circle06)
-target_link_libraries(circle-tensordump PRIVATE mio_circle06_helper)
+target_link_libraries(circle-tensordump PRIVATE mio_circle08)
+target_link_libraries(circle-tensordump PRIVATE mio_circle08_helper)
 target_link_libraries(circle-tensordump PRIVATE safemain)
 
 install(TARGETS circle-tensordump DESTINATION bin)
diff --git a/compiler/circle-tensordump/requires.cmake b/compiler/circle-tensordump/requires.cmake
index b3a2638ef..8a57c8f11 100644
--- a/compiler/circle-tensordump/requires.cmake
+++ b/compiler/circle-tensordump/requires.cmake
@@ -1,4 +1,4 @@
 require("arser")
 require("foder")
-require("mio-circle06")
+require("mio-circle08")
 require("safemain")
diff --git a/compiler/circle-verify/CMakeLists.txt b/compiler/circle-verify/CMakeLists.txt
index cdf74cc26..3ccdd0306 100644
--- a/compiler/circle-verify/CMakeLists.txt
+++ b/compiler/circle-verify/CMakeLists.txt
@@ -1,14 +1,14 @@
-if(NOT TARGET mio_circle06)
-  message(STATUS "Skip circle-verify: mio_circle06 not found")
+if(NOT TARGET mio_circle08)
+  message(STATUS "Skip circle-verify: mio_circle08 not found")
   return()
-endif(NOT TARGET mio_circle06)
+endif(NOT TARGET mio_circle08)
 
 file(GLOB_RECURSE SOURCES "src/*.cpp")
 
 add_executable(circle-verify ${SOURCES})
 target_include_directories(circle-verify PRIVATE src)
 target_link_libraries(circle-verify arser)
-target_link_libraries(circle-verify mio_circle06)
+target_link_libraries(circle-verify mio_circle08)
 target_link_libraries(circle-verify safemain)
 target_link_libraries(circle-verify cwrap)
 target_link_libraries(circle-verify foder)
diff --git a/compiler/circle-verify/requires.cmake b/compiler/circle-verify/requires.cmake
index 2fd44ad75..d382ef976 100644
--- a/compiler/circle-verify/requires.cmake
+++ b/compiler/circle-verify/requires.cmake
@@ -1,5 +1,5 @@
 require("arser")
-require("mio-circle06")
+require("mio-circle08")
 require("safemain")
 require("cwrap")
 require("foder")
diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst
index 2dd24af60..4bf6a80d6 100644
--- a/compiler/circle2circle-dredd-recipe-test/test.lst
+++ b/compiler/circle2circle-dredd-recipe-test/test.lst
@@ -10,23 +10,67 @@
 
 ## TFLITE RECIPE
 
-Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm)
+Add(BatchMatMulV2_000 PASS resolve_customop_batchmatmul)
+Add(BroadcastTo_000 PASS resolve_former_customop)
+Add(DepthwiseConv2D_003 PASS)
+Add(FullyConnected_007 PASS replace_non_const_fc_with_batch_matmul)
+Add(FullyConnected_008 PASS replace_non_const_fc_with_batch_matmul)
+Add(HardSwish_001 PASS decompose_hardswish)
+Add(MatMul_000 PASS resolve_customop_matmul)
+Add(MaxPoolWithArgmax_000 PASS resolve_customop_max_pool_with_argmax)
+Add(MaxPoolWithArgmax_001 PASS resolve_customop_max_pool_with_argmax)
+Add(MaxPoolWithArgmax_002 PASS resolve_customop_max_pool_with_argmax)
+Add(Net_Add_FloorMod_Gather_000 PASS remove_gather_guard)
+Add(Net_Add_FullyConnected_000 PASS fuse_add_to_fullyconnected_bias)
+Add(Net_Add_FullyConnected_001 PASS fuse_add_to_fullyconnected_bias)
+Add(Net_Add_FullyConnected_002 PASS fuse_add_to_fullyconnected_bias)
 Add(Net_BroadcastTo_AddV2_000 PASS resolve_customop_add)
 Add(Net_BroadcastTo_AddV2_001 PASS resolve_customop_add)
+Add(Net_BroadcastTo_AddV2_002 PASS resolve_customop_add)
+Add(Net_Conv_Add_000 PASS fuse_add_with_conv)
+Add(Net_Conv_Add_001 PASS fuse_add_with_conv)
+Add(Net_Conv_Add_002 PASS fuse_add_with_conv)
 Add(Net_Conv_Add_Mul_000 PASS fuse_batchnorm_with_conv)
 Add(Net_Conv_Add_Mul_001 PASS fuse_batchnorm_with_conv)
 Add(Net_Conv_Add_Mul_002 PASS fuse_batchnorm_with_conv)
 Add(Net_Conv_FakeQuant_000 PASS remove_fakequant)
-Add(Net_Conv_QuantDequant_000 PASS remove_quantdequant)
 Add(Net_Conv_Min_Max_000 PASS transform_min_max_to_relu6)
 Add(Net_Conv_Min_Relu_000 PASS transform_min_relu_to_relu6)
+Add(Net_Conv_Mul_000 PASS fuse_mul_with_conv)
+Add(Net_Conv_Mul_001 PASS fuse_mul_with_conv)
+Add(Net_Conv_Mul_002 PASS fuse_mul_with_conv)
+Add(Net_Conv_Mul_003 PASS fuse_mul_with_conv)
 Add(Net_Conv_PReluGraph_000 PASS fuse_prelu)
+Add(Net_Conv_QuantDequant_000 PASS remove_quantdequant)
 Add(Net_Conv_Relu6_000 PASS fuse_activation_function)
 Add(Net_Duplicate_Weights_000 PASS remove_duplicate_const)
 Add(Net_DwConv_BN_000 PASS fuse_batchnorm_with_dwconv)
 Add(Net_DwConv_BN_001 PASS fuse_batchnorm_with_dwconv)
+Add(Net_FC_Gelu_FC_000 PASS replace_with_fc_gelu_fc)
 Add(Net_FullyConnected_Add_000 PASS fold_fully_connected)
+Add(Net_Gelu_000 PASS fuse_gelu)
+Add(Net_Gelu_001 PASS fuse_gelu)
+Add(Net_Horizontal_FullyConnected_Add_000 PASS fuse_horizontal_fc_layers)
+Add(Net_InstanceNorm_001 PASS fuse_instnorm)
+Add(Net_InstanceNorm_003 PASS fuse_instnorm)
+Add(Net_InstanceNorm_004 PASS fuse_instnorm)
+Add(Net_InstanceNorm_005 PASS fuse_instnorm)
+Add(Net_InstanceNorm_006 PASS fuse_instnorm)
+Add(Net_InstanceNorm_007 PASS fuse_instnorm)
+Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6)
+Add(Net_Mul_Add_000 PASS remove_unnecessary_add)
+Add(Net_Mul_Add_001 PASS remove_unnecessary_add)
+Add(Net_Mul_Add_002 PASS remove_unnecessary_add)
+Add(Net_Mul_Add_003 PASS remove_unnecessary_add)
+Add(Net_Mul_Div_000 PASS fuse_mul_with_div)
+Add(Net_Mul_Div_001 PASS fuse_mul_with_div)
+Add(Net_Mul_FullyConnected_000 PASS fuse_mul_to_fullyconnected_weights fold_mul)
+Add(Net_Mul_FullyConnected_001 PASS fuse_mul_to_fullyconnected_weights fold_mul)
+Add(Net_Mul_FullyConnected_002 PASS fuse_mul_to_fullyconnected_weights fold_mul)
+Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm)
 Add(Net_Reshape_Reshape_000 PASS remove_redundant_reshape)
+Add(Net_Shape_Add_000 PASS fold_shape)
+Add(Net_Sqrt_Div_000 PASS transform_sqrt_div_to_rsqrt_mul)
 Add(Net_Squeeze_Squeeze_000 PASS substitute_squeeze_to_reshape)
 Add(Net_TConv_Add_000 PASS fuse_add_with_tconv)
 Add(Net_TConv_Add_001 PASS fuse_add_with_tconv)
@@ -37,26 +81,22 @@ Add(Net_TConv_BN_002 PASS fuse_batchnorm_with_tconv)
 Add(Net_TConv_BN_003 PASS fuse_batchnorm_with_tconv)
 Add(Net_TConv_BN_004 PASS fuse_batchnorm_with_tconv)
 Add(Net_TConv_BN_005 PASS fuse_batchnorm_with_tconv)
-Add(Net_InstanceNorm_001 PASS fuse_instnorm)
-Add(Net_InstanceNorm_003 PASS fuse_instnorm)
-Add(Net_InstanceNorm_004 PASS fuse_instnorm)
-Add(Net_InstanceNorm_005 PASS fuse_instnorm)
-Add(Net_InstanceNorm_006 PASS fuse_instnorm)
-Add(Net_InstanceNorm_007 PASS fuse_instnorm)
-Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6)
-Add(BatchMatMulV2_000 PASS resolve_customop_batchmatmul)
-Add(MatMul_000 PASS resolve_customop_matmul)
-Add(DepthwiseConv2D_003 PASS)
+Add(Net_TConv_Slice_000 PASS fuse_slice_with_tconv)
+Add(Net_TConv_Slice_001 PASS fuse_slice_with_tconv)
+Add(Net_TConv_Slice_002 PASS fuse_slice_with_tconv)
+Add(Net_TConv_Slice_003 PASS fuse_slice_with_tconv)
+Add(Net_Trans_Reshape_Trans_000 PASS remove_unnecessary_transpose)
 Add(PadV2_001 PASS substitute_padv2_to_pad)
+Add(Softmax_001 PASS decompose_softmax)
+Add(Softmax_002 PASS decompose_softmax)
 Add(StridedSlice_003 PASS substitute_strided_slice_to_reshape)
-Add(MaxPoolWithArgmax_000 PASS resolve_customop_max_pool_with_argmax)
-Add(MaxPoolWithArgmax_001 PASS resolve_customop_max_pool_with_argmax)
-Add(MaxPoolWithArgmax_002 PASS resolve_customop_max_pool_with_argmax)
-Add(FullyConnected_007 PASS replace_non_const_fc_with_batch_matmul)
-Add(FullyConnected_008 PASS replace_non_const_fc_with_batch_matmul)
-Add(Net_Gelu_000 PASS fuse_gelu)
-Add(Net_Gelu_001 PASS fuse_gelu)
-Add(HardSwish_001 PASS decompose_hardswish)
+
+# CSE test
+
+Add(CSE_Quantize_000 PASS common_subexpression_elimination)
+Add(CSE_Quantize_001 PASS common_subexpression_elimination)
+Add(CSE_Transpose_000 PASS common_subexpression_elimination)
+Add(CSE_Transpose_001 PASS common_subexpression_elimination)
 
 ## CIRCLE RECIPE
 
@@ -74,24 +114,24 @@ Add(REGRESS_ONNX_Conv_BN_001 PASS
     remove_unnecessary_reshape
     fuse_batchnorm_with_conv)
 
-Add(REGRESS_ONNX_Conv_BN_Relu6_001 PASS
+Add(REGRESS_ONNX_Conv_BN_MeanMean_001 PASS
     convert_nchw_to_nhwc
     nchw_to_nhwc_input_shape
     nchw_to_nhwc_output_shape
     remove_redundant_transpose
-    transform_min_max_to_relu6
     fuse_batchnorm_with_conv
-    fuse_activation_function)
+    fuse_activation_function
+    fuse_mean_with_mean
+    fuse_transpose_with_mean)
 
-Add(REGRESS_ONNX_Conv_BN_MeanMean_001 PASS
+Add(REGRESS_ONNX_Conv_BN_Relu6_001 PASS
     convert_nchw_to_nhwc
     nchw_to_nhwc_input_shape
     nchw_to_nhwc_output_shape
     remove_redundant_transpose
+    transform_min_max_to_relu6
     fuse_batchnorm_with_conv
-    fuse_activation_function
-    fuse_mean_with_mean
-    fuse_transpose_with_mean)
+    fuse_activation_function)
 
 Add(REGRESS_ONNX_Mul_Mul_000 PASS
     convert_nchw_to_nhwc)
diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp
index 6a7be2204..757c368f3 100644
--- a/compiler/circle2circle/src/Circle2Circle.cpp
+++ b/compiler/circle2circle/src/Circle2Circle.cpp
@@ -81,13 +81,22 @@ int entry(int argc, char **argv)
   add_switch(arser, "--fold_fully_connected",
              "This will fold FullyConnected operator with constant inputs");
   add_switch(arser, "--fold_gather", "This will fold Gather operator");
+  add_switch(arser, "--fold_mul", "This will fold Mul operator");
+  add_switch(arser, "--fold_reshape", "This will fold Reshape operator");
+  add_switch(arser, "--fold_shape", "This will fold Shape operator");
   add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
+  add_switch(arser, "--fold_squeeze", "This will fold Squeeze operator");
   add_switch(arser, "--forward_reshape_to_unaryop",
              "This will move Reshape after UnaryOp for certain conditions");
   add_switch(arser, "--forward_transpose_op",
              "This will move Transpose Op forward if possible (for further optimization)");
   add_switch(arser, "--fuse_activation_function",
              "This will fuse Activation function to a preceding operator");
+  add_switch(arser, "--fuse_horizontal_fc_layers",
+             "This will fuse horizontal FullyConnected layers");
+  add_switch(arser, "--fuse_add_to_fullyconnected_bias",
+             "This will fuse Add to the following FullyConnected bias");
+  add_switch(arser, "--fuse_add_with_conv", "This will fuse Add operator to Convolution operator");
   add_switch(arser, "--fuse_add_with_fully_connected",
              "This will fuse Add operator to FullyConnected operator");
   add_switch(arser, "--fuse_add_with_tconv",
@@ -103,6 +112,14 @@ int entry(int argc, char **argv)
   add_switch(arser, "--fuse_mean_with_mean",
              "This will fuse two Mean operations when one directly follows the other. This will fold them "
              "into one operation and merge reduction indices.");
+  add_switch(arser, "--fuse_mul_to_fullyconnected_weights",
+             "This will fuse Mul to the following FullyConnected weights");
+  add_switch(arser, "--fuse_mul_with_conv",
+             "This will fuse Mul operation with a preceding Conv if possible.");
+  add_switch(arser, "--fuse_mul_with_div",
+             "This will fuse Mul operation with a Div operation whose numerator is const.");
+  add_switch(arser, "--fuse_slice_with_tconv",
+             "This will fuse Slice operation with a preceding TConv if possible.");
   add_switch(arser, "--fuse_transpose_with_mean",
              "This will fuse Mean operation with a preceding Transpose under certain conditions.");
   add_switch(arser, "--make_batchnorm_gamma_positive",
@@ -113,23 +130,35 @@ int entry(int argc, char **argv)
              "This will fuse BatchNorm operators of pre-activations to Convolution operator");
   add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator");
   add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator");
+  add_switch(arser, "--fuse_rsqrt", "This will fuse operators to Rsqrt operator");
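All of these switches are plain on/off flags parsed by arser, so a typical run chains several passes over one model in a single invocation. A hedged usage sketch (file names are hypothetical; the flags are the ones registered above):

    import subprocess

    # apply a few of the new passes to a circle model in one run
    subprocess.run([
        "circle2circle",
        "--fold_mul",
        "--fuse_mul_with_conv",
        "--remove_unnecessary_transpose",
        "input.circle",
        "output.circle",
    ], check=True)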
" + "CAUTION: user must guarantee that indices are all non-negative values."); + add_switch(arser, "--remove_qdq_for_mpo", + "This will remove QDQ to simulate mixed-precision operator"); add_switch(arser, "--remove_quantdequant", "This will remove Quantize-Dequantize sequence"); add_switch(arser, "--remove_redundant_quantize", "This will remove redundant Quantize operators"); add_switch(arser, "--remove_redundant_reshape", "This will fuse or remove subsequent Reshape operators"); add_switch(arser, "--remove_redundant_transpose", "This will fuse or remove subsequent Transpose operators"); + add_switch(arser, "--remove_unnecessary_add", + "This will remove unnecessary add of zero constant"); add_switch(arser, "--remove_unnecessary_reshape", "This will remove unnecessary reshape operators"); add_switch(arser, "--remove_unnecessary_slice", "This will remove unnecessary slice operators"); add_switch(arser, "--remove_unnecessary_strided_slice", "This will remove unnecessary strided slice operators"); add_switch(arser, "--remove_unnecessary_split", "This will remove unnecessary split operators"); + add_switch(arser, "--remove_unnecessary_transpose", + "This will remove unnecessary transpose operators"); add_switch(arser, "--replace_cw_mul_add_with_depthwise_conv", "This will replace channel-wise mul/add with DepthwiseConv2D operator"); add_switch(arser, "--replace_sub_with_add", "This will replace sub with add operator"); + add_switch(arser, "--replace_with_fc_gelu_fc", + "This will replace a specific pattern into FC + Gelu + FC pattern."); add_switch(arser, "--resolve_customop_add", "This will convert Custom(Add) to Add operator"); add_switch(arser, "--resolve_customop_batchmatmul", "This will convert Custom(BatchMatmul) to BatchMatmul operator"); @@ -139,6 +168,8 @@ int entry(int argc, char **argv) "This will convert Custom(MaxPoolWithArgmax) to equivalent set of operators"); add_switch(arser, "--resolve_customop_splitv", "This will convert Custom(SplitV) to SplitV operator"); + add_switch(arser, "--resolve_former_customop", + "This will convert a former custom op to builtin in from schema version upgrade"); add_switch(arser, "--shuffle_weight_to_16x1float32", "This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that " "it only converts weights whose row is a multiple of 16"); @@ -169,13 +200,25 @@ int entry(int argc, char **argv) "Transform Minimum(6)-Maximum(0) pattern to Relu6 operator"); add_switch(arser, "--transform_min_relu_to_relu6", "Transform Minimum(6)-Relu pattern to Relu6 operator"); + add_switch(arser, "--transform_sqrt_div_to_rsqrt_mul", + "Transform Sqrt-Div pattern to Rsqrt-Mul operators"); add_switch(arser, "--decompose_hardswish", "Decompose HardSwish operator to Add, Mul and Relu6 operators"); + add_switch(arser, "--decompose_softmax", + "Decompose Softmax operator into multiple operators for special backends"); + add_switch(arser, "--common_subexpression_elimination", + "Perform common subexpression elimination"); add_switch(arser, "--mute_warnings", "This will turn off warning messages"); add_switch(arser, "--disable_validation", "This will turn off operator validations. 
   add_switch(arser, "--mute_warnings", "This will turn off warning messages");
   add_switch(arser, "--disable_validation",
              "This will turn off operator validations. May help input model investigation.");
   add_switch(arser, "--generate_profile_data", "This will turn on profiling data generation.");
 
+  // NOTE Experimental options; these will be removed someday
+  // Add experimental options here
+  add_switch(arser, "--exp_disable_sep_transposeconv_actfunc",
+             "This will turn off experimental separation of activation function from "
+             "TransposeConv.");
+
   // Convert dynamic batch to single batch
   // Users have to use this option only when the first dimension of rank 4 input (NHWC or NCHW)
   // is dynamic. Remove this comment after non-rank 4 is supported.
@@ -237,16 +280,30 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::FoldFullyConnected);
   if (arser.get<bool>("--fold_gather"))
     options->enable(Algorithms::FoldGather);
+  if (arser.get<bool>("--fold_mul"))
+    options->enable(Algorithms::FoldMul);
+  if (arser.get<bool>("--fold_reshape"))
+    options->enable(Algorithms::FoldReshape);
+  if (arser.get<bool>("--fold_shape"))
+    options->enable(Algorithms::FoldShape);
   if (arser.get<bool>("--fold_sparse_to_dense"))
     options->enable(Algorithms::FoldSparseToDense);
+  if (arser.get<bool>("--fold_squeeze"))
+    options->enable(Algorithms::FoldSqueeze);
   if (arser.get<bool>("--forward_reshape_to_unaryop"))
     options->enable(Algorithms::ForwardReshapeToUnaryOp);
   if (arser.get<bool>("--forward_transpose_op"))
     options->enable(Algorithms::ForwardTransposeOp);
   if (arser.get<bool>("--fuse_activation_function"))
     options->enable(Algorithms::FuseActivationFunction);
+  if (arser.get<bool>("--fuse_horizontal_fc_layers"))
+    options->enable(Algorithms::FuseHorizontalFullyConnected);
   if (arser.get<bool>("--fuse_batchnorm_with_conv"))
     options->enable(Algorithms::FuseBatchNormWithConv);
+  if (arser.get<bool>("--fuse_add_to_fullyconnected_bias"))
+    options->enable(Algorithms::FuseAddToFullyConnectedBias);
+  if (arser.get<bool>("--fuse_add_with_conv"))
+    options->enable(Algorithms::FuseAddWithConv);
   if (arser.get<bool>("--fuse_add_with_fully_connected"))
     options->enable(Algorithms::FuseAddWithFullyConnected);
   if (arser.get<bool>("--fuse_add_with_tconv"))
@@ -255,12 +312,20 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::FuseBatchNormWithDwConv);
   if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
     options->enable(Algorithms::FuseBatchNormWithTConv);
+  if (arser.get<bool>("--fuse_mul_to_fullyconnected_weights"))
+    options->enable(Algorithms::FuseMulToFullyConnectedWeights);
+  if (arser.get<bool>("--fuse_slice_with_tconv"))
+    options->enable(Algorithms::FuseSliceWithTConv);
   if (arser.get<bool>("--fuse_bcq"))
     options->enable(Algorithms::FuseBCQ);
   if (arser.get<bool>("--fuse_instnorm"))
     options->enable(Algorithms::FuseInstanceNorm);
   if (arser.get<bool>("--fuse_mean_with_mean"))
     options->enable(Algorithms::FuseMeanWithMean);
+  if (arser.get<bool>("--fuse_mul_with_conv"))
+    options->enable(Algorithms::FuseMulWithConv);
+  if (arser.get<bool>("--fuse_mul_with_div"))
+    options->enable(Algorithms::FuseMulWithDiv);
   if (arser.get<bool>("--make_batchnorm_gamma_positive"))
     options->enable(Algorithms::MakeBatchNormGammaPositive);
   if (arser.get<bool>("--fuse_preactivation_batchnorm"))
@@ -269,12 +334,18 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::FusePRelu);
   if (arser.get<bool>("--fuse_gelu"))
     options->enable(Algorithms::FuseGelu);
+  if (arser.get<bool>("--fuse_rsqrt"))
+    options->enable(Algorithms::FuseRsqrt);
   if (arser.get<bool>("--fuse_transpose_with_mean"))
     options->enable(Algorithms::FuseTransposeWithMean);
   if (arser.get<bool>("--remove_duplicate_const"))
     options->enable(Algorithms::RemoveDuplicateConst);
   if (arser.get<bool>("--remove_fakequant"))
     options->enable(Algorithms::RemoveFakeQuant);
+  if (arser.get<bool>("--remove_gather_guard"))
+    options->enable(Algorithms::RemoveGatherGuard);
+  if (arser.get<bool>("--remove_qdq_for_mpo"))
+    options->enable(Algorithms::RemoveQDQForMixedPrecisionOp);
   if (arser.get<bool>("--remove_quantdequant"))
     options->enable(Algorithms::RemoveQuantDequantSeq);
   if (arser.get<bool>("--remove_redundant_quantize"))
@@ -283,6 +354,8 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::RemoveRedundantReshape);
   if (arser.get<bool>("--remove_redundant_transpose"))
     options->enable(Algorithms::RemoveRedundantTranspose);
+  if (arser.get<bool>("--remove_unnecessary_add"))
+    options->enable(Algorithms::RemoveUnnecessaryAdd);
   if (arser.get<bool>("--remove_unnecessary_reshape"))
     options->enable(Algorithms::RemoveUnnecessaryReshape);
   if (arser.get<bool>("--remove_unnecessary_slice"))
@@ -291,10 +364,14 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::RemoveUnnecessaryStridedSlice);
   if (arser.get<bool>("--remove_unnecessary_split"))
     options->enable(Algorithms::RemoveUnnecessarySplit);
+  if (arser.get<bool>("--remove_unnecessary_transpose"))
+    options->enable(Algorithms::RemoveUnnecessaryTranspose);
   if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
     options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
   if (arser.get<bool>("--replace_sub_with_add"))
     options->enable(Algorithms::ReplaceSubWithAdd);
+  if (arser.get<bool>("--replace_with_fc_gelu_fc"))
+    options->enable(Algorithms::ReplaceWithFCGeluFC);
   if (arser.get<bool>("--resolve_customop_add"))
     options->enable(Algorithms::ResolveCustomOpAdd);
   if (arser.get<bool>("--resolve_customop_batchmatmul"))
@@ -305,6 +382,8 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::ResolveCustomOpMaxPoolWithArgmax);
   if (arser.get<bool>("--resolve_customop_splitv"))
     options->enable(Algorithms::ResolveCustomOpSplitV);
+  if (arser.get<bool>("--resolve_former_customop"))
+    options->enable(Algorithms::ResolveFormerCustomOp);
   if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
     options->enable(Algorithms::ShuffleWeightTo16x1Float32);
   if (arser.get<bool>("--replace_non_const_fc_with_batch_matmul"))
@@ -325,13 +404,27 @@ int entry(int argc, char **argv)
     options->enable(Algorithms::TransformMinMaxToRelu6Pass);
   if (arser.get<bool>("--transform_min_relu_to_relu6"))
     options->enable(Algorithms::TransformMinReluToRelu6Pass);
+  if (arser.get<bool>("--transform_sqrt_div_to_rsqrt_mul"))
+    options->enable(Algorithms::TransformSqrtDivToRsqrtMul);
+  if (arser.get<bool>("--common_subexpression_elimination"))
+    options->enable(Algorithms::CommonSubExpressionElimination);
   if (arser.get<bool>("--decompose_hardswish"))
     options->enable(Algorithms::DecomposeHardSwishPass);
+  if (arser.get<bool>("--decompose_softmax"))
+    options->enable(Algorithms::DecomposeSoftmaxPass);
   if (arser.get<bool>("--expand_broadcast_const"))
     options->enable(Algorithms::ExpandBroadcastConst);
   if (arser.get<bool>("--unroll_unidirseqlstm"))
     options->enable(Algorithms::UnrollUnidirSeqLSTM);
 
+  // NOTE Experimental options; these will be removed someday
+  // Add experimental options here
+  // NOTE XpSepActFromTransposeConv is enabled by default;
+  // --exp_disable_sep_transposeconv_actfunc turns it off,
+  // which will leave TransposeConv with a fused activation
+  if (!arser.get<bool>("--exp_disable_sep_transposeconv_actfunc"))
+    options->enable(Algorithms::XpSepActFromTransposeConv);
+
   if (arser.get<bool>("--mute_warnings"))
     settings->set(luci::UserSettings::Key::MuteWarnings, true);
   if (arser.get<bool>("--disable_validation"))
diff --git a/compiler/circlechef/CMakeLists.txt b/compiler/circlechef/CMakeLists.txt
index 56c501c24..18b58a9c1 100644
--- a/compiler/circlechef/CMakeLists.txt
+++ b/compiler/circlechef/CMakeLists.txt
@@ -5,10 +5,10 @@
if(NOT Protobuf_FOUND) return() endif(NOT Protobuf_FOUND) -if(NOT TARGET mio_circle06) - message(STATUS "circlechef: SKIP (missing mio-circle06)") +if(NOT TARGET mio_circle08) + message(STATUS "circlechef: SKIP (missing mio-circle08)") return() -endif(NOT TARGET mio_circle06) +endif(NOT TARGET mio_circle08) # Recipe Parser add_subdirectory(proto) diff --git a/compiler/circlechef/circle/CMakeLists.txt b/compiler/circlechef/circle/CMakeLists.txt index cdd6040b7..e50d4a64e 100644 --- a/compiler/circlechef/circle/CMakeLists.txt +++ b/compiler/circlechef/circle/CMakeLists.txt @@ -4,7 +4,7 @@ add_library(circlechef_circle STATIC ${SOURCES}) target_include_directories(circlechef_circle PUBLIC include) target_include_directories(circlechef_circle PRIVATE src) target_link_libraries(circlechef_circle circlechef_proto) -target_link_libraries(circlechef_circle mio_circle06) -target_link_libraries(circlechef_circle mio_circle06_helper) +target_link_libraries(circlechef_circle mio_circle08) +target_link_libraries(circlechef_circle mio_circle08_helper) target_link_libraries(circlechef_circle cwrap) target_link_libraries(circlechef_circle souschef) diff --git a/compiler/circlechef/circle/src/CircleImport.cpp b/compiler/circlechef/circle/src/CircleImport.cpp index f8756ef94..f983d3ebe 100644 --- a/compiler/circlechef/circle/src/CircleImport.cpp +++ b/compiler/circlechef/circle/src/CircleImport.cpp @@ -44,7 +44,7 @@ bool CircleImport::select_sub_graph(uint32_t sgindex) _inputs.clear(); _outputs.clear(); - if (_subgraphs->Length() <= sgindex) + if (_subgraphs->size() <= sgindex) { assert(false); return false; diff --git a/compiler/circlechef/circle/src/CircleImport.h b/compiler/circlechef/circle/src/CircleImport.h index 9c1d161b6..69453fdda 100644 --- a/compiler/circlechef/circle/src/CircleImport.h +++ b/compiler/circlechef/circle/src/CircleImport.h @@ -54,7 +54,7 @@ public: const std::vector &inputs() const { return _inputs; } const std::vector &outputs() const { return _outputs; } - uint32_t num_subgraph() const { return _subgraphs->Length(); } + uint32_t num_subgraph() const { return _subgraphs->size(); } circle::BuiltinOperator builtin_code(const circle::Operator *op) const; std::string opcode_name(const circle::Operator *op) const; diff --git a/compiler/circlechef/circle/src/CircleOpChefs.h b/compiler/circlechef/circle/src/CircleOpChefs.h index 6a0ce5dc3..cf8f658c1 100644 --- a/compiler/circlechef/circle/src/CircleOpChefs.h +++ b/compiler/circlechef/circle/src/CircleOpChefs.h @@ -21,6 +21,7 @@ #include "Op/BatchMatMul.h" #include "Op/BCQFullyConnected.h" #include "Op/BCQGather.h" +#include "Op/GRU.h" #include "Op/InstanceNorm.h" #endif // __CIRCLE_OP_CHEFS_H__ diff --git a/compiler/circlechef/circle/src/CircleOpRegistry.h b/compiler/circlechef/circle/src/CircleOpRegistry.h index 2bf1e19ed..a01878421 100644 --- a/compiler/circlechef/circle/src/CircleOpRegistry.h +++ b/compiler/circlechef/circle/src/CircleOpRegistry.h @@ -58,6 +58,7 @@ private: REG_TFL_OP(BATCH_MATMUL, CircleOpBatchMatMul); REG_TFL_OP(BCQ_FULLY_CONNECTED, CircleOpBCQFullyConnected); REG_TFL_OP(BCQ_GATHER, CircleOpBCQGather); + REG_TFL_OP(GRU, CircleOpGRU); REG_TFL_OP(INSTANCE_NORM, CircleOpInstanceNorm); #undef REG_TFL_OP } diff --git a/compiler/circlechef/circle/src/Convert.cpp b/compiler/circlechef/circle/src/Convert.cpp index 248687fed..8f11e00cd 100644 --- a/compiler/circlechef/circle/src/Convert.cpp +++ b/compiler/circlechef/circle/src/Convert.cpp @@ -31,10 +31,14 @@ circlechef::TensorType as_circlechef_type(const circle::TensorType 
type) return circlechef::INT64; case circle::TensorType_UINT8: return circlechef::UINT8; + case circle::TensorType_UINT4: + return circlechef::UINT4; case circle::TensorType_BOOL: return circlechef::BOOL; case circle::TensorType_INT16: return circlechef::INT16; + case circle::TensorType_INT4: + return circlechef::INT4; // TODO handle other types // TensorType_FLOAT16 // TensorType_STRING diff --git a/compiler/circlechef/circle/src/Convert.h b/compiler/circlechef/circle/src/Convert.h index 7842c4b01..050c24bd6 100644 --- a/compiler/circlechef/circle/src/Convert.h +++ b/compiler/circlechef/circle/src/Convert.h @@ -45,8 +45,8 @@ template std::vector as_index_vector(const flatbuffers::Vector ret(flat_array->Length()); - for (uint32_t i = 0; i < flat_array->Length(); i++) + std::vector ret(flat_array->size()); + for (uint32_t i = 0; i < flat_array->size(); i++) { ret[i] = flat_array->Get(i); } diff --git a/compiler/circlechef/circle/src/Op/GRU.cpp b/compiler/circlechef/circle/src/Op/GRU.cpp new file mode 100644 index 000000000..a45daf2f8 --- /dev/null +++ b/compiler/circlechef/circle/src/Op/GRU.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GRU.h" + +#include "Convert.h" + +namespace circlechef +{ + +void CircleOpGRU::filler(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const +{ + // index 1, 2, 3, 4, 5 maybe constant + const std::vector &inputs = as_index_vector(op->inputs()); + assert(inputs.size() == 6); + + import->set_tensor_filler(inputs[1]); // set gaussian filler + import->set_tensor_filler(inputs[2]); + import->set_tensor_filler(inputs[3]); + import->set_tensor_filler(inputs[4]); + import->set_tensor_filler(inputs[5]); +} + +circlechef::Operation *CircleOpGRU::build(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const +{ + auto op_params = op->builtin_options_as_GRUOptions(); + assert(op_params != nullptr); + + auto operation = model_recipe->add_operation(); + + operation->set_type("GRU"); + + auto op_options = operation->mutable_gru_options(); + + op_options->set_activation(as_circlechef_activation(op_params->fused_activation_function())); + op_options->set_return_sequences(op_params->return_sequences()); + op_options->set_time_major(op_params->time_major()); + + return operation; +} + +} // namespace circlechef diff --git a/compiler/circlechef/circle/src/Op/GRU.h b/compiler/circlechef/circle/src/Op/GRU.h new file mode 100644 index 000000000..bfd867184 --- /dev/null +++ b/compiler/circlechef/circle/src/Op/GRU.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_OP_CIRCLE_GRU_H__ +#define __CIRCLE_OP_CIRCLE_GRU_H__ + +#include "CircleOpChef.h" + +namespace circlechef +{ + +/** + * @brief circlechef operator builder for GRU + */ +class CircleOpGRU : public CircleOpChef +{ +public: + void filler(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const override; + circlechef::Operation *build(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const override; +}; + +} // namespace circlechef + +#endif // __CIRCLE_OP_CIRCLE_GRU_H__ diff --git a/compiler/circlechef/circle/src/RecipeChef.cpp b/compiler/circlechef/circle/src/RecipeChef.cpp index e21bca8a6..645a714b7 100644 --- a/compiler/circlechef/circle/src/RecipeChef.cpp +++ b/compiler/circlechef/circle/src/RecipeChef.cpp @@ -82,7 +82,7 @@ std::unique_ptr generate_recipe(const circle::Model *model) auto operators = circle_import.operators(); // operand fillers for adding all operators - for (uint32_t i = 0; i < operators->Length(); ++i) + for (uint32_t i = 0; i < operators->size(); ++i) { const auto *op = operators->Get(i); circle::BuiltinOperator builtincode = circle_import.builtin_code(op); @@ -99,7 +99,7 @@ std::unique_ptr generate_recipe(const circle::Model *model) } // add all operands(tensors) - for (uint32_t i = 0; i < tensors->Length(); ++i) + for (uint32_t i = 0; i < tensors->size(); ++i) { auto tensor = tensors->Get(i); @@ -198,7 +198,7 @@ std::unique_ptr generate_recipe(const circle::Model *model) } // add all operators - for (uint32_t i = 0; i < operators->Length(); ++i) + for (uint32_t i = 0; i < operators->size(); ++i) { const auto *op = operators->Get(i); circle::BuiltinOperator builtincode = circle_import.builtin_code(op); diff --git a/compiler/circlechef/core/CMakeLists.txt b/compiler/circlechef/core/CMakeLists.txt index dc1dbc4dc..073abbdfc 100644 --- a/compiler/circlechef/core/CMakeLists.txt +++ b/compiler/circlechef/core/CMakeLists.txt @@ -7,7 +7,7 @@ target_include_directories(circlechef_core PUBLIC include) target_include_directories(circlechef_core PRIVATE src) target_link_libraries(circlechef_core PUBLIC circlechef_proto) target_link_libraries(circlechef_core PUBLIC circlechef_log) -target_link_libraries(circlechef_core PUBLIC mio_circle06) +target_link_libraries(circlechef_core PUBLIC mio_circle08) target_link_libraries(circlechef_core PUBLIC souschef) target_link_libraries(circlechef_core PRIVATE nncc_coverage) diff --git a/compiler/circlechef/core/src/Convert.cpp b/compiler/circlechef/core/src/Convert.cpp index 6066324b0..d2c70de1e 100644 --- a/compiler/circlechef/core/src/Convert.cpp +++ b/compiler/circlechef/core/src/Convert.cpp @@ -56,18 +56,22 @@ circle::TensorType as_circle_tensortype(const circlechef::TensorType &value) { case circlechef::FLOAT32: return circle::TensorType_FLOAT32; + case circlechef::INT64: + return circle::TensorType_INT64; case circlechef::INT32: return circle::TensorType_INT32; + case circlechef::INT16: + return circle::TensorType_INT16; + case circlechef::INT4: + return circle::TensorType_INT4; case circlechef::UINT8: return 
circle::TensorType_UINT8; - case circlechef::INT64: - return circle::TensorType_INT64; + case circlechef::UINT4: + return circle::TensorType_UINT4; case circlechef::STRING: return circle::TensorType_STRING; case circlechef::BOOL: return circle::TensorType_BOOL; - case circlechef::INT16: - return circle::TensorType_INT16; default: break; } diff --git a/compiler/circlechef/core/src/Convert.test.cpp b/compiler/circlechef/core/src/Convert.test.cpp index b17f5df44..8946b6e10 100644 --- a/compiler/circlechef/core/src/Convert.test.cpp +++ b/compiler/circlechef/core/src/Convert.test.cpp @@ -44,11 +44,13 @@ TEST(ConvertTest, as_circle_activation_NEG) TEST(ConvertTest, as_circle_tensortype) { ASSERT_EQ(circle::TensorType_FLOAT32, as_circle_tensortype(circlechef::FLOAT32)); + ASSERT_EQ(circle::TensorType_INT64, as_circle_tensortype(circlechef::INT64)); ASSERT_EQ(circle::TensorType_INT32, as_circle_tensortype(circlechef::INT32)); + ASSERT_EQ(circle::TensorType_INT16, as_circle_tensortype(circlechef::INT16)); + ASSERT_EQ(circle::TensorType_INT4, as_circle_tensortype(circlechef::INT4)); ASSERT_EQ(circle::TensorType_UINT8, as_circle_tensortype(circlechef::UINT8)); - ASSERT_EQ(circle::TensorType_INT64, as_circle_tensortype(circlechef::INT64)); + ASSERT_EQ(circle::TensorType_UINT4, as_circle_tensortype(circlechef::UINT4)); ASSERT_EQ(circle::TensorType_BOOL, as_circle_tensortype(circlechef::BOOL)); - ASSERT_EQ(circle::TensorType_INT16, as_circle_tensortype(circlechef::INT16)); } TEST(ConvertTest, as_circle_tensortype_NEG) diff --git a/compiler/circlechef/core/src/DataChef.def b/compiler/circlechef/core/src/DataChef.def index c634c047e..ae3a07478 100644 --- a/compiler/circlechef/core/src/DataChef.def +++ b/compiler/circlechef/core/src/DataChef.def @@ -5,19 +5,26 @@ // DATA_CHEF(TYPE, NAME, FACTORY_CLASS) // "TYPE" SHOULD BE an enum tag of tflchef::TensorType DATA_CHEF(FLOAT32, constant, ConstantDataChefFactory) -DATA_CHEF(BOOL, constant, ConstantDataChefFactory) -DATA_CHEF(UINT8, constant, ConstantDataChefFactory) -DATA_CHEF(INT16, constant, ConstantDataChefFactory) -DATA_CHEF(INT32, constant, ConstantDataChefFactory) DATA_CHEF(INT64, constant, ConstantDataChefFactory) +DATA_CHEF(INT32, constant, ConstantDataChefFactory) +DATA_CHEF(INT16, constant, ConstantDataChefFactory) +DATA_CHEF(INT4, constant, ConstantInt4DataChefFactory) +DATA_CHEF(UINT8, constant, ConstantDataChefFactory) +DATA_CHEF(UINT4, constant, ConstantUint4DataChefFactory) +DATA_CHEF(BOOL, constant, ConstantDataChefFactory) + +DATA_CHEF(FLOAT32, explicit, ExplicitDataChefFactory) DATA_CHEF(INT64, explicit, ExplicitDataChefFactory) DATA_CHEF(INT32, explicit, ExplicitDataChefFactory) DATA_CHEF(INT16, explicit, ExplicitDataChefFactory) +DATA_CHEF(INT4, explicit, ExplicitInt4DataChefFactory) DATA_CHEF(UINT8, explicit, ExplicitDataChefFactory) -DATA_CHEF(BOOL, explicit, ExplicitDataChefFactory) -DATA_CHEF(FLOAT32, explicit, ExplicitDataChefFactory) +DATA_CHEF(UINT4, explicit, ExplicitUint4DataChefFactory) DATA_CHEF(STRING, explicit, ExplicitDataChefFactory) +DATA_CHEF(BOOL, explicit, ExplicitDataChefFactory) + DATA_CHEF(FLOAT32, gaussian, GaussianFloat32DataChefFactory) DATA_CHEF(INT32, gaussian, GaussianInt32DataChefFactory) DATA_CHEF(INT16, gaussian, GaussianInt16DataChefFactory) DATA_CHEF(UINT8, gaussian, GaussianUint8DataChefFactory) + diff --git a/compiler/circlechef/core/src/ModelChef.cpp b/compiler/circlechef/core/src/ModelChef.cpp index 6c5206dfc..62f6c8c0d 100644 --- a/compiler/circlechef/core/src/ModelChef.cpp +++ 
b/compiler/circlechef/core/src/ModelChef.cpp
@@ -89,9 +89,11 @@ DataChefRegistry &data_chef_registry(const circlechef::TensorType &type)
   static DataChefRegistry s64;
   static DataChefRegistry fp32;
   static DataChefRegistry u8;
+  static DataChefRegistry u4;
   static DataChefRegistry string;
   static DataChefRegistry boolean;
   static DataChefRegistry s16;
+  static DataChefRegistry s4;
 
   switch (type)
   {
@@ -103,12 +105,16 @@ DataChefRegistry &data_chef_registry(const circlechef::TensorType &type)
     return fp32;
   case circlechef::UINT8:
     return u8;
+  case circlechef::UINT4:
+    return u4;
   case circlechef::STRING:
     return string;
   case circlechef::BOOL:
     return boolean;
   case circlechef::INT16:
     return s16;
+  case circlechef::INT4:
+    return s4;
   default:
     break;
   }
@@ -135,10 +141,10 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe)
 
   for (const auto &operation : model_recipe.operation())
   {
-    auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
-    if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
+    if (operation.type() == "Custom")
       continue;
 
+    auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
     // Various operation version is unified as the highest version among them
     if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
         builtin_map[op_chef->code()] < operation.version())
@@ -151,10 +157,10 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe)
     const auto &graph = model_recipe.graph(g);
     for (const auto &operation : graph.operation())
     {
-      auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
-      if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
+      if (operation.type() == "Custom")
        continue;
 
+      auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
       // Various operation version is unified as the highest version among them
       if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
           builtin_map[op_chef->code()] < operation.version())
@@ -171,9 +177,11 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &model_recipe)
   std::set<std::string> customcode_set;
   for (const auto &operation : model_recipe.operation())
   {
-    auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
-    if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
-      customcode_set.insert(operation.type());
+    if (operation.type() == "Custom")
+    {
+      assert(not operation.custom_code().empty());
+      customcode_set.insert(operation.custom_code());
+    }
   }
 
   // Add ops used in Graphs(subgraphs)
@@ -182,9 +190,11 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &model_recipe)
     const auto &graph = model_recipe.graph(g);
     for (const auto &operation : graph.operation())
     {
-      auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
-      if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
-        customcode_set.insert(operation.type());
+      if (operation.type() == "Custom")
+      {
+        assert(not operation.custom_code().empty());
+        customcode_set.insert(operation.custom_code());
+      }
     }
   }
 
@@ -296,6 +306,34 @@ template <typename T> void cook_graph(const T &graph, CookParams &cp)
 
       // Create Data
       int32_t count = (element_count(dims) > 0) ? element_count(dims) : filler.arg_size();
       auto data_vec = chef->generate(count);
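The INT4/UINT4 packing below stores two 4-bit values per byte, low nibble first, in a buffer of (count + 1) / 2 bytes. The same scheme in a few lines (illustrative sketch, not the chef code itself):

    def pack_nibbles(values):
        # two 4-bit values per byte; low nibble first, odd tail stays alone
        packed = []
        for i in range(0, len(values), 2):
            byte = values[i] & 0x0F
            if i + 1 < len(values):
                byte |= (values[i + 1] & 0x0F) << 4
            packed.append(byte)
        return bytes(packed)

    assert pack_nibbles([1, 2, 3]) == bytes([0x21, 0x03])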
+      // pack for INT4 and replace data_vec
+      if (operand.type() == circlechef::TensorType::INT4)
+      {
+        uint32_t packed = (count + 1) / 2;
+        std::vector<uint8_t> data_packed(packed);
+        for (uint32_t idx = 0; idx < packed; ++idx)
+        {
+          uint32_t sidx = idx * 2;
+          data_packed[idx] = data_vec[sidx++] & 0x0f;
+          if (sidx < count)
+            data_packed[idx] |= data_vec[sidx] << 4;
+        }
+        data_vec = data_packed;
+      }
+      // pack for UINT4 and replace data_vec
+      else if (operand.type() == circlechef::TensorType::UINT4)
+      {
+        uint32_t packed = (count + 1) / 2;
+        std::vector<uint8_t> data_packed(packed);
+        for (uint32_t idx = 0; idx < packed; ++idx)
+        {
+          uint32_t sidx = idx * 2;
+          data_packed[idx] = data_vec[sidx++] & 0x0f;
+          if (sidx < count)
+            data_packed[idx] |= data_vec[sidx] << 4;
+        }
+        data_vec = data_packed;
+      }
 
       auto data = flatbuffer_builder->CreateVector(data_vec);
 
       // Create Buffer
@@ -418,7 +456,11 @@ template <typename T> void cook_graph(const T &graph, CookParams &cp)
   {
     assert(operation.has_type());
 
-    auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
+    std::string op_type = operation.type();
+    if (not operation.custom_code().empty())
+      op_type = operation.custom_code();
+
+    auto op_chef = op_chef_registry().lookup(op_type).create(&operation);
 
     // Create 'inputs'
     std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
diff --git a/compiler/circlechef/core/src/Op/FullyConnected.cpp b/compiler/circlechef/core/src/Op/FullyConnected.cpp
new file mode 100644
index 000000000..cb567b951
--- /dev/null
+++ b/compiler/circlechef/core/src/Op/FullyConnected.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FullyConnected.h"
+#include "Convert.h"
+
+#include <cassert>
+
+flatbuffers::Offset<void> FullyConnectedChef::value(flatbuffers::FlatBufferBuilder &fbb) const
+{
+  auto &operation = (*_operation);
+
+  assert(operation.has_fullyconnected_options());
+
+  auto circle_activation = as_circle_activation(operation.fullyconnected_options().activation());
+
+  circle::FullyConnectedOptionsBuilder fc_options_builder{fbb};
+  fc_options_builder.add_fused_activation_function(circle_activation);
+  fc_options_builder.add_keep_num_dims(operation.fullyconnected_options().keep_num_dims());
+
+  return fc_options_builder.Finish().Union();
+}
+
+std::unique_ptr<OpChef>
+FullyConnectedChefFactory::create(const circlechef::Operation *operation) const
+{
+  return std::unique_ptr<OpChef>{new FullyConnectedChef{operation}};
+}
diff --git a/compiler/circlechef/core/src/Op/FullyConnected.h b/compiler/circlechef/core/src/Op/FullyConnected.h
new file mode 100644
index 000000000..56c74bb5a
--- /dev/null
+++ b/compiler/circlechef/core/src/Op/FullyConnected.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd.
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_FULLYCONNECTED_H__ +#define __OP_FULLYCONNECTED_H__ + +#include "OpChef.h" + +class FullyConnectedChef final : public OpChef +{ +public: + explicit FullyConnectedChef(const circlechef::Operation *operation) : _operation{operation} + { + // DO NOTHING + } + +public: + circle::BuiltinOperator code(void) const override + { + return circle::BuiltinOperator_FULLY_CONNECTED; + } + + circle::BuiltinOptions type(void) const override + { + return circle::BuiltinOptions_FullyConnectedOptions; + } + + flatbuffers::Offset value(flatbuffers::FlatBufferBuilder &fbb) const override; + +private: + const circlechef::Operation *_operation; +}; + +struct FullyConnectedChefFactory final : public OpChefFactory +{ + std::unique_ptr create(const circlechef::Operation *operation) const override; +}; + +#endif // __OP_FULLYCONNECTED_H__ diff --git a/compiler/circlechef/core/src/Op/GRU.cpp b/compiler/circlechef/core/src/Op/GRU.cpp new file mode 100644 index 000000000..c32e9d2fb --- /dev/null +++ b/compiler/circlechef/core/src/Op/GRU.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GRU.h" + +#include "Convert.h" + +flatbuffers::Offset GRUChef::value(flatbuffers::FlatBufferBuilder &fbb) const +{ + auto &operation = (*_operation); + + assert(operation.has_gru_options()); + auto circle_activation = as_circle_activation(operation.gru_options().activation()); + auto return_sequences = operation.gru_options().return_sequences(); + auto time_major = operation.gru_options().time_major(); + + circle::GRUOptionsBuilder options_builder{fbb}; + options_builder.add_fused_activation_function(circle_activation); + options_builder.add_return_sequences(return_sequences); + options_builder.add_time_major(time_major); + + return options_builder.Finish().Union(); +} + +std::unique_ptr GRUChefFactory::create(const circlechef::Operation *operation) const +{ + return std::unique_ptr{new GRUChef{operation}}; +} diff --git a/compiler/circlechef/core/src/Op/GRU.h b/compiler/circlechef/core/src/Op/GRU.h new file mode 100644 index 000000000..0215cb731 --- /dev/null +++ b/compiler/circlechef/core/src/Op/GRU.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_CIRCLE_GRU_H__ +#define __OP_CIRCLE_GRU_H__ + +#include "OpChef.h" + +class GRUChef final : public OpChef +{ +public: + explicit GRUChef(const circlechef::Operation *operation) : _operation{operation} + { + // DO NOTHING + } + +public: + circle::BuiltinOperator code(void) const override { return circle::BuiltinOperator_GRU; } + + circle::BuiltinOptions type(void) const override { return circle::BuiltinOptions_GRUOptions; } + + flatbuffers::Offset value(flatbuffers::FlatBufferBuilder &fbb) const override; + +private: + const circlechef::Operation *_operation; +}; + +struct GRUChefFactory final : public OpChefFactory +{ + std::unique_ptr create(const circlechef::Operation *operation) const override; +}; + +#endif // __OP_CIRCLE_GRU_H__ diff --git a/compiler/circlechef/core/src/OpChef.def b/compiler/circlechef/core/src/OpChef.def index 3128d3ba2..6084f6c34 100644 --- a/compiler/circlechef/core/src/OpChef.def +++ b/compiler/circlechef/core/src/OpChef.def @@ -7,4 +7,6 @@ OP_CHEF(BatchMatMul, BatchMatMulChefFactory) OP_CHEF(BCQFullyConnected, BCQFullyConnectedChefFactory) OP_CHEF(BCQGather, BCQGatherChefFactory) +OP_CHEF(FullyConnected, FullyConnectedChefFactory) +OP_CHEF(GRU, GRUChefFactory) OP_CHEF(InstanceNorm, InstanceNormChefFactory) diff --git a/compiler/circlechef/core/src/OpChefs.h b/compiler/circlechef/core/src/OpChefs.h index e13c5e0c6..10fb30c06 100644 --- a/compiler/circlechef/core/src/OpChefs.h +++ b/compiler/circlechef/core/src/OpChefs.h @@ -20,6 +20,8 @@ #include "Op/BatchMatMul.h" #include "Op/BCQFullyConnected.h" #include "Op/BCQGather.h" +#include "Op/FullyConnected.h" +#include "Op/GRU.h" #include "Op/InstanceNorm.h" #endif // __OP_CHEFS_H__ diff --git a/compiler/circlechef/proto/circlechef.proto b/compiler/circlechef/proto/circlechef.proto index d5e08576f..234d746cc 100644 --- a/compiler/circlechef/proto/circlechef.proto +++ b/compiler/circlechef/proto/circlechef.proto @@ -14,6 +14,7 @@ package circlechef; // This enum value corresponds to TensorType in TensorFlow Lite schema enum TensorType { + UINT4 = -1; FLOAT32 = 0; INT32 = 2; UINT8 = 3; @@ -21,6 +22,7 @@ enum TensorType { STRING = 5; BOOL = 6; INT16 = 7; + INT4 = 17; } message TensorShape { @@ -76,6 +78,17 @@ message InstanceNormOptions { optional Activation activation = 2 [default = NONE]; } +message FullyConnectedOptions { + optional Activation activation = 1 [default = NONE]; + optional bool keep_num_dims = 2 [ default = false ]; +} + +message GRUOptions { + optional Activation activation = 1 [default = NONE]; + optional bool return_sequences = 2 [default = false]; + optional bool time_major = 3 [default = false]; +} + message BCQFullyConnectedOptions { optional int32 weights_hidden_size = 1 [default = 0]; optional Activation activation = 2 [default = NONE]; @@ -91,11 +104,14 @@ message Operation { repeated string input = 2; repeated string output = 3; optional int32 version = 4 [default = 1]; + optional string custom_code 
= 5; optional BatchMatMulOptions batch_matmul_options = 100; optional InstanceNormOptions instance_norm_options = 101; optional BCQFullyConnectedOptions bcq_fully_connected_options = 102; optional BCQGatherOptions bcq_gather_options = 103; + optional GRUOptions gru_options = 104; + optional FullyConnectedOptions fullyconnected_options = 105; } // For additional subgraphs diff --git a/compiler/circlechef/requires.cmake b/compiler/circlechef/requires.cmake index 67eaa278c..77bfddc97 100644 --- a/compiler/circlechef/requires.cmake +++ b/compiler/circlechef/requires.cmake @@ -1,7 +1,6 @@ require("arser") -require("nnkit") require("cwrap") -require("mio-circle06") +require("mio-circle08") require("safemain") require("hermes") require("hermes-std") diff --git a/compiler/circlechef/tests/int4_datatype/test.recipe b/compiler/circlechef/tests/int4_datatype/test.recipe new file mode 100644 index 000000000..6bc50b0ef --- /dev/null +++ b/compiler/circlechef/tests/int4_datatype/test.recipe @@ -0,0 +1,34 @@ +operand { + name: "ifm1" + type: INT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } +} +operand { + name: "constant" + type: INT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } + filler { + tag: "explicit" + arg: "-8" arg: "-7" arg: "-6" arg: "-5" + arg: "-4" arg: "-3" arg: "-2" arg: "-1" + arg: "0" arg: "1" arg: "2" arg: "3" + arg: "4" arg: "5" arg: "6" arg: "7" + } +} +operand { + name: "ofm" + type: INT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } +} +operation { + type: "BatchMatMul" + input: "ifm1" + input: "constant" + output: "ofm" + batch_matmul_options { + adjoint_lhs: false + adjoint_rhs: false + } +} +input: "ifm1" +output: "ofm" diff --git a/compiler/circlechef/tests/uint4_datatype/test.recipe b/compiler/circlechef/tests/uint4_datatype/test.recipe new file mode 100644 index 000000000..ce7e872a2 --- /dev/null +++ b/compiler/circlechef/tests/uint4_datatype/test.recipe @@ -0,0 +1,34 @@ +operand { + name: "ifm1" + type: UINT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } +} +operand { + name: "constant" + type: UINT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } + filler { + tag: "explicit" + arg: "0" arg: "1" arg: "2" arg: "3" + arg: "4" arg: "5" arg: "6" arg: "7" + arg: "8" arg: "9" arg: "10" arg: "11" + arg: "12" arg: "13" arg: "14" arg: "15" + } +} +operand { + name: "ofm" + type: UINT4 + shape { dim: 1 dim: 4 dim: 4 dim: 1 } +} +operation { + type: "BatchMatMul" + input: "ifm1" + input: "constant" + output: "ofm" + batch_matmul_options { + adjoint_lhs: false + adjoint_rhs: false + } +} +input: "ifm1" +output: "ofm" diff --git a/compiler/circledump/CMakeLists.txt b/compiler/circledump/CMakeLists.txt index b7326730a..9945ba0f0 100644 --- a/compiler/circledump/CMakeLists.txt +++ b/compiler/circledump/CMakeLists.txt @@ -1,7 +1,7 @@ -if(NOT TARGET mio_circle06) - message(STATUS "Skip circledump: mio_circle06 not found") +if(NOT TARGET mio_circle08) + message(STATUS "Skip circledump: mio_circle08 not found") return() -endif(NOT TARGET mio_circle06) +endif(NOT TARGET mio_circle08) set(DRIVER "driver/Driver.cpp") @@ -11,8 +11,8 @@ add_executable(circledump ${DRIVER} ${SOURCES}) target_include_directories(circledump PRIVATE include) target_link_libraries(circledump arser) target_link_libraries(circledump foder) -target_link_libraries(circledump mio_circle06) -target_link_libraries(circledump mio_circle06_helper) +target_link_libraries(circledump mio_circle08) +target_link_libraries(circledump mio_circle08_helper) target_link_libraries(circledump safemain) install(TARGETS circledump DESTINATION bin) diff --git 
a/compiler/circledump/README.md b/compiler/circledump/README.md index f71194b08..9fa265300 100644 --- a/compiler/circledump/README.md +++ b/compiler/circledump/README.md @@ -65,6 +65,6 @@ O T(3) ofm ### Dependency -- mio-circle06 +- mio-circle08 - safemain - FlatBuffers diff --git a/compiler/circledump/requires.cmake b/compiler/circledump/requires.cmake index b3a2638ef..8a57c8f11 100644 --- a/compiler/circledump/requires.cmake +++ b/compiler/circledump/requires.cmake @@ -1,4 +1,4 @@ require("arser") require("foder") -require("mio-circle06") +require("mio-circle08") require("safemain") diff --git a/compiler/circledump/src/Dump.cpp b/compiler/circledump/src/Dump.cpp index 69427a20e..166931648 100644 --- a/compiler/circledump/src/Dump.cpp +++ b/compiler/circledump/src/Dump.cpp @@ -126,24 +126,11 @@ void dump_sub_graph(std::ostream &os, mio::circle::Reader &reader) { auto tensors = reader.tensors(); auto operators = reader.operators(); - auto data_format = reader.data_format(); - - // dump data_format - os << "Data Format:" << std::endl; - if (data_format == circle::DataFormat::DataFormat_CHANNELS_LAST) - { - os << "CHANNEL_LAST (NHWC for 2d, NDHWC for 3d data)" << std::endl; - } - else if (data_format == circle::DataFormat::DataFormat_CHANNELS_FIRST) - { - os << "CHANNEL_FIRST (NCHW for 2d, NCDHW for 3d data)" << std::endl; - } - os << std::endl; // dump operands(tensors) os << "Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) " << "B(buffer index) (variable) OperandName" << std::endl; - for (uint32_t i = 0; i < tensors->Length(); ++i) + for (uint32_t i = 0; i < tensors->size(); ++i) { // TODO refactor to some better structure auto tensor = tensors->Get(i); @@ -294,7 +281,7 @@ void dump_sub_graph(std::ostream &os, mio::circle::Reader &reader) os << " Option(values) ... 
<-- depending on OpCode" << std::endl; os << " I T(tensor index) OperandName <-- as input" << std::endl; os << " O T(tensor index) OperandName <-- as output" << std::endl; - for (uint32_t i = 0; i < operators->Length(); ++i) + for (uint32_t i = 0; i < operators->size(); ++i) { const auto op = operators->Get(i); circle::BuiltinOperator builtincode = reader.builtin_code(op); @@ -392,7 +379,7 @@ void dump_model(std::ostream &os, const circle::Model *model) // dump buffer os << "Buffers: B(index) (length) values, if any" << std::endl; - for (uint32_t i = 0; i < buffers->Length(); ++i) + for (uint32_t i = 0; i < buffers->size(); ++i) { const uint8_t *buff_data; size_t size = reader.buffer_info(i, &buff_data); @@ -410,7 +397,7 @@ void dump_model(std::ostream &os, const circle::Model *model) if (metadata != nullptr) { os << "metadata : B(index) name" << std::endl; - for (uint32_t i = 0; i < metadata->Length(); ++i) + for (uint32_t i = 0; i < metadata->size(); ++i) { const auto buff_id = metadata->Get(i)->buffer(); const auto metadata_name = metadata->Get(i)->name()->str(); @@ -430,14 +417,14 @@ void dump_model(std::ostream &os, const circle::Model *model) if (signaturedefs != nullptr) { os << "SignatureDef" << std::endl; - for (uint32_t i = 0; i < signaturedefs->Length(); ++i) + for (uint32_t i = 0; i < signaturedefs->size(); ++i) { auto sign_i = signaturedefs->Get(i); os << "S(" << i << ") signature_key(" << sign_i->signature_key()->c_str() << "), sub_graph(" << sign_i->subgraph_index() << ")" << std::endl; auto inputs_i = sign_i->inputs(); - for (uint32_t t = 0; t < inputs_i->Length(); ++t) + for (uint32_t t = 0; t < inputs_i->size(); ++t) { auto inputs_i_t = inputs_i->Get(t); os << " I(" << t << ")" @@ -446,7 +433,7 @@ void dump_model(std::ostream &os, const circle::Model *model) } auto outputs_i = sign_i->outputs(); - for (uint32_t t = 0; t < outputs_i->Length(); ++t) + for (uint32_t t = 0; t < outputs_i->size(); ++t) { auto outputs_i_t = outputs_i->Get(t); os << " O(" << t << ")" diff --git a/compiler/circledump/src/OpPrinter.cpp b/compiler/circledump/src/OpPrinter.cpp index bfcb1ec18..61a0941ea 100644 --- a/compiler/circledump/src/OpPrinter.cpp +++ b/compiler/circledump/src/OpPrinter.cpp @@ -135,7 +135,7 @@ public: if (auto conv_params = op->builtin_options_as_Conv2DOptions()) { os << " "; - os << "Padding(" << conv_params->padding() << ") "; + os << "Padding(" << EnumNamePadding(conv_params->padding()) << ") "; os << "Stride.W(" << conv_params->stride_w() << ") "; os << "Stride.H(" << conv_params->stride_h() << ") "; os << "Dilation.W(" << conv_params->dilation_w_factor() << ") "; @@ -184,7 +184,7 @@ public: if (auto pool_params = op->builtin_options_as_Pool2DOptions()) { os << " "; - os << "Padding(" << pool_params->padding() << ") "; + os << "Padding(" << EnumNamePadding(pool_params->padding()) << ") "; os << "Stride.W(" << pool_params->stride_w() << ") "; os << "Stride.H(" << pool_params->stride_h() << ") "; os << "Filter.W(" << pool_params->filter_width() << ") "; @@ -298,7 +298,7 @@ public: if (auto conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { os << " "; - os << "Padding(" << conv_params->padding() << ") "; + os << "Padding(" << EnumNamePadding(conv_params->padding()) << ") "; os << "Stride.W(" << conv_params->stride_w() << ") "; os << "Stride.H(" << conv_params->stride_h() << ") "; os << "DepthMultiplier(" << conv_params->depth_multiplier() << ") "; @@ -662,7 +662,7 @@ public: if (auto params = op->builtin_options_as_TransposeConvOptions()) { os << " "; - os << 
"Padding(" << params->padding() << ") "; + os << "Padding(" << EnumNamePadding(params->padding()) << ") "; os << "Stride.W(" << params->stride_w() << ") "; os << "Stride.H(" << params->stride_h() << ") "; os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function()) @@ -791,6 +791,24 @@ public: } }; +class GRUPrinter : public OpPrinter +{ +public: + void options(const circle::Operator *op, std::ostream &os) const override + { + if (auto *params = op->builtin_options_as_GRUOptions()) + { + os << " "; + os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function()) + << ") "; + os << "return_sequences(" << params->return_sequences() << ") "; + os << "time_major(" << params->time_major() << ") "; + + os << std::endl; + } + } +}; + class InstanceNormPrinter : public OpPrinter { public: @@ -853,6 +871,7 @@ OpPrinterRegistry::OpPrinterRegistry() // There is no Option for PRELU // There is no Option for RELU // There is no Option for RELU6 + // There is no Option for RELU_0_TO_1 // There is no Option for RELU_N1_TO_1 _op_map[circle::BuiltinOperator_REDUCE_ANY] = make_unique(); _op_map[circle::BuiltinOperator_REDUCE_MAX] = make_unique(); @@ -891,6 +910,7 @@ OpPrinterRegistry::OpPrinterRegistry() // Circle only _op_map[circle::BuiltinOperator_BCQ_FULLY_CONNECTED] = make_unique(); _op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique(); + _op_map[circle::BuiltinOperator_GRU] = make_unique(); _op_map[circle::BuiltinOperator_INSTANCE_NORM] = make_unique(); } diff --git a/compiler/cli/exclude.me b/compiler/cli/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/coco/exclude.me b/compiler/coco/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/common-artifacts/CMakeLists.txt b/compiler/common-artifacts/CMakeLists.txt index 2b032034a..61c2f44b3 100644 --- a/compiler/common-artifacts/CMakeLists.txt +++ b/compiler/common-artifacts/CMakeLists.txt @@ -49,6 +49,7 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "aarch64") COMMAND ${CMAKE_COMMAND} -E echo "flatbuffers==23.5.26" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${CMAKE_COMMAND} -E echo "protobuf==4.23.3" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${CMAKE_COMMAND} -E echo "pydot==1.4.2" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} + COMMAND ${CMAKE_COMMAND} -E echo "pytest==7.4.3" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${VIRTUALENV_OVERLAY_TF_2_12_1}/bin/${PYTHON_OVERLAY} -m pip --default-timeout=1000 ${PIP_OPTION_TRUSTED_HOST} install --upgrade pip setuptools COMMAND ${VIRTUALENV_OVERLAY_TF_2_12_1}/bin/${PYTHON_OVERLAY} -m pip --default-timeout=1000 @@ -63,6 +64,7 @@ else(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "aarch64") COMMAND ${CMAKE_COMMAND} -E echo "flatbuffers==23.5.26" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${CMAKE_COMMAND} -E echo "protobuf==4.23.3" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${CMAKE_COMMAND} -E echo "pydot==1.4.2" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} + COMMAND ${CMAKE_COMMAND} -E echo "pytest==7.4.3" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_12_1} COMMAND ${VIRTUALENV_OVERLAY_TF_2_12_1}/bin/${PYTHON_OVERLAY} -m pip --default-timeout=1000 ${PIP_OPTION_TRUSTED_HOST} install --upgrade pip setuptools COMMAND ${VIRTUALENV_OVERLAY_TF_2_12_1}/bin/${PYTHON_OVERLAY} -m pip --default-timeout=1000 diff --git a/compiler/common-artifacts/exclude.lst b/compiler/common-artifacts/exclude.lst index 75055225b..25eb12467 100644 --- a/compiler/common-artifacts/exclude.lst +++ 
b/compiler/common-artifacts/exclude.lst @@ -30,7 +30,6 @@ tcgenerate(BatchToSpaceND_000) tcgenerate(BroadcastTo_000) # luci-interpreter doesn't support custom operator tcgenerate(Ceil_000) tcgenerate(Conv2D_003) # runtime doesn't support dilation -tcgenerate(Cos_000) tcgenerate(Densify_000) # luci-interpreter doesn't support tcgenerate(DepthwiseConv2D_001) # runtime doesn't support dilation tcgenerate(DepthwiseConv2D_003) # runtime doesn't support dilation @@ -42,6 +41,9 @@ tcgenerate(Fill_000) tcgenerate(Fill_001) tcgenerate(FloorMod_000) tcgenerate(FloorMod_001) +tcgenerate(FullyConnected_I4_000) +tcgenerate(FullyConnected_I4_001) +tcgenerate(FullyConnected_I4_002) tcgenerate(FullyConnected_U8_000) tcgenerate(GatherNd_000) tcgenerate(GatherNd_001) @@ -69,6 +71,7 @@ tcgenerate(Net_Conv_FakeQuant_000) # luci-interpreter doesn't support FakeQuant tcgenerate(Net_Dangle_001) tcgenerate(Net_Densify_Add_000) # luci-interpreter doesn't support Densify yet tcgenerate(Net_Densify_Dequantize_Add_000) # luci-interpreter doesn't support Densify/Dequantize yet +tcgenerate(Net_FC_Gelu_FC_000) # luci-interpreter doesn't support custom operator Erf tcgenerate(Net_Gather_SparseToDense_AddV2_000) # luci-interpreter doesn't support custom operator tcgenerate(Net_Gelu_000) # luci-interpreter doesn't support custom operator tcgenerate(Net_Gelu_001) # luci-interpreter doesn't support custom operator @@ -130,7 +133,6 @@ tcgenerate(SelectV2_000) tcgenerate(SelectV2_001) tcgenerate(SelectV2_002) tcgenerate(Shape_000) -tcgenerate(Sin_000) tcgenerate(Slice_001) # luci-interpreter doesn't support Slice with -1 tcgenerate(SpaceToBatchND_000) tcgenerate(SpaceToBatchND_001) @@ -143,7 +145,6 @@ tcgenerate(Sum_000) tcgenerate(Sum_001) tcgenerate(Sum_dynamic_000) # TestDataGenerator does not support unknown dimension tcgenerate(Sum_dynamic_001) # TestDataGenerator does not support unknown dimension -tcgenerate(Tile_000) tcgenerate(Tile_U8_000) tcgenerate(TopKV2_000) tcgenerate(TopKV2_001) @@ -167,5 +168,11 @@ tcgenerate(ZerosLike_000) tcgenerate(BCQFullyConnected_000) tcgenerate(BCQFullyConnected_001) tcgenerate(BCQGather_000) +tcgenerate(CircleBatchMatMul_I4_000) +tcgenerate(CircleBatchMatMul_U4_000) +tcgenerate(CircleFullyConnected_U4_000) +tcgenerate(CircleFullyConnected_U4_001) +tcgenerate(CircleFullyConnected_U4_002) +tcgenerate(GRU_000) # luci-interpreter does not support custom GRU tcgenerate(InstanceNorm_000) tcgenerate(InstanceNorm_001) diff --git a/compiler/common-artifacts/src/TestDataGenerator.cpp b/compiler/common-artifacts/src/TestDataGenerator.cpp index 7481050c5..72f4fbb88 100644 --- a/compiler/common-artifacts/src/TestDataGenerator.cpp +++ b/compiler/common-artifacts/src/TestDataGenerator.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -265,7 +266,7 @@ int entry(int argc, char **argv) input_file.createDataSet("value/" + std::to_string(input_index), dtype, *dataspace)); auto data_size = ::element_num(dims); - auto dtype_size = loco::size(input_node->dtype()); + auto dtype_size = luci::size(input_node->dtype()); auto byte_size = dtype_size * data_size; std::vector data(byte_size); @@ -329,7 +330,7 @@ int entry(int argc, char **argv) auto dataset = std::make_unique( output_file.createDataSet("value/" + std::to_string(output_index), dtype, *dataspace)); - uint32_t tensor_bytesize = loco::size(output_node->dtype()); + uint32_t tensor_bytesize = luci::size(output_node->dtype()); tensor_bytesize *= ::element_num(dims); std::vector output_data(tensor_bytesize); 
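// NOTE The loco::size -> luci::size change above is what keeps this buffer sizing working for the new 4-bit types: loco::size() now throws for S4/U4 (see the DataTypeTraits.h hunk later in this patch), while luci::size() is presumed here to report a whole-byte size per sub-byte element rather than packing two 4-bit values into one byte.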
interpreter.readOutputTensor(output_node, output_data.data(), output_data.size()); diff --git a/compiler/dalgona/src/Dalgona.cpp b/compiler/dalgona/src/Dalgona.cpp index 1a35b6d03..e662c2074 100644 --- a/compiler/dalgona/src/Dalgona.cpp +++ b/compiler/dalgona/src/Dalgona.cpp @@ -18,6 +18,7 @@ #include "PythonHooks.h" #include "RandomUtils.h" +#include #include #include #include @@ -51,7 +52,7 @@ template <typename NodeT> size_t getByteSize(const NodeT *node) { assert(node != nullptr); // FIX_CALLER_UNLESS - uint32_t dtype_size = loco::size(node->dtype()); + uint32_t dtype_size = luci::size(node->dtype()); return static_cast<size_t>(dtype_size) * static_cast<size_t>(numElements(node)); } diff --git a/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h b/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h index add441147..7ddda5574 100644 --- a/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h +++ b/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h @@ -61,16 +61,16 @@ public: * @param buffer_bytes : byte size of the buffer */ void readTensor(int32_t data_idx, int32_t input_idx, loco::DataType *dtype, - std::vector *shape, void *buffer, size_t buffer_bytes); + std::vector *shape, void *buffer, size_t buffer_bytes) const; // Read a raw tensor (no type/shape is specified) - void readTensor(int32_t data_idx, int32_t input_idx, void *buffer, size_t buffer_bytes); + void readTensor(int32_t data_idx, int32_t input_idx, void *buffer, size_t buffer_bytes) const; - bool isRawData() { return _group.attrExists("rawData"); } + bool isRawData() const { return _group.attrExists("rawData"); } - int32_t numData() { return _group.getNumObjs(); } + int32_t numData() const { return _group.getNumObjs(); } - int32_t numInputs(int32_t data_idx); + int32_t numInputs(int32_t data_idx) const; private: H5::H5File _file; diff --git a/compiler/dio-hdf5/src/HDF5Importer.cpp b/compiler/dio-hdf5/src/HDF5Importer.cpp index 920899058..22139611e 100644 --- a/compiler/dio-hdf5/src/HDF5Importer.cpp +++ b/compiler/dio-hdf5/src/HDF5Importer.cpp @@ -122,14 +122,14 @@ HDF5Importer::HDF5Importer(const std::string &path) _file = H5::H5File(path, H5F_ACC_RDONLY); } -int32_t HDF5Importer::numInputs(int32_t record_idx) +int32_t HDF5Importer::numInputs(int32_t record_idx) const { auto records = _group.openGroup(std::to_string(record_idx)); return records.getNumObjs(); } void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffer, - size_t buffer_bytes) + size_t buffer_bytes) const { auto record = _group.openGroup(std::to_string(record_idx)); auto tensor = record.openDataSet(std::to_string(input_idx)); @@ -141,7 +141,7 @@ void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffe } void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape, - void *buffer, size_t buffer_bytes) + void *buffer, size_t buffer_bytes) const { auto record = _group.openGroup(std::to_string(record_idx)); auto tensor = record.openDataSet(std::to_string(input_idx)); diff --git a/compiler/embedded-import-value-test/src/TestDriver.cpp b/compiler/embedded-import-value-test/src/TestDriver.cpp index 63fd745eb..b937ba5b8 100644 --- a/compiler/embedded-import-value-test/src/TestDriver.cpp +++ b/compiler/embedded-import-value-test/src/TestDriver.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -31,7 +32,7 @@ namespace uint32_t tensor_size_of(const luci::CircleNode *node) { - uint32_t tensor_size = loco::size(node->dtype()); + uint32_t tensor_size = luci::size(node->dtype()); for (uint32_t i = 0; i < 
node->rank(); ++i) tensor_size *= node->dim(i).value(); return tensor_size; @@ -45,8 +46,8 @@ std::vector random_data_for(const luci::CircleInput *node) // define size of buffer in elements const auto dtype = node->dtype(); - assert(inputs_data.size() % loco::size(dtype) == 0); // FIX ME UNLESS - const auto element_count = inputs_data.size() / loco::size(dtype); + assert(inputs_data.size() % luci::size(dtype) == 0); // FIX ME UNLESS + const auto element_count = inputs_data.size() / luci::size(dtype); // random generator engine std::random_device device; diff --git a/compiler/enco-intf/exclude.me b/compiler/enco-intf/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/enco/exclude.me b/compiler/enco/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/enco/frontend/tflite/CMakeLists.txt b/compiler/enco/frontend/tflite/CMakeLists.txt index 995e66f81..c71bde7ee 100644 --- a/compiler/enco/frontend/tflite/CMakeLists.txt +++ b/compiler/enco/frontend/tflite/CMakeLists.txt @@ -1,4 +1,4 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) if(NOT FlatBuffers_FOUND) return() diff --git a/compiler/enco/test/tflite/CMakeLists.txt b/compiler/enco/test/tflite/CMakeLists.txt index 81d5ed2a2..8c831af33 100644 --- a/compiler/enco/test/tflite/CMakeLists.txt +++ b/compiler/enco/test/tflite/CMakeLists.txt @@ -18,6 +18,10 @@ endfunction(get_test_configuration) ### ### Prepare test(s) ### +if(NOT TARGET nnkit_tflite_backend) + return() +endif(NOT TARGET nnkit_tflite_backend) + if(NOT TARGET tflchef-file) return() endif(NOT TARGET tflchef-file) diff --git a/compiler/encodump/exclude.me b/compiler/encodump/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/exo/CMakeLists.txt b/compiler/exo/CMakeLists.txt index 645db714c..697c39dd9 100644 --- a/compiler/exo/CMakeLists.txt +++ b/compiler/exo/CMakeLists.txt @@ -1,4 +1,4 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) if(NOT FlatBuffers_FOUND) message(STATUS "Build exo: FALSE (missing FlatBuffers)") diff --git a/compiler/exo/exclude.me b/compiler/exo/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/fipe/exclude.me b/compiler/fipe/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/hermes-std/include/hermes/ConsoleReporter.h b/compiler/hermes-std/include/hermes/ConsoleReporter.h index c55e46a17..6e037dc3e 100644 --- a/compiler/hermes-std/include/hermes/ConsoleReporter.h +++ b/compiler/hermes-std/include/hermes/ConsoleReporter.h @@ -32,6 +32,7 @@ struct ConsoleReporter final : public hermes::Sink private: bool _is_colored = false; + bool _env_checked = false; }; } // namespace hermes diff --git a/compiler/hermes-std/src/ConsoleReporter.cpp b/compiler/hermes-std/src/ConsoleReporter.cpp index 524ed59d8..03f60ad09 100644 --- a/compiler/hermes-std/src/ConsoleReporter.cpp +++ b/compiler/hermes-std/src/ConsoleReporter.cpp @@ -42,12 +42,16 @@ static constexpr const char *kTermColorResetAllCode = "\033[0m"; void ConsoleReporter::notify(const hermes::Message *m) { - const char *env_color_p = std::getenv("ONE_HERMES_COLOR"); - if (env_color_p) + if (not _env_checked) { - auto env_color_str = std::string(env_color_p); - if ((env_color_str == "1") or (env_color_str == "ON")) - _is_colored = true; + const char *env_color_p = std::getenv("ONE_HERMES_COLOR"); + if (env_color_p) + { + auto env_color_str = std::string(env_color_p); + if 
((env_color_str == "1") or (env_color_str == "ON")) + _is_colored = true; + } + _env_checked = true; } if (_is_colored) diff --git a/compiler/kuma/exclude.me b/compiler/kuma/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/loco/include/loco/IR/DataType.h b/compiler/loco/include/loco/IR/DataType.h index b89edf29e..bb42165f2 100644 --- a/compiler/loco/include/loco/IR/DataType.h +++ b/compiler/loco/include/loco/IR/DataType.h @@ -27,11 +27,13 @@ enum class DataType { Unknown, // Unknown type (serves as a default value) + U4, // 4-bit unsigned integer U8, // 8-bit unsigned integer U16, // 16-bit unsigned integer U32, // 32-bit unsigned integer U64, // 64-bit unsigned integer + S4, // 4-bit signed integer S8, // 8-bit signed integer S16, // 16-bit signed integer S32, // 32-bit signed integer diff --git a/compiler/loco/include/loco/IR/DataTypeTraits.h b/compiler/loco/include/loco/IR/DataTypeTraits.h index 6be46c3b3..6bb3e5733 100644 --- a/compiler/loco/include/loco/IR/DataTypeTraits.h +++ b/compiler/loco/include/loco/IR/DataTypeTraits.h @@ -35,12 +35,24 @@ template <DataType DT> struct DataTypeImpl }; // TODO Support other enum values +template <> struct DataTypeImpl<DataType::S4> +{ + // Use C++ int8_t type for 4bit integer + using Type = int8_t; +}; + template <> struct DataTypeImpl<DataType::S8> { // Use C++ int8_t type for 8bit integer using Type = int8_t; }; +template <> struct DataTypeImpl<DataType::U4> +{ + // Use C++ uint8_t type for unsigned 4bit integer + using Type = uint8_t; +}; + template <> struct DataTypeImpl<DataType::U8> { // Use C++ uint8_t type for unsigned 8bit integer @@ -123,8 +135,12 @@ inline uint32_t size(DataType data_type) { switch (data_type) { + case DataType::S4: + throw std::runtime_error("S4 type is not supported by loco::size"); case DataType::S8: return sizeof(DataTypeImpl<DataType::S8>::Type); + case DataType::U4: + throw std::runtime_error("U4 type is not supported by loco::size"); case DataType::U8: return sizeof(DataTypeImpl<DataType::U8>::Type); case DataType::S16: diff --git a/compiler/loco/include/loco/IR/NodePool.h b/compiler/loco/include/loco/IR/NodePool.h index 4db4caae3..77823c5a5 100644 --- a/compiler/loco/include/loco/IR/NodePool.h +++ b/compiler/loco/include/loco/IR/NodePool.h @@ -34,7 +34,7 @@ public: ~NodePool(); public: - template <typename Derived, typename... Args> Derived *create(Args &&... args) + template <typename Derived, typename... Args> Derived *create(Args &&...args) { std::unique_ptr<Derived> ptr{new Derived(std::forward<Args>(args)...)}; ptr->graph(_graph); diff --git a/compiler/loco/src/IR/MockupNode.h b/compiler/loco/src/IR/MockupNode.h index 16eaccf36..45a231d30 100644 --- a/compiler/loco/src/IR/MockupNode.h +++ b/compiler/loco/src/IR/MockupNode.h @@ -46,7 +46,7 @@ public: Node *arg(uint32_t N) const final { return _arg.node(); } void drop(void) final { _arg.node(nullptr); } - Node *in(void)const { return _arg.node(); } + Node *in(void) const { return _arg.node(); } void in(Node *node) { _arg.node(node); } private: diff --git a/compiler/loco/src/Service/GraphBuilder.h b/compiler/loco/src/Service/GraphBuilder.h index 74eed2af8..f3015205c 100644 --- a/compiler/loco/src/Service/GraphBuilder.h +++ b/compiler/loco/src/Service/GraphBuilder.h @@ -87,7 +87,7 @@ public: public: // "Layer" is in theory a subgraph builder. template <typename Layer, typename... Args> - auto push(Args &&... 
args) + auto push(Args &&...args) -> decltype(static_cast(nullptr)->operator()(static_cast(nullptr))) { Layer layer{std::forward(args)...}; diff --git a/compiler/loco/src/Service/TypeInference.test.cpp b/compiler/loco/src/Service/TypeInference.test.cpp index 0d2cc8864..b3895b111 100644 --- a/compiler/loco/src/Service/TypeInference.test.cpp +++ b/compiler/loco/src/Service/TypeInference.test.cpp @@ -233,17 +233,23 @@ public: TEST(MultiDialectTypeInferenceRuleTest, test1) { - // Create a simple network : Pull - S8 - U8 - Push + // Create a simple network : Pull - S4 - S8 - U4 - U8 - Push auto g = loco::make_graph(); auto pull_node = g->nodes()->create(); pull_node->dtype(loco::DataType::FLOAT32); + auto s4_node = g->nodes()->create>(); + s4_node->input(pull_node); + auto s8_node = g->nodes()->create>(); - s8_node->input(pull_node); + s8_node->input(s4_node); + + auto u4_node = g->nodes()->create>(); + u4_node->input(s8_node); auto u8_node = g->nodes()->create>(); - u8_node->input(s8_node); + u8_node->input(u4_node); auto push_node = g->nodes()->create(); push_node->from(u8_node); @@ -257,26 +263,38 @@ TEST(MultiDialectTypeInferenceRuleTest, test1) loco::link(graph_output, push_node); // initially they don't have type info + ASSERT_FALSE(loco::dtype_known(s4_node)); ASSERT_FALSE(loco::dtype_known(s8_node)); + ASSERT_FALSE(loco::dtype_known(u4_node)); ASSERT_FALSE(loco::dtype_known(u8_node)); // Run Type Inference TestTypeInferenceRule u8_rule; TestTypeInferenceRule s8_rule; + TestTypeInferenceRule s4_rule; + TestTypeInferenceRule u4_rule; loco::CanonicalTypeInferenceRule canon_rule; loco::MultiDialectTypeInferenceRule rules; rules.bind(TestDialect::get(), &s8_rule) .bind(TestDialect::get(), &u8_rule) + .bind(TestDialect::get(), &s4_rule) + .bind(TestDialect::get(), &u4_rule) .bind(loco::CanonicalDialect::get(), &canon_rule); loco::apply(&rules).to(g.get()); // Verify! 
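// The S4/U4 assertions below follow the same pattern as the pre-existing S8/U8 checks: each dialect-specific test rule is expected to stamp its node with the matching loco::DataType, so after loco::apply() every node in the Pull - S4 - S8 - U4 - U8 - Push chain should report a known dtype.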
+ ASSERT_TRUE(loco::dtype_known(s4_node)); + ASSERT_EQ(loco::DataType::S4, loco::dtype_get(s4_node)); + ASSERT_TRUE(loco::dtype_known(s8_node)); ASSERT_EQ(loco::DataType::S8, loco::dtype_get(s8_node)); + ASSERT_TRUE(loco::dtype_known(u4_node)); + ASSERT_EQ(loco::DataType::U4, loco::dtype_get(u4_node)); + ASSERT_TRUE(loco::dtype_known(u8_node)); ASSERT_EQ(loco::DataType::U8, loco::dtype_get(u8_node)); } diff --git a/compiler/locop/src/FormattedGraph.cpp b/compiler/locop/src/FormattedGraph.cpp index 94bfbd2f8..604037183 100644 --- a/compiler/locop/src/FormattedGraph.cpp +++ b/compiler/locop/src/FormattedGraph.cpp @@ -41,6 +41,8 @@ std::string str(const loco::DataType &dtype) case loco::DataType::Unknown: return "Unknown"; + case loco::DataType::U4: + return "U4"; case loco::DataType::U8: return "U8"; case loco::DataType::U16: @@ -50,6 +52,8 @@ std::string str(const loco::DataType &dtype) case loco::DataType::U64: return "U64"; + case loco::DataType::S4: + return "S4"; case loco::DataType::S8: return "S8"; case loco::DataType::S16: diff --git a/compiler/luci-compute/CMakeLists.txt b/compiler/luci-compute/CMakeLists.txt index b7ddb4406..fdfefa18c 100644 --- a/compiler/luci-compute/CMakeLists.txt +++ b/compiler/luci-compute/CMakeLists.txt @@ -1,23 +1,59 @@ -nnas_find_package(TensorFlowSource EXACT 2.8.0 QUIET) -nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.8.0 QUIET) -nnas_find_package(TensorFlowRuySource EXACT 2.8.0 QUIET) +nnas_find_package(TensorFlowSource EXACT 2.16.1 QUIET) +nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.16.1 QUIET) +nnas_find_package(TensorFlowRuySource EXACT 2.16.1 QUIET) +nnas_find_package(NEON2SSESource QUIET) if(NOT TensorFlowSource_FOUND) - message(STATUS "Build luci-compute: FAILED (missing TensorFlowSource 2.8.0)") + message(STATUS "Build luci-compute: FAILED (missing TensorFlowSource 2.16.1)") return() endif(NOT TensorFlowSource_FOUND) if(NOT TensorFlowGEMMLowpSource_FOUND) - message(STATUS "Build luci-compute: FAILED (missing TensorFlowGEMMLowpSource 2.8.0)") + message(STATUS "Build luci-compute: FAILED (missing TensorFlowGEMMLowpSource 2.16.1)") return() endif(NOT TensorFlowGEMMLowpSource_FOUND) if(NOT TensorFlowRuySource_FOUND) - message(STATUS "Build luci-compute: FAILED (missing TensorFlowRuySource 2.8.0)") + message(STATUS "Build luci-compute: FAILED (missing TensorFlowRuySource 2.16.1)") return() endif(NOT TensorFlowRuySource_FOUND) -add_library(luci_compute INTERFACE) -target_include_directories(luci_compute SYSTEM INTERFACE "${TensorFlowSource_DIR}") -target_include_directories(luci_compute SYSTEM INTERFACE "${TensorFlowGEMMLowpSource_DIR}") -target_include_directories(luci_compute SYSTEM INTERFACE "${TensorFlowRuySource_DIR}") +if(NOT NEON2SSESource_FOUND) + message(STATUS "Build luci-compute: FAILED (missing NEON2SSESource)") + return() +endif(NOT NEON2SSESource_FOUND) + +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) + +# Require for tflite::RuntimeShape +list(APPEND SOURCES "${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/runtime_shape.cc") + +if (NOT LUCI_LIBRARY_TYPE) + set(LUCI_LIBRARY_TYPE "SHARED") +endif(NOT LUCI_LIBRARY_TYPE) + +add_library(luci_compute ${LUCI_LIBRARY_TYPE} ${SOURCES}) +target_include_directories(luci_compute PUBLIC include) +target_include_directories(luci_compute PRIVATE src) +target_include_directories(luci_compute SYSTEM PRIVATE "${TensorFlowSource_DIR}") +target_include_directories(luci_compute SYSTEM PRIVATE 
"${TensorFlowGEMMLowpSource_DIR}") +target_include_directories(luci_compute SYSTEM PRIVATE "${TensorFlowRuySource_DIR}") +target_include_directories(luci_compute SYSTEM PRIVATE "${NEON2SSESource_DIR}") +target_link_libraries(luci_compute PUBLIC loco) +install(TARGETS luci_compute DESTINATION lib) + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(luci_compute_test ${TESTS}) +target_include_directories(luci_compute_test PRIVATE src) +target_link_libraries(luci_compute_test luci_compute) +target_include_directories(luci_compute_test SYSTEM PRIVATE "${TensorFlowSource_DIR}") +target_include_directories(luci_compute_test SYSTEM PRIVATE "${TensorFlowGEMMLowpSource_DIR}") +target_include_directories(luci_compute_test SYSTEM PRIVATE "${TensorFlowRuySource_DIR}") +target_include_directories(luci_compute_test SYSTEM PRIVATE "${NEON2SSESource_DIR}") diff --git a/compiler/luci-compute/include/luci_compute/DepthwiseConv2D.h b/compiler/luci-compute/include/luci_compute/DepthwiseConv2D.h new file mode 100644 index 000000000..4f76485e0 --- /dev/null +++ b/compiler/luci-compute/include/luci_compute/DepthwiseConv2D.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_COMPUTE_DEPTHWISE_CONV2D_H__ +#define __LUCI_COMPUTE_DEPTHWISE_CONV2D_H__ + +#include "Types.h" + +#include + +namespace luci +{ +namespace compute +{ + +// TODO extract some common for multiple Ops +class DepthwiseConv2D +{ +public: + DepthwiseConv2D() = default; + +public: + DepthwiseParams ¶ms(void) { return _params; } + + void input(const loco::TensorShape &shape, const float *data) + { + _input_shape = shape; + _input_data = data; + } + + void filter(const loco::TensorShape &shape, const float *data) + { + _filter_shape = shape; + _filter_data = data; + } + + void bias(const loco::TensorShape &shape, const float *data) + { + _bias_shape = shape; + _bias_data = data; + } + + void fused_act_func(FusedActFunc func) { _fused_act_func = func; }; + + void output(float *data) { _output_data = data; } + +public: + bool prepare(void); + const loco::TensorShape &output_shape(void) const { return _output_shape; } + void compute(void); + +private: + // param to pass to compute kernel + DepthwiseParams _params = {}; + // shape and data for inputs + loco::TensorShape _input_shape; + loco::TensorShape _filter_shape; + loco::TensorShape _bias_shape; + const float *_input_data = nullptr; + const float *_filter_data = nullptr; + const float *_bias_data = nullptr; + FusedActFunc _fused_act_func = FusedActFunc::UNDEFINED; + + // compute results + loco::TensorShape _output_shape; + float *_output_data = nullptr; +}; + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_DEPTHWISE_CONV2D_H__ diff --git a/compiler/luci-compute/include/luci_compute/FullyConnected.h b/compiler/luci-compute/include/luci_compute/FullyConnected.h new file mode 100644 index 000000000..e97264a05 --- /dev/null +++ b/compiler/luci-compute/include/luci_compute/FullyConnected.h @@ -0,0 +1,89 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_COMPUTE_FULLY_CONNECTED_H__ +#define __LUCI_COMPUTE_FULLY_CONNECTED_H__ + +#include "Types.h" + +#include + +namespace luci +{ +namespace compute +{ + +// TODO extract some common for multiple Ops +class FullyConnected +{ +public: + FullyConnected() = default; + +public: + FullyConnectedParams ¶ms(void) { return _params; } + + bool keep_num_dims(void) const { return _keep_num_dims; } + void keep_num_dims(bool knd) { _keep_num_dims = knd; } + + void input(const loco::TensorShape &shape, const float *data) + { + _input_shape = shape; + _input_data = data; + } + + void weights(const loco::TensorShape &shape, const float *data) + { + _weights_shape = shape; + _weights_data = data; + } + + void bias(const loco::TensorShape &shape, const float *data) + { + _bias_shape = shape; + _bias_data = data; + } + + void fused_act_func(FusedActFunc func) { _fused_act_func = func; }; + + void output(float *data) { _output_data = data; } + +public: + bool prepare(void); + const loco::TensorShape &output_shape(void) const { return _output_shape; } + void compute(void); + +private: + // param to pass to compute kernel + FullyConnectedParams _params = {}; + // new param from tflite version 5 + bool _keep_num_dims = false; + // shape and data for inputs + loco::TensorShape _input_shape; + loco::TensorShape _weights_shape; + loco::TensorShape _bias_shape; + const float *_input_data = nullptr; + const float *_weights_data = nullptr; + const float *_bias_data = nullptr; + FusedActFunc _fused_act_func = FusedActFunc::UNDEFINED; + + // compute results + loco::TensorShape _output_shape; + float *_output_data = nullptr; +}; + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_FULLY_CONNECTED_H__ diff --git a/compiler/luci-compute/include/luci_compute/Types.h b/compiler/luci-compute/include/luci_compute/Types.h new file mode 100644 index 000000000..7f643064e --- /dev/null +++ b/compiler/luci-compute/include/luci_compute/Types.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// portion copied from TF2.8.0 tensorflow/lite/kernels/internal/types.h + +#ifndef __LUCI_COMPUTE_TYPES_H__ +#define __LUCI_COMPUTE_TYPES_H__ + +#include +#include +#include + +namespace luci +{ +namespace compute +{ + +// from tflite as-is +enum class PaddingType : uint8_t +{ + kNone, + kSame, + kValid +}; + +// from tflite as-is +struct PaddingValues +{ + int16_t width; + int16_t height; + // offset is used for calculating "remaining" padding, for example, `width` + // is 1 and `width_offset` is 1, so padding_left is 1 while padding_right is + // 1 + 1 = 2. + int16_t width_offset; + // Same as width_offset except it's over the height dimension. 
+ int16_t height_offset; +}; + +// from tflite as-is +struct DepthwiseParams +{ + PaddingType padding_type; + PaddingValues padding_values; + int16_t stride_width; + int16_t stride_height; + int16_t dilation_width_factor; + int16_t dilation_height_factor; + int16_t depth_multiplier; + // uint8_t inference params. + // TODO(b/65838351): Use smaller types if appropriate. + int32_t input_offset; + int32_t weights_offset; + int32_t output_offset; + int32_t output_multiplier; + int output_shift; + // uint8_t, etc, activation params. + int32_t quantized_activation_min; + int32_t quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; + const int32_t *output_multiplier_per_channel; + const int32_t *output_shift_per_channel; +}; + +// from tflite, with tidy long comments +enum class FullyConnectedWeightsFormat : uint8_t +{ + kDefault, + kShuffled4x16Int8, +}; + +// from tflite as-is +struct FullyConnectedParams +{ + // uint8_t inference params. + // TODO(b/65838351): Use smaller types if appropriate. + int32_t input_offset; + int32_t weights_offset; + int32_t output_offset; + int32_t output_multiplier; + int output_shift; + // uint8_t, etc, activation params. + int32_t quantized_activation_min; + int32_t quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; + // Mark the operands as cacheable if they are unchanging, e.g. weights. + bool lhs_cacheable; + bool rhs_cacheable; + FullyConnectedWeightsFormat weights_format; +}; + +// from luci as-is +enum class FusedActFunc +{ + UNDEFINED, // This is not defined by TFLite or Circle. This was added to + // prevent programming error. + NONE, + RELU, + RELU_N1_TO_1, + RELU6, + TANH, + SIGN_BIT +}; + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_TYPES_H__ diff --git a/compiler/luci-compute/requires.cmake b/compiler/luci-compute/requires.cmake new file mode 100644 index 000000000..44f6870da --- /dev/null +++ b/compiler/luci-compute/requires.cmake @@ -0,0 +1 @@ +require("loco") diff --git a/compiler/luci-compute/src/ConvertTypes.cpp b/compiler/luci-compute/src/ConvertTypes.cpp new file mode 100644 index 000000000..9e4ae1749 --- /dev/null +++ b/compiler/luci-compute/src/ConvertTypes.cpp @@ -0,0 +1,67 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ConvertTypes.h" + +#include +#include + +namespace luci +{ +namespace compute +{ + +tflite::RuntimeShape tflite_shape(const loco::TensorShape &shape) +{ + tflite::RuntimeShape runtime_shape(shape.rank()); + for (uint32_t i = 0; i < shape.rank(); ++i) + { + if (not shape.dim(i).known()) + throw std::runtime_error("luci-comp tflite_shape shape unknown."); + runtime_shape.SetDim(i, shape.dim(i).value()); + } + return runtime_shape; +} + +tflite::PaddingType tflite_padding(const PaddingType type) +{ + switch (type) + { + case PaddingType::kSame: + return tflite::PaddingType::kSame; + case PaddingType::kValid: + return tflite::PaddingType::kValid; + default: + break; + } + throw std::runtime_error("luci-comp tflite_padding unsupported type."); +} + +tflite::FullyConnectedWeightsFormat tflite_weights_format(const FullyConnectedWeightsFormat type) +{ + switch (type) + { + case FullyConnectedWeightsFormat::kDefault: + return tflite::FullyConnectedWeightsFormat::kDefault; + case FullyConnectedWeightsFormat::kShuffled4x16Int8: + return tflite::FullyConnectedWeightsFormat::kShuffled4x16Int8; + default: + break; + } + throw std::runtime_error("luci-comp tflite_weights_format unsupported type."); +} + +} // namespace compute +} // namespace luci diff --git a/compiler/luci-compute/src/ConvertTypes.h b/compiler/luci-compute/src/ConvertTypes.h new file mode 100644 index 000000000..43e4fe1aa --- /dev/null +++ b/compiler/luci-compute/src/ConvertTypes.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_COMPUTE_CONVERT_TYPES_H__ +#define __LUCI_COMPUTE_CONVERT_TYPES_H__ + +#include "luci_compute/Types.h" + +#include + +#include + +namespace luci +{ +namespace compute +{ + +tflite::RuntimeShape tflite_shape(const loco::TensorShape &shape); + +tflite::PaddingType tflite_padding(const PaddingType type); + +tflite::FullyConnectedWeightsFormat tflite_weights_format(const FullyConnectedWeightsFormat type); + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_CONVERT_TYPES_H__ diff --git a/compiler/luci-compute/src/ConvertTypes.test.cpp b/compiler/luci-compute/src/ConvertTypes.test.cpp new file mode 100644 index 000000000..f19af5ac9 --- /dev/null +++ b/compiler/luci-compute/src/ConvertTypes.test.cpp @@ -0,0 +1,69 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ConvertTypes.h" + +#include + +TEST(ConvertTypes, tflite_shape) +{ + loco::TensorShape shape; + shape.rank(2); + shape.dim(0) = 1; + shape.dim(1) = 2; + + auto tflite_shape = luci::compute::tflite_shape(shape); + EXPECT_EQ(tflite_shape.DimensionsCount(), 2); + EXPECT_EQ(tflite_shape.Dims(0), 1); + EXPECT_EQ(tflite_shape.Dims(1), 2); +} + +TEST(ConvertTypes, tflite_shape_NEG) +{ + loco::TensorShape shape; + shape.rank(2); + shape.dim(0) = 1; + + ASSERT_ANY_THROW(luci::compute::tflite_shape(shape)); +} + +TEST(ConvertTypes, tflite_padding) +{ + auto pts = luci::compute::PaddingType::kSame; + ASSERT_EQ(luci::compute::tflite_padding(pts), tflite::PaddingType::kSame); + auto ptv = luci::compute::PaddingType::kValid; + ASSERT_EQ(luci::compute::tflite_padding(ptv), tflite::PaddingType::kValid); +} + +TEST(ConvertTypes, tflite_padding_NEG) +{ + auto pt = luci::compute::PaddingType::kNone; + ASSERT_ANY_THROW(luci::compute::tflite_padding(pt)); +} + +TEST(ConvertTypes, tflite_weights_format) +{ + auto fwf = luci::compute::FullyConnectedWeightsFormat::kDefault; + ASSERT_EQ(luci::compute::tflite_weights_format(fwf), + tflite::FullyConnectedWeightsFormat::kDefault); +} + +TEST(ConvertTypes, tflite_weights_format_NEG) +{ + // force convert with invalid value as future unhandled value + luci::compute::FullyConnectedWeightsFormat fwf = + static_cast(250); + ASSERT_ANY_THROW(luci::compute::tflite_weights_format(fwf)); +} diff --git a/compiler/luci-compute/src/ConvertValues.cpp b/compiler/luci-compute/src/ConvertValues.cpp new file mode 100644 index 000000000..5a187254f --- /dev/null +++ b/compiler/luci-compute/src/ConvertValues.cpp @@ -0,0 +1,53 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConvertValues.h" + +#include // std::numeric_limits +#include + +namespace luci +{ +namespace compute +{ + +void get_act_minmax(const FusedActFunc act, float &act_min, float &act_max) +{ + switch (act) + { + case FusedActFunc::NONE: + case FusedActFunc::TANH: + act_min = std::numeric_limits::lowest(); + act_max = std::numeric_limits::max(); + break; + case FusedActFunc::RELU: + act_min = 0; + act_max = std::numeric_limits::max(); + break; + case FusedActFunc::RELU_N1_TO_1: + act_min = -1; + act_max = 1; + break; + case FusedActFunc::RELU6: + act_min = 0; + act_max = 6; + break; + default: + throw std::runtime_error("luci-comp get_act_minmax unsupported type."); + } +} + +} // namespace compute +} // namespace luci diff --git a/compiler/luci-compute/src/ConvertValues.h b/compiler/luci-compute/src/ConvertValues.h new file mode 100644 index 000000000..506752d2a --- /dev/null +++ b/compiler/luci-compute/src/ConvertValues.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_COMPUTE_CONVERT_VALUES_H__ +#define __LUCI_COMPUTE_CONVERT_VALUES_H__ + +#include "luci_compute/Types.h" + +namespace luci +{ +namespace compute +{ + +void get_act_minmax(const FusedActFunc act, float &act_min, float &act_max); + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_CONVERT_VALUES_H__ diff --git a/compiler/luci-compute/src/ConvertValues.test.cpp b/compiler/luci-compute/src/ConvertValues.test.cpp new file mode 100644 index 000000000..ed68fdacc --- /dev/null +++ b/compiler/luci-compute/src/ConvertValues.test.cpp @@ -0,0 +1,35 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConvertValues.h" + +#include + +TEST(ConvertValues, get_act_minmax) +{ + auto func = luci::compute::FusedActFunc::RELU6; + float act_min, act_max; + ASSERT_NO_THROW(luci::compute::get_act_minmax(func, act_min, act_max)); + EXPECT_EQ(act_min, 0); + EXPECT_EQ(act_max, 6); +} + +TEST(ConvertValues, get_act_minmax_NEG) +{ + // force convert with invalid value as future unhandled value + luci::compute::FusedActFunc func = static_cast(250); + float act_min, act_max; + ASSERT_ANY_THROW(luci::compute::get_act_minmax(func, act_min, act_max)); +} diff --git a/compiler/luci-compute/src/DepthwiseConv2D.cpp b/compiler/luci-compute/src/DepthwiseConv2D.cpp new file mode 100644 index 000000000..739fc9c1c --- /dev/null +++ b/compiler/luci-compute/src/DepthwiseConv2D.cpp @@ -0,0 +1,171 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci_compute/Types.h" +#include "luci_compute/DepthwiseConv2D.h" + +#include "ConvertTypes.h" +#include "ConvertValues.h" + +#include + +#include +#include + +namespace luci +{ +namespace compute +{ + +namespace +{ + +int32_t compute_output(PaddingType padding, int32_t in_size, int32_t filter_size, int32_t stride, + int32_t dilation_rate) +{ + assert(in_size > 0); + assert(filter_size > 0); + assert(stride > 0); + assert(dilation_rate > 0); + + auto const effective_filter_size = (filter_size - 1) * dilation_rate + 1; + switch (padding) + { + case PaddingType::kSame: + return (in_size + stride - 1) / stride; + + case PaddingType::kValid: + return (in_size + stride - effective_filter_size) / stride; + + default: + return -1; + } + return -1; +} + +int16_t compute_padding(int32_t out_size, int32_t in_size, int32_t filter_size, int32_t stride, + int32_t dilation_rate) +{ + assert(out_size > 0); + assert(in_size > 0); + assert(filter_size > 0); + assert(stride > 0); + assert(dilation_rate > 0); + + auto const effective_filter_size = (filter_size - 1) * dilation_rate + 1; + auto const padding = ((out_size - 1) * stride + effective_filter_size - in_size) / 2; + assert(padding < INT16_MAX); + return padding > 0 ? static_cast(padding) : 0; +} + +} // namespace + +bool DepthwiseConv2D::prepare(void) +{ + // TODO support other ranks if necessary + if (_input_shape.rank() != 4 || _filter_shape.rank() != 4) + return false; + // if bias exist, check if rank is 1 + if (_bias_data && _bias_shape.rank() != 1) + return false; + + auto const input_batches = _input_shape.dim(0).value(); + auto const input_height = _input_shape.dim(1).value(); + auto const input_width = _input_shape.dim(2).value(); + auto const input_depth = _input_shape.dim(3).value(); + + auto const filter_height = _filter_shape.dim(1).value(); + auto const filter_width = _filter_shape.dim(2).value(); + auto const filter_channels_out = _filter_shape.dim(3).value(); + + if (filter_channels_out % input_depth != 0) + return false; // wrong input/output depth ratio + + if (_params.depth_multiplier != static_cast(filter_channels_out / input_depth)) + return false; // wrong depth multiplier value + + if (_bias_shape.dim(0).value() != filter_channels_out) + return false; // unsupported bias value + + auto output_height = compute_output(_params.padding_type, input_height, filter_height, + _params.stride_height, _params.dilation_height_factor); + if (output_height < 0) + return false; + + auto output_width = compute_output(_params.padding_type, input_width, filter_width, + _params.stride_width, _params.dilation_width_factor); + if (output_width < 0) + return false; + + get_act_minmax(_fused_act_func, _params.float_activation_min, _params.float_activation_max); + + _output_shape.rank(4); + _output_shape.dim(0) = input_batches; + _output_shape.dim(1) = output_height; + _output_shape.dim(2) = output_width; + _output_shape.dim(3) = filter_channels_out; + + _params.padding_values.height = + compute_padding(output_height, input_height, filter_height, _params.stride_height, + _params.dilation_height_factor); + _params.padding_values.width = compute_padding( + output_width, input_width, filter_width, _params.stride_width, _params.dilation_width_factor); + + return true; +} + +void DepthwiseConv2D::compute(void) +{ + assert(_input_data != nullptr); + assert(_filter_data != nullptr); + // NOTE _bias_shape can be nullptr + assert(_output_data != nullptr); + + // NOTE if this fails, structure may have changed + 
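// compute::DepthwiseParams is kept as a field-for-field mirror of tflite::DepthwiseParams, so the sizeof comparison below is a cheap compile-time tripwire for upstream layout changes; note it cannot catch a reordering of same-sized fields, so the member-by-member copy that follows still has to be kept in sync by hand.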
static_assert(sizeof(compute::DepthwiseParams) == sizeof(tflite::DepthwiseParams)); + + tflite::DepthwiseParams params; + + // clang-format off + params.padding_type = tflite_padding(_params.padding_type); + params.padding_values.width = _params.padding_values.width; + params.padding_values.height = _params.padding_values.height; + params.padding_values.width_offset = _params.padding_values.width_offset; + params.padding_values.height_offset = _params.padding_values.height_offset; + params.stride_width = _params.stride_width; + params.stride_height = _params.stride_height; + params.dilation_width_factor = _params.dilation_width_factor; + params.dilation_height_factor = _params.dilation_height_factor; + params.depth_multiplier = _params.depth_multiplier; + params.input_offset = _params.input_offset; + params.weights_offset = _params.weights_offset; + params.output_offset = _params.output_offset; + params.output_multiplier = _params.output_multiplier; + params.output_shift = _params.output_shift; + params.quantized_activation_min = _params.quantized_activation_min; + params.quantized_activation_max = _params.quantized_activation_max; + params.float_activation_min = _params.float_activation_min; + params.float_activation_max = _params.float_activation_max; + params.output_multiplier_per_channel = _params.output_multiplier_per_channel; + params.output_shift_per_channel = _params.output_shift_per_channel; + // clang-format on + + tflite::reference_ops::DepthwiseConv( + params, tflite_shape(_input_shape), _input_data, tflite_shape(_filter_shape), _filter_data, + tflite_shape(_bias_shape), _bias_data, tflite_shape(_output_shape), _output_data); +} + +} // namespace compute +} // namespace luci diff --git a/compiler/luci-compute/src/DepthwiseConv2D.test.cpp b/compiler/luci-compute/src/DepthwiseConv2D.test.cpp new file mode 100644 index 000000000..adf2503a0 --- /dev/null +++ b/compiler/luci-compute/src/DepthwiseConv2D.test.cpp @@ -0,0 +1,143 @@ +/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "ConvertValues.h"
+
+#include <luci_compute/DepthwiseConv2D.h>
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+
+class DepthwiseConv2DTest : public ::testing::Test
+{
+protected:
+  loco::TensorShape tensor_shape(const std::initializer_list<uint32_t> shape)
+  {
+    loco::TensorShape tensor_shape;
+    tensor_shape.rank(shape.size());
+    uint32_t i = 0;
+    for (auto it = shape.begin(); it != shape.end(); ++it, ++i)
+      tensor_shape.dim(i) = *it;
+    return tensor_shape;
+  }
+
+  std::vector<uint32_t> vector_shape(const loco::TensorShape &tensor_shape)
+  {
+    std::vector<uint32_t> shape;
+    for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
+      shape.push_back(tensor_shape.dim(r).value());
+    return shape;
+  }
+
+protected:
+  luci::compute::DepthwiseConv2D _dwconv2d;
+};
+
+TEST_F(DepthwiseConv2DTest, prepare_compute)
+{
+  auto input_shape = tensor_shape({1, 4, 2, 2});
+  std::vector<float> input_data{
+    1,  2,  7,  8,  //
+    3,  4,  9,  10, //
+    5,  6,  11, 12, //
+    13, 14, 15, 16, //
+  };
+  auto filter_shape = tensor_shape({1, 2, 2, 4});
+  std::vector<float> filter_data{
+    1,  2,   3,   4,   //
+    -9, 10,  -11, 12,  //
+    5,  6,   7,   8,   //
+    13, -14, 15,  -16, //
+  };
+  auto bias_shape = tensor_shape({4});
+  std::vector<float> bias_data{1, 2, 3, 4};
+
+  auto &params = _dwconv2d.params();
+  params.padding_type = luci::compute::PaddingType::kValid;
+  params.stride_height = 2;
+  params.stride_width = 1;
+  params.dilation_height_factor = 1;
+  params.dilation_width_factor = 1;
+  params.depth_multiplier = 2;
+
+  _dwconv2d.input(input_shape, input_data.data());
+  _dwconv2d.filter(filter_shape, filter_data.data());
+  _dwconv2d.bias(bias_shape, bias_data.data());
+  _dwconv2d.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_TRUE(_dwconv2d.prepare());
+
+  auto output_shape = _dwconv2d.output_shape();
+  auto output_count = loco::element_count(&output_shape);
+  std::vector<float> output_data_vector;
+  output_data_vector.resize(output_count);
+
+  _dwconv2d.output(output_data_vector.data());
+
+  ASSERT_NO_THROW(_dwconv2d.compute());
+
+  std::vector<float> ref_output_data{
+    71,  0, 99,  0,  //
+    167, 0, 227, 28, //
+  };
+  std::vector<uint32_t> ref_output_shape{1, 2, 1, 4};
+  std::vector<uint32_t> output_shape_vector = vector_shape(output_shape);
+
+  EXPECT_THAT(output_data_vector, ref_output_data);
+  EXPECT_THAT(output_shape_vector, ref_output_shape);
+}
+
+TEST_F(DepthwiseConv2DTest, prepare_invalid_rank_NEG)
+{
+  auto input_shape = tensor_shape({2}); // expect rank 4
+  std::vector<float> input_data{1, 2};
+  auto filter_shape = tensor_shape({2});
+  std::vector<float> filter_data{1, 2};
+  auto bias_shape = tensor_shape({1});
+  std::vector<float> bias_data{1};
+
+  _dwconv2d.input(input_shape, input_data.data());
+  _dwconv2d.filter(filter_shape, filter_data.data());
+  _dwconv2d.bias(bias_shape, bias_data.data());
+  _dwconv2d.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_FALSE(_dwconv2d.prepare());
+}
+
+TEST_F(DepthwiseConv2DTest, prepare_invalid_shape_NEG)
+{
+  auto input_shape = tensor_shape({1, 4, 2, 2});
+  std::vector<float> input_data{
+    1,  2,  7,  8,  //
+    3,  4,  9,  10, //
+    5,  6,  11, 12, //
+    13, 14, 15, 16, //
+  };
+  auto filter_shape = tensor_shape({1, 2, 2, 3}); // expect {1, 2, 2, 4}
+  std::vector<float> filter_data{
+    1,  2,  3,   4,  //
+    -9, 10, -11, 12, //
+    5,  6,  7,   8,  //
+  };
+  auto bias_shape = tensor_shape({4});
+  std::vector<float> bias_data{1, 2, 3, 4};
+
+  _dwconv2d.input(input_shape, input_data.data());
+  _dwconv2d.filter(filter_shape, filter_data.data());
+  _dwconv2d.bias(bias_shape, bias_data.data());
+  _dwconv2d.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_FALSE(_dwconv2d.prepare());
+}
diff --git a/compiler/luci-compute/src/FullyConnected.cpp
b/compiler/luci-compute/src/FullyConnected.cpp
new file mode 100644
index 000000000..112c943f1
--- /dev/null
+++ b/compiler/luci-compute/src/FullyConnected.cpp
@@ -0,0 +1,109 @@
+/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci_compute/Types.h"
+#include "luci_compute/FullyConnected.h"
+
+#include "ConvertTypes.h"
+#include "ConvertValues.h"
+
+#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+
+#include <cassert>
+#include <cstdint>
+
+namespace luci
+{
+namespace compute
+{
+
+namespace
+{
+
+} // namespace
+
+bool FullyConnected::prepare(void)
+{
+  if (_input_shape.rank() < 1 || _weights_shape.rank() != 2)
+    return false;
+
+  auto const input_elems = element_count(&_input_shape);
+  auto const weights_height = _weights_shape.dim(0).value();
+  auto const weights_width = _weights_shape.dim(1).value();
+  if (weights_height == 0 || weights_width == 0)
+    return false;
+  if (input_elems % weights_width != 0)
+    return false;
+  auto const batch_size = input_elems / weights_width;
+  auto const num_units = weights_height;
+  if (_bias_data)
+  {
+    if (element_count(&_bias_shape) != num_units)
+      return false;
+  }
+
+  get_act_minmax(_fused_act_func, _params.float_activation_min, _params.float_activation_max);
+
+  if (_keep_num_dims)
+  {
+    _output_shape.rank(_input_shape.rank());
+    for (uint32_t i = 0; i < _input_shape.rank(); i++)
+      _output_shape.dim(i) = _input_shape.dim(i);
+    _output_shape.dim(_input_shape.rank() - 1) = num_units;
+  }
+  else
+  {
+    _output_shape.rank(2);
+    _output_shape.dim(0) = batch_size;
+    _output_shape.dim(1) = num_units;
+  }
+
+  return true;
+}
+
+void FullyConnected::compute(void)
+{
+  assert(_input_data != nullptr);
+  assert(_weights_data != nullptr);
+  // NOTE _bias_data can be nullptr
+  assert(_output_data != nullptr);
+
+  // NOTE if this fails, structure may have changed
+  static_assert(sizeof(compute::FullyConnectedParams) == sizeof(tflite::FullyConnectedParams));
+
+  tflite::FullyConnectedParams params;
+
+  // clang-format off
+  params.input_offset = _params.input_offset;
+  params.weights_offset = _params.weights_offset;
+  params.output_offset = _params.output_offset;
+  params.output_multiplier = _params.output_multiplier;
+  params.output_shift = _params.output_shift;
+  params.quantized_activation_min = _params.quantized_activation_min;
+  params.quantized_activation_max = _params.quantized_activation_max;
+  params.float_activation_min = _params.float_activation_min;
+  params.float_activation_max = _params.float_activation_max;
+  params.lhs_cacheable = _params.lhs_cacheable;
+  params.rhs_cacheable = _params.rhs_cacheable;
+  params.weights_format = tflite_weights_format(_params.weights_format);
+  // clang-format on
+
+  tflite::reference_ops::FullyConnected(
+    params, tflite_shape(_input_shape), _input_data, tflite_shape(_weights_shape), _weights_data,
+    tflite_shape(_bias_shape), _bias_data, tflite_shape(_output_shape), _output_data);
+}
+
+} // namespace compute
+} // namespace luci
diff --git
a/compiler/luci-compute/src/FullyConnected.test.cpp b/compiler/luci-compute/src/FullyConnected.test.cpp
new file mode 100644
index 000000000..3e35288b8
--- /dev/null
+++ b/compiler/luci-compute/src/FullyConnected.test.cpp
@@ -0,0 +1,135 @@
+/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConvertValues.h"
+
+#include <luci_compute/FullyConnected.h>
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+
+class FullyConnectedTest : public ::testing::Test
+{
+protected:
+  loco::TensorShape tensor_shape(const std::initializer_list<uint32_t> shape)
+  {
+    loco::TensorShape tensor_shape;
+    tensor_shape.rank(shape.size());
+    uint32_t i = 0;
+    for (auto it = shape.begin(); it != shape.end(); ++it, ++i)
+      tensor_shape.dim(i) = *it;
+    return tensor_shape;
+  }
+
+  std::vector<uint32_t> vector_shape(const loco::TensorShape &tensor_shape)
+  {
+    std::vector<uint32_t> shape;
+    for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
+      shape.push_back(tensor_shape.dim(r).value());
+    return shape;
+  }
+
+protected:
+  luci::compute::FullyConnected _fc;
+};
+
+TEST_F(FullyConnectedTest, prepare_compute)
+{
+  auto input_shape = tensor_shape({3, 2, 2, 1});
+  std::vector<float> input_data{
+    -3, -5, 5,  4,  //
+    9,  -2, -3, -2, //
+    -4, 9,  -8, 1,  //
+  };
+  auto weights_shape = tensor_shape({3, 6});
+  std::vector<float> weights_data{
+    -3, -7, 4, -4, -6, 4,  //
+    3,  5,  2, 3,  -3, -8, //
+    -3, 7,  4, 9,  0,  -5, //
+  };
+  auto bias_shape = tensor_shape({3});
+  std::vector<float> bias_data{-1, -5, -8};
+
+  auto &params = _fc.params();
+  params.weights_format = luci::compute::FullyConnectedWeightsFormat::kDefault;
+
+  _fc.input(input_shape, input_data.data());
+  _fc.weights(weights_shape, weights_data.data());
+  _fc.bias(bias_shape, bias_data.data());
+  _fc.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_TRUE(_fc.prepare());
+
+  auto output_shape = _fc.output_shape();
+  auto output_count = loco::element_count(&output_shape);
+  std::vector<float> output_data_vector;
+  output_data_vector.resize(output_count);
+
+  _fc.output(output_data_vector.data());
+
+  ASSERT_NO_THROW(_fc.compute());
+
+  std::vector<float> ref_output_data{
+    0,  0,  32, //
+    22, 11, 47, //
+  };
+  std::vector<uint32_t> ref_output_shape{2, 3};
+  std::vector<uint32_t> output_shape_vector = vector_shape(output_shape);
+
+  EXPECT_THAT(output_data_vector, ref_output_data);
+  EXPECT_THAT(output_shape_vector, ref_output_shape);
+}
+
+TEST_F(FullyConnectedTest, prepare_invalid_rank_NEG)
+{
+  auto input_shape = tensor_shape({3});
+  std::vector<float> input_data{-3, -5, 5};
+  auto weights_shape = tensor_shape({3}); // expect rank 2
+  std::vector<float> weights_data{-3, -7, 4};
+  auto bias_shape = tensor_shape({3});
+  std::vector<float> bias_data{-1, -5, -8};
+
+  _fc.input(input_shape, input_data.data());
+  _fc.weights(weights_shape, weights_data.data());
+  _fc.bias(bias_shape, bias_data.data());
+  _fc.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_FALSE(_fc.prepare());
+}
+
+TEST_F(FullyConnectedTest, prepare_invalid_shape_NEG)
+{
+  auto input_shape = tensor_shape({3, 2, 2, 1});
+  std::vector<float> input_data{
+    -3, -5, 5, 4, //
+    9,  -2, -3, -2, //
+    -4, 9,  -8, 1,  //
+  };
+  auto weights_shape = tensor_shape({3, 5}); // expect {3, 6}
+  std::vector<float> weights_data{
+    -3, -7, 4, -4, -6, //
+    3,  5,  2, 3,  -3, //
+    -3, 7,  4, 9,  0,  //
+  };
+  auto bias_shape = tensor_shape({3});
+  std::vector<float> bias_data{-1, -5, -8};
+
+  _fc.input(input_shape, input_data.data());
+  _fc.weights(weights_shape, weights_data.data());
+  _fc.bias(bias_shape, bias_data.data());
+  _fc.fused_act_func(luci::compute::FusedActFunc::RELU);
+
+  EXPECT_FALSE(_fc.prepare());
+}
diff --git a/compiler/luci-eval-driver/src/EvalDriver.cpp b/compiler/luci-eval-driver/src/EvalDriver.cpp
index fb48f67e2..22f13f070 100644
--- a/compiler/luci-eval-driver/src/EvalDriver.cpp
+++ b/compiler/luci-eval-driver/src/EvalDriver.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include
 #include
 #include
 
@@ -48,7 +49,7 @@ void writeDataToFile(const std::string &filename, const char *data, size_t data_
 template <typename NodeT> size_t getTensorSize(const NodeT *node)
 {
-  uint32_t tensor_size = loco::size(node->dtype());
+  uint32_t tensor_size = luci::size(node->dtype());
   for (uint32_t i = 0; i < node->rank(); ++i)
     tensor_size *= node->dim(i).value();
   return tensor_size;
diff --git a/compiler/luci-interpreter/include/luci_interpreter/BuddyMemoryManager.h b/compiler/luci-interpreter/include/luci_interpreter/BuddyMemoryManager.h
index 205baa626..fec08993c 100644
--- a/compiler/luci-interpreter/include/luci_interpreter/BuddyMemoryManager.h
+++ b/compiler/luci-interpreter/include/luci_interpreter/BuddyMemoryManager.h
@@ -114,7 +114,7 @@ private:
     const int32_t l = lowerLog2(block->size + sizeof(Block));
     const int64_t address = ((uint8_t *)block - (uint8_t *)_start_block);
-    buddy = (Block *)((address ^ (1 << l)) + (uint8_t *)_start_block);
+    buddy = (Block *)((address ^ (1LL << l)) + (uint8_t *)_start_block);
 
     if (!buddy->is_free || buddy->size != block->size)
       return nullptr;
diff --git a/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h b/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
index 8e2f457a5..d64961c20 100644
--- a/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
+++ b/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
@@ -60,6 +60,8 @@ public:
 
   void readOutputTensor(const luci::CircleOutput *output_node, void *data, size_t data_size);
 
+  size_t getOutputTensorSize(const luci::CircleOutput *output_node);
+
   void interpret();
 
   void attachObserver(ExecutionObserver *observer);
diff --git a/compiler/luci-interpreter/include/luci_interpreter/core/DataType.h b/compiler/luci-interpreter/include/luci_interpreter/core/DataType.h
index 27bf719b5..57499a636 100644
--- a/compiler/luci-interpreter/include/luci_interpreter/core/DataType.h
+++ b/compiler/luci-interpreter/include/luci_interpreter/core/DataType.h
@@ -19,6 +19,7 @@
 
 #include
 #include
+#include
 
 #include
 
@@ -29,7 +30,7 @@ using DataType = loco::DataType;
 
 template <DataType DT> using DataTypeImpl = loco::DataTypeImpl<DT>
; -inline size_t getDataTypeSize(DataType data_type) { return loco::size(data_type); } +inline size_t getDataTypeSize(DataType data_type) { return luci::size(data_type); } } // namespace luci_interpreter diff --git a/compiler/luci-interpreter/include/luci_interpreter/core/Tensor.h b/compiler/luci-interpreter/include/luci_interpreter/core/Tensor.h index ad3388785..f118ee22c 100644 --- a/compiler/luci-interpreter/include/luci_interpreter/core/Tensor.h +++ b/compiler/luci-interpreter/include/luci_interpreter/core/Tensor.h @@ -60,6 +60,17 @@ public: return result; } + // TODO Replace num_elements + int64_t large_num_elements() const + { + int64_t result = 1; + for (const int32_t dim : _dims) + { + result *= dim; + } + return result; + } + bool operator==(const Shape &other) const { return _dims == other._dims; } bool operator!=(const Shape &other) const { return !operator==(other); } diff --git a/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst b/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst index e4d42de33..4cccfbcd0 100644 --- a/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst +++ b/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst @@ -4,9 +4,12 @@ REGISTER_KERNEL(ArgMax) REGISTER_KERNEL(AveragePool2D) REGISTER_KERNEL(BatchMatMul) REGISTER_KERNEL(BatchToSpaceND) +REGISTER_KERNEL(BroadcastTo) REGISTER_KERNEL(Cast) REGISTER_KERNEL(Concatenation) REGISTER_KERNEL(Conv2D) +REGISTER_KERNEL(Cos) +REGISTER_KERNEL(CumSum) REGISTER_KERNEL(DepthToSpace) REGISTER_KERNEL(DepthwiseConv2D) REGISTER_KERNEL(Dequantize) @@ -57,6 +60,7 @@ REGISTER_KERNEL(Quantize) REGISTER_KERNEL(ReduceMax) REGISTER_KERNEL(ReduceProd) REGISTER_KERNEL(Relu) +REGISTER_KERNEL(Relu0To1) REGISTER_KERNEL(Relu6) REGISTER_KERNEL(Reshape) REGISTER_KERNEL(ResizeBilinear) @@ -64,7 +68,9 @@ REGISTER_KERNEL(ResizeNearestNeighbor) REGISTER_KERNEL(ReverseV2) REGISTER_KERNEL(Rsqrt) REGISTER_KERNEL(Select) +REGISTER_KERNEL(SelectV2) REGISTER_KERNEL(Shape) +REGISTER_KERNEL(Sin) REGISTER_KERNEL(Slice) REGISTER_KERNEL(Softmax) REGISTER_KERNEL(SpaceToBatchND) @@ -80,6 +86,7 @@ REGISTER_KERNEL(Sub) REGISTER_KERNEL(Sum) REGISTER_KERNEL(SVDF) REGISTER_KERNEL(Tanh) +REGISTER_KERNEL(Tile) REGISTER_KERNEL(Transpose) REGISTER_KERNEL(TransposeConv) REGISTER_KERNEL(UnidirectionalSequenceLSTM) diff --git a/compiler/luci-interpreter/pal/linux/PALBroadcastTo.h b/compiler/luci-interpreter/pal/linux/PALBroadcastTo.h new file mode 100644 index 000000000..22964104d --- /dev/null +++ b/compiler/luci-interpreter/pal/linux/PALBroadcastTo.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_BROADCASTTO_H
+#define LUCI_INTERPRETER_PAL_BROADCASTTO_H
+
+#include <tensorflow/lite/kernels/internal/reference/broadcast_to.h>
+
+namespace luci_interpreter_pal
+{
+
+static inline void BroadcastTo(const tflite::RuntimeShape &input_shape, const char *input_data,
+                               const tflite::RuntimeShape &output_shape, char *output_data,
+                               TfLiteType data_type)
+{
+  // The BroadcastTo op supports up to 8 dims (kMaxDims) in tensorflow,
+  // but we currently support only up to 5 dims because of a compiler bug in gcc 7.4.0.
+  // https://github.com/tensorflow/tensorflow/blob/932af96ae91b4fa647dc50ad0f14c3e0b60affab/tensorflow/lite/kernels/broadcast_to.cc#L118
+  constexpr int kMaxDims = 5;
+
+  tflite::reference_ops::BroadcastTo<kMaxDims>(input_shape, input_data, output_shape, output_data,
+                                               data_type);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_BROADCASTTO_H
diff --git a/compiler/luci-interpreter/pal/linux/PALConv2d.h b/compiler/luci-interpreter/pal/linux/PALConv2d.h
index 985a15f39..8ffcd864b 100644
--- a/compiler/luci-interpreter/pal/linux/PALConv2d.h
+++ b/compiler/luci-interpreter/pal/linux/PALConv2d.h
@@ -111,9 +111,11 @@ static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
     const int32_t output_width = output_shape.Dims(2);
 
     auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
-    int32_t scratchpad_size = batches * output_width * output_height * input_depth * filter_height *
-                              filter_width * data_type_size;
-    luci_interpreter::Shape scratchpad_shape{scratchpad_size};
+    // im2col_shape
+    // data_type_size is added because we use U8 for the scratchpad buffer dtype
+    luci_interpreter::Shape scratchpad_shape{batches, output_height, output_width,
+                                             input_depth * filter_height * filter_width,
+                                             data_type_size};
     scratchpad->resize(scratchpad_shape);
   }
   else
diff --git a/compiler/luci-interpreter/pal/linux/PALRelu0To1.h b/compiler/luci-interpreter/pal/linux/PALRelu0To1.h
new file mode 100644
index 000000000..0960d5266
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALRelu0To1.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_RELU0TO1_H
+#define LUCI_INTERPRETER_PAL_RELU0TO1_H
+
+#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+
+namespace luci_interpreter_pal
+{
+static inline void Relu0To1(const tflite::RuntimeShape &input_shape, const float *input_data,
+                            const tflite::RuntimeShape &output_shape, float *output_data)
+{
+  tflite::optimized_ops::Relu0To1(input_shape, input_data, output_shape, output_data);
+}
+
+template <typename T>
+static inline void ReluX(const tflite::ReluParams &params, const tflite::RuntimeShape &input_shape,
+                         const T *input_data, const tflite::RuntimeShape &output_shape,
+                         T *output_data)
+{
+  tflite::optimized_ops::ReluX(params, input_shape, input_data, output_shape, output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_RELU0TO1_H
diff --git a/compiler/luci-interpreter/pal/linux/pal.cmake b/compiler/luci-interpreter/pal/linux/pal.cmake
index 28f6352bc..e31105fbd 100644
--- a/compiler/luci-interpreter/pal/linux/pal.cmake
+++ b/compiler/luci-interpreter/pal/linux/pal.cmake
@@ -25,10 +25,11 @@ elseif("${TARGET_ARCH}" STREQUAL "aarch64")
 endif()
 
 macro(initialize_pal)
-    nnas_find_package(TensorFlowSource EXACT 2.8.0 QUIET)
-    nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.8.0 QUIET)
-    nnas_find_package(TensorFlowEigenSource EXACT 2.8.0 QUIET)
-    nnas_find_package(TensorFlowRuySource EXACT 2.8.0 QUIET)
+    nnas_find_package(TensorFlowSource EXACT 2.12.1 QUIET)
+    nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.12.1 QUIET)
+    nnas_find_package(TensorFlowEigenSource EXACT 2.12.1 QUIET)
+    nnas_find_package(TensorFlowRuySource EXACT 2.12.1 QUIET)
+    nnas_find_package(TensorFlowThreadpoolSource EXACT 2.12.1 QUIET)
 
     if (NOT TensorFlowSource_FOUND)
         message(STATUS "Skipping luci-interpreter: TensorFlow not found")
@@ -50,6 +51,11 @@ macro(initialize_pal)
         return()
     endif ()
 
+    if (NOT TensorFlowThreadpoolSource_FOUND)
+        message(STATUS "Skipping luci-interpreter: Threadpool not found")
+        return()
+    endif ()
+
     find_package(Threads REQUIRED)
 
     set(PAL_INITIALIZED TRUE)
@@ -61,6 +67,7 @@ macro(add_pal_to_target TGT)
                               "${TensorFlowRuySource_DIR}"
                               "${TensorFlowGEMMLowpSource_DIR}"
                               "${TensorFlowEigenSource_DIR}"
+                              "${TensorFlowThreadpoolSource_DIR}/include"
                               "${TensorFlowSource_DIR}")
     target_include_directories(${TGT} PRIVATE ${LUCI_INTERPRETER_PAL_DIR})
 
     # instead add sources with visitors in this library
     set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/tensor_utils.cc
         ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
-        ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+        ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc
+        ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
+        ${TensorFlowSource_DIR}/tensorflow/lite/kernels/kernel_util.cc
+        ${TensorFlowSource_DIR}/tensorflow/lite/core/c/common.cc)
 
     if(TARGET_ARCH_BASE STREQUAL "arm")
         # NOTE may need to revise this list for version upgrade
@@ -126,6 +136,7 @@ macro(add_pal_to_target TGT)
                               "${TensorFlowRuySource_DIR}"
                               "${TensorFlowGEMMLowpSource_DIR}"
                               "${TensorFlowEigenSource_DIR}"
+                              "${TensorFlowThreadpoolSource_DIR}/include"
                               "${TensorFlowSource_DIR}"
     )
diff --git a/compiler/luci-interpreter/src/CMakeLists.txt b/compiler/luci-interpreter/src/CMakeLists.txt
index 997b75a84..401b8f303 100644
--- a/compiler/luci-interpreter/src/CMakeLists.txt
+++ b/compiler/luci-interpreter/src/CMakeLists.txt
@@ -39,7 +39,7 @@ else ()
   add_library(${LUCI_INTERPRETER_BINARY}
STATIC ${SOURCES}) endif () -set(TEST_SOURCES BuddyMemoryManager.test.cpp) +set(TEST_SOURCES SimpleMemoryManager.test.cpp BuddyMemoryManager.test.cpp) target_include_directories(${LUCI_INTERPRETER_BINARY} PUBLIC "${LUCI_INTERPRETER_INCLUDE_DIR}") target_include_directories(${LUCI_INTERPRETER_BINARY} PRIVATE "${LUCI_INTERPRETER_SOURCE_DIR}") @@ -57,5 +57,5 @@ endif(NOT ENABLE_TEST) nnas_find_package(GTest REQUIRED) -GTest_AddTest(buddy_manager_test ${TEST_SOURCES}) -target_link_libraries(buddy_manager_test ${LUCI_INTERPRETER_BINARY}) +GTest_AddTest(luci_interpreter_memory_manager_test ${TEST_SOURCES}) +target_link_libraries(luci_interpreter_memory_manager_test ${LUCI_INTERPRETER_BINARY}) diff --git a/compiler/luci-interpreter/src/Interpreter.cpp b/compiler/luci-interpreter/src/Interpreter.cpp index 8cf272efd..fd46ec35d 100644 --- a/compiler/luci-interpreter/src/Interpreter.cpp +++ b/compiler/luci-interpreter/src/Interpreter.cpp @@ -125,6 +125,20 @@ void Interpreter::readOutputTensor(const luci::CircleOutput *output_node, void * tensor->readData(data, data_size); } +size_t Interpreter::getOutputTensorSize(const luci::CircleOutput *output_node) +{ + Tensor *tensor = _runtime_module->getOutputTensors()[output_node->index()]; + if (tensor == nullptr) + { + const std::string &name = output_node->name(); + throw std::runtime_error("Cannot find tensor size for output node named \"" + name + "\"."); + } + + size_t tensor_size = luci_interpreter::getDataTypeSize(tensor->element_type()); + tensor_size *= tensor->shape().num_elements(); + return tensor_size; +} + void Interpreter::interpret() { _runtime_module->execute(); } void Interpreter::attachObserver(ExecutionObserver *observer) diff --git a/compiler/luci-interpreter/src/SimpleMemoryManager.cpp b/compiler/luci-interpreter/src/SimpleMemoryManager.cpp index 230e39896..a39c34a0a 100644 --- a/compiler/luci-interpreter/src/SimpleMemoryManager.cpp +++ b/compiler/luci-interpreter/src/SimpleMemoryManager.cpp @@ -30,7 +30,9 @@ void SimpleMemoryManager::allocate_memory(luci_interpreter::Tensor &tensor) release_memory(tensor); } const auto element_size = getDataTypeSize(tensor.element_type()); - const auto num_elements = tensor.shape().num_elements(); + + // Use large_num_elements to avoid overflow + const auto num_elements = tensor.shape().large_num_elements(); auto *data = new uint8_t[num_elements * element_size]; tensor.set_data_buffer(data); diff --git a/compiler/luci-interpreter/src/SimpleMemoryManager.test.cpp b/compiler/luci-interpreter/src/SimpleMemoryManager.test.cpp new file mode 100644 index 000000000..18902e3e6 --- /dev/null +++ b/compiler/luci-interpreter/src/SimpleMemoryManager.test.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "luci_interpreter/SimpleMemoryManager.h"
+#include <gtest/gtest.h>
+
+using namespace luci_interpreter;
+using namespace testing;
+
+TEST(SimpleMemoryManager, basic)
+{
+  SimpleMemoryManager smm;
+  Tensor t(DataType::U8, Shape({1, 16, 16, 256}), AffineQuantization{}, "t");
+
+  EXPECT_NO_THROW(smm.allocate_memory(t));
+  EXPECT_NO_THROW(smm.release_memory(t));
+}
+
+TEST(SimpleMemoryManager, huge)
+{
+  SimpleMemoryManager smm;
+  Tensor t(DataType::U8, Shape({1, 512, 512, 256 * 3 * 3 * 4}), AffineQuantization{}, "t");
+
+  EXPECT_NO_THROW(smm.allocate_memory(t));
+  EXPECT_NO_THROW(smm.release_memory(t));
+}
+
+TEST(SimpleMemoryManager, string_dtype_NEG)
+{
+  SimpleMemoryManager smm;
+  Tensor t(DataType::STRING, Shape({1, 16, 16, 4}), AffineQuantization{}, "t");
+
+  EXPECT_ANY_THROW(smm.allocate_memory(t));
+}
+
+TEST(SimpleMemoryManager, negative_shape_NEG)
+{
+  SimpleMemoryManager smm;
+  Tensor t(DataType::U8, Shape({1, 16, 16, -4}), AffineQuantization{}, "t");
+
+  EXPECT_ANY_THROW(smm.allocate_memory(t));
+}
diff --git a/compiler/luci-interpreter/src/core/KernelParams.h b/compiler/luci-interpreter/src/core/KernelParams.h
index 4ddbcefb8..132ed5c49 100644
--- a/compiler/luci-interpreter/src/core/KernelParams.h
+++ b/compiler/luci-interpreter/src/core/KernelParams.h
@@ -65,6 +65,12 @@ struct Conv2DParams
   Activation activation;
 };
 
+struct CumSumParams
+{
+  bool exclusive;
+  bool reverse;
+};
+
 struct DepthToSpaceParams
 {
   int block_size;
diff --git a/compiler/luci-interpreter/src/kernels/Abs.cpp b/compiler/luci-interpreter/src/kernels/Abs.cpp
index 5c6331501..35e1b9d54 100644
--- a/compiler/luci-interpreter/src/kernels/Abs.cpp
+++ b/compiler/luci-interpreter/src/kernels/Abs.cpp
@@ -42,7 +42,7 @@ void Abs::execute() const
       eval();
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp Abs Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/Add.cpp b/compiler/luci-interpreter/src/kernels/Add.cpp
index d7bf3084f..e954b61f1 100644
--- a/compiler/luci-interpreter/src/kernels/Add.cpp
+++ b/compiler/luci-interpreter/src/kernels/Add.cpp
@@ -70,7 +70,7 @@ void Add::execute() const
       evalQuantizedS16();
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp Add Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp b/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
index d3bade9e4..cd42bcb4f 100644
--- a/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
@@ -99,7 +99,7 @@ void AveragePool2D::execute() const
       evalSInt8();
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp AveragePool2D Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp b/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp
index 24ca22996..c159e55af 100644
--- a/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp
+++ b/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp
@@ -58,7 +58,7 @@ void BatchMatMul::configure()
 
   // TODO Support non-float types
   if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
-    throw std::runtime_error("Unsupported type.");
+    throw std::runtime_error("luci-intp BatchMatMul(1) Unsupported type.");
 
   LUCI_INTERPRETER_CHECK(lhs->element_type() == rhs->element_type());
 
@@ -180,7 +180,7 @@
                            getTensorData<float>(output()));
       break;
    default:
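      // [editor's note, not part of the upstream patch] The one-line change
      // below is the pattern applied to every kernel in this change set: the
      // generic "Unsupported type." message gains a "luci-intp <KernelName>"
      // prefix (plus an index such as Conv2D(1)..(3) where a kernel throws
      // from several places), so a failure log now identifies exactly which
      // kernel and which check rejected the model.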
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp BatchMatMul(2) Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/BatchToSpaceND.cpp b/compiler/luci-interpreter/src/kernels/BatchToSpaceND.cpp
index bd315ff7b..3df0f6a40 100644
--- a/compiler/luci-interpreter/src/kernels/BatchToSpaceND.cpp
+++ b/compiler/luci-interpreter/src/kernels/BatchToSpaceND.cpp
@@ -96,7 +96,7 @@ void BatchToSpaceND::execute() const
                                getTensorData(output()));
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp BatchToSpaceND Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/BroadcastTo.cpp b/compiler/luci-interpreter/src/kernels/BroadcastTo.cpp
new file mode 100644
index 000000000..a0ae0f831
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BroadcastTo.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/BroadcastTo.h"
+#include "kernels/Utils.h"
+
+#include "PALBroadcastTo.h"
+
+#include <stdexcept>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+// TODO Extract this function to Utils.h
+Shape extractShapeFromTensor(const Tensor *tensor)
+{
+  Shape shape(tensor->shape().num_elements());
+
+  // Ensure the shape is a 1D tensor
+  LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == 1);
+
+  if (tensor->element_type() == DataType::S32)
+  {
+    const auto *shape_data = tensor->data<int32_t>();
+    for (int i = 0; i < tensor->shape().num_elements(); ++i)
+    {
+      // Ensure each dim value of the shape is non-negative
+      LUCI_INTERPRETER_CHECK(shape_data[i] >= 0);
+
+      shape.dim(i) = shape_data[i];
+    }
+  }
+  else if (tensor->element_type() == DataType::S64)
+  {
+    const auto *shape_data = tensor->data<int64_t>();
+    for (int i = 0; i < tensor->shape().num_elements(); ++i)
+    {
+      // Ensure each dim value of the shape is non-negative
+      LUCI_INTERPRETER_CHECK(shape_data[i] >= 0);
+
+      shape.dim(i) = static_cast<int32_t>(shape_data[i]);
+      // Check for value overflow
+      LUCI_INTERPRETER_CHECK(static_cast<int64_t>(shape.dim(i)) == shape_data[i]);
+    }
+  }
+  else
+  {
+    LUCI_INTERPRETER_CHECK(false);
+  }
+  return shape;
+}
+
+} // namespace
+
+BroadcastTo::BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output)
+  : Kernel({input, shape}, {output})
+{
+}
+
+void BroadcastTo::configure()
+{
+  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
+
+  Shape output_shape = extractShapeFromTensor(shape());
+
+  int input_rank = input()->shape().num_dims();
+  int output_rank = output_shape.num_dims();
+
+  // Ensure the output rank is not less than the input rank
+  LUCI_INTERPRETER_CHECK(input_rank <= output_rank);
+
+  // Check whether the output shape is broadcastable from the input shape.
+  // From https://www.tensorflow.org/api_docs/python/tf/broadcast_to:
+  // if a tensor has fewer axes than necessary, its shape is padded on the left with ones.
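+  // [editor's note, not part of the upstream patch] Example of the check
+  // below: broadcasting a {2, 3} input to {2, 2, 3} gives extending_rank = 1,
+  // so input dim 0 (= 2) is compared with output dim 1 (= 2) and input dim 1
+  // (= 3) with output dim 2 (= 3): broadcastable. Each input dim must be
+  // either 1 or equal to its right-aligned output dim.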
+  int extending_rank = output_rank - input_rank;
+  for (int idx = 0; idx < input_rank; ++idx)
+  {
+    LUCI_INTERPRETER_CHECK(input()->shape().dim(idx) == 1 ||
+                           input()->shape().dim(idx) == output_shape.dim(extending_rank + idx));
+  }
+
+  output()->resize(output_shape);
+}
+
+void BroadcastTo::execute() const
+{
+  switch (input()->element_type())
+  {
+    case DataType::FLOAT32:
+      evalFloat();
+      break;
+    case DataType::BOOL:
+      evalBool();
+      break;
+    default:
+      throw std::runtime_error("luci-intp BroadcastTo Unsupported type.");
+  }
+}
+
+void BroadcastTo::evalFloat() const
+{
+  luci_interpreter_pal::BroadcastTo(getTensorShape(input()), getTensorData<char>(input()),
+                                    getTensorShape(output()), getTensorData<char>(output()),
+                                    TfLiteType::kTfLiteFloat32);
+}
+
+void BroadcastTo::evalBool() const
+{
+  luci_interpreter_pal::BroadcastTo(getTensorShape(input()), getTensorData<char>(input()),
+                                    getTensorShape(output()), getTensorData<char>(output()),
+                                    TfLiteType::kTfLiteBool);
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/BroadcastTo.h b/compiler/luci-interpreter/src/kernels/BroadcastTo.h
new file mode 100644
index 000000000..0037dfcfb
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BroadcastTo.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_BROADCASTTO_H
+#define LUCI_INTERPRETER_KERNELS_BROADCASTTO_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class BroadcastTo : public Kernel
+{
+public:
+  BroadcastTo(const Tensor *input, const Tensor *shape, Tensor *output);
+
+  const Tensor *input() const { return _inputs[0]; }
+  const Tensor *shape() const { return _inputs[1]; }
+  Tensor *output() const { return _outputs[0]; }
+
+  void configure() override;
+  void execute() const override;
+
+private:
+  void evalFloat() const;
+  void evalBool() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_BROADCASTTO_H
diff --git a/compiler/luci-interpreter/src/kernels/BroadcastTo.test.cpp b/compiler/luci-interpreter/src/kernels/BroadcastTo.test.cpp
new file mode 100644
index 000000000..dffaaa495
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BroadcastTo.test.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/BroadcastTo.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+using namespace testing;
+
+template <typename T>
+void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> shape_shape,
+           std::initializer_list<int32_t> output_shape, std::initializer_list<float> input_data,
+           std::initializer_list<T> shape_data, std::initializer_list<float> output_data)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  constexpr DataType element_type = DataType::FLOAT32;
+  constexpr DataType shape_type = getElementType<T>();
+
+  Tensor input_tensor =
+    makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
+  Tensor shape_tensor = makeInputTensor<shape_type>(shape_shape, shape_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(element_type);
+
+  BroadcastTo kernel(&input_tensor, &shape_tensor, &output_tensor);
+
+  kernel.configure();
+  memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(output_data));
+}
+
+template <typename T>
+void Check_bool(std::initializer_list<int32_t> input_shape,
+                std::initializer_list<int32_t> shape_shape,
+                std::initializer_list<int32_t> output_shape,
+                std::initializer_list<bool> input_data, std::initializer_list<T> shape_data,
+                std::initializer_list<bool> output_data)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  constexpr DataType element_type = DataType::BOOL;
+  constexpr DataType shape_type = getElementType<T>();
+
+  Tensor input_tensor =
+    makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
+  Tensor shape_tensor = makeInputTensor<shape_type>(shape_shape, shape_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(element_type);
+
+  BroadcastTo kernel(&input_tensor, &shape_tensor, &output_tensor);
+
+  kernel.configure();
+  memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), output_shape);
+}
+
+class BroadcastToTest : public ::testing::Test
+{
+};
+
+TEST_F(BroadcastToTest, SimpleS32)
+{
+  Check<int32_t>(/*input_shape*/ {1, 3}, /*shape_shape*/ {2}, /*output_shape*/ {2, 3},
+                 /*input_data*/
+                 {1, 2, 3},
+                 /*shape_data*/
+                 {2, 3},
+                 /*output_data*/
+                 {
+                   1, 2, 3, // Row 1
+                   1, 2, 3, // Row 2
+                 });
+  SUCCEED();
+}
+
+TEST_F(BroadcastToTest, SimpleS64)
+{
+  Check<int64_t>(/*input_shape*/ {1, 3}, /*shape_shape*/ {2}, /*output_shape*/ {2, 3},
+                 /*input_data*/
+                 {1, 2, 3},
+                 /*shape_data*/
+                 {2, 3},
+                 /*output_data*/
+                 {
+                   1, 2, 3, // Row 1
+                   1, 2, 3, // Row 2
+                 });
+  SUCCEED();
+}
+
+TEST_F(BroadcastToTest, SimpleBool)
+{
+  Check_bool<int32_t>(/*input_shape*/ {1, 3}, /*shape_shape*/ {2}, /*output_shape*/ {2, 3},
+                      /*input_data*/
+                      {true, false, true},
+                      /*shape_data*/
+                      {2, 3},
+                      /*output_data*/
+                      {
+                        true, false, true, // Row 1
+                        true, false, true, // Row 2
+                      });
+  SUCCEED();
+}
+
+TEST_F(BroadcastToTest, DifferentInOutType_NEG)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>({1, 3}, {1, 2, 3}, memory_manager.get());
+  Tensor shape_tensor = makeInputTensor<DataType::S32>({2}, {2, 3}, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::U8);
+
+  BroadcastTo kernel(&input_tensor, &shape_tensor, &output_tensor);
+
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(BroadcastToTest, BroadcastAble_NEG)
+{
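+  // [editor's note, not part of the upstream patch] This negative case fails
+  // in configure(): {2, 6} is not broadcastable from a {2, 3} input, since the
+  // trailing input dim is 3, which is neither 1 nor 6, so
+  // LUCI_INTERPRETER_CHECK throws.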
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>({2, 3}, {1, 2, 3, 1, 2, 3}, memory_manager.get());
+  Tensor shape_tensor = makeInputTensor<DataType::S32>({2}, {2, 6}, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  BroadcastTo kernel(&input_tensor, &shape_tensor, &output_tensor);
+
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Concatenation.cpp b/compiler/luci-interpreter/src/kernels/Concatenation.cpp
index 46ee5941e..f2f556bb3 100644
--- a/compiler/luci-interpreter/src/kernels/Concatenation.cpp
+++ b/compiler/luci-interpreter/src/kernels/Concatenation.cpp
@@ -107,7 +107,7 @@ void Concatenation::execute() const
       evalGeneric();
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp Concatenation Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.cpp b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
index 234f95425..9aae9da26 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
@@ -75,7 +75,7 @@ void Conv2D::configure()
   }
   else
   {
-    throw std::runtime_error("Unsupported type.");
+    throw std::runtime_error("luci-intp Conv2D(1) Unsupported type.");
   }
   LUCI_INTERPRETER_CHECK(output()->element_type() == input()->element_type());
 
@@ -143,7 +143,7 @@ void Conv2D::execute() const
         evalFloat();
         break;
       }
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp Conv2D(2) Unsupported type.");
     case DataType::U8:
       if (filter()->scales().size() == 1)
       {
@@ -164,7 +164,7 @@ void Conv2D::execute() const
        evalQuantizedS16();
        break;
      default:
-       throw std::runtime_error("Unsupported type.");
+       throw std::runtime_error("luci-intp Conv2D(3) Unsupported type.");
   }
 }
 
diff --git a/compiler/luci-interpreter/src/kernels/Cos.cpp b/compiler/luci-interpreter/src/kernels/Cos.cpp
new file mode 100644
index 000000000..c14593dee
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Cos.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Cos.h"
+
+#include "kernels/Utils.h"
+
+#include <cmath>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+template <typename T>
+inline void CalcCos(const T *input_data, const size_t num_elements, T *output_data)
+{
+  for (size_t idx = 0; idx < num_elements; ++idx)
+  {
+    output_data[idx] = std::cos(input_data[idx]);
+  }
+}
+
+} // namespace
+
+Cos::Cos(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Cos::configure()
+{
+  LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
+  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
+  output()->resize(input()->shape());
+}
+
+void Cos::execute() const
+{
+  switch (input()->element_type())
+  {
+    case DataType::FLOAT32:
+      evalFloat();
+      break;
+    default:
+      throw std::runtime_error("luci-intp Cos Unsupported type.");
+  }
+}
+
+void Cos::evalFloat() const
+{
+  const int size = tflite::MatchingFlatSize(getTensorShape(input()), getTensorShape(output()));
+  CalcCos(getTensorData<float>(input()), size, getTensorData<float>(output()));
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Cos.h b/compiler/luci-interpreter/src/kernels/Cos.h
new file mode 100644
index 000000000..6097c5bc2
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Cos.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_COS_H
+#define LUCI_INTERPRETER_KERNELS_COS_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Cos : public Kernel
+{
+public:
+  Cos(const Tensor *input, Tensor *output);
+
+  const Tensor *input() const { return _inputs[0]; }
+  Tensor *output() const { return _outputs[0]; }
+
+  void configure() override;
+  void execute() const override;
+
+private:
+  void evalFloat() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_COS_H
diff --git a/compiler/luci-interpreter/src/kernels/Cos.test.cpp b/compiler/luci-interpreter/src/kernels/Cos.test.cpp
new file mode 100644
index 000000000..ad7ea6fb9
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Cos.test.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Cos.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+#include <cmath>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+#define PI 3.14159265358979323846
+
+TEST(CosTest, Float)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  Shape input_shape{1, 1, 3};
+  std::vector<float> input_data{0.0f, PI / 3.0f, -PI / 3.0f};
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  Cos kernel(&input_tensor, &output_tensor);
+  kernel.configure();
+  memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  std::vector<int32_t> ref_output_shape{1, 1, 3};
+  std::vector<float> ref_output_data{std::cos(0.0f), std::cos(PI / 3.0f), std::cos(-PI / 3.0f)};
+  EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST(CosTest, InvalidDType_NEG)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  Shape input_shape{1, 1, 3};
+  std::vector<int64_t> input_data{1l, 2l, 3l};
+  Tensor input_tensor =
+    makeInputTensor<DataType::S64>(input_shape, input_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::S64);
+
+  Cos kernel(&input_tensor, &output_tensor);
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/CumSum.cpp b/compiler/luci-interpreter/src/kernels/CumSum.cpp
new file mode 100644
index 000000000..4c54172f1
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/CumSum.cpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/CumSum.h"
+
+#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+
+#include "kernels/Utils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+CumSum::CumSum(const Tensor *input, const Tensor *axis, Tensor *output, const CumSumParams &params)
+  : KernelWithParams<CumSumParams>({input, axis}, {output}, params)
+{
+}
+
+void CumSum::configure()
+{
+  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
+  LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= 1);
+
+  LUCI_INTERPRETER_CHECK(axis()->element_type() == DataType::S32);
+  LUCI_INTERPRETER_CHECK(axis()->shape().num_dims() == 0);
+
+  output()->resize(input()->shape());
+}
+
+void CumSum::execute() const
+{
+  switch (input()->element_type())
+  {
+    case DataType::FLOAT32:
+      tflite::optimized_ops::CumSum(getTensorData<float>(input()), getTensorShape(input()),
+                                    *getTensorData<int32_t>(axis()), params().exclusive,
+                                    params().reverse, getTensorData<float>(output()));
+      break;
+    case DataType::S32:
+      tflite::optimized_ops::CumSum(getTensorData<int32_t>(input()), getTensorShape(input()),
+                                    *getTensorData<int32_t>(axis()), params().exclusive,
+                                    params().reverse, getTensorData<int32_t>(output()));
+      break;
+    case DataType::S64:
+      tflite::optimized_ops::CumSum(getTensorData<int64_t>(input()), getTensorShape(input()),
+                                    *getTensorData<int32_t>(axis()), params().exclusive,
+                                    params().reverse, getTensorData<int64_t>(output()));
+      break;
+    default:
+      throw std::runtime_error("luci-intp CumSum Unsupported type.");
+  }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/CumSum.h b/compiler/luci-interpreter/src/kernels/CumSum.h
new file mode 100644
index 000000000..bc08480a6
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/CumSum.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_CUMSUM_H
+#define LUCI_INTERPRETER_KERNELS_CUMSUM_H
+
+#include
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class CumSum : public KernelWithParams<CumSumParams>
+{
+public:
+  CumSum(const Tensor *input, const Tensor *axis, Tensor *output, const CumSumParams &params);
+
+  const Tensor *input() const { return _inputs[0]; }
+  const Tensor *axis() const { return _inputs[1]; }
+  Tensor *output() const { return _outputs[0]; }
+
+  void configure() override;
+  void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_CUMSUM_H
diff --git a/compiler/luci-interpreter/src/kernels/CumSum.test.cpp b/compiler/luci-interpreter/src/kernels/CumSum.test.cpp
new file mode 100644
index 000000000..2a8ea98f8
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/CumSum.test.cpp
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/CumSum.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+using namespace testing;
+
+class CumSumTest : public ::testing::Test
+{
+protected:
+  void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+  std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(CumSumTest, Float)
+{
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape(0);
+  std::vector<float> output_data{1, 2, 3, 4, 6, 8, 10, 12};
+  std::vector<int32_t> output_shape{1, 1, 2, 4};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<float>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
+}
+
+TEST_F(CumSumTest, Int32)
+{
+  std::vector<int32_t> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape(0);
+  std::vector<int32_t> output_data{1, 2, 3, 4, 6, 8, 10, 12};
+  std::vector<int32_t> output_shape{1, 1, 2, 4};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<int32_t>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
+}
+
+TEST_F(CumSumTest, Int64)
+{
+  std::vector<int64_t> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape(0);
+  std::vector<int64_t> output_data{1, 2, 3, 4, 6, 8, 10, 12};
+  std::vector<int32_t> output_shape{1, 1, 2, 4};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::S64>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::S64);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<int64_t>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
+}
+
+TEST_F(CumSumTest, Float_Reverse)
+{
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape(0);
+  std::vector<float> output_data{6, 8, 10, 12, 5, 6, 7, 8};
+  std::vector<int32_t> output_shape{1, 1, 2, 4};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{false, true};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<float>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
+}
+
+TEST_F(CumSumTest, Float_Exclusive)
+{
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 4, 2};
+  Shape axis_shape(0);
+  std::vector<float> output_data{0, 0, 1, 2, 4, 6, 9, 12};
+  std::vector<int32_t> output_shape{1, 1, 4, 2};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{true, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<float>(output_tensor), ::testing::ElementsAreArray(output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
+}
+
+TEST_F(CumSumTest, InputShapeInvalid_NEG)
+{
+  std::vector<float> input_data{1};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape(0);
+  Shape axis_shape(0);
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(CumSumTest, AxisShapeInvalid_NEG)
+{
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<int32_t> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape{1};
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(CumSumTest, AxisTypeInvalid_NEG)
+{
+  std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8};
+  std::vector<float> axis_data{2};
+  Shape input_shape{1, 1, 2, 4};
+  Shape axis_shape(0);
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+
+  Tensor axis_tensor =
+    makeInputTensor<DataType::FLOAT32>(axis_shape, axis_data, _memory_manager.get());
+
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  CumSumParams params{false, false};
+
+  CumSum kernel(&input_tensor, &axis_tensor, &output_tensor, params);
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git
a/compiler/luci-interpreter/src/kernels/DepthToSpace.cpp b/compiler/luci-interpreter/src/kernels/DepthToSpace.cpp index 3a9acd1d4..ec5c918b4 100644 --- a/compiler/luci-interpreter/src/kernels/DepthToSpace.cpp +++ b/compiler/luci-interpreter/src/kernels/DepthToSpace.cpp @@ -72,7 +72,7 @@ void DepthToSpace::execute() const getTensorData(output())); break; default: - throw std::runtime_error("Unsupported Type."); + throw std::runtime_error("luci-intp DepthToSpace Unsupported Type."); } } diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp index c554c309d..a48416f57 100644 --- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp +++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp @@ -75,7 +75,7 @@ void DepthwiseConv2D::configure() } else { - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp DepthwiseConv2D(1) Unsupported type."); } LUCI_INTERPRETER_CHECK(output()->element_type() == input()->element_type()); @@ -130,7 +130,7 @@ void DepthwiseConv2D::execute() const evalFloat(); break; } - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp DepthwiseConv2D(2) Unsupported type."); case DataType::U8: if (filter()->scales().size() == 1) { @@ -151,7 +151,7 @@ void DepthwiseConv2D::execute() const evalQuantizedS16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp DepthwiseConv2D(3) Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Dequantize.cpp b/compiler/luci-interpreter/src/kernels/Dequantize.cpp index 96399e5c7..883d9b5fe 100644 --- a/compiler/luci-interpreter/src/kernels/Dequantize.cpp +++ b/compiler/luci-interpreter/src/kernels/Dequantize.cpp @@ -71,7 +71,7 @@ void Dequantize::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Dequantize Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Div.cpp b/compiler/luci-interpreter/src/kernels/Div.cpp index dd1532278..190077aed 100644 --- a/compiler/luci-interpreter/src/kernels/Div.cpp +++ b/compiler/luci-interpreter/src/kernels/Div.cpp @@ -56,7 +56,7 @@ void Div::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Div Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Elu.cpp b/compiler/luci-interpreter/src/kernels/Elu.cpp index 697d63be4..a5f366e04 100644 --- a/compiler/luci-interpreter/src/kernels/Elu.cpp +++ b/compiler/luci-interpreter/src/kernels/Elu.cpp @@ -44,7 +44,7 @@ void Elu::execute() const getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Elu Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Equal.cpp b/compiler/luci-interpreter/src/kernels/Equal.cpp index a57e127b7..c3c11b1ad 100644 --- a/compiler/luci-interpreter/src/kernels/Equal.cpp +++ b/compiler/luci-interpreter/src/kernels/Equal.cpp @@ -59,7 +59,7 @@ void Equal::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Equal Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Exp.cpp b/compiler/luci-interpreter/src/kernels/Exp.cpp index e7c560a88..e276f7fd7 100644 --- 
a/compiler/luci-interpreter/src/kernels/Exp.cpp +++ b/compiler/luci-interpreter/src/kernels/Exp.cpp @@ -42,7 +42,7 @@ void Exp::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Exp Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/ExpandDims.cpp b/compiler/luci-interpreter/src/kernels/ExpandDims.cpp index ba35c99fa..5cbc4e259 100644 --- a/compiler/luci-interpreter/src/kernels/ExpandDims.cpp +++ b/compiler/luci-interpreter/src/kernels/ExpandDims.cpp @@ -40,7 +40,7 @@ void ExpandDims::configure() axis_value = static_cast(*getTensorData(axis())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp ExpandDims Unsupported type."); } const auto input_shape = input()->shape(); diff --git a/compiler/luci-interpreter/src/kernels/Fill.cpp b/compiler/luci-interpreter/src/kernels/Fill.cpp index e09d6331a..8bea85f61 100644 --- a/compiler/luci-interpreter/src/kernels/Fill.cpp +++ b/compiler/luci-interpreter/src/kernels/Fill.cpp @@ -80,7 +80,7 @@ void Fill::configure() configureShape(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Fill(1) Unsupported type."); } } @@ -109,7 +109,7 @@ void Fill::execute() const getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Fill(2) Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Floor.cpp b/compiler/luci-interpreter/src/kernels/Floor.cpp index e3c4246cc..f15ee2427 100644 --- a/compiler/luci-interpreter/src/kernels/Floor.cpp +++ b/compiler/luci-interpreter/src/kernels/Floor.cpp @@ -43,7 +43,7 @@ void Floor::execute() const break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Floor Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/FloorDiv.cpp b/compiler/luci-interpreter/src/kernels/FloorDiv.cpp index a7a10a336..f2e617cca 100644 --- a/compiler/luci-interpreter/src/kernels/FloorDiv.cpp +++ b/compiler/luci-interpreter/src/kernels/FloorDiv.cpp @@ -48,7 +48,7 @@ void FloorDiv::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp FloorDiv Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/FloorMod.cpp b/compiler/luci-interpreter/src/kernels/FloorMod.cpp index a64fcad3a..8238883bd 100644 --- a/compiler/luci-interpreter/src/kernels/FloorMod.cpp +++ b/compiler/luci-interpreter/src/kernels/FloorMod.cpp @@ -75,7 +75,7 @@ void FloorMod::execute() const evalInteger(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp FloorMod Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/FullyConnected.cpp b/compiler/luci-interpreter/src/kernels/FullyConnected.cpp index bd2bb2f35..ce34655dc 100644 --- a/compiler/luci-interpreter/src/kernels/FullyConnected.cpp +++ b/compiler/luci-interpreter/src/kernels/FullyConnected.cpp @@ -54,9 +54,23 @@ void FullyConnected::configure() LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8); LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32) } + else if (weights()->element_type() == DataType::S4) + { + // TODO support other combinations when needed + LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32); + 
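+    // Hybrid setup: the 4-bit weights are dequantized into a FLOAT32 scratch
+    // tensor at execute() time and the matmul itself runs in float, so the
+    // activation, output and bias tensors must all be FLOAT32 here.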
LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32); + LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32) + } + else if (weights()->element_type() == DataType::U4) + { + // TODO support other combinations when needed + LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32); + LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32); + LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32) + } else { - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp FullyConnected(1) Unsupported type."); } const Shape &input_shape = input()->shape(); @@ -70,9 +84,6 @@ void FullyConnected::configure() const int32_t batch_size = input_shape.num_elements() / weights_shape.dim(1); const int32_t num_units = weights_shape.dim(0); - if (bias()) - LUCI_INTERPRETER_CHECK(bias()->shape().num_elements() == weights()->shape().dim(0)); - if (params().keep_num_dims == false) { output()->resize({batch_size, num_units}); @@ -89,19 +100,41 @@ void FullyConnected::configure() void FullyConnected::execute() const { - switch (input()->element_type()) + const bool is_hybrid = + (input()->element_type() == DataType::FLOAT32 && + (weights()->element_type() == DataType::S4 || weights()->element_type() == DataType::U4) && + output()->element_type() == DataType::FLOAT32 && + (!bias() || bias()->element_type() == DataType::FLOAT32)); + if (is_hybrid) + { + switch (weights()->element_type()) + { + case DataType::S4: + evalHybridWI4AF32(); + break; + case DataType::U4: + evalHybridWU4AF32(); + break; + default: + throw std::runtime_error("luci-intp FullyConnected(3) Unsupported type."); + } + } + else { - case DataType::U8: - evalQuantized(); - break; - case DataType::S8: - evalQuantizedS8(); - break; - case DataType::FLOAT32: - evalFloat(); - break; - default: - throw std::runtime_error("Unsupported type."); + switch (input()->element_type()) + { + case DataType::U8: + evalQuantized(); + break; + case DataType::S8: + evalQuantizedS8(); + break; + case DataType::FLOAT32: + evalFloat(); + break; + default: + throw std::runtime_error("luci-intp FullyConnected(2) Unsupported type."); + } } } @@ -188,5 +221,130 @@ void FullyConnected::evalQuantizedS8() const getTensorShape(output()), getTensorData(output())); } +void FullyConnected::evalHybridWI4AF32() const +{ + float activation_min{}; + float activation_max{}; + calculateActivationRange(_params.activation, &activation_min, &activation_max); + + tflite::FullyConnectedParams params{}; + params.float_activation_min = activation_min; + params.float_activation_max = activation_max; + params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault; + + const int8_t *weights_int4 = getTensorData(weights()); + float *weights_float = getTensorData(scratch()); + const Shape &weights_shape = weights()->shape(); + const auto weights_scales = weights()->scales(); + const auto weights_quantized_dimension = weights()->quantized_dimension(); + // Invariant for per-channel quantization of FC weights. 
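+  // Weights are laid out as {num_units, input_size}, so quantized_dimension 0
+  // means one scale per output unit; the per-channel branch below then walks
+  // the buffer as outer x channel x inner, with the inner stride contiguous.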
+ LUCI_INTERPRETER_CHECK(weights_quantized_dimension == 0); + + if (weights_scales.size() == 1) + { + // Per tensor + const auto scale = weights()->scale(); + for (int32_t i = 0; i < weights_shape.num_elements(); ++i) + { + weights_float[i] = scale * static_cast(weights_int4[i]); + } + } + else + { + // Per channel + const int32_t quant_dim_size = weights_shape.dim(weights_quantized_dimension); + + size_t outer_dims_size = 1; + size_t inner_dims_size = 1; + for (int i = 0; i < weights_quantized_dimension; ++i) + outer_dims_size *= weights_shape.dim(i); + for (int i = weights_quantized_dimension + 1; i < weights_shape.num_dims(); ++i) + inner_dims_size *= weights_shape.dim(i); + + for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it) + for (int32_t channel = 0; channel < quant_dim_size; ++channel) + { + float scale = weights_scales[channel]; + size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel); + for (size_t inner_it = 0; inner_it < inner_dims_size; ++inner_it) + { + LUCI_INTERPRETER_CHECK(offset + inner_it < + static_cast(weights_shape.num_elements())); + weights_float[offset + inner_it] = + scale * static_cast(weights_int4[offset + inner_it]); + } + } + } + + tflite::reference_ops::FullyConnected( + params, getTensorShape(input()), getTensorData(input()), getTensorShape(scratch()), + getTensorData(scratch()), getTensorShape(bias()), getTensorData(bias()), + getTensorShape(output()), getTensorData(output())); +} + +void FullyConnected::evalHybridWU4AF32() const +{ + float activation_min{}; + float activation_max{}; + calculateActivationRange(_params.activation, &activation_min, &activation_max); + + tflite::FullyConnectedParams params{}; + params.float_activation_min = activation_min; + params.float_activation_max = activation_max; + params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault; + + const auto *weights_uint4 = getTensorData(weights()); + auto *weights_float = getTensorData(scratch()); + const Shape &weights_shape = weights()->shape(); + const auto weights_scales = weights()->scales(); + const auto weights_zero_points = weights()->zero_points(); + const auto weights_quantized_dimension = weights()->quantized_dimension(); + LUCI_INTERPRETER_CHECK(weights_quantized_dimension == 0); + if (weights_scales.size() == 1) + { + // Per tensor + const auto scale = weights()->scale(); + const auto zero_point = weights()->zero_point(); + LUCI_INTERPRETER_CHECK(zero_point >= 0 and zero_point <= 15); + for (int32_t i = 0; i < weights_shape.num_elements(); ++i) + { + weights_float[i] = + scale * static_cast(static_cast(weights_uint4[i]) - zero_point); + } + } + else + { + // Per channel + const int32_t quant_dim_size = weights_shape.dim(weights_quantized_dimension); + + size_t outer_dims_size = 1; + size_t inner_dims_size = 1; + for (int i = 0; i < weights_quantized_dimension; ++i) + outer_dims_size *= weights_shape.dim(i); + for (int i = weights_quantized_dimension + 1; i < weights_shape.num_dims(); ++i) + inner_dims_size *= weights_shape.dim(i); + + for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it) + for (int32_t channel = 0; channel < quant_dim_size; ++channel) + { + int32_t zero_point = weights_zero_points[channel]; + LUCI_INTERPRETER_CHECK(zero_point >= 0 and zero_point <= 15); + float scale = weights_scales[channel]; + size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel); + for (size_t inner_it = 0; inner_it < inner_dims_size; ++inner_it) + { + weights_float[offset + inner_it] = + scale * + 
static_cast(static_cast(weights_uint4[offset + inner_it]) - zero_point); + } + } + } + + tflite::reference_ops::FullyConnected( + params, getTensorShape(input()), getTensorData(input()), getTensorShape(scratch()), + getTensorData(scratch()), getTensorShape(bias()), getTensorData(bias()), + getTensorShape(output()), getTensorData(output())); +} + } // namespace kernels } // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/FullyConnected.h b/compiler/luci-interpreter/src/kernels/FullyConnected.h index 2a7c068c0..860775e9b 100644 --- a/compiler/luci-interpreter/src/kernels/FullyConnected.h +++ b/compiler/luci-interpreter/src/kernels/FullyConnected.h @@ -30,11 +30,17 @@ class FullyConnected : public KernelWithParams public: FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, const FullyConnectedParams ¶ms); - + FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, + Tensor *scratch, const FullyConnectedParams ¶ms) + : FullyConnected(input, weights, bias, output, params) + { + _scratch = scratch; + } const Tensor *input() const { return _inputs[0]; } const Tensor *weights() const { return _inputs[1]; } const Tensor *bias() const { return _inputs[2]; } Tensor *output() const { return _outputs[0]; } + Tensor *scratch() const { return _scratch; } void configure() override; void execute() const override; @@ -43,6 +49,9 @@ private: void evalFloat() const; void evalQuantized() const; void evalQuantizedS8() const; + void evalHybridWI4AF32() const; + void evalHybridWU4AF32() const; + Tensor *_scratch = nullptr; }; } // namespace kernels diff --git a/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp b/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp index 4474cc4fb..ddb96932c 100644 --- a/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp +++ b/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp @@ -154,6 +154,370 @@ TYPED_TEST(FullyConnectedTest, Simple) }); } +TEST(FullyConnectedTest, SimpleS4PerTensor) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 0, 1, // unit = 0 + 0, 0, // unit = 1 + -1, -1, // unit = 2 + 0, 0, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::S4, weights_shape, {{0.5}, {0}}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(int8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(-8, 7, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + kernel.configure(); + 
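+  // Expected values, worked through: scale 0.5 dequantizes the S4 weights to
+  // {0, 0.5; 0, 0; -0.5, -0.5; 0, 0}, so input {1, 3} gives per-unit sums
+  // {1.5, 0, -2, 0}; adding bias {0, 1, 2, 3} and applying RELU yields
+  // {1.5, 1, 0, 3}, as in output_data above.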
memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); + EXPECT_THAT(extractTensorData(output_tensor), + FloatArrayNear(output_data, quantized_tolerance)); +} + +TEST(FullyConnectedTest, SimpleS4PerChannel) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 0, 1, // unit = 0 + 0, 0, // unit = 1 + -1, -1, // unit = 2 + 0, 0, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::S4, weights_shape, {{0.5, 0.5, 0.5, 0.5}, {0, 0, 0, 0}, 0}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(int8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(-8, 7, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); + EXPECT_THAT(extractTensorData(output_tensor), + FloatArrayNear(output_data, quantized_tolerance)); +} + +TEST(FullyConnectedTest, SimpleS4WrongBiasType_NEG) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 0, 1, // unit = 0 + 0, 0, // unit = 1 + -1, -1, // unit = 2 + 0, 0, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::S4, weights_shape, {{0.5}, {8}}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(int8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST(FullyConnectedTest, SimpleS4WrongInputType_NEG) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + 
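+  // With S4 weights, configure() accepts only FLOAT32 activations (see the
+  // FullyConnected::configure checks above), so constructing the input
+  // tensor with any other element type must make configure() throw.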
std::initializer_list bias_shape{4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 0, 1, // unit = 0 + 0, 0, // unit = 1 + -1, -1, // unit = 2 + 0, 0, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::S4, weights_shape, {{0.5}, {8}}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(int8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST(FullyConnectedTest, SimpleU4PerTensor) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 8, 9, // unit = 0 + 8, 8, // unit = 1 + 7, 7, // unit = 2 + 8, 8, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::U4, weights_shape, {{0.5}, {8}}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(uint8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(0, 15, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); + EXPECT_THAT(extractTensorData(output_tensor), + FloatArrayNear(output_data, quantized_tolerance)); +} + +TEST(FullyConnectedTest, SimpleU4PerChannel) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 8, 9, // unit = 0 + 8, 8, // unit = 1 + 7, 7, // unit = 2 + 8, 8, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + 
Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::U4, weights_shape, {{0.5, 0.5, 0.5, 0.5}, {8, 8, 8, 8}, 0}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(uint8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(0, 15, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); + EXPECT_THAT(extractTensorData(output_tensor), + FloatArrayNear(output_data, quantized_tolerance)); +} + +TEST(FullyConnectedTest, SimpleU4WrongBiasType_NEG) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 8, 9, // unit = 0 + 8, 8, // unit = 1 + 7, 7, // unit = 2 + 8, 8, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::U4, weights_shape, {{0.5, 0.5}, {8, 8}, 1}, ""); + memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(uint8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(0, 15, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST(FullyConnectedTest, SimpleU4WrongInputType_NEG) +{ + std::initializer_list input_shape{1, 2}; + std::initializer_list weights_shape{4, 2}; + std::initializer_list bias_shape{4}; + std::initializer_list output_shape{1, 4}; + std::initializer_list input_data{ + 1, 3, // batch = 0 + }; + std::initializer_list weights_initializer{ + 8, 9, // unit = 0 + 8, 8, // unit = 1 + 7, 7, // unit = 2 + 8, 8, // unit = 3 + }; + std::initializer_list bias_data{0, 1, 2, 3}; + std::initializer_list output_data{ + 1.5, 1, 0, 3, // batch = 0 + }; + std::unique_ptr memory_manager = std::make_unique(); + + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + std::vector quantized_data(weights_initializer); + Tensor weights_tensor(DataType::U4, weights_shape, {{0.5, 0.5}, {8, 8}, 1}, ""); + 
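+  // U4 reuses the S4 path with an unsigned code range: with zero point 8,
+  // codes 7/8/9 dequantize to -0.5/0/+0.5 at scale 0.5, mirroring the S4
+  // tests; the mismatched input type is what configure() must reject here.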
memory_manager->allocate_memory(weights_tensor); + weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(uint8_t)); + Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, ""); + memory_manager->allocate_memory(weights_scratch); + + Tensor bias_tensor = + makeInputTensor(bias_shape, bias_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + const float quantized_tolerance = getTolerance(0, 15, 15); + + FullyConnectedParams params{}; + params.activation = Activation::RELU; + + FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor, + &weights_scratch, params); + EXPECT_ANY_THROW(kernel.configure()); +} + TEST(FullyConnectedTest, InvalidBiasType_NEG) { Shape input_shape{3, 2, 2, 1}; diff --git a/compiler/luci-interpreter/src/kernels/Gather.cpp b/compiler/luci-interpreter/src/kernels/Gather.cpp index f1256660f..c04e7f622 100644 --- a/compiler/luci-interpreter/src/kernels/Gather.cpp +++ b/compiler/luci-interpreter/src/kernels/Gather.cpp @@ -42,7 +42,7 @@ void Gather::configure() } else { - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Gather(1) Unsupported type."); } LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 || @@ -102,7 +102,7 @@ void Gather::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Gather(2) Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Gelu.cpp b/compiler/luci-interpreter/src/kernels/Gelu.cpp index 44e018e0e..998bb5153 100644 --- a/compiler/luci-interpreter/src/kernels/Gelu.cpp +++ b/compiler/luci-interpreter/src/kernels/Gelu.cpp @@ -48,7 +48,7 @@ void Gelu::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Gelu Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Greater.cpp b/compiler/luci-interpreter/src/kernels/Greater.cpp index 5ccae3c38..e1bc32554 100644 --- a/compiler/luci-interpreter/src/kernels/Greater.cpp +++ b/compiler/luci-interpreter/src/kernels/Greater.cpp @@ -59,7 +59,7 @@ void Greater::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Greather Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp b/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp index 27e42c971..f39cbe01f 100644 --- a/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp +++ b/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp @@ -62,7 +62,7 @@ void GreaterEqual::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp GreaterEqual Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/HardSwish.cpp b/compiler/luci-interpreter/src/kernels/HardSwish.cpp index b1008459a..21d7d575f 100644 --- a/compiler/luci-interpreter/src/kernels/HardSwish.cpp +++ b/compiler/luci-interpreter/src/kernels/HardSwish.cpp @@ -44,7 +44,7 @@ void HardSwish::execute() const getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp HardSwish Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/InstanceNorm.cpp b/compiler/luci-interpreter/src/kernels/InstanceNorm.cpp index 22a329be6..685d4e514 100644 
--- a/compiler/luci-interpreter/src/kernels/InstanceNorm.cpp +++ b/compiler/luci-interpreter/src/kernels/InstanceNorm.cpp @@ -55,7 +55,7 @@ void InstanceNorm::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp InstanceNorm Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/L2Normalize.cpp b/compiler/luci-interpreter/src/kernels/L2Normalize.cpp index 64222953f..ab7b27229 100644 --- a/compiler/luci-interpreter/src/kernels/L2Normalize.cpp +++ b/compiler/luci-interpreter/src/kernels/L2Normalize.cpp @@ -58,7 +58,7 @@ void L2Normalize::execute() const eval(input()->zero_point()); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp L2Normalize Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/L2Pool2D.cpp b/compiler/luci-interpreter/src/kernels/L2Pool2D.cpp index 5a88808d5..615381f71 100644 --- a/compiler/luci-interpreter/src/kernels/L2Pool2D.cpp +++ b/compiler/luci-interpreter/src/kernels/L2Pool2D.cpp @@ -80,7 +80,7 @@ void L2Pool2D::execute() const getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp L2Pool2D Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LeakyRelu.cpp b/compiler/luci-interpreter/src/kernels/LeakyRelu.cpp index 3833a55e8..122b7ef12 100644 --- a/compiler/luci-interpreter/src/kernels/LeakyRelu.cpp +++ b/compiler/luci-interpreter/src/kernels/LeakyRelu.cpp @@ -59,7 +59,7 @@ void LeakyRelu::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LeakyRelu Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Less.cpp b/compiler/luci-interpreter/src/kernels/Less.cpp index 8d26ff297..707f741db 100644 --- a/compiler/luci-interpreter/src/kernels/Less.cpp +++ b/compiler/luci-interpreter/src/kernels/Less.cpp @@ -59,7 +59,7 @@ void Less::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Less Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LessEqual.cpp b/compiler/luci-interpreter/src/kernels/LessEqual.cpp index b474bc47a..a1a9ff891 100644 --- a/compiler/luci-interpreter/src/kernels/LessEqual.cpp +++ b/compiler/luci-interpreter/src/kernels/LessEqual.cpp @@ -59,7 +59,7 @@ void LessEqual::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LessEqual Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LocalResponseNormalization.cpp b/compiler/luci-interpreter/src/kernels/LocalResponseNormalization.cpp index a2bf442b0..64cf99b2b 100644 --- a/compiler/luci-interpreter/src/kernels/LocalResponseNormalization.cpp +++ b/compiler/luci-interpreter/src/kernels/LocalResponseNormalization.cpp @@ -57,7 +57,7 @@ void LocalResponseNormalization::execute() const getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LocalResponseNormalizartion Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Log.cpp b/compiler/luci-interpreter/src/kernels/Log.cpp index fa5f90e66..f69caaa65 100644 --- a/compiler/luci-interpreter/src/kernels/Log.cpp +++ b/compiler/luci-interpreter/src/kernels/Log.cpp @@ -37,7 +37,7 @@ void 
Log::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Log Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LogSoftmax.cpp b/compiler/luci-interpreter/src/kernels/LogSoftmax.cpp index 79c315338..b577cee05 100644 --- a/compiler/luci-interpreter/src/kernels/LogSoftmax.cpp +++ b/compiler/luci-interpreter/src/kernels/LogSoftmax.cpp @@ -57,7 +57,7 @@ void LogSoftmax::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LogSoftmax Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LogicalAnd.cpp b/compiler/luci-interpreter/src/kernels/LogicalAnd.cpp index 8e7263231..2baaafb6f 100644 --- a/compiler/luci-interpreter/src/kernels/LogicalAnd.cpp +++ b/compiler/luci-interpreter/src/kernels/LogicalAnd.cpp @@ -46,7 +46,7 @@ void LogicalAnd::execute() const evalLogicalAnd(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LogicalAnd Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/LogicalNot.cpp b/compiler/luci-interpreter/src/kernels/LogicalNot.cpp index 65ab961aa..014ce2d56 100644 --- a/compiler/luci-interpreter/src/kernels/LogicalNot.cpp +++ b/compiler/luci-interpreter/src/kernels/LogicalNot.cpp @@ -41,7 +41,7 @@ void LogicalNot::execute() const evalLogicalNot(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp LogicalNot Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Logistic.cpp b/compiler/luci-interpreter/src/kernels/Logistic.cpp index 58e4f185d..fc69aeb6e 100644 --- a/compiler/luci-interpreter/src/kernels/Logistic.cpp +++ b/compiler/luci-interpreter/src/kernels/Logistic.cpp @@ -49,7 +49,7 @@ void Logistic::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Logistic Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/MaxPool2D.cpp b/compiler/luci-interpreter/src/kernels/MaxPool2D.cpp index 8d9760ff2..a105a27b7 100644 --- a/compiler/luci-interpreter/src/kernels/MaxPool2D.cpp +++ b/compiler/luci-interpreter/src/kernels/MaxPool2D.cpp @@ -81,7 +81,7 @@ void MaxPool2D::execute() const evalSInt16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp MaxPool2D Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Maximum.cpp b/compiler/luci-interpreter/src/kernels/Maximum.cpp index b102b5e27..e384ce8a6 100644 --- a/compiler/luci-interpreter/src/kernels/Maximum.cpp +++ b/compiler/luci-interpreter/src/kernels/Maximum.cpp @@ -49,7 +49,7 @@ void Maximum::execute() const evalMaximum(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Maximum Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Mean.cpp b/compiler/luci-interpreter/src/kernels/Mean.cpp index 8e65e0d6d..3d321ccbc 100644 --- a/compiler/luci-interpreter/src/kernels/Mean.cpp +++ b/compiler/luci-interpreter/src/kernels/Mean.cpp @@ -190,7 +190,7 @@ void Mean::execute() const evalQuantizedS16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Mean Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Minimum.cpp 
b/compiler/luci-interpreter/src/kernels/Minimum.cpp index 5d3dcde72..ae8228312 100644 --- a/compiler/luci-interpreter/src/kernels/Minimum.cpp +++ b/compiler/luci-interpreter/src/kernels/Minimum.cpp @@ -49,7 +49,7 @@ void Minimum::execute() const evalMinimum(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Minimum Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/MirrorPad.cpp b/compiler/luci-interpreter/src/kernels/MirrorPad.cpp index bae1eac70..c0d23882d 100644 --- a/compiler/luci-interpreter/src/kernels/MirrorPad.cpp +++ b/compiler/luci-interpreter/src/kernels/MirrorPad.cpp @@ -82,7 +82,7 @@ void MirrorPad::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp MirrorPad Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Mul.cpp b/compiler/luci-interpreter/src/kernels/Mul.cpp index 531fb4fa1..43e3ae61c 100644 --- a/compiler/luci-interpreter/src/kernels/Mul.cpp +++ b/compiler/luci-interpreter/src/kernels/Mul.cpp @@ -68,7 +68,7 @@ void Mul::execute() const evalQuantizedS16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Mul Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Neg.cpp b/compiler/luci-interpreter/src/kernels/Neg.cpp index c6fe08a9e..fd72f2b98 100644 --- a/compiler/luci-interpreter/src/kernels/Neg.cpp +++ b/compiler/luci-interpreter/src/kernels/Neg.cpp @@ -44,7 +44,7 @@ void Neg::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Neg Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.cpp b/compiler/luci-interpreter/src/kernels/NotEqual.cpp index 54e5eee34..4e0822cb2 100644 --- a/compiler/luci-interpreter/src/kernels/NotEqual.cpp +++ b/compiler/luci-interpreter/src/kernels/NotEqual.cpp @@ -58,8 +58,11 @@ void NotEqual::execute() const case DataType::U8: evalQuantized(); break; + case DataType::BOOL: + evalBool(); + break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp NotEqual Unsupported type."); } } @@ -138,5 +141,28 @@ void NotEqual::evalQuantized() const } } +void NotEqual::evalBool() const +{ + const auto x_data = getTensorData(x()); + const auto y_data = getTensorData(y()); + auto output_data = getTensorData(output()); + + tflite::ComparisonParams op_params; + op_params.is_broadcast = x()->shape() != y()->shape(); + + if (op_params.is_broadcast) + { + tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data, + getTensorShape(y()), y_data, + getTensorShape(output()), output_data); + } + else + { + tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data, + getTensorShape(y()), y_data, getTensorShape(output()), + output_data); + } +} + } // namespace kernels } // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.h b/compiler/luci-interpreter/src/kernels/NotEqual.h index d2aafe893..221d2c26b 100644 --- a/compiler/luci-interpreter/src/kernels/NotEqual.h +++ b/compiler/luci-interpreter/src/kernels/NotEqual.h @@ -40,6 +40,7 @@ private: void evalFloat() const; template void evalInteger() const; void evalQuantized() const; + void evalBool() const; private: int32_t _x_multiplier = 0; diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp 
b/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp index 45bf4022a..7a6d0a3ec 100644 --- a/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp +++ b/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp @@ -99,6 +99,36 @@ TEST_F(NotEqualTest, FloatBroardcast) EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3})); } +TEST_F(NotEqualTest, BoolSimple) +{ + std::vector x_data{ + true, false, false, // Row 1 + false, true, true, // Row 2 + }; + + std::vector y_data{ + false, false, true, // Row 1 + true, true, false, // Row 2 + }; + + std::vector ref_output_data{ + true, false, true, // Row 1 + true, false, true, // Row 2 + }; + + Tensor x_tensor = makeInputTensor({2, 3}, x_data, _memory_manager.get()); + Tensor y_tensor = makeInputTensor({2, 3}, y_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::BOOL); + + NotEqual kernel(&x_tensor, &y_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorData(output_tensor), ::testing::ElementsAreArray(ref_output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 3})); +} + template void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager) { @@ -281,6 +311,17 @@ TEST_F(NotEqualTest, Float_Broadcast_NEG) ASSERT_ANY_THROW(kernel.configure()); } +TEST_F(NotEqualTest, Bool_Broadcast_NEG) +{ + Tensor x_tensor = makeInputTensor({2}, {true, true}, _memory_manager.get()); + Tensor y_tensor = + makeInputTensor({3}, {true, false, false}, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::BOOL); + + NotEqual kernel(&x_tensor, &y_tensor, &output_tensor); + ASSERT_ANY_THROW(kernel.configure()); +} + TEST_F(NotEqualTest, Int32_Broadcast_NEG) { Tensor x_tensor = makeInputTensor({2}, {1, 2}, _memory_manager.get()); diff --git a/compiler/luci-interpreter/src/kernels/PRelu.cpp b/compiler/luci-interpreter/src/kernels/PRelu.cpp index 5a6b05c3a..5a5fba4a4 100644 --- a/compiler/luci-interpreter/src/kernels/PRelu.cpp +++ b/compiler/luci-interpreter/src/kernels/PRelu.cpp @@ -103,7 +103,7 @@ void PRelu::execute() const evalQuantizedS16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp PRelu Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Pack.cpp b/compiler/luci-interpreter/src/kernels/Pack.cpp index 42aab330c..8ba6cbac1 100644 --- a/compiler/luci-interpreter/src/kernels/Pack.cpp +++ b/compiler/luci-interpreter/src/kernels/Pack.cpp @@ -48,7 +48,7 @@ void Pack::configure() t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 && t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64) { - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Pack(1) Unsupported type."); } for (uint32_t i = 1; i < _inputs.size(); ++i) @@ -116,7 +116,7 @@ void Pack::execute() const evalGeneric(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Pack(2) Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Pad.cpp b/compiler/luci-interpreter/src/kernels/Pad.cpp index c07f6e310..b8d600b7f 100644 --- a/compiler/luci-interpreter/src/kernels/Pad.cpp +++ b/compiler/luci-interpreter/src/kernels/Pad.cpp @@ -106,7 +106,7 @@ void Pad::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw 
std::runtime_error("luci-intp Pad Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/PadV2.cpp b/compiler/luci-interpreter/src/kernels/PadV2.cpp index 197cdaa69..d75e572a5 100644 --- a/compiler/luci-interpreter/src/kernels/PadV2.cpp +++ b/compiler/luci-interpreter/src/kernels/PadV2.cpp @@ -100,7 +100,7 @@ void PadV2::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp PadV2 Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Pow.cpp b/compiler/luci-interpreter/src/kernels/Pow.cpp index 722c64024..519f1b6ea 100644 --- a/compiler/luci-interpreter/src/kernels/Pow.cpp +++ b/compiler/luci-interpreter/src/kernels/Pow.cpp @@ -50,7 +50,7 @@ void Pow::execute() const eval(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Pow Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Quantize.cpp b/compiler/luci-interpreter/src/kernels/Quantize.cpp index 0c8544a65..70f4f77a2 100644 --- a/compiler/luci-interpreter/src/kernels/Quantize.cpp +++ b/compiler/luci-interpreter/src/kernels/Quantize.cpp @@ -132,7 +132,7 @@ void Quantize::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Quantize(1) Unsupported type."); } break; } @@ -152,7 +152,7 @@ void Quantize::execute() const break; } default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Quantize(2) Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/ReduceMax.cpp b/compiler/luci-interpreter/src/kernels/ReduceMax.cpp index d58cd1563..8456afd0a 100644 --- a/compiler/luci-interpreter/src/kernels/ReduceMax.cpp +++ b/compiler/luci-interpreter/src/kernels/ReduceMax.cpp @@ -149,9 +149,12 @@ void ReduceMax::execute() const case DataType::FLOAT32: evalFloat(); break; + case DataType::BOOL: + evalBool(); + break; // TODO Support quantized kernels default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp ReduceMax Unsupported type."); } } @@ -177,5 +180,27 @@ void ReduceMax::evalFloat() const [](const float current, const float in) -> float { return (in > current) ? in : current; }); } +void ReduceMax::evalBool() const +{ + const auto *axes_data = getTensorData(axes()); + int num_axes = axes()->shape().num_elements(); + + auto temp_index = getOutputTensors()[1]; + auto resolved_axes = getOutputTensors()[2]; + + int num_resolved_axis = 0; + LUCI_INTERPRETER_CHECK( + tflite::reference_ops::ResolveAxis(input()->shape().num_dims(), axes_data, num_axes, + getTensorData(resolved_axes), &num_resolved_axis)); + + bool init_value = std::numeric_limits::lowest(); + tflite::reference_ops::ReduceGeneric( + getTensorData(input()), getTensorShape(input()).DimsData(), input()->shape().num_dims(), + getTensorData(output()), getTensorShape(output()).DimsData(), + output()->shape().num_dims(), axes_data, num_axes, _params.keep_dims, + getTensorData(temp_index), getTensorData(resolved_axes), init_value, + [](const bool current, const bool in) -> bool { return (in > current) ? 
in : current; }); +} + } // namespace kernels } // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/ReduceMax.h b/compiler/luci-interpreter/src/kernels/ReduceMax.h index 25a66278a..f512f66bb 100644 --- a/compiler/luci-interpreter/src/kernels/ReduceMax.h +++ b/compiler/luci-interpreter/src/kernels/ReduceMax.h @@ -42,6 +42,7 @@ public: private: void evalFloat() const; + void evalBool() const; }; } // namespace kernels diff --git a/compiler/luci-interpreter/src/kernels/ReduceMax.test.cpp b/compiler/luci-interpreter/src/kernels/ReduceMax.test.cpp index ab688827b..6c41c39db 100644 --- a/compiler/luci-interpreter/src/kernels/ReduceMax.test.cpp +++ b/compiler/luci-interpreter/src/kernels/ReduceMax.test.cpp @@ -98,6 +98,68 @@ TEST_F(ReduceMaxTest, FloatKeepDims) EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape)); } +TEST_F(ReduceMaxTest, BoolNotKeepDims) +{ + std::vector input_data = {true, true, false, false, true, false, false, true, + true, true, false, false, true, true, false, true, + true, false, true, false, true, false, false, true}; + + std::vector axis_data{1, 0, -3, -3}; + Tensor input_tensor = + makeInputTensor({4, 3, 2}, input_data, _memory_manager.get()); + Tensor axis_tensor = makeInputTensor({4}, axis_data, _memory_manager.get()); + Tensor temp_index(DataType::S32, Shape({}), {}, ""); + Tensor resolved_axes(DataType::S32, Shape({}), {}, ""); + Tensor output_tensor = makeOutputTensor(DataType::BOOL); + + ReducerParams params{}; + params.keep_dims = false; + + ReduceMax kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, + params); + kernel.configure(); + _memory_manager->allocate_memory(temp_index); + _memory_manager->allocate_memory(resolved_axes); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + std::vector ref_output_data{true, true}; + std::initializer_list ref_output_shape{2}; + EXPECT_THAT(extractTensorData(output_tensor), ::testing::ElementsAreArray(ref_output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape)); +} + +TEST_F(ReduceMaxTest, BoolKeepDims) +{ + std::vector input_data = {true, true, false, false, true, false, false, true, + true, true, false, false, true, true, false, true, + true, false, true, false, true, false, false, true}; + + std::vector axis_data{0, 2}; + Tensor input_tensor = + makeInputTensor({4, 3, 2}, input_data, _memory_manager.get()); + Tensor axis_tensor = makeInputTensor({2}, axis_data, _memory_manager.get()); + Tensor temp_index(DataType::S32, Shape({}), {}, ""); + Tensor resolved_axes(DataType::S32, Shape({}), {}, ""); + Tensor output_tensor = makeOutputTensor(DataType::BOOL); + + ReducerParams params{}; + params.keep_dims = true; + + ReduceMax kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, + params); + kernel.configure(); + _memory_manager->allocate_memory(temp_index); + _memory_manager->allocate_memory(resolved_axes); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + std::vector ref_output_data{true, true, true}; + std::initializer_list ref_output_shape{1, 3, 1}; + EXPECT_THAT(extractTensorData(output_tensor), ::testing::ElementsAreArray(ref_output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape)); +} + } // namespace } // namespace kernels } // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/ReduceProd.cpp 
b/compiler/luci-interpreter/src/kernels/ReduceProd.cpp
index f3fc7d3f1..24c4780b8 100644
--- a/compiler/luci-interpreter/src/kernels/ReduceProd.cpp
+++ b/compiler/luci-interpreter/src/kernels/ReduceProd.cpp
@@ -150,7 +150,7 @@ void ReduceProd::execute() const
       break;
     // TODO Support quantized kernels
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp ReduceProd Unsupported type.");
   }
 }
diff --git a/compiler/luci-interpreter/src/kernels/Relu.cpp b/compiler/luci-interpreter/src/kernels/Relu.cpp
index 747ec6cc8..b88099f4b 100644
--- a/compiler/luci-interpreter/src/kernels/Relu.cpp
+++ b/compiler/luci-interpreter/src/kernels/Relu.cpp
@@ -59,7 +59,7 @@ void Relu::execute() const
       evalQuantizedS16();
       break;
     default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp Relu Unsupported type.");
   }
 }
diff --git a/compiler/luci-interpreter/src/kernels/Relu0To1.cpp b/compiler/luci-interpreter/src/kernels/Relu0To1.cpp
new file mode 100644
index 000000000..0488894f0
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Relu0To1.cpp
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Relu0To1.h"
+#include "kernels/Utils.h"
+
+#include "PALRelu0To1.h"
+
+#include <algorithm>
+
+namespace luci_interpreter
+{
+
+namespace kernels
+{
+
+Relu0To1::Relu0To1(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Relu0To1::configure()
+{
+  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
+
+  if (input()->element_type() == DataType::U8)
+  {
+    double multiplier = input()->scale() / output()->scale();
+    quantizeMultiplier(multiplier, &_output_multiplier, &_output_shift);
+  }
+  output()->resize(input()->shape());
+}
+
+void Relu0To1::execute() const
+{
+  switch (input()->element_type())
+  {
+    case DataType::FLOAT32:
+      evalFloat();
+      break;
+    case DataType::U8:
+      evalQuantized();
+      break;
+    default:
+      throw std::runtime_error("luci-intp Relu0To1 Unsupported type.");
+  }
+}
+
+void Relu0To1::evalFloat() const
+{
+  const auto input_data = getTensorData<float>(input());
+  const auto input_shape = getTensorShape(input());
+  auto output_data = getTensorData<float>(output());
+  auto output_shape = getTensorShape(output());
+
+  luci_interpreter_pal::Relu0To1(input_shape, input_data, output_shape, output_data);
+}
+
+void Relu0To1::evalQuantized() const
+{
+  tflite::ReluParams params;
+  params.input_offset = input()->zero_point();
+  params.output_offset = output()->zero_point();
+  params.output_multiplier = _output_multiplier;
+  params.output_shift = _output_shift;
+
+  params.quantized_activation_min =
+    std::max(static_cast<int32_t>(std::numeric_limits<uint8_t>::min()), params.output_offset);
+  params.quantized_activation_max =
+    std::min(static_cast<int32_t>(std::numeric_limits<uint8_t>::max()),
+             params.output_offset + static_cast<int32_t>(roundf(1.f / output()->scale())));
+
+  luci_interpreter_pal::ReluX(params, getTensorShape(input()),
getTensorData(input()), + getTensorShape(output()), getTensorData(output())); +} + +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Relu0To1.h b/compiler/luci-interpreter/src/kernels/Relu0To1.h new file mode 100644 index 000000000..ae481a72b --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Relu0To1.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_KERNELS_RELU0TO1_H +#define LUCI_INTERPRETER_KERNELS_RELU0TO1_H + +#include "core/Kernel.h" + +namespace luci_interpreter +{ +namespace kernels +{ + +class Relu0To1 : public Kernel +{ +public: + Relu0To1(const Tensor *input, Tensor *output); + + const Tensor *input() const { return _inputs[0]; } + Tensor *output() const { return _outputs[0]; } + + void configure() override; + void execute() const override; + +private: + void evalFloat() const; + void evalQuantized() const; + +private: + int32_t _output_multiplier{0}; + int32_t _output_shift{0}; +}; + +} // namespace kernels +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_KERNELS_RELU0TO1_H diff --git a/compiler/luci-interpreter/src/kernels/Relu0To1.test.cpp b/compiler/luci-interpreter/src/kernels/Relu0To1.test.cpp new file mode 100644 index 000000000..61dc3b0b2 --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Relu0To1.test.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2021 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "kernels/Relu0To1.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class Relu0To1Test : public ::testing::Test
+{
+protected:
+  void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+  std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(Relu0To1Test, FloatSimple)
+{
+  std::vector<float> input_data{
+    0.0f, 0.5f, 0.1f,   // Row 1
+    2.0f, -1.0f, -2.0f, // Row 2
+  };
+
+  std::vector<float> ref_output_data{
+    0.0f, 0.5f, 0.1f, // Row 1
+    1.0f, 0.0f, 0.0f, // Row 2
+  };
+
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>({2, 3}, input_data, _memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  Relu0To1 kernel(&input_tensor, &output_tensor);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorData<float>(output_tensor),
+              ::testing::ElementsAreArray(ref_output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 3}));
+}
+
+TEST_F(Relu0To1Test, Uint8Quantized)
+{
+  // Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
+  const float f_min = (-128.0 / 128.0) * 10;
+  const float f_max = (127.0 / 128.0) * 10;
+  const float tolerance = (f_max - f_min) / 255.0;
+
+  std::vector<float> input_data{
+    0.0, -0.6, 0.2, -0.4, //
+    0.3, -2.0, 1.1, -0.1, //
+  };
+
+  std::pair<float, int32_t> quant_param = quantizationParams<uint8_t>(f_min, f_max);
+  Tensor input_tensor = makeInputTensor<DataType::U8>(
+    {1, 2, 4, 1}, quant_param.first, quant_param.second, input_data, _memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+
+  Relu0To1 kernel(&input_tensor, &output_tensor);
+  kernel.configure();
+  _memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 4, 1}));
+  EXPECT_THAT(extractTensorData<uint8_t>(output_tensor),
+              ::testing::ElementsAreArray({128, 128, 131, 128, 132, 128, 141, 128}));
+  EXPECT_THAT(dequantizeTensorData(output_tensor),
+              FloatArrayNear({0, 0, 0.2, 0, 0.3, 0, 1.0, 0}, tolerance));
+}
+
+TEST_F(Relu0To1Test, Uint8Requantized)
+{
+  // Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
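+  // A sketch of the arithmetic, assuming quantizationParams computes the usual
+  // asymmetric affine parameters scale = (max - min) / 255 and
+  // zero_point = round(-min / scale): for f_min = -10.0 and f_max = 9.921875
+  // this yields scale = 0.078125 and zero_point = 128, which is why 0.0f maps
+  // to the quantized value 128 in Uint8Quantized above.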
+ const float in_min = (-128.0 / 128.0) * 10; + const float in_max = (127.0 / 128.0) * 10; + const float out_min = (0.0 / 256.0) * 0; + const float out_max = (255.0 / 256.0) * 1; + const float tolerance = (in_max - in_min) / 255.0; + + std::vector input_data{ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }; + + std::pair quant_input = quantizationParams(in_min, in_max); + Tensor input_tensor = makeInputTensor( + {1, 2, 4, 1}, quant_input.first, quant_input.second, input_data, _memory_manager.get()); + + std::pair quant_output = quantizationParams(out_min, out_max); + Tensor output_tensor = makeOutputTensor(DataType::U8, quant_output.first, quant_output.second); + + Relu0To1 kernel(&input_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 4, 1})); + EXPECT_THAT(extractTensorData(output_tensor), + ::testing::ElementsAreArray({0, 0, 60, 0, 80, 0, 255, 0})); + EXPECT_THAT(dequantizeTensorData(output_tensor), + FloatArrayNear({0, 0, 0.2, 0, 0.3, 0, 1.0, 0}, tolerance)); +} + +TEST_F(Relu0To1Test, Input_Output_Type_NEG) +{ + Tensor input_tensor = makeInputTensor({1}, {1.f}, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::U8); + + Relu0To1 kernel(&input_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST_F(Relu0To1Test, Invalid_Input_Type_NEG) +{ + Tensor input_tensor = makeInputTensor({1}, {1}, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::S64); + + Relu0To1 kernel(&input_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + EXPECT_ANY_THROW(kernel.execute()); +} + +} // namespace +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Relu6.cpp b/compiler/luci-interpreter/src/kernels/Relu6.cpp index 07205ed3a..06762fe09 100644 --- a/compiler/luci-interpreter/src/kernels/Relu6.cpp +++ b/compiler/luci-interpreter/src/kernels/Relu6.cpp @@ -52,7 +52,7 @@ void Relu6::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Relu6 Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/ResizeBilinear.cpp b/compiler/luci-interpreter/src/kernels/ResizeBilinear.cpp index e2ddd6a7b..8ffc57b30 100644 --- a/compiler/luci-interpreter/src/kernels/ResizeBilinear.cpp +++ b/compiler/luci-interpreter/src/kernels/ResizeBilinear.cpp @@ -66,7 +66,7 @@ void ResizeBilinear::execute() const getTensorData(size()), getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp ResizeBilinear Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.cpp b/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.cpp index 306cefbc2..90c6f9810 100644 --- a/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.cpp +++ b/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.cpp @@ -66,7 +66,7 @@ void ResizeNearestNeighbor::execute() const getTensorData(size()), getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp ResizeNearestNeighbor Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Rsqrt.cpp 
b/compiler/luci-interpreter/src/kernels/Rsqrt.cpp index 6dd92dc98..768fcfffe 100644 --- a/compiler/luci-interpreter/src/kernels/Rsqrt.cpp +++ b/compiler/luci-interpreter/src/kernels/Rsqrt.cpp @@ -46,7 +46,7 @@ void Rsqrt::execute() const break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Rsqrt Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/SVDF.cpp b/compiler/luci-interpreter/src/kernels/SVDF.cpp index b124e242c..62c479333 100644 --- a/compiler/luci-interpreter/src/kernels/SVDF.cpp +++ b/compiler/luci-interpreter/src/kernels/SVDF.cpp @@ -84,7 +84,7 @@ void SVDF::configure() } else { - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp SVDF Unsupported type."); } // Check all the parameters of tensor match within themselves and match the diff --git a/compiler/luci-interpreter/src/kernels/Select.cpp b/compiler/luci-interpreter/src/kernels/Select.cpp index b4ab5f621..0ce24db30 100644 --- a/compiler/luci-interpreter/src/kernels/Select.cpp +++ b/compiler/luci-interpreter/src/kernels/Select.cpp @@ -33,8 +33,6 @@ namespace kernels Select::Select(const Tensor *condition, const Tensor *t, const Tensor *e, Tensor *output) : Kernel({condition, t, e}, {output}) { - // NOTE _requires_broadcast is for SelectV2 - _requires_broadcast = false; _has_low_rank_input_condition = false; } @@ -64,7 +62,7 @@ void Select::execute() const evalFloat(); break; default: - throw std::runtime_error("Select: unsupported type."); + throw std::runtime_error("luci-intp Select unsupported type."); } } @@ -84,11 +82,6 @@ void Select::evalFloat() const tflite::reference_ops::RankOneSelect(condition_shape, condition_data, t_shape, t_data, e_shape, e_data, output_shape, output_data); } - else if (_requires_broadcast) - { - // TODO support broadcast kernel when upgrade to TF2.10.x or above - assert(false); - } else { tflite::reference_ops::Select(condition_shape, condition_data, t_shape, t_data, e_shape, e_data, diff --git a/compiler/luci-interpreter/src/kernels/Select.h b/compiler/luci-interpreter/src/kernels/Select.h index d67b4f5fc..a378ad5ef 100644 --- a/compiler/luci-interpreter/src/kernels/Select.h +++ b/compiler/luci-interpreter/src/kernels/Select.h @@ -42,8 +42,6 @@ private: void evalFloat() const; private: - // for SelectV2 - bool _requires_broadcast = false; // True if input condition is scalar or input condition has rank one and // matches the first dimension of other inputs. bool _has_low_rank_input_condition = false; diff --git a/compiler/luci-interpreter/src/kernels/SelectV2.cpp b/compiler/luci-interpreter/src/kernels/SelectV2.cpp new file mode 100644 index 000000000..f56b3c30a --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/SelectV2.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernels/SelectV2.h" +#include "kernels/Utils.h" + +#include +// TODO use select.h when version up +// #include + +#include + +namespace luci_interpreter +{ +namespace kernels +{ + +SelectV2::SelectV2(const Tensor *condition, const Tensor *t, const Tensor *e, Tensor *output) + : Kernel({condition, t, e}, {output}) +{ +} + +void SelectV2::configure() +{ + LUCI_INTERPRETER_CHECK(condition()->element_type() == DataType::BOOL); + LUCI_INTERPRETER_CHECK(t()->element_type() == e()->element_type()); + LUCI_INTERPRETER_CHECK(t()->element_type() == output()->element_type()); + + auto cond_shape = condition()->shape(); + auto t_shape = t()->shape(); + auto e_shape = e()->shape(); + + output()->resize( + calculateShapeForBroadcast(cond_shape, calculateShapeForBroadcast(t_shape, e_shape))); +} + +void SelectV2::execute() const +{ + auto t_type = t()->element_type(); + switch (t_type) + { + case DataType::FLOAT32: + evaluate(); + break; + case DataType::S32: + evaluate(); + break; + case DataType::S64: + evaluate(); + break; + default: + throw std::runtime_error("luci-intp SelectV2 unsupported type."); + } +} + +template void SelectV2::evaluate() const +{ + const auto condition_shape = getTensorShape(condition()); + const auto condition_data = getTensorData(condition()); + const auto t_shape = getTensorShape(t()); + const auto t_data = getTensorData(t()); + const auto e_shape = getTensorShape(e()); + const auto e_data = getTensorData(e()); + const auto output_shape = getTensorShape(output()); + auto output_data = getTensorData(output()); + + tflite::reference_ops::BroadcastSelect5DSlow( + condition_shape, condition_data, t_shape, t_data, e_shape, e_data, output_shape, output_data); +} + +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/SelectV2.h b/compiler/luci-interpreter/src/kernels/SelectV2.h new file mode 100644 index 000000000..3c73d94aa --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/SelectV2.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LUCI_INTERPRETER_KERNELS_SELECTV2_H +#define LUCI_INTERPRETER_KERNELS_SELECTV2_H + +#include "core/Kernel.h" + +namespace luci_interpreter +{ +namespace kernels +{ + +class SelectV2 : public Kernel +{ +public: + SelectV2(const Tensor *cond, const Tensor *t, const Tensor *e, Tensor *output); + + const Tensor *condition() const { return _inputs[0]; } + const Tensor *t() const { return _inputs[1]; } + const Tensor *e() const { return _inputs[2]; } + Tensor *output() const { return _outputs[0]; } + + void configure() override; + void execute() const override; + +private: + template void evaluate() const; +}; + +} // namespace kernels +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_KERNELS_SELECTV2_H diff --git a/compiler/luci-interpreter/src/kernels/SelectV2.test.cpp b/compiler/luci-interpreter/src/kernels/SelectV2.test.cpp new file mode 100644 index 000000000..66809ceab --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/SelectV2.test.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernels/SelectV2.h" +#include "kernels/TestUtils.h" + +#include "luci_interpreter/TestMemoryManager.h" + +namespace luci_interpreter +{ +namespace kernels +{ +namespace +{ + +using namespace testing; + +class SelectV2Test : public ::testing::Test +{ +protected: + void SetUp() override { _memory_manager = std::make_unique(); } + + std::unique_ptr _memory_manager; +}; + +std::vector c_data_single{0}; + +std::vector c_data{ + 1, 1, 1, // Row 1 + 0, 0, 0, // Row 2 +}; + +std::vector f32t_data_single{-0.5}; + +std::vector f32t_data{ + 0.5, 0.7, 0.9, // Row 1 + 1, 0, -1, // Row 2 +}; + +std::vector f32e_data{ + 0.9, 0.7, 0.5, // Row 1 + -1, 0, 1, // Row 2 +}; + +std::vector ref_f32o_data{ + 0.5, 0.7, 0.9, // Row 1 + -1, 0, 1, // Row 2 +}; + +std::vector ref_broadcast_f32o_data{ + -0.5, -0.5, -0.5, // Row 1 + 0.9, 0.7, 0.5, // Row 2 + -0.5, -0.5, -0.5, // Row 3 + -1, 0, 1, // Row 4 +}; + +std::vector i32t_data_single{2}; + +std::vector i32t_data{ + 5, -7, 9, // Row 1 + 1, 0, -1, // Row 2 +}; + +std::vector i32e_data{ + 9, 7, -5, // Row 1 + -1, 0, 1, // Row 2 +}; + +std::vector ref_i32o_data{ + 5, -7, 9, // Row 1 + -1, 0, 1, // Row 2 +}; + +std::vector ref_broadcast_i32o_data{ + 2, 2, 2, // Row 1 + 9, 7, -5, // Row 2 + 2, 2, 2, // Row 3 + -1, 0, 1, // Row 4 +}; + +TEST_F(SelectV2Test, FloatSimple) +{ + Tensor c_tensor = makeInputTensor({2, 3}, c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({2, 3}, f32t_data, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 3}, f32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorData(output_tensor), 
::testing::ElementsAreArray(ref_f32o_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 3})); +} + +TEST_F(SelectV2Test, FloatBroadcast4D) +{ + Tensor c_tensor = makeInputTensor({1, 2, 3, 1}, c_data, _memory_manager.get()); + Tensor t_tensor = + makeInputTensor({1}, f32t_data_single, _memory_manager.get()); + Tensor e_tensor = + makeInputTensor({2, 1, 3, 1}, f32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorData(output_tensor), + ::testing::ElementsAreArray(ref_broadcast_f32o_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 2, 3, 1})); +} + +TEST_F(SelectV2Test, Int32Simple) +{ + Tensor c_tensor = makeInputTensor({2, 3}, c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({2, 3}, i32t_data, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 3}, i32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::S32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorData(output_tensor), + ::testing::ElementsAreArray(ref_i32o_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 3})); +} + +TEST_F(SelectV2Test, Int32Broadcast4D) +{ + Tensor c_tensor = makeInputTensor({1, 2, 3, 1}, c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({1}, i32t_data_single, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 1, 3, 1}, i32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::S32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + kernel.configure(); + _memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + EXPECT_THAT(extractTensorData(output_tensor), + ::testing::ElementsAreArray(ref_broadcast_i32o_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 2, 3, 1})); +} + +TEST_F(SelectV2Test, Invalid_C_Type_NEG) +{ + std::vector i_c_data{ + 1, 1, 1, // Row 1 + 0, 0, 0, // Row 2 + }; + + Tensor c_tensor = makeInputTensor({2, 3}, i_c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({2, 3}, f32t_data, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 3}, f32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST_F(SelectV2Test, Invalid_O_Type_NEG) +{ + Tensor c_tensor = makeInputTensor({2, 3}, c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({2, 3}, f32t_data, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 3}, f32e_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::BOOL); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST_F(SelectV2Test, MixedType_NEG) +{ + Tensor c_tensor = makeInputTensor({2, 3}, c_data, _memory_manager.get()); + Tensor t_tensor = makeInputTensor({2, 3}, i32t_data, _memory_manager.get()); + Tensor e_tensor = makeInputTensor({2, 3}, f32e_data, _memory_manager.get()); + Tensor output_tensor = 
makeOutputTensor(DataType::FLOAT32); + + SelectV2 kernel(&c_tensor, &t_tensor, &e_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +} // namespace +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Shape.cpp b/compiler/luci-interpreter/src/kernels/Shape.cpp index 0429fe1e5..6cc378b04 100644 --- a/compiler/luci-interpreter/src/kernels/Shape.cpp +++ b/compiler/luci-interpreter/src/kernels/Shape.cpp @@ -50,7 +50,7 @@ void ShapeKernel::execute() const evalInt(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Shape Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Sin.cpp b/compiler/luci-interpreter/src/kernels/Sin.cpp new file mode 100644 index 000000000..11db254a0 --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Sin.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernels/Sin.h" + +#include "kernels/Utils.h" + +#include + +namespace luci_interpreter +{ +namespace kernels +{ + +namespace +{ + +template +inline void CalcSin(const T *input_data, const size_t num_elements, T *output_data) +{ + for (size_t idx = 0; idx < num_elements; ++idx) + { + output_data[idx] = std::sin(input_data[idx]); + } +} + +} // namespace + +Sin::Sin(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {} + +void Sin::configure() +{ + LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32); + LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type()); + output()->resize(input()->shape()); +} + +void Sin::execute() const +{ + switch (input()->element_type()) + { + case DataType::FLOAT32: + evalFloat(); + break; + default: + throw std::runtime_error("luci-intp Sin Unsupported type."); + } +} + +void Sin::evalFloat() const +{ + const int size = tflite::MatchingFlatSize(getTensorShape(input()), getTensorShape(output())); + CalcSin(getTensorData(input()), size, getTensorData(output())); +} + +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Sin.h b/compiler/luci-interpreter/src/kernels/Sin.h new file mode 100644 index 000000000..d42f3daa8 --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Sin.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_SIN_H
+#define LUCI_INTERPRETER_KERNELS_SIN_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Sin : public Kernel
+{
+public:
+  Sin(const Tensor *input, Tensor *output);
+
+  const Tensor *input() const { return _inputs[0]; }
+  Tensor *output() const { return _outputs[0]; }
+
+  void configure() override;
+  void execute() const override;
+
+private:
+  void evalFloat() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_SIN_H
diff --git a/compiler/luci-interpreter/src/kernels/Sin.test.cpp b/compiler/luci-interpreter/src/kernels/Sin.test.cpp
new file mode 100644
index 000000000..1aab252e0
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Sin.test.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Sin.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+#include <cmath>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+#define PI 3.14159265358979323846
+
+using namespace testing;
+
+TEST(SinTest, Float)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  Shape input_shape{1, 1, 3};
+  std::vector<float> input_data{0.0f, PI / 3.0f, -PI / 3.0f};
+  Tensor input_tensor =
+    makeInputTensor<DataType::FLOAT32>(input_shape, input_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+  Sin kernel(&input_tensor, &output_tensor);
+  kernel.configure();
+  memory_manager->allocate_memory(output_tensor);
+  kernel.execute();
+
+  std::vector<int32_t> ref_output_shape{1, 1, 3};
+  std::vector<float> ref_output_data{std::sin(0.0f), std::sin(PI / 3.0f), std::sin(-PI / 3.0f)};
+  EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+  EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST(SinTest, InvalidDType_NEG)
+{
+  std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+  Shape input_shape{1, 1, 3};
+  std::vector<int64_t> input_data{1l, 2l, 3l};
+  Tensor input_tensor =
+    makeInputTensor<DataType::S64>(input_shape, input_data, memory_manager.get());
+  Tensor output_tensor = makeOutputTensor(DataType::S64);
+
+  Sin kernel(&input_tensor, &output_tensor);
+  EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Slice.cpp b/compiler/luci-interpreter/src/kernels/Slice.cpp
index 2fe2c5471..d7cf46a4b 100644
--- a/compiler/luci-interpreter/src/kernels/Slice.cpp
+++ b/compiler/luci-interpreter/src/kernels/Slice.cpp
@@ -90,7 +90,7 @@ void Slice::configure()
  }
  else
  {
-    throw std::runtime_error("Unsupported type.");
+    throw
std::runtime_error("luci-intp Slice Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Softmax.cpp b/compiler/luci-interpreter/src/kernels/Softmax.cpp index c230aaa70..c195fcdac 100644 --- a/compiler/luci-interpreter/src/kernels/Softmax.cpp +++ b/compiler/luci-interpreter/src/kernels/Softmax.cpp @@ -64,7 +64,7 @@ void Softmax::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Softmax Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/SpaceToBatchND.cpp b/compiler/luci-interpreter/src/kernels/SpaceToBatchND.cpp index 630cd38c4..57c6e2e09 100644 --- a/compiler/luci-interpreter/src/kernels/SpaceToBatchND.cpp +++ b/compiler/luci-interpreter/src/kernels/SpaceToBatchND.cpp @@ -95,7 +95,7 @@ void SpaceToBatchND::execute() const getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp ShapeToBatchND Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/SpaceToDepth.cpp b/compiler/luci-interpreter/src/kernels/SpaceToDepth.cpp index 7c29e8cb0..06180ff10 100644 --- a/compiler/luci-interpreter/src/kernels/SpaceToDepth.cpp +++ b/compiler/luci-interpreter/src/kernels/SpaceToDepth.cpp @@ -71,7 +71,7 @@ void SpaceToDepth::execute() const getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp SpaceToDepth Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Split.cpp b/compiler/luci-interpreter/src/kernels/Split.cpp index 1a563f307..c175bbb75 100644 --- a/compiler/luci-interpreter/src/kernels/Split.cpp +++ b/compiler/luci-interpreter/src/kernels/Split.cpp @@ -72,7 +72,7 @@ void Split::execute() const TF_LITE_SPLIT(uint8_t); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Split Unsupported type."); } #undef TF_LITE_SPLIT } diff --git a/compiler/luci-interpreter/src/kernels/SplitV.cpp b/compiler/luci-interpreter/src/kernels/SplitV.cpp index aa6820889..f3cc504b0 100644 --- a/compiler/luci-interpreter/src/kernels/SplitV.cpp +++ b/compiler/luci-interpreter/src/kernels/SplitV.cpp @@ -102,7 +102,7 @@ void SplitV::execute() const TF_LITE_SPLIT(int16_t); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp SplitV Unsupported type."); } #undef TF_LITE_SPLIT } diff --git a/compiler/luci-interpreter/src/kernels/Sqrt.cpp b/compiler/luci-interpreter/src/kernels/Sqrt.cpp index 46e9fc9ad..46fe313f2 100644 --- a/compiler/luci-interpreter/src/kernels/Sqrt.cpp +++ b/compiler/luci-interpreter/src/kernels/Sqrt.cpp @@ -46,7 +46,7 @@ void Sqrt::execute() const break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Sqrt Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Square.cpp b/compiler/luci-interpreter/src/kernels/Square.cpp index bc71905c1..24a85eef8 100644 --- a/compiler/luci-interpreter/src/kernels/Square.cpp +++ b/compiler/luci-interpreter/src/kernels/Square.cpp @@ -46,7 +46,7 @@ void Square::execute() const break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Square Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/SquaredDifference.cpp b/compiler/luci-interpreter/src/kernels/SquaredDifference.cpp index 3bafeba4a..ca7605234 100644 --- 
a/compiler/luci-interpreter/src/kernels/SquaredDifference.cpp +++ b/compiler/luci-interpreter/src/kernels/SquaredDifference.cpp @@ -46,7 +46,7 @@ void SquaredDifference::execute() const evalSquaredDifference(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp SquaredDifference Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/StridedSlice.cpp b/compiler/luci-interpreter/src/kernels/StridedSlice.cpp index a8730d861..993256b3d 100644 --- a/compiler/luci-interpreter/src/kernels/StridedSlice.cpp +++ b/compiler/luci-interpreter/src/kernels/StridedSlice.cpp @@ -141,8 +141,18 @@ void StridedSlice::execute() const getTensorData(input()), getTensorShape(output()), getTensorData(output())); break; + case DataType::S64: + tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()), + getTensorData(input()), getTensorShape(output()), + getTensorData(output())); + break; + case DataType::BOOL: + tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()), + getTensorData(input()), getTensorShape(output()), + getTensorData(output())); + break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp StridedSlice Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Sub.cpp b/compiler/luci-interpreter/src/kernels/Sub.cpp index 1fd583c62..abceb6fef 100644 --- a/compiler/luci-interpreter/src/kernels/Sub.cpp +++ b/compiler/luci-interpreter/src/kernels/Sub.cpp @@ -58,7 +58,7 @@ void Sub::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Sub Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Sum.cpp b/compiler/luci-interpreter/src/kernels/Sum.cpp index 645f02c36..ef870e45d 100644 --- a/compiler/luci-interpreter/src/kernels/Sum.cpp +++ b/compiler/luci-interpreter/src/kernels/Sum.cpp @@ -149,7 +149,7 @@ void Sum::execute() const evalFloat(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Sum Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Tanh.cpp b/compiler/luci-interpreter/src/kernels/Tanh.cpp index d47a0bde9..7cb59bedb 100644 --- a/compiler/luci-interpreter/src/kernels/Tanh.cpp +++ b/compiler/luci-interpreter/src/kernels/Tanh.cpp @@ -49,7 +49,7 @@ void Tanh::execute() const evalQuantized(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Tanh Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Tile.cpp b/compiler/luci-interpreter/src/kernels/Tile.cpp new file mode 100644 index 000000000..0ab93f52b --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Tile.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernels/Tile.h" + +#include "kernels/Utils.h" + +namespace luci_interpreter +{ +namespace kernels +{ + +Tile::Tile(const Tensor *input, const Tensor *multiples, Tensor *output) + : Kernel({input, multiples}, {output}) +{ +} + +void Tile::configure() +{ + LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= 1); + LUCI_INTERPRETER_CHECK(multiples()->shape().num_dims() == 1); + LUCI_INTERPRETER_CHECK(multiples()->shape().dim(0) == input()->shape().num_dims()); + LUCI_INTERPRETER_CHECK(multiples()->element_type() == DataType::S32); + + Shape output_shape(input()->shape().num_dims()); + const int32_t *muldata = getTensorData(multiples()); + int32_t num_dim = multiples()->shape().dim(0); + for (int32_t dim = 0; dim < num_dim; ++dim) + { + output_shape.dim(dim) = input()->shape().dim(dim) * muldata[dim]; + } + output()->resize(output_shape); +} + +void Tile::execute() const +{ + switch (output()->element_type()) + { + case DataType::FLOAT32: + evalFloat(); + break; + default: + throw std::runtime_error("luci-intp Tile Unsupported type."); + } +} + +namespace +{ + +template +void CopyMultipleTimes(const T *in_data, int32_t in_size, M multiplier, T *out_data) +{ + for (M i = 0; i < multiplier; ++i) + { + const T *in_end = in_data + in_size; + T *new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const tflite::RuntimeShape &in_dimensions, const T *in_data, + const M *multiples, T *out_data, int dimension) +{ + if (in_dimensions.DimensionsCount() == 0) + { + // If input tensor is a scalar, then just copy it to output (no need to multiply). + *out_data = *in_data; + return std::make_pair(0, 0); + } + + const int dimension_size = in_dimensions.Dims(dimension); + if (dimension == in_dimensions.DimensionsCount() - 1) + { + CopyMultipleTimes(in_data, dimension_size, multiples[dimension], out_data); + return std::make_pair(dimension_size, dimension_size * static_cast(multiples[dimension])); + } + + int total_stride_size = 0, total_tiled_stride_size = 0; + const T *copy_from_data = in_data; + T *copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) + { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multiples, copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, multiples[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + static_cast(total_tiled_stride_size * multiples[dimension])); +} + +} // namespace + +void Tile::evalFloat() const +{ + TileOneDimension(getTensorShape(input()), getTensorData(input()), + getTensorData(multiples()), getTensorData(output()), 0); +} + +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Tile.h b/compiler/luci-interpreter/src/kernels/Tile.h new file mode 100644 index 000000000..7e3302ce8 --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Tile.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_KERNELS_TILE_H +#define LUCI_INTERPRETER_KERNELS_TILE_H + +#include "core/Kernel.h" +#include "core/KernelParams.h" + +namespace luci_interpreter +{ +namespace kernels +{ + +class Tile : public Kernel +{ +public: + Tile(const Tensor *input, const Tensor *multiplies, Tensor *output); + + const Tensor *input() const { return _inputs[0]; } + const Tensor *multiples() const { return _inputs[1]; } + Tensor *output() const { return _outputs[0]; } + + void configure() override; + void execute() const override; + +private: + void evalFloat() const; +}; + +} // namespace kernels +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_KERNELS_TILE_H diff --git a/compiler/luci-interpreter/src/kernels/Tile.test.cpp b/compiler/luci-interpreter/src/kernels/Tile.test.cpp new file mode 100644 index 000000000..2bb5c12e2 --- /dev/null +++ b/compiler/luci-interpreter/src/kernels/Tile.test.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernels/Tile.h" +#include "kernels/TestUtils.h" +#include "luci_interpreter/TestMemoryManager.h" + +namespace luci_interpreter +{ +namespace kernels +{ +namespace +{ + +using namespace testing; + +TEST(TileTest, FloatMul12) +{ + std::unique_ptr memory_manager = std::make_unique(); + Shape input_shape{1, 3}; + std::vector input_data{1.0f, 2.0f, 3.0f}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + Shape mul_shape{2}; + std::vector mul_data{1, 2}; + Tensor mul_tensor = makeInputTensor(mul_shape, mul_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + Tile kernel(&input_tensor, &mul_tensor, &output_tensor); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + std::vector ref_output_shape{1, 6}; + std::vector ref_output_data{1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f}; + EXPECT_THAT(extractTensorData(output_tensor), FloatArrayNear(ref_output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape)); +} + +TEST(TileTest, FloatMul21) +{ + std::unique_ptr memory_manager = std::make_unique(); + Shape input_shape{1, 3}; + std::vector input_data{1.0f, 2.0f, 3.0f}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + Shape mul_shape{2}; + std::vector mul_data{2, 1}; + Tensor mul_tensor = makeInputTensor(mul_shape, mul_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + Tile kernel(&input_tensor, &mul_tensor, &output_tensor); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + std::vector ref_output_shape{2, 3}; + std::vector ref_output_data{1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f}; + EXPECT_THAT(extractTensorData(output_tensor), FloatArrayNear(ref_output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape)); +} + +TEST(TileTest, MultiplesShapeInvalid_NEG) +{ + std::unique_ptr memory_manager = std::make_unique(); + Shape input_shape{1, 3}; + std::vector input_data{1.0f, 2.0f, 3.0f}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + Shape mul_shape{3}; + std::vector mul_data{1, 2, 3}; + Tensor mul_tensor = makeInputTensor(mul_shape, mul_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + Tile kernel(&input_tensor, &mul_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST(TileTest, MultiplesDTypeInvalid_NEG) +{ + std::unique_ptr memory_manager = std::make_unique(); + Shape input_shape{1, 3}; + std::vector input_data{1.0f, 2.0f, 3.0f}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + Shape mul_shape{2}; + std::vector mul_data{1.0f, 2.0f}; + Tensor mul_tensor = makeInputTensor(mul_shape, mul_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + Tile kernel(&input_tensor, &mul_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +TEST(TileTest, MultiplesDimInvalid_NEG) +{ + std::unique_ptr memory_manager = std::make_unique(); + Shape input_shape{1, 3}; + std::vector input_data{1.0f, 2.0f, 3.0f}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, memory_manager.get()); + Shape mul_shape{3}; + std::vector mul_data{1, 2, 3}; + Tensor mul_tensor = makeInputTensor(mul_shape, mul_data, memory_manager.get()); + Tensor output_tensor = 
makeOutputTensor(DataType::S32); + + Tile kernel(&input_tensor, &mul_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + +} // namespace +} // namespace kernels +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/kernels/Transpose.cpp b/compiler/luci-interpreter/src/kernels/Transpose.cpp index 802d87295..725a9523c 100644 --- a/compiler/luci-interpreter/src/kernels/Transpose.cpp +++ b/compiler/luci-interpreter/src/kernels/Transpose.cpp @@ -70,13 +70,18 @@ void Transpose::execute() const getTensorData(input()), getTensorShape(output()), getTensorData(output())); break; + case DataType::S64: + tflite::reference_ops::Transpose(params, getTensorShape(input()), + getTensorData(input()), getTensorShape(output()), + getTensorData(output())); + break; case DataType::U8: tflite::reference_ops::Transpose(params, getTensorShape(input()), getTensorData(input()), getTensorShape(output()), getTensorData(output())); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Transpose Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp index 43be8f8b9..a41db63b1 100644 --- a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp +++ b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp @@ -51,7 +51,7 @@ template class TransposeTest : public ::testing::Test { }; -using DataTypes = ::testing::Types; +using DataTypes = ::testing::Types; TYPED_TEST_SUITE(TransposeTest, DataTypes); TYPED_TEST(TransposeTest, Small3D) diff --git a/compiler/luci-interpreter/src/kernels/TransposeConv.cpp b/compiler/luci-interpreter/src/kernels/TransposeConv.cpp index 08bfbf319..01bdd80eb 100644 --- a/compiler/luci-interpreter/src/kernels/TransposeConv.cpp +++ b/compiler/luci-interpreter/src/kernels/TransposeConv.cpp @@ -115,18 +115,26 @@ void TransposeConv::execute() const evalQuantizedS16(); break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp TransposeConv Unsupported type."); } } void TransposeConv::evalFloat() const { + float activation_min{}; + float activation_max{}; + // TODO support activation + assert(_params.activation == Activation::NONE); + calculateActivationRange(Activation::NONE, &activation_min, &activation_max); + tflite::ConvParams op_params{}; op_params.padding_type = tflite::PaddingType::kSame; op_params.padding_values.height = _padding_height; op_params.padding_values.width = _padding_width; op_params.stride_height = params().stride_height; op_params.stride_width = params().stride_width; + op_params.float_activation_min = activation_min; + op_params.float_activation_max = activation_max; tflite::reference_ops::TransposeConv(op_params, // getTensorShape(input()), getTensorData(input()), // getTensorShape(filter()), getTensorData(filter()), // diff --git a/compiler/luci-interpreter/src/kernels/Unpack.cpp b/compiler/luci-interpreter/src/kernels/Unpack.cpp index 9127241c0..bd38b5d8e 100644 --- a/compiler/luci-interpreter/src/kernels/Unpack.cpp +++ b/compiler/luci-interpreter/src/kernels/Unpack.cpp @@ -76,7 +76,7 @@ void Unpack::execute() const case DataType::U8: return executeImpl(); default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error("luci-intp Unpack Unsupported type."); } } diff --git a/compiler/luci-interpreter/src/kernels/Utils.cpp b/compiler/luci-interpreter/src/kernels/Utils.cpp index a04dbcc0f..edd876137 100644 --- 
a/compiler/luci-interpreter/src/kernels/Utils.cpp
+++ b/compiler/luci-interpreter/src/kernels/Utils.cpp
@@ -124,10 +124,18 @@ void calculateActivationRangeQuantized(Activation activation, const Tensor *outp
  int32_t qmax{};
  switch (output->element_type())
  {
+    case DataType::U4:
+      qmin = 0;
+      qmax = 15;
+      break;
    case DataType::U8:
      qmin = 0;
      qmax = std::numeric_limits<uint8_t>::max();
      break;
+    case DataType::S4:
+      qmin = -8;
+      qmax = 7;
+      break;
    case DataType::S8:
      qmin = -std::numeric_limits<int8_t>::max();
      qmax = std::numeric_limits<int8_t>::max();
@@ -139,7 +147,7 @@ void calculateActivationRangeQuantized(Activation activation, const Tensor *outp
      qmax = std::numeric_limits<int16_t>::max();
      break;
    default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp (calculateActivationRangeQuantized) Unsupported type.");
  }
  calculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, activation_min,
diff --git a/compiler/luci-interpreter/src/loader/GraphLoader.cpp b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
index ba99a579b..cf83713d9 100644
--- a/compiler/luci-interpreter/src/loader/GraphLoader.cpp
+++ b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
@@ -54,10 +54,14 @@ const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
{
  switch (node->dtype())
  {
+    case DataType::U4:
+      return getNodeDataImpl<DataType::U4>(node, data_size);
    case DataType::U8:
      return getNodeDataImpl<DataType::U8>(node, data_size);
    case DataType::FLOAT32:
      return getNodeDataImpl<DataType::FLOAT32>(node, data_size);
+    case DataType::S4:
+      return getNodeDataImpl<DataType::S4>(node, data_size);
    case DataType::S8:
      return getNodeDataImpl<DataType::S8>(node, data_size);
    case DataType::S16:
@@ -69,7 +73,7 @@ const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
    case DataType::BOOL:
      return getNodeDataImpl<DataType::BOOL>(node, data_size);
    default:
-      throw std::runtime_error("Unsupported type.");
+      throw std::runtime_error("luci-intp (getNodeData) Unsupported type.");
  }
}
diff --git a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
index 10a01f418..128d1a43a 100644
--- a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
+++ b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
@@ -25,6 +25,8 @@
#include
#include
#include
+#include <kernels/Cos.h>
+#include <kernels/CumSum.h>
#include
#include
#include
@@ -67,6 +69,7 @@
#include
#include
#include
+#include <kernels/Sin.h>
#include
#include
#include
@@ -78,6 +81,7 @@
#include
#include
#include
+#include <kernels/Tile.h>
#include
#include
#include
@@ -99,7 +103,7 @@ protected:
  std::unique_ptr _memory_manager;
-  template <typename NodeT, typename... Args> NodeT *createNode(Args &&... args)
+  template <typename NodeT, typename... Args> NodeT *createNode(Args &&...args)
  {
    auto *node = _graph.nodes()->create<NodeT>(std::forward<Args>(args)...);
    // The actual type does not matter for the purpose of the tests.
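    // createNode relies on perfect forwarding: std::forward<Args>(args)...
    // preserves each argument's value category, so the node constructor sees
    // exactly the lvalues/rvalues the individual test passed in.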
@@ -299,6 +303,42 @@ TEST_F(KernelBuilderTest, Conv2D) EXPECT_THAT(kernel->params().activation, Eq(op->fusedActivationFunction())); } +TEST_F(KernelBuilderTest, Cos) +{ + auto *input = createInputNode(); + + auto *op = createNode(); + op->x(input); + + auto kernel = buildKernel(op); + ASSERT_THAT(kernel, NotNull()); + + checkTensor(kernel->input(), input); + checkTensor(kernel->output(), op); +} + +TEST_F(KernelBuilderTest, CumSum) +{ + auto *input = createInputNode(); + auto *axis = createInputNode(); + + auto *op = createNode(); + op->input(input); + op->axis(axis); + + op->exclusive(false); + op->reverse(false); + + auto kernel = buildKernel(op); + ASSERT_THAT(kernel, NotNull()); + + checkTensor(kernel->input(), input); + checkTensor(kernel->axis(), axis); + checkTensor(kernel->output(), op); + EXPECT_THAT(kernel->params().exclusive, Eq(op->exclusive())); + EXPECT_THAT(kernel->params().reverse, Eq(op->reverse())); +} + TEST_F(KernelBuilderTest, DepthToSpace) { auto *input = createInputNode(); @@ -1069,6 +1109,20 @@ TEST_F(KernelBuilderTest, Rsqrt) checkTensor(kernel->output(), op); } +TEST_F(KernelBuilderTest, Sin) +{ + auto *input = createInputNode(); + + auto *op = createNode(); + op->x(input); + + auto kernel = buildKernel(op); + ASSERT_THAT(kernel, NotNull()); + + checkTensor(kernel->input(), input); + checkTensor(kernel->output(), op); +} + TEST_F(KernelBuilderTest, Slice) { auto *input = createInputNode(); @@ -1286,6 +1340,23 @@ TEST_F(KernelBuilderTest, Tanh) checkTensor(kernel->output(), op); } +TEST_F(KernelBuilderTest, Tile) +{ + auto *input = createInputNode(); + auto *multiples = createInputNode(); + + auto *op = createNode(); + op->input(input); + op->multiples(multiples); + + auto kernel = buildKernel(op); + ASSERT_THAT(kernel, NotNull()); + + checkTensor(kernel->input(), input); + checkTensor(kernel->multiples(), multiples); + checkTensor(kernel->output(), op); +} + TEST_F(KernelBuilderTest, Transpose) { auto *input = createInputNode(); diff --git a/compiler/luci-interpreter/src/loader/nodes/BroadcastTo.cpp b/compiler/luci-interpreter/src/loader/nodes/BroadcastTo.cpp new file mode 100644 index 000000000..a08bcaf34 --- /dev/null +++ b/compiler/luci-interpreter/src/loader/nodes/BroadcastTo.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Builders.h" + +#include "kernels/BroadcastTo.h" + +namespace luci_interpreter +{ + +std::unique_ptr build_kernel_CircleBroadcastTo(const luci::CircleNode *circle_node, + KernelBuilderHelper &helper) +{ + const auto *node = loco::must_cast(circle_node); + assert(node->arity() == 2); + + const Tensor *input = helper.getInputTensor(node->input()); + const Tensor *shape = helper.getInputTensor(node->shape()); + Tensor *output = helper.getOutputTensor(node); + + return std::make_unique(input, shape, output); +} + +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/loader/nodes/Cos.cpp b/compiler/luci-interpreter/src/loader/nodes/Cos.cpp new file mode 100644 index 000000000..c1e1b2102 --- /dev/null +++ b/compiler/luci-interpreter/src/loader/nodes/Cos.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Builders.h" + +#include "kernels/Cos.h" + +namespace luci_interpreter +{ + +std::unique_ptr build_kernel_CircleCos(const luci::CircleNode *circle_node, + KernelBuilderHelper &helper) +{ + const auto *node = loco::must_cast(circle_node); + assert(node->arity() == 1); + + const Tensor *input = helper.getInputTensor(node->x()); + Tensor *output = helper.getOutputTensor(node); + + return std::make_unique(input, output); +} + +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/loader/nodes/CumSum.cpp b/compiler/luci-interpreter/src/loader/nodes/CumSum.cpp new file mode 100644 index 000000000..a12e1b69a --- /dev/null +++ b/compiler/luci-interpreter/src/loader/nodes/CumSum.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Builders.h" + +#include "kernels/CumSum.h" + +namespace luci_interpreter +{ + +std::unique_ptr build_kernel_CircleCumSum(const luci::CircleNode *circle_node, + KernelBuilderHelper &helper) +{ + const auto *node = loco::must_cast(circle_node); + const Tensor *input = helper.getInputTensor(node->input()); + const Tensor *axis = helper.getInputTensor(node->axis()); + Tensor *output = helper.getOutputTensor(node); + + CumSumParams params{}; + params.exclusive = node->exclusive(); + params.reverse = node->reverse(); + + return std::make_unique(input, axis, output, params); +} + +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp b/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp index b7b742b8a..d0a53ace4 100644 --- a/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp +++ b/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp @@ -35,7 +35,19 @@ std::unique_ptr build_kernel_CircleFullyConnected(const luci::CircleNode FullyConnectedParams params{}; params.activation = node->fusedActivationFunction(); params.keep_num_dims = node->keep_num_dims(); - + if (weights->element_type() == loco::DataType::S4 || + weights->element_type() == loco::DataType::U4) + { + auto scratchpad = + std::make_unique(input->element_type(), weights->shape(), AffineQuantization{}, ""); + scratchpad->set_observable(false); + scratchpad->set_data_buffer(nullptr); + Tensor *scratchpad_tmp = + helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad)); + helper.getRuntimeGraph(node->graph())->configureAllocations(scratchpad_tmp); + return std::make_unique(input, weights, bias, output, scratchpad_tmp, + params); + } return std::make_unique(input, weights, bias, output, params); } diff --git a/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp b/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp new file mode 100644 index 000000000..10a93a170 --- /dev/null +++ b/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Builders.h" + +#include "kernels/Relu0To1.h" + +namespace luci_interpreter +{ + +std::unique_ptr build_kernel_CircleRelu0To1(const luci::CircleNode *circle_node, + KernelBuilderHelper &helper) +{ + const auto *node = loco::must_cast(circle_node); + assert(node->arity() == 1); + + const Tensor *input = helper.getInputTensor(node->features()); + Tensor *output = helper.getOutputTensor(node); + + return std::make_unique(input, output); +} + +} // namespace luci_interpreter diff --git a/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp b/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp new file mode 100644 index 000000000..92e5263bf --- /dev/null +++ b/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
+  if (weights->element_type() == loco::DataType::S4 ||
+      weights->element_type() == loco::DataType::U4)
+  {
+    auto scratchpad =
+      std::make_unique<Tensor>(input->element_type(), weights->shape(), AffineQuantization{}, "");
+    scratchpad->set_observable(false);
+    scratchpad->set_data_buffer(nullptr);
+    Tensor *scratchpad_tmp =
+      helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad));
+    helper.getRuntimeGraph(node->graph())->configureAllocations(scratchpad_tmp);
+    return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, scratchpad_tmp,
+                                                     params);
+  }
   return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, params);
 }
diff --git a/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp b/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp
new file mode 100644
index 000000000..10a93a170
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Relu0To1.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Relu0To1.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleRelu0To1(const luci::CircleNode *circle_node,
+                                                    KernelBuilderHelper &helper)
+{
+  const auto *node = loco::must_cast<const luci::CircleRelu0To1 *>(circle_node);
+  assert(node->arity() == 1);
+
+  const Tensor *input = helper.getInputTensor(node->features());
+  Tensor *output = helper.getOutputTensor(node);
+
+  return std::make_unique<kernels::Relu0To1>(input, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp b/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp
new file mode 100644
index 000000000..92e5263bf
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/SelectV2.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/SelectV2.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleSelectV2(const luci::CircleNode *circle_node,
+                                                    KernelBuilderHelper &helper)
+{
+  const auto *node = loco::must_cast<const luci::CircleSelectV2 *>(circle_node);
+  assert(node->arity() == 3);
+
+  const Tensor *c = helper.getInputTensor(node->condition());
+  const Tensor *t = helper.getInputTensor(node->t());
+  const Tensor *e = helper.getInputTensor(node->e());
+  Tensor *output = helper.getOutputTensor(node);
+
+  return std::make_unique<kernels::SelectV2>(c, t, e, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/Sin.cpp b/compiler/luci-interpreter/src/loader/nodes/Sin.cpp
new file mode 100644
index 000000000..b20062a2e
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Sin.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Sin.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleSin(const luci::CircleNode *circle_node,
+                                               KernelBuilderHelper &helper)
+{
+  const auto *node = loco::must_cast<const luci::CircleSin *>(circle_node);
+  assert(node->arity() == 1);
+
+  const Tensor *input = helper.getInputTensor(node->x());
+  Tensor *output = helper.getOutputTensor(node);
+
+  return std::make_unique<kernels::Sin>(input, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/Tile.cpp b/compiler/luci-interpreter/src/loader/nodes/Tile.cpp
new file mode 100644
index 000000000..2e8c17887
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Tile.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Tile.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleTile(const luci::CircleNode *circle_node,
+                                                KernelBuilderHelper &helper)
+{
+  const auto *node = loco::must_cast<const luci::CircleTile *>(circle_node);
+  assert(node->arity() == 2);
+
+  const Tensor *input = helper.getInputTensor(node->input());
+  const Tensor *multiples = helper.getInputTensor(node->multiples());
+  Tensor *output = helper.getOutputTensor(node);
+
+  return std::make_unique<kernels::Tile>(input, multiples, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-pass-value-py-test/CMakeLists.txt b/compiler/luci-pass-value-py-test/CMakeLists.txt
new file mode 100644
index 000000000..9b59ae6e3
--- /dev/null
+++ b/compiler/luci-pass-value-py-test/CMakeLists.txt
@@ -0,0 +1,55 @@
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1")
+set(TEST_LIST_FILE "test.lst")
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+macro(eval RECIPE PASS_OPTION)
+  set(CIRCLE_FILE "${RECIPE}.circle")
+  set(CIRCLE_PATH "${ARTIFACTS_BIN_PATH}/${CIRCLE_FILE}")
+
+  set(PASS_CIRCLE_FILE "${RECIPE}.pass.circle")
+  set(PASS_CIRCLE_OUTPUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${PASS_CIRCLE_FILE}")
+
+  set(DASH_PASS_OPTION "--${PASS_OPTION}")
+  foreach(MORE_OPTIONS ${ARGN})
+    list(APPEND DASH_PASS_OPTION "--${MORE_OPTIONS}")
+  endforeach()
+  # NOTE if there are two options, 'DASH_PASS_OPTION' will be like '--option_a;--option_b';
+  #      add_custom_command() will translate ';' to two arguments, as '--option_a --option_b'.
+  #      Do not use set(DASH_PASS_OPTION "${DASH_PASS_OPTION} --${ARG}")
+  #      as this would become '"--option_a --option_b"', which is one string argument.
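+  # e.g. an entry like eval(Softmax_001 decompose_softmax) from test.lst will run:
+  #   circle2circle --decompose_softmax Softmax_001.circle Softmax_001.pass.circle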
+
+  # Generate optimized .circle
+  add_custom_command(OUTPUT ${PASS_CIRCLE_OUTPUT_PATH}
+    COMMAND $<TARGET_FILE:circle2circle> ${DASH_PASS_OPTION} ${CIRCLE_PATH} ${PASS_CIRCLE_OUTPUT_PATH}
+    DEPENDS $<TARGET_FILE:circle2circle> ${CIRCLE_PATH}
+    COMMENT "Generate ${PASS_CIRCLE_FILE} with ${DASH_PASS_OPTION}"
+  )
+
+  # depends
+  list(APPEND TEST_DEPS ${PASS_CIRCLE_OUTPUT_PATH})
+
+endmacro(eval)
+
+# Read "test.lst"
+include("test.lst")
+# Read "test.local.lst" if exists
+include("test.local.lst" OPTIONAL)
+
+add_custom_target(luci_pass_value_py_test_files ALL DEPENDS ${TEST_DEPS})
+add_dependencies(luci_pass_value_py_test_files common_artifacts_deps)
+
+add_test(NAME luci_pass_value_py_test
+  COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_luci_eval.py
+          --test_list ${TEST_LIST_FILE}
+          --tflite_dir ${ARTIFACTS_BIN_PATH}
+          --circle_dir ${CMAKE_CURRENT_BINARY_DIR}
+          --luci_eval_driver $<TARGET_FILE:luci_eval_driver>
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+)
diff --git a/compiler/luci-pass-value-py-test/README.md b/compiler/luci-pass-value-py-test/README.md
new file mode 100644
index 000000000..8d5626e5c
--- /dev/null
+++ b/compiler/luci-pass-value-py-test/README.md
@@ -0,0 +1,20 @@
+# luci-pass-value-py-test
+
+`luci-pass-value-py-test` validates that a tflite model and the circle model
+produced from it by a specific optimization yield the same execution results.
+
+The test proceeds as follows:
+
+Step 0: Use the tflite and circle files in the 'common-artifacts' folder as source models.
+  - the tflite file is used to generate the reference execution result
+  - the circle file is the source to which the optimization is applied
+
+Step 1: Run circle2circle with the given optimization option to produce a transformed circle.
+  - "modelfile.circle" -> circle2circle -> "modelfile.pass.circle"
+
+Step 2: Run the TFLite interpreter and luci-interpreter for the source tflite and circle, respectively
+  (with the same input tensors filled with random values).
+  - "modelfile.tflite" ------> TFLite interpreter -> Execution result 1
+  - "modelfile.pass.circle" -> luci-interpreter ---> Execution result 2
+
+Step 3: Compare execution results 1 and 2. The test PASSES if they are the same.
diff --git a/compiler/luci-pass-value-py-test/conftest.py b/compiler/luci-pass-value-py-test/conftest.py
new file mode 100644
index 000000000..be8fa3b79
--- /dev/null
+++ b/compiler/luci-pass-value-py-test/conftest.py
@@ -0,0 +1,40 @@
+import re
+
+
+def extract_test_args(s):
+    p = re.compile('eval\\((.*)\\)')
+    result = p.search(s)
+    return result.group(1)
+
+
+def pytest_addoption(parser):
+    parser.addoption("--test_list", action="store", help="Path to test list")
+    parser.addoption(
+        "--tflite_dir", action="store", help="Directory including tflite file")
+    parser.addoption(
+        "--circle_dir", action="store", help="Directory including circle file")
+    parser.addoption(
+        "--luci_eval_driver", action="store", help="Path to luci eval driver")
+
+
+def pytest_generate_tests(metafunc):
+    list_path = metafunc.config.getoption('test_list')
+    tflite_dir = metafunc.config.getoption('tflite_dir')
+    circle_dir = metafunc.config.getoption('circle_dir')
+    eval_driver_path = metafunc.config.getoption('luci_eval_driver')
+    if list_path is None:
+        tests_default_tol = []
+    else:
+        with open(list_path) as f:
+            contents = [line.rstrip() for line in f]
+
+        comment_removed = [line for line in contents if not line.startswith('#')]
+        newline_removed = [line for line in comment_removed if line.startswith('eval(')]
+        test_args = [extract_test_args(line) for line in newline_removed]
+        # eval(TEST_NAME PASS_1 PASS_2 ..)
+        tests_default_tol = [(arg.split()[0], tflite_dir, circle_dir, eval_driver_path)
+                             for arg in test_args]
+
+    if 'test_name' in metafunc.fixturenames:
+        metafunc.parametrize('test_name,tflite_dir,circle_dir,eval_driver_path',
+                             tests_default_tol)
diff --git a/compiler/luci-pass-value-py-test/requires.cmake b/compiler/luci-pass-value-py-test/requires.cmake
new file mode 100644
index 000000000..19be7219c
--- /dev/null
+++ b/compiler/luci-pass-value-py-test/requires.cmake
@@ -0,0 +1,3 @@
+require("common-artifacts")
+require("luci-eval-driver")
+require("circle2circle")
diff --git a/compiler/luci-pass-value-py-test/test.lst b/compiler/luci-pass-value-py-test/test.lst
new file mode 100644
index 000000000..287ddcb94
--- /dev/null
+++ b/compiler/luci-pass-value-py-test/test.lst
@@ -0,0 +1,90 @@
+#
+# Format:
+#   eval(MODEL PASS)
+# MODEL: tflite model file name in the build/compiler/common-artifacts folder.
+# PASS: optimization pass(es) to test; multiple passes may be given.
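+# Example: eval(Net_InstanceNorm_001 fuse_instnorm) optimizes
+#   Net_InstanceNorm_001.circle with --fuse_instnorm and checks the result
+#   against TFLite execution of Net_InstanceNorm_001.tflite.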
+# + +# eval(Net_Preactivation_BN_000 fuse_preactivation_batchnorm) : value diff exist +# --> https://github.com/Samsung/ONE/issues/5782 +eval(FullyConnected_007 replace_non_const_fc_with_batch_matmul) +eval(HardSwish_001 decompose_hardswish) +eval(Net_Add_FloorMod_Gather_000 remove_gather_guard) +eval(Net_Add_FullyConnected_000 fuse_add_to_fullyconnected_bias) +eval(Net_Add_FullyConnected_001 fuse_add_to_fullyconnected_bias) +eval(Net_Add_FullyConnected_002 fuse_add_to_fullyconnected_bias) +eval(Net_Conv_Add_000 fuse_add_with_conv) +eval(Net_Conv_Add_001 fuse_add_with_conv) +# eval(Net_Conv_Add_002 fuse_add_with_conv) --> Conv2D w/o bias fails in tflite interpreter +eval(Net_Conv_Add_Mul_000 fuse_batchnorm_with_conv) +eval(Net_Conv_Add_Mul_000 fuse_batchnorm_with_conv) +eval(Net_Conv_Add_Mul_001 fuse_batchnorm_with_conv) +eval(Net_Conv_Add_Mul_002 fuse_batchnorm_with_conv) +eval(Net_Conv_Min_Max_000 transform_min_max_to_relu6) +eval(Net_Conv_Min_Relu_000 transform_min_relu_to_relu6) +eval(Net_Conv_Mul_000 fuse_mul_with_conv) +eval(Net_Conv_Mul_001 fuse_mul_with_conv) +eval(Net_Conv_Mul_002 fuse_mul_with_conv) +eval(Net_Conv_Mul_003 fuse_mul_with_conv) +eval(Net_Conv_PReluGraph_000 fuse_prelu) +eval(Net_Conv_Relu6_000 fuse_activation_function) +eval(Net_Densify_Add_000 fold_densify) +eval(Net_Dequantize_Add_000 fold_dequantize) +eval(Net_DwConv_BN_000 fuse_batchnorm_with_dwconv) +eval(Net_DwConv_BN_001 fuse_batchnorm_with_dwconv) +eval(Net_FullyConnected_Add_000 fold_fully_connected) +eval(Net_Horizontal_FullyConnected_Add_000 fuse_horizontal_fc_layers) +eval(Net_InstanceNorm_001 fuse_instnorm) +eval(Net_InstanceNorm_002 fuse_instnorm) +eval(Net_InstanceNorm_003 fuse_instnorm) +eval(Net_Mul_Add_000 remove_unnecessary_add) +eval(Net_Mul_Add_001 remove_unnecessary_add) +eval(Net_Mul_Add_002 remove_unnecessary_add) +eval(Net_Mul_Add_003 remove_unnecessary_add) +eval(Net_Mul_Div_000 fuse_mul_with_div) +eval(Net_Mul_Div_001 fuse_mul_with_div) +eval(Net_Mul_FullyConnected_000 fuse_mul_to_fullyconnected_weights) +eval(Net_Mul_FullyConnected_001 fuse_mul_to_fullyconnected_weights) +eval(Net_Mul_FullyConnected_002 fuse_mul_to_fullyconnected_weights) +eval(Net_Reshape_Mean_000 forward_reshape_to_unaryop) +eval(Net_Reshape_Neg_000 forward_reshape_to_unaryop) +eval(Net_Reshape_Reshape_000 remove_redundant_reshape) +eval(Net_Shape_Add_000 fold_shape) +eval(Net_Sqrt_Div_000 transform_sqrt_div_to_rsqrt_mul) +eval(Net_Squeeze_Squeeze_000 substitute_squeeze_to_reshape) +eval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice) +eval(Net_TConv_Add_000 fuse_add_with_tconv) +eval(Net_TConv_Add_001 fuse_add_with_tconv) +eval(Net_TConv_Add_002 fuse_add_with_tconv) +eval(Net_TConv_BN_000 fuse_batchnorm_with_tconv) +eval(Net_TConv_BN_001 fuse_batchnorm_with_tconv) +eval(Net_TConv_BN_002 fuse_batchnorm_with_tconv) +eval(Net_TConv_BN_003 fuse_batchnorm_with_tconv) +eval(Net_TConv_BN_004 fuse_batchnorm_with_tconv) +eval(Net_TConv_BN_005 fuse_batchnorm_with_tconv) +eval(Net_TConv_Slice_000 fuse_slice_with_tconv) +eval(Net_TConv_Slice_001 fuse_slice_with_tconv) +eval(Net_TConv_Slice_002 fuse_slice_with_tconv) +eval(Net_TConv_Slice_003 fuse_slice_with_tconv) +eval(Net_Trans_Reshape_Trans_000 remove_unnecessary_transpose) +eval(Net_Transpose_Add_000 forward_transpose_op) +eval(Net_Transpose_Abs_000 forward_transpose_op) +eval(Softmax_001 decompose_softmax) +eval(Softmax_002 decompose_softmax) +eval(UnidirectionalSequenceLSTM_003 unroll_unidirseqlstm) +eval(UnidirectionalSequenceLSTM_004 
unroll_unidirseqlstm) + +# test for limited support for FLOAT16 +eval(Net_Densify_Dequantize_Add_000 fold_dequantize fold_densify) +eval(Net_Dequantize_Add_000 fold_dequantize) + +# test SignatureDef, with any optimization +#eval(SignatureDef_MultiOut_000 fuse_instnorm) +#eval(SignatureDef_MultiOut_001 fuse_instnorm) + +# test for common subexpression elimination +eval(CSE_Quantize_000 common_subexpression_elimination) +eval(CSE_Transpose_000 common_subexpression_elimination) diff --git a/compiler/luci-pass-value-py-test/test_luci_eval.py b/compiler/luci-pass-value-py-test/test_luci_eval.py new file mode 100644 index 000000000..4cb59c177 --- /dev/null +++ b/compiler/luci-pass-value-py-test/test_luci_eval.py @@ -0,0 +1,119 @@ +import numpy as np +import tensorflow as tf +import subprocess +import os + + +def luci_eval_verify(test_name, + tflite_dir, + circle_dir, + eval_driver, + rtolf32=1e-5, + atolf32=1e-5): + tflite_model = os.path.join(tflite_dir, test_name + ".tflite") + circle_model = os.path.join(circle_dir, test_name + ".pass.circle") + + # NOTE reuse f32 value as int value too + rtolint = int(rtolf32) + atolint = int(atolf32) + + # Build TFLite interpreter. + interpreter = tf.lite.Interpreter(tflite_model) + interpreter.allocate_tensors() + + # Read SignatureDef and get output tensor id orders for remapping + full_signatures = interpreter._get_full_signature_list() + full_signatures_outputs_remap = None + if full_signatures != None: + signature_serving_default = full_signatures.get('serving_default', None) + if signature_serving_default != None: + signature_outputs = signature_serving_default['outputs'] + + full_signatures_outputs_remap = [] + for index, (key, value) in enumerate(signature_outputs.items()): + full_signatures_outputs_remap.append(value) + + # Generate random input data. + num_inputs = len(interpreter.get_input_details()) + for i in range(num_inputs): + input_details = interpreter.get_input_details()[i] + if input_details["dtype"] == np.float32: + input_data = np.array( + np.random.random_sample(input_details["shape"]), input_details["dtype"]) + elif input_details["dtype"] == np.uint8: + input_data = np.array( + np.random.randint(0, 256, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int16: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int32: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.int64: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + elif input_details["dtype"] == np.bool_: + input_data = np.array( + np.random.choice(a=[True, False], size=input_details["shape"]), + input_details["dtype"]) + else: + assert False, "Unsupported input dtype" + + interpreter.set_tensor(input_details["index"], input_data) + input_data.tofile(circle_model + ".input" + str(i)) + + # Do inference + interpreter.invoke() + + # Execute luci interpreter. + subprocess.run( + [ + eval_driver, circle_model, + str(num_inputs), circle_model + ".input", circle_model + ".output" + ], + check=True) + + # Compare the results. 
+ inpt_output_details = interpreter.get_output_details() + for idx in range(len(inpt_output_details)): + output_details = inpt_output_details[idx] + output_data = np.fromfile(circle_model + ".output" + str(idx), + output_details["dtype"]) + shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r') + output_shape = [int(i) for i in shape_file.read().split(',')] + luci_output_data = np.reshape(output_data, output_shape) + output_tensor = output_details["index"] + if full_signatures_outputs_remap != None: + output_tensor = full_signatures_outputs_remap[idx] + intp_output_data = interpreter.get_tensor(output_tensor) + err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model + if output_details["dtype"] == np.uint8: + assert np.allclose( + luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg + elif output_details["dtype"] == np.float32: + assert np.allclose( + luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32), err_msg + elif output_details["dtype"] == np.int64: + assert np.allclose( + luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg + elif output_details["dtype"] == np.int32: + assert np.allclose( + luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg + elif output_details["dtype"] == np.int16: + assert np.allclose( + luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg + elif output_details["dtype"] == np.bool_: + assert np.allclose( + luci_output_data, intp_output_data, rtol=0, atol=0), err_msg + else: + assert False, "Unsupported data type: " + output_details["dtype"] + + +# arguments must be in sync with `conftest.py` +def test_luci_eval(test_name: str, tflite_dir: str, circle_dir: str, + eval_driver_path: str): + luci_eval_verify(test_name, tflite_dir, circle_dir, eval_driver_path) diff --git a/compiler/luci-pass-value-test/exclude.me b/compiler/luci-pass-value-test/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/luci-pass-value-test/test.lst b/compiler/luci-pass-value-test/test.lst index 93634b2fc..d22464c61 100644 --- a/compiler/luci-pass-value-test/test.lst +++ b/compiler/luci-pass-value-test/test.lst @@ -13,6 +13,10 @@ addeval(Net_Conv_Add_Mul_001 fuse_batchnorm_with_conv) addeval(Net_Conv_Add_Mul_002 fuse_batchnorm_with_conv) addeval(Net_Conv_Min_Max_000 transform_min_max_to_relu6) addeval(Net_Conv_Min_Relu_000 transform_min_relu_to_relu6) +addeval(Net_Conv_Mul_000 fuse_mul_with_conv) +addeval(Net_Conv_Mul_001 fuse_mul_with_conv) +addeval(Net_Conv_Mul_002 fuse_mul_with_conv) +addeval(Net_Conv_Mul_003 fuse_mul_with_conv) addeval(HardSwish_001 decompose_hardswish) addeval(Net_Conv_PReluGraph_000 fuse_prelu) addeval(Net_Conv_Relu6_000 fuse_activation_function) @@ -21,7 +25,12 @@ addeval(Net_Dequantize_Add_000 fold_dequantize) addeval(Net_DwConv_BN_000 fuse_batchnorm_with_dwconv) addeval(Net_DwConv_BN_001 fuse_batchnorm_with_dwconv) addeval(Net_FullyConnected_Add_000 fold_fully_connected) +addeval(Net_Horizontal_FullyConnected_Add_000 fuse_horizontal_fc_layers) addeval(Net_Reshape_Neg_000 forward_reshape_to_unaryop) +addeval(Net_Mul_Add_000 remove_unnecessary_add) +addeval(Net_Mul_Add_001 remove_unnecessary_add) +addeval(Net_Mul_Add_002 remove_unnecessary_add) +addeval(Net_Mul_Add_003 remove_unnecessary_add) addeval(Net_Reshape_Reshape_000 remove_redundant_reshape) addeval(Net_Squeeze_Squeeze_000 substitute_squeeze_to_reshape) addeval(Net_TConv_Add_000 fuse_add_with_tconv) @@ -33,6 +42,11 @@ addeval(Net_TConv_BN_002 
fuse_batchnorm_with_tconv)
 addeval(Net_TConv_BN_003 fuse_batchnorm_with_tconv)
 addeval(Net_TConv_BN_004 fuse_batchnorm_with_tconv)
 addeval(Net_TConv_BN_005 fuse_batchnorm_with_tconv)
+addeval(Net_TConv_Slice_000 fuse_slice_with_tconv)
+addeval(Net_TConv_Slice_001 fuse_slice_with_tconv)
+addeval(Net_TConv_Slice_002 fuse_slice_with_tconv)
+addeval(Net_TConv_Slice_003 fuse_slice_with_tconv)
+addeval(Net_Trans_Reshape_Trans_000 remove_unnecessary_transpose)
 addeval(Net_InstanceNorm_001 fuse_instnorm)
 addeval(Net_InstanceNorm_002 fuse_instnorm)
 addeval(Net_InstanceNorm_003 fuse_instnorm)
@@ -40,6 +54,8 @@ addeval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice)
 addeval(FullyConnected_007 replace_non_const_fc_with_batch_matmul)
 addeval(Net_Transpose_Add_000 forward_transpose_op)
 addeval(Net_Transpose_Abs_000 forward_transpose_op)
+addeval(Softmax_001 decompose_softmax)
+addeval(Softmax_002 decompose_softmax)
 addeval(UnidirectionalSequenceLSTM_003 unroll_unidirseqlstm)
 addeval(UnidirectionalSequenceLSTM_004 unroll_unidirseqlstm)
@@ -50,3 +66,7 @@ addeval(Net_Densify_Dequantize_Add_000 fold_dequantize fold_densify)
 # test SignatureDef, with any optimization
 #addeval(SignatureDef_MultiOut_000 fuse_instnorm)
 #addeval(SignatureDef_MultiOut_001 fuse_instnorm)
+
+# test for common subexpression elimination
+addeval(CSE_Quantize_000 common_subexpression_elimination)
+addeval(CSE_Transpose_000 common_subexpression_elimination)
diff --git a/compiler/luci-ref-value-py-test/CMakeLists.txt b/compiler/luci-ref-value-py-test/CMakeLists.txt
new file mode 100644
index 000000000..e147279b8
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/CMakeLists.txt
@@ -0,0 +1,24 @@
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1")
+set(TEST_LIST_FILE "test.lst")
+
+nncc_find_resource(TensorFlowLiteRecipes)
+set(TFLITE_RECIPE_REPO "${TensorFlowLiteRecipes_DIR}")
+
+nncc_find_resource(CircleRecipes)
+set(CIRCLE_RECIPE_REPO "${CircleRecipes_DIR}")
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+add_test(NAME luci_ref_value_py_test
+  COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_luci_eval.py
+          --test_list ${TEST_LIST_FILE}
+          --tflrecipe ${TFLITE_RECIPE_REPO}
+          --circlerecipe ${CIRCLE_RECIPE_REPO}
+          --artifacts ${ARTIFACTS_BIN_PATH}
+          --binary ${CMAKE_CURRENT_BINARY_DIR}
+          --luci_eval_driver $<TARGET_FILE:luci_eval_driver>
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+)
diff --git a/compiler/luci-ref-value-py-test/README.md b/compiler/luci-ref-value-py-test/README.md
new file mode 100644
index 000000000..56221a40d
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/README.md
@@ -0,0 +1,36 @@
+# luci-ref-value-py-test
+
+`luci-ref-value-py-test` validates circle model files (luci IR graphs) that the current
+`luci-value-py-test` cannot cover, such as tflite models with data types the
+TFLite interpreter does not support, like `INT4`.
+
+The test proceeds as follows:
+
+Step 1: Take the tflite files listed in test.lst from common-artifacts and generate
+circle files from them.
+```
+"tflite file" -> tflite2circle -> "circle file"
+```
+
+Step 2: Read the reference input files, run luci-eval-driver, and collect the output files.
+```
+"circle file" -> luci-eval-driver -> "execution result"
+```
+
+Step 3: Compare the execution result with the reference output files. They must match.
+
+Reference input/output files are text files in a simple fixed format, as follows:
+- the first line is the shape, like `1,4`
+- the second line is the data type, like `float32`
+- the third line is the comma-separated values, like `0.1,0.2,0.3,0.4`
+
+Place them in the same folder as the `test.recipe` file, named `ref.inputN` and `ref.outputN`,
+where `N` is the I/O index starting from `0`.
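+
+For example, a `ref.input0` describing a single `1,4` `float32` input tensor
+would contain:
+```
+1,4
+float32
+0.1,0.2,0.3,0.4
+```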
diff --git a/compiler/luci-ref-value-py-test/conftest.py b/compiler/luci-ref-value-py-test/conftest.py
new file mode 100644
index 000000000..b7d46902c
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/conftest.py
@@ -0,0 +1,118 @@
+import re
+import os
+import shutil
+
+
+def extract_test_args(s):
+    p = re.compile('eval\\((.*)\\)')
+    result = p.search(s)
+    return result.group(1)
+
+
+def pytest_addoption(parser):
+    parser.addoption("--test_list", action="store", help="Path to test list")
+    parser.addoption("--artifacts", action="store", help="Path to test artifacts")
+    parser.addoption("--tflrecipe", action="store", help="Path to tfl recipes")
+    parser.addoption("--circlerecipe", action="store", help="Path to circle recipes")
+    parser.addoption("--binary", action="store", help="Path to test binary")
+    parser.addoption(
+        "--luci_eval_driver", action="store", help="Path to luci eval driver")
+
+
+def copy_if_changed(src_filepath, dst_filepath):
+    do_copy = False
+    if (os.path.isfile(dst_filepath)):
+        file_diff = os.stat(src_filepath).st_mtime - os.stat(dst_filepath).st_mtime
+        if file_diff > 1:
+            print("file:" + src_filepath + " changed, update")
+            do_copy = True
+    else:
+        do_copy = True
+
+    if do_copy:
+        print("file:" + src_filepath + " copy to: " + dst_filepath)
+        shutil.copyfile(src_filepath, dst_filepath)
+
+
+# prepare reference input/output files to build folder for luci-eval-driver
+# from ref data in res/TensorFlowLiteRecipes/*/ref.input* and ref.output*
+# as model_name.ref.input* and model_name.ref.output*
+def copy_ref_files(ref_file_src, ref_file_dst):
+    num_data = 0
+    while True:
+        input_file_src = ref_file_src + str(num_data)
+        if (not os.path.isfile(input_file_src)):
+            break
+        input_file_dst = ref_file_dst + str(num_data)
+        copy_if_changed(input_file_src, input_file_dst)
+        # try next file
+        num_data = num_data + 1
+
+
+# copy circle model from common-artifacts to build binary
+def copy_circle_model(model_src, model_dst):
+    copy_if_changed(model_src, model_dst)
+
+
+def prepare_materials(test_name, tflrecipe_path, circlerecipe_path, binary_path,
+                      artifacts_path):
+    # tfl? or circle?
+    recipe_path = tflrecipe_path
+    # check for the 'test.recipe' file, since 'ref.input?' can be absent for a model with no inputs
+    test_recipe = os.path.join(recipe_path, test_name, 'test.recipe')
+    if (not os.path.isfile(test_recipe)):
+        recipe_path = circlerecipe_path
+
+    ref_input_src = os.path.join(recipe_path, test_name, 'ref.input')
+    ref_input_dst = os.path.join(binary_path, test_name + '.ref.input')
+    copy_ref_files(ref_input_src, ref_input_dst)
+
+    ref_input_src = os.path.join(recipe_path, test_name, 'ref.output')
+    ref_input_dst = os.path.join(binary_path, test_name + '.ref.output')
+    copy_ref_files(ref_input_src, ref_input_dst)
+
+    circle_model_src = os.path.join(artifacts_path, test_name + '.circle')
+    circle_model_dst = os.path.join(binary_path, test_name + '.circle')
+    copy_circle_model(circle_model_src, circle_model_dst)
+
+
+def pytest_generate_tests(metafunc):
+    list_path = metafunc.config.getoption('test_list')
+    artifacts_path = metafunc.config.getoption('artifacts')
+    tflrecipe_path = metafunc.config.getoption('tflrecipe')
+    circlerecipe_path = metafunc.config.getoption('circlerecipe')
+    binary_path = metafunc.config.getoption('binary')
+    eval_driver_path = metafunc.config.getoption('luci_eval_driver')
+    if list_path is None:
+        tests_default_tol = []
+        tests_with_tol = []
+    else:
+        with open(list_path) as f:
+            contents = [line.rstrip() for line in f]
+
+        comment_removed = [line for line in contents if not line.startswith('#')]
+        newline_removed = [line for line in comment_removed if line.startswith('eval(')]
+        test_args = [extract_test_args(line) for line in newline_removed]
+        # eval(TEST_NAME)
+        tests_default_tol = [(arg, binary_path, eval_driver_path) for arg in test_args
+                             if len(arg.split()) == 1]
+        # eval(TEST_NAME RTOL ATOL)
+        tests_with_tol = [(arg.split()[0], binary_path, eval_driver_path, arg.split()[1],
+                           arg.split()[2]) for arg in test_args if len(arg.split()) == 3]
+
+    # copy circle file to binary
+    for test_item in tests_default_tol:
+        prepare_materials(test_item[0], tflrecipe_path, circlerecipe_path,
+                          binary_path, artifacts_path)
+
+    for test_item in tests_with_tol:
+        prepare_materials(test_item[0], tflrecipe_path, circlerecipe_path,
+                          binary_path, artifacts_path)
+
+    if 'default_test_name' in metafunc.fixturenames:
+        metafunc.parametrize('default_test_name,binary_path,eval_driver_path',
+                             tests_default_tol)
+
+    if 'tol_test_name' in metafunc.fixturenames:
+        metafunc.parametrize('tol_test_name,binary_path,eval_driver_path,rtolf32,atolf32',
+                             tests_with_tol)
diff --git a/compiler/luci-ref-value-py-test/requires.cmake b/compiler/luci-ref-value-py-test/requires.cmake
new file mode 100644
index 000000000..c4461490a
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/requires.cmake
@@ -0,0 +1,2 @@
+require("common-artifacts")
+require("luci-eval-driver")
diff --git a/compiler/luci-ref-value-py-test/test.lst b/compiler/luci-ref-value-py-test/test.lst
new file mode 100644
index 000000000..478a73ac4
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/test.lst
@@ -0,0 +1,5 @@
+# test with given reference data
+eval(FullyConnected_I4_002 0.1 0.1)
+
+# circle recipes
+eval(CircleFullyConnected_U4_002 0.1 0.1)
diff --git a/compiler/luci-ref-value-py-test/test_luci_eval.py b/compiler/luci-ref-value-py-test/test_luci_eval.py
new file mode 100644
index 000000000..cbe991d46
--- /dev/null
+++ b/compiler/luci-ref-value-py-test/test_luci_eval.py
@@ -0,0 +1,132 @@
+import numpy as np
+import subprocess
+import os
+
+
+# read input/output data files model_name.ref.input* and
+# model_name.ref.output* and return the contents
+def recover_fromfile(path, test_name, suffix):
+    # .ref file format
+    # 1st line: shape, e.g. "2,4"
+    # 2nd line: dtype, e.g. "float32"
+    # 3rd line: comma-separated values
+    ref_filename = test_name + ".ref." + suffix
+    ref_datapath = os.path.join(path, ref_filename)
+
+    num_data = 0
+    parse_shape = []
+    parse_dtype = []
+    parse_value = []
+
+    while True:
+        refnum_filepath = ref_datapath + str(num_data)
+        if (not os.path.isfile(refnum_filepath)):
+            break
+        with open(refnum_filepath, "r") as ref_file:
+            lines = ref_file.readlines()
+        assert len(lines) >= 3, "Invalid file: " + ref_filename + str(num_data)
+        print("load reference data from", test_name)
+        shape = [int(i) for i in lines[0].split(",")]
+        dtype = lines[1].strip("\r\n \t")
+        if dtype == "float32":
+            value = [float(i) for i in lines[2].split(",")]
+        else:
+            assert False, "Unsupported data type: " + dtype
+
+        # validate shape and number of elements
+        num_elements = 1
+        for dim in shape:
+            num_elements = num_elements * dim
+        if num_elements != len(value):
+            assert False, "Number of value elements does not match the shape"
+
+        parse_shape.append(shape)
+        parse_dtype.append(dtype)
+        parse_value.append(value)
+
+        num_data = num_data + 1
+
+    return num_data, parse_shape, parse_dtype, parse_value
+
+
+def recover_inputs(path, test_name):
+    return recover_fromfile(path, test_name, "input")
+
+
+def recover_outputs(path, test_name):
+    return recover_fromfile(path, test_name, "output")
+
+
+# save reference data to input files for luci-eval-driver
+def save_binary_inputs(path, test_name, num_inputs, input_shape, input_dtype, input_data):
+    circle_inputpath = os.path.join(path, test_name + ".circle.input")
+    for index in range(0, num_inputs):
+        # reference input value
+        if input_dtype[index] == "float32":
+            nps = np.asarray(input_data[index], dtype=np.float32)
+            nps.tofile(circle_inputpath + str(index))
+        else:
+            assert False, "Unsupported data type: " + input_dtype[index]
+        # reference input shape
+        nps = np.asarray(input_shape[index], dtype=np.short)
+        nps.tofile(circle_inputpath + str(index) + ".shape", sep=",")
+        # reference input dtype
+        with open(circle_inputpath + str(index) + ".dtype", "w") as dtype_file:
+            dtype_file.write(input_dtype[index])
+
+
+def luci_eval_verify(test_name, binary_path, eval_driver, rtolf32=1e-5, atolf32=1e-5):
+    circle_model = os.path.join(binary_path, test_name + ".circle")
+
+    num_inputs, input_shape, input_dtype, input_data = recover_inputs(
+        binary_path, test_name)
+    assert num_inputs > 0, "No valid reference input file"
+    save_binary_inputs(binary_path, test_name, num_inputs, input_shape, input_dtype,
+                       input_data)
+
+    num_outputs, output_shape, output_dtype, output_data = recover_outputs(
+        binary_path, test_name)
+    assert num_outputs > 0, "No valid reference output file"
+
+    # Execute luci interpreter.
+    subprocess.run(
+        [
+            eval_driver, circle_model,
+            str(num_inputs), circle_model + ".input", circle_model + ".output"
+        ],
+        check=True)
+
+    # Compare the results.
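+    # np.allclose(a, b) passes when |a - b| <= atol + rtol * |b| holds elementwise,
+    # so rtolf32/atolf32 bound the allowed drift of luci results from the reference.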
+    for idx in range(num_outputs):
+        luci_output_data = np.fromfile(circle_model + ".output" + str(idx),
+                                       output_dtype[idx])
+        luci_output_data = np.reshape(luci_output_data, output_shape[idx])
+        ref_output_data = np.reshape(output_data[idx], output_shape[idx])
+
+        show_vals_and_stop = False
+        if output_dtype[idx] == "float32":
+            if not np.allclose(
+                    luci_output_data, ref_output_data, rtol=rtolf32, atol=atolf32):
+                show_vals_and_stop = True
+        else:
+            assert False, "Unsupported data type: " + output_dtype[idx]
+
+        if show_vals_and_stop:
+            print("\nreference:\n", ref_output_data)
+            print("luci:\n", luci_output_data)
+            message = "Execution result of " + test_name + " does not match with reference"
+            assert False, message
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval(default_test_name: str, binary_path: str, eval_driver_path: str):
+    luci_eval_verify(default_test_name, binary_path, eval_driver_path)
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval_tol(tol_test_name: str, binary_path: str, eval_driver_path: str,
+                       rtolf32: str, atolf32: str):
+    luci_eval_verify(tol_test_name, binary_path, eval_driver_path, float(rtolf32),
+                     float(atolf32))
diff --git a/compiler/luci-value-py-test/CMakeLists.txt b/compiler/luci-value-py-test/CMakeLists.txt
new file mode 100644
index 000000000..1c6208146
--- /dev/null
+++ b/compiler/luci-value-py-test/CMakeLists.txt
@@ -0,0 +1,37 @@
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_12_1")
+set(TEST_LIST_FILE "test.lst")
+
+if(NOT CMAKE_CROSSCOMPILING)
+  get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+  add_test(NAME luci_value_py_test
+    COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_luci_eval.py
+            --test_list ${TEST_LIST_FILE}
+            --artifacts ${ARTIFACTS_BIN_PATH}
+            --luci_eval_driver $<TARGET_FILE:luci_eval_driver>
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+  )
+else(NOT CMAKE_CROSSCOMPILING)
+  # NOTE the target test is carried out using reference input/output data from host
+  #      test results. This is because it would be difficult to prepare
+  #      TensorFlow Lite for the target device.
+  #      Thus, one must run the host test first, and then run the test on the target
+  #      device with the test result files from the host test.
+  if(NOT DEFINED ENV{BUILD_HOST_EXEC})
+    message(STATUS "BUILD_HOST_EXEC not set: Skip luci-value-py-test")
+    return()
+  endif(NOT DEFINED ENV{BUILD_HOST_EXEC})
+
+  set(ARTIFACTS_BIN_PATH $ENV{BUILD_HOST_EXEC}/compiler/common-artifacts)
+  add_test(NAME luci_value_py_test
+    COMMAND ${VIRTUALENV}/bin/python -m pytest -sv test_luci_eval_ref.py
+            --test_list ${TEST_LIST_FILE}
+            --artifacts ${ARTIFACTS_BIN_PATH}
+            --target_artifacts ${CMAKE_CURRENT_BINARY_DIR}
+            --luci_eval_driver $<TARGET_FILE:luci_eval_driver>
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+  )
+endif(NOT CMAKE_CROSSCOMPILING)
diff --git a/compiler/luci-value-py-test/README.md b/compiler/luci-value-py-test/README.md
new file mode 100644
index 000000000..33d34c07f
--- /dev/null
+++ b/compiler/luci-value-py-test/README.md
@@ -0,0 +1,19 @@
+# luci-value-py-test
+
+`luci-value-py-test` validates luci IR graph model files (.circle).
+
+The test proceeds as follows:
+
+Step 1: Generate tflite files and circle files from TFLite recipes (listed in test.lst).
+```
+"TFLite recipe" -> tflchef -> "tflite file" -> tflite2circle -> "circle file"
+```
+
+Step 2: Run the TFLite interpreter and luci-interpreter for the generated tflite and circle, respectively
+(with the same input tensors filled with random values).
+```
+circle file -> luci-interpreter -------> Execution result 1
+tflite file -> TFLite interpreter -----> Execution result 2
+```
+
+Step 3: Compare execution results 1 and 2. The results must be the same.
diff --git a/compiler/luci-value-py-test/conftest.py b/compiler/luci-value-py-test/conftest.py
new file mode 100644
index 000000000..042a265aa
--- /dev/null
+++ b/compiler/luci-value-py-test/conftest.py
@@ -0,0 +1,74 @@
+import re
+
+
+def extract_test_args(s):
+    p = re.compile('eval\\((.*)\\)')
+    result = p.search(s)
+    return result.group(1)
+
+
+def pytest_addoption(parser):
+    parser.addoption("--test_list", action="store", help="Path to test list")
+    parser.addoption("--artifacts", action="store", help="Path to test artifacts")
+    parser.addoption(
+        "--target_artifacts", action="store", help="Path to test target artifacts")
+    parser.addoption(
+        "--luci_eval_driver", action="store", help="Path to luci eval driver")
+
+
+def pytest_generate_tests(metafunc):
+    list_path = metafunc.config.getoption('test_list')
+    artifacts_path = metafunc.config.getoption('artifacts')
+    target_artifacts_path = metafunc.config.getoption('target_artifacts')
+    eval_driver_path = metafunc.config.getoption('luci_eval_driver')
+    if list_path is None:
+        tests_default_tol = []
+        tests_with_tol = []
+        ref_tests_default_tol = []
+        ref_tests_with_tol = []
+    else:
+        with open(list_path) as f:
+            contents = [line.rstrip() for line in f]
+
+        comment_removed = [line for line in contents if not line.startswith('#')]
+        newline_removed = [line for line in comment_removed if line.startswith('eval(')]
+        test_args = [extract_test_args(line) for line in newline_removed]
+        # eval(TEST_NAME)
+        tests_default_tol = [(arg, artifacts_path, eval_driver_path) for arg in test_args
+                             if len(arg.split()) == 1]
+        # eval(TEST_NAME RTOL ATOL)
+        tests_with_tol = [(arg.split()[0], artifacts_path, eval_driver_path,
+                           arg.split()[1], arg.split()[2]) for arg in test_args
+                          if len(arg.split()) == 3]
+
+    if 'default_test_name' in metafunc.fixturenames:
+        metafunc.parametrize('default_test_name,artifacts_path,eval_driver_path',
+                             tests_default_tol)
+
+    if 'tol_test_name' in metafunc.fixturenames:
+        metafunc.parametrize(
+            'tol_test_name,artifacts_path,eval_driver_path,rtolf32,atolf32',
+            tests_with_tol)
+
+    if target_artifacts_path is not None:
+        # eval(TEST_NAME)
+        ref_tests_default_tol = [(arg, artifacts_path, target_artifacts_path,
+                                  eval_driver_path) for arg in test_args
+                                 if len(arg.split()) == 1]
+        # eval(TEST_NAME RTOL ATOL)
+        ref_tests_with_tol = [(arg.split()[0], artifacts_path,
+                               target_artifacts_path, eval_driver_path, arg.split()[1],
+                               arg.split()[2]) for arg in test_args
+                              if len(arg.split()) == 3]
+        #
+        # for cross platform test
+        #
+        if 'default_ref_test_name' in metafunc.fixturenames:
+            metafunc.parametrize(
+                'default_ref_test_name,ref_artifacts_path,target_artifacts_path,eval_driver_path',
+                ref_tests_default_tol)
+
+        if 'tol_ref_test_name' in metafunc.fixturenames:
+            metafunc.parametrize(
+                'tol_ref_test_name,ref_artifacts_path,target_artifacts_path,eval_driver_path,rtolf32,atolf32',
+                ref_tests_with_tol)
diff --git a/compiler/luci-value-py-test/requires.cmake b/compiler/luci-value-py-test/requires.cmake
new file mode 100644
index 000000000..c4461490a
--- /dev/null
+++ b/compiler/luci-value-py-test/requires.cmake
@@ -0,0 +1,2 @@
+require("common-artifacts")
+require("luci-eval-driver")
diff --git a/compiler/luci-value-py-test/test.lst
b/compiler/luci-value-py-test/test.lst new file mode 100644 index 000000000..2c928fb8d --- /dev/null +++ b/compiler/luci-value-py-test/test.lst @@ -0,0 +1,211 @@ +eval(Abs_000) +eval(Add_000) +eval(Add_001) +eval(Add_U8_000) +#eval(AddN_000) +eval(ArgMax_000) +eval(ArgMax_001) +eval(ArgMax_002) +eval(ArgMax_003) +eval(ArgMax_U8_000) +eval(ArgMax_U8_001) +eval(ArgMax_U8_002) +eval(ArgMax_U8_003) +#eval(ArgMin_000) +#eval(ArgMin_001) +#eval(ArgMin_002) +#eval(ArgMin_003) +#eval(ArgMin_U8_000) +#eval(ArgMin_U8_001) +#eval(ArgMin_U8_002) +#eval(ArgMin_U8_003) +eval(AveragePool2D_000) +eval(BatchMatMul_000) +#eval(BatchMatMulV2_000) +#eval(BatchMatMulV2_001) +#eval(BatchToSpaceND_000) +eval(BroadcastTo_001) +eval(Cast_000) +eval(Cast_001) +#eval(Ceil_000) +eval(Concatenation_000) +eval(Concatenation_U8_000) +eval(Conv2D_000) +eval(Conv2D_001) +eval(Conv2D_002) +eval(Conv2D_003) +#eval(Conv2D_U8_000) --> test with tolerance +eval(Conv2D_U8_001) +eval(Cos_000) +eval(CumSum_000) +eval(DepthToSpace_000) +eval(DepthwiseConv2D_000) +eval(DepthwiseConv2D_U8_000) +#eval(DepthwiseConv2D_U8_001) +eval(DepthwiseConv2D_001) +eval(Div_000) +eval(ELU_000) +eval(Equal_000) +eval(Exp_000) +eval(ExpandDims_000) +#eval(ExpandDims_001) +#eval(ExpandDims_002) +eval(ExpandDims_003) +#eval(Fill_000) +#eval(Fill_001) +eval(Floor_000) +eval(FloorDiv_000) +eval(FloorDiv_001) +eval(FloorMod_000) +eval(FloorMod_001) +eval(FullyConnected_000) +eval(FullyConnected_001) +eval(FullyConnected_002) +#eval(FullyConnected_U8_000) +eval(Gather_000) +#eval(GatherNd_000) +eval(Gelu_000) +eval(Greater_000) +eval(GreaterEqual_000) +eval(HardSwish_000) +eval(If_000) +eval(If_001) +eval(L2Normalize_000) +eval(L2Pool2D_000) +#eval(L2Pool2D_U8_000) +eval(LeakyRelu_000) +eval(Less_000) +eval(LessEqual_000) +eval(LocalResponseNormalization_000) +eval(Log_000) +eval(LogicalAnd_000) +eval(LogicalNot_000) +eval(LogicalOr_000) +eval(Logistic_000) +eval(LogSoftmax_000) +#eval(MatMul_000) +#eval(MatrixDiag_000) +#eval(MatrixSetDiag_000) +eval(Maximum_000) +eval(MaxPool2D_000) +eval(MaxPool2D_U8_000) +eval(Mean_000) +eval(Mean_001) +eval(Mean_U8_000) +eval(Minimum_000) +#eval(MirrorPad_000) +eval(Mul_000) +#eval(Mul_U8_000) +eval(Neg_000) +eval(NotEqual_000) +eval(OneHot_000) +eval(OneHot_001) +eval(OneHot_002) +#eval(OneHot_003) +eval(Pack_000) +eval(Pack_U8_000) +eval(Pad_000) +eval(Pad_U8_000) +eval(Pow_000) +eval(PRelu_000) +#eval(Range_000) +#eval(Rank_000) +#eval(ReduceAny_000) +#eval(ReduceAny_001) +#eval(ReduceAny_002) +#eval(ReduceAny_003) +eval(ReduceMax_000) +#eval(ReduceMin_000) +eval(ReduceProd_000) +eval(ReduceProd_001) +eval(ReduceProd_002) +eval(ReduceProd_003) +eval(ReLU_000) +eval(ReLU0To1_000) +eval(ReLU6_000) +#eval(ReLUN1To1_000) +eval(Reshape_000) +eval(Reshape_001) +eval(Reshape_002) +#eval(Reshape_003) +eval(Reshape_U8_000) +eval(ResizeBilinear_000) +eval(ResizeNearestNeighbor_000) +#eval(ReverseSequence_000) +#eval(ReverseV2_000) +#eval(Round_000) +eval(Rsqrt_000) +#eval(ScatterNd_000) +#eval(SegmentSum_000) +eval(Select_000) +eval(Select_001) +eval(Select_002) +eval(SelectV2_000) +eval(SelectV2_001) +eval(SelectV2_002) +eval(Shape_000) +eval(SignatureDef_MultiOut_000) +eval(SignatureDef_MultiOut_001) +eval(Sin_000) +eval(Slice_000) +eval(Softmax_000) +eval(Softmax_U8_000) +eval(SpaceToBatchND_000) +eval(SpaceToBatchND_001) +eval(SpaceToBatchND_002) +eval(SpaceToBatchND_003) +eval(SpaceToDepth_000) +#eval(SparseToDense_000) +eval(Split_000) +eval(SplitV_000) +eval(Sqrt_000) +eval(Square_000) +eval(SquaredDifference_000) 
+eval(Squeeze_000)
+eval(Squeeze_001)
+eval(StridedSlice_000)
+eval(StridedSlice_001)
+eval(StridedSlice_002)
+eval(StridedSlice_003)
+eval(StridedSlice_004)
+eval(Sub_000)
+eval(Sub_U8_000)
+eval(Sum_000)
+eval(Sum_001)
+eval(Tanh_000)
+eval(Tile_000)
+eval(Tile_001)
+eval(Tile_002)
+#eval(Tile_U8_000)
+#eval(TopKV2_000)
+#eval(TopKV2_001)
+eval(Transpose_000)
+eval(TransposeConv_000)
+eval(UnidirectionalSequenceLSTM_002)
+eval(UnidirectionalSequenceLSTM_003)
+eval(UnidirectionalSequenceLSTM_004)
+eval(Unpack_000)
+eval(Unpack_001)
+eval(Unpack_002)
+eval(Unpack_003)
+eval(UnidirectionalSequenceLSTM_002)
+#eval(Where_000)
+#eval(Where_001)
+eval(While_000)
+eval(While_001)
+eval(While_002)
+#eval(While_003)
+#eval(ZerosLike_000)
+
+# Simple Network test
+eval(Part_While_000)
+eval(Part_While_001)
+
+# Tests with tolerance
+eval(SVDF_000 8e-3 8e-3)
+eval(SVDF_001 8e-3 8e-3)
+# TODO fix Conv2D_U8_000 to test without tolerance
+# refer https://github.com/Samsung/ONE/issues/11255#issuecomment-1685424361
+eval(Conv2D_U8_000 5 5)
+# refer https://github.com/Samsung/ONE/issues/10438
+eval(YUV_TO_RGB_U8_000 1 1)
diff --git a/compiler/luci-value-py-test/test_luci_eval.py b/compiler/luci-value-py-test/test_luci_eval.py
new file mode 100644
index 000000000..b3fa4422b
--- /dev/null
+++ b/compiler/luci-value-py-test/test_luci_eval.py
@@ -0,0 +1,143 @@
+import numpy as np
+import tensorflow as tf
+import subprocess
+import os
+
+
+def luci_eval_verify(test_name, artifacts, eval_driver, rtolf32=1e-5, atolf32=1e-5):
+    tflite_model = os.path.join(artifacts, test_name + ".tflite")
+    circle_model = os.path.join(artifacts, test_name + ".circle")
+
+    # NOTE reuse f32 value as int value too
+    rtolint = int(rtolf32)
+    atolint = int(atolf32)
+
+    # Build TFLite interpreter.
+    interpreter = tf.lite.Interpreter(tflite_model)
+    interpreter.allocate_tensors()
+
+    # Read SignatureDef and get output tensor id orders for remapping
+    full_signatures = interpreter._get_full_signature_list()
+    full_signatures_outputs_remap = None
+    if full_signatures != None:
+        signature_serving_default = full_signatures.get('serving_default', None)
+        if signature_serving_default != None:
+            signature_outputs = signature_serving_default['outputs']
+
+            full_signatures_outputs_remap = []
+            for index, (key, value) in enumerate(signature_outputs.items()):
+                full_signatures_outputs_remap.append(value)
+
+    # Generate random input data.
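+    # uint8 inputs draw from the full [0, 256) range; other integer dtypes use
+    # [0, 100) and bool picks from {True, False}. Each input is also written to
+    # <circle_model>.inputN (+ .shape/.dtype files) for luci-eval-driver.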
+ num_inputs = len(interpreter.get_input_details()) + for i in range(num_inputs): + input_details = interpreter.get_input_details()[i] + if input_details["dtype"] == np.float32: + input_data = np.array( + np.random.random_sample(input_details["shape"]), input_details["dtype"]) + input_dtype = "float32" + elif input_details["dtype"] == np.uint8: + input_data = np.array( + np.random.randint(0, 256, size=input_details["shape"]), + input_details["dtype"]) + input_dtype = "uint8" + elif input_details["dtype"] == np.int16: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + input_dtype = "int16" + elif input_details["dtype"] == np.int32: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + input_dtype = "int32" + elif input_details["dtype"] == np.int64: + input_data = np.array( + np.random.randint(0, 100, size=input_details["shape"]), + input_details["dtype"]) + input_dtype = "int64" + elif input_details["dtype"] == np.bool_: + input_data = np.array( + np.random.choice(a=[True, False], size=input_details["shape"]), + input_details["dtype"]) + input_dtype = "bool" + else: + assert False, "Unsupported input dtype" + + interpreter.set_tensor(input_details["index"], input_data) + input_data.tofile(circle_model + ".input" + str(i)) + input_details["shape"].tofile( + circle_model + ".input" + str(i) + ".shape", sep=',') + with open(circle_model + ".input" + str(i) + ".dtype", 'w') as dtype_file: + dtype_file.write(input_dtype) + + # Do inference + interpreter.invoke() + + # Execute luci interpreter. + subprocess.run( + [ + eval_driver, circle_model, + str(num_inputs), circle_model + ".input", circle_model + ".output" + ], + check=True) + + # Compare the results. 
+    intp_output_details = interpreter.get_output_details()
+    for idx in range(len(intp_output_details)):
+        output_details = intp_output_details[idx]
+        output_data = np.fromfile(circle_model + ".output" + str(idx),
+                                  output_details["dtype"])
+        shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
+        output_shape = [int(i) for i in shape_file.read().split(',')]
+        luci_output_data = np.reshape(output_data, output_shape)
+        output_tensor = output_details["index"]
+        if full_signatures_outputs_remap != None:
+            output_tensor = full_signatures_outputs_remap[idx]
+        intp_output_data = interpreter.get_tensor(output_tensor)
+        err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model
+        if output_details["dtype"] == np.uint8:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
+            output_dtype = "uint8"
+        elif output_details["dtype"] == np.float32:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32), err_msg
+            output_dtype = "float32"
+        elif output_details["dtype"] == np.int64:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
+            output_dtype = "int64"
+        elif output_details["dtype"] == np.int32:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
+            output_dtype = "int32"
+        elif output_details["dtype"] == np.int16:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
+            output_dtype = "int16"
+        elif output_details["dtype"] == np.bool_:
+            assert np.allclose(
+                luci_output_data, intp_output_data, rtol=0, atol=0), err_msg
+            output_dtype = "bool"
+        else:
+            assert False, "Unsupported data type: " + output_details["dtype"]
+
+        # save outputN.dtype file
+        with open(circle_model + ".output" + str(idx) + ".dtype", 'w') as dtype_file:
+            dtype_file.write(output_dtype)
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval(default_test_name: str, artifacts_path: str, eval_driver_path: str):
+    luci_eval_verify(default_test_name, artifacts_path, eval_driver_path)
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval_tol(tol_test_name: str, artifacts_path: str, eval_driver_path: str,
+                       rtolf32: str, atolf32: str):
+    luci_eval_verify(tol_test_name, artifacts_path, eval_driver_path, float(rtolf32),
+                     float(atolf32))
diff --git a/compiler/luci-value-py-test/test_luci_eval_ref.py b/compiler/luci-value-py-test/test_luci_eval_ref.py
new file mode 100644
index 000000000..f476c78fa
--- /dev/null
+++ b/compiler/luci-value-py-test/test_luci_eval_ref.py
@@ -0,0 +1,137 @@
+import numpy as np
+import tensorflow as tf
+import subprocess
+import os
+
+#
+# This script compares the execution result of luci-interpreter with that from ref_model path
+#
+# Basic usage:
+#   luci_eval_verifier_ref.py --driver build/compiler/luci-eval-driver/luci_eval_driver
+#           --ref_model ref_model_path --model this_model_path
+# Assumption:
+#   these files are assumed to exist, with these roles
+#   - ref_model_path.circle; circle model
+#   - ref_model_path.circle.inputN; N'th input numpy data
+#   - ref_model_path.circle.inputN.dtype; N'th input data type in text
+#   - ref_model_path.circle.inputN.shape; N'th input data shape in CSV
+#   - ref_model_path.circle.outputN; N'th output numpy data
+#   - ref_model_path.circle.outputN.dtype; N'th output data type in text
+#   - ref_model_path.circle.outputN.shape; N'th output data shape in CSV
+
+
+def dtype_from_file(file_path):
+    with open(file_path, 'r') as dtype_file:
+        dtype_str = dtype_file.read()
+    if dtype_str == "float32":
+        return np.float32
+    if dtype_str == "uint8":
+        return np.uint8
+    if dtype_str == "int16":
+        return np.int16
+    if dtype_str == "int32":
+        return np.int32
+    if dtype_str == "int64":
+        return np.int64
+    if dtype_str == "bool":
+        return np.bool_
+    assert False, "Unsupported dtype from file: " + dtype_str
+
+
+def luci_eval_verify_ref(test_name,
+                         ref_artifacts,
+                         target_artifacts,
+                         eval_driver,
+                         rtolf32=1e-5,
+                         atolf32=1e-5):
+    circle_model_ref = os.path.join(ref_artifacts, test_name + ".circle")
+    circle_model = os.path.join(target_artifacts, test_name + ".circle")
+
+    # NOTE reuse f32 value as int value too
+    rtolint = int(rtolf32)
+    atolint = int(atolf32)
+
+    # get num of inputs by checking existence of model.inputN
+    check_input = 0
+    while True:
+        input_file_path = circle_model_ref + ".input" + str(check_input)
+        if not os.path.isfile(input_file_path):
+            num_inputs = check_input
+            break
+        check_input = check_input + 1
+
+    assert num_inputs != 0, "input file does not exist for " + circle_model_ref
+
+    # get num of outputs by checking existence of model.outputN
+    check_output = 0
+    while True:
+        output_file_path = circle_model_ref + ".output" + str(check_output)
+        if not os.path.isfile(output_file_path):
+            num_outputs = check_output
+            break
+        check_output = check_output + 1
+
+    assert num_outputs != 0, "output file does not exist for " + circle_model_ref
+
+    # Execute luci interpreter with reference input
+    subprocess.run(
+        [
+            eval_driver, circle_model_ref,
+            str(num_inputs), circle_model_ref + ".input", circle_model + ".output"
+        ],
+        check=True)
+
+    # Compare the results.
+    for idx in range(num_outputs):
+        output_dtype = dtype_from_file(circle_model_ref + ".output" + str(idx) + ".dtype")
+        shape_file = open(circle_model_ref + ".output" + str(idx) + ".shape", 'r')
+        output_shape = [int(i) for i in shape_file.read().split(',')]
+
+        output_data_ref = np.fromfile(circle_model_ref + ".output" + str(idx),
+                                      output_dtype)
+        luci_output_data_ref = np.reshape(output_data_ref, output_shape)
+
+        output_data = np.fromfile(circle_model + ".output" + str(idx), output_dtype)
+        luci_output_data = np.reshape(output_data, output_shape)
+
+        err_msg = "Execution result of " + circle_model_ref + " does not match with " + circle_model
+        if output_dtype == np.uint8:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=rtolint,
+                atol=atolint), err_msg
+        elif output_dtype == np.float32:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=rtolf32,
+                atol=atolf32), err_msg
+        elif output_dtype == np.int64:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=rtolint,
+                atol=atolint), err_msg
+        elif output_dtype == np.int32:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=rtolint,
+                atol=atolint), err_msg
+        elif output_dtype == np.int16:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=rtolint,
+                atol=atolint), err_msg
+        elif output_dtype == np.bool_:
+            assert np.allclose(
+                luci_output_data, luci_output_data_ref, rtol=0, atol=0), err_msg
+        else:
+            assert False, "Unsupported data type: " + output_dtype
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval_ref(default_ref_test_name: str, ref_artifacts_path: str,
+                       target_artifacts_path: str, eval_driver_path: str):
+    luci_eval_verify_ref(default_ref_test_name, ref_artifacts_path, target_artifacts_path,
+                         eval_driver_path)
+
+
+# arguments must be in sync with `conftest.py`
+def test_luci_eval_tol_ref(tol_ref_test_name: str, ref_artifacts_path: str,
+                           target_artifacts_path: str, eval_driver_path: str,
+                           rtolf32: str, atolf32: str):
+    luci_eval_verify_ref(tol_ref_test_name, ref_artifacts_path, target_artifacts_path,
+                         eval_driver_path, float(rtolf32), float(atolf32))
diff --git a/compiler/luci-value-test/exclude.me b/compiler/luci-value-test/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/luci-value-test/test.lst b/compiler/luci-value-test/test.lst
index 3368b6450..591421eb3 100644
--- a/compiler/luci-value-test/test.lst
+++ b/compiler/luci-value-test/test.lst
@@ -35,7 +35,8 @@ addeval(Conv2D_002)
 addeval(Conv2D_003)
 #addeval(Conv2D_U8_000) --> test with tolerance
 addeval(Conv2D_U8_001)
-#addeval(Cos_000)
+addeval(Cos_000)
+addeval(CumSum_000)
 addeval(DepthToSpace_000)
 addeval(DepthwiseConv2D_000)
 addeval(DepthwiseConv2D_U8_000)
@@ -143,7 +144,7 @@ addeval(Select_002)
 #addeval(Shape_000)
 addeval(SignatureDef_MultiOut_000)
 addeval(SignatureDef_MultiOut_001)
-#addeval(Sin_000)
+addeval(Sin_000)
 addeval(Slice_000)
 addeval(Softmax_000)
 addeval(Softmax_U8_000)
@@ -170,7 +171,9 @@ addeval(Sub_U8_000)
 addeval(Sum_000)
 addeval(Sum_001)
 addeval(Tanh_000)
-#addeval(Tile_000)
+addeval(Tile_000)
+addeval(Tile_001)
+addeval(Tile_002)
 #addeval(Tile_U8_000)
 #addeval(TopKV2_000)
 #addeval(TopKV2_001)
@@ -201,6 +204,6 @@ addevaltol(SVDF_000 8e-3 8e-3)
 addevaltol(SVDF_001 8e-3 8e-3)
 # TODO fix Conv2D_U8_000 to test without tolerenace
 # refer https://github.com/Samsung/ONE/issues/11255#issuecomment-1685424361
-addeval(Conv2D_U8_000 1 1)
+addevaltol(Conv2D_U8_000 5 5)
 # refer https://github.com/Samsung/ONE/issues/10438
 addevaltol(YUV_TO_RGB_U8_000 1 1)
diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt
index fb0e20e00..bc10ad24c 100644
--- a/compiler/luci/export/CMakeLists.txt
+++ b/compiler/luci/export/CMakeLists.txt
@@ -12,7 +12,7 @@ target_include_directories(luci_export PUBLIC include)
 target_link_libraries(luci_export PRIVATE luci_lang)
 target_link_libraries(luci_export PRIVATE luci_service)
 target_link_libraries(luci_export PRIVATE luci_pass)
-target_link_libraries(luci_export PRIVATE mio_circle06)
+target_link_libraries(luci_export PRIVATE mio_circle08)
 target_link_libraries(luci_export PRIVATE luci_env)
 target_link_libraries(luci_export PRIVATE luci_log)
 target_link_libraries(luci_export PRIVATE luci_logex)
@@ -36,6 +36,6 @@ target_include_directories(luci_export_test PRIVATE src)
 target_link_libraries(luci_export_test luci_export)
 target_link_libraries(luci_export_test luci_plan)
 target_link_libraries(luci_export_test luci_lang)
-target_link_libraries(luci_export_test mio_circle06)
+target_link_libraries(luci_export_test mio_circle08)
 target_link_libraries(luci_export_test luci_env)
 target_link_libraries(luci_export_test oops)
diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
index 811373ffe..efc2a5106 100644
--- a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
+++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
@@ -87,6 +87,7 @@ public:
                                           node->asymmetric_quantize_inputs())
       .Union();
   }
+  flatbuffers::Offset<void> visit(luci::CircleBroadcastTo *) { return _no_option; }
   flatbuffers::Offset<void> visit(luci::CircleCast *node)
   {
     if (node->out_data_type() == loco::DataType::Unknown)
@@ -117,6 +118,10 @@ public:
   {
     return circle::CreateCosOptions(_builder).Union();
   }
+  flatbuffers::Offset<void> visit(luci::CircleCumSum *node)
+  {
+    return
circle::CreateCumsumOptions(_builder, node->exclusive(), node->reverse()).Union(); + } flatbuffers::Offset visit(luci::CircleCustom *) { return _no_option; } flatbuffers::Offset visit(luci::CircleDensify *) { @@ -353,6 +358,7 @@ public: return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); } flatbuffers::Offset visit(luci::CircleRelu *) { return _no_option; } + flatbuffers::Offset visit(luci::CircleRelu0To1 *) { return _no_option; } flatbuffers::Offset visit(luci::CircleRelu6 *) { return _no_option; } flatbuffers::Offset visit(luci::CircleReluN1To1 *) { return _no_option; } flatbuffers::Offset visit(luci::CircleReshape *node) @@ -530,6 +536,12 @@ public: return circle::CreateBCQGatherOptions(_builder, node->input_hidden_size(), node->axis()) .Union(); } + flatbuffers::Offset visit(luci::CircleGRU *node) + { + return circle::CreateGRUOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()), + node->returnSequences(), node->timeMajor()) + .Union(); + } flatbuffers::Offset visit(luci::CircleInstanceNorm *node) { return circle::CreateInstanceNormOptions(_builder, node->epsilon(), diff --git a/compiler/luci/export/src/CircleExportMetadata.cpp b/compiler/luci/export/src/CircleExportMetadata.cpp index 017002f5c..25d0168ec 100644 --- a/compiler/luci/export/src/CircleExportMetadata.cpp +++ b/compiler/luci/export/src/CircleExportMetadata.cpp @@ -56,7 +56,7 @@ const std::vector CircleExportMetadata::encoded_execution_plan_table() const auto id = kv.first; write_u32(data, id); - const auto plan_vector = kv.second; + const auto &plan_vector = kv.second; const auto size = plan_vector.size(); write_u32(data, size); @@ -81,7 +81,7 @@ const std::vector CircleExportMetadata::encoded_source_table(void) const auto id = kv.first; write_u32(data, id); - const auto origin_name = kv.second; + const auto &origin_name = kv.second; const auto length = origin_name.length(); write_u32(data, length + 1); // name + '\0 @@ -107,7 +107,7 @@ const std::vector CircleExportMetadata::encoded_op_table(void) const auto id = kv.first; write_u32(data, id); - const auto origins = kv.second; + const auto &origins = kv.second; const auto node_num = origins.size(); write_u32(data, node_num); diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp index 083add9be..014ef45d7 100644 --- a/compiler/luci/export/src/CircleExporterImpl.cpp +++ b/compiler/luci/export/src/CircleExporterImpl.cpp @@ -76,7 +76,7 @@ Offset>> encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map &opcodes) { std::vector> operator_codes_vec(opcodes.size()); - for (auto it : opcodes) + for (const auto &it : opcodes) { uint32_t idx = it.second; int8_t dep_code = 127; // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES @@ -117,8 +117,7 @@ CircleExporterImpl::exportSubgraph(SerializedGraphData &gd) auto outputs = _builder.CreateVector(gd._outputs); auto operators = _builder.CreateVector(gd._operators); auto name = _builder.CreateString(gd._name); - auto df = gd._data_format; - auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators, name, df); + auto subgraph = CreateSubGraph(_builder, tensors, inputs, outputs, operators, name); return subgraph; } @@ -202,9 +201,6 @@ void CircleExporterImpl::exportModule(Module *module) // set Subgraph name gd._name = graph->name(); - // TODO set this value properly - gd._data_format = circle::DataFormat::DataFormat_CHANNELS_LAST; - // parse graph into SerializedModelData structure exportOpDefinedTensors(graph, 
_builder, md, gd); diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index 9473c2c4e..6678b0dc3 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -50,9 +50,13 @@ circle::TensorType to_circle_tensortype(loco::DataType type) { switch (type) { + case loco::DataType::U4: + return circle::TensorType_UINT4; case loco::DataType::U8: return circle::TensorType_UINT8; + case loco::DataType::S4: + return circle::TensorType_INT4; case loco::DataType::S8: return circle::TensorType_INT8; case loco::DataType::S16: diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst index a047f29d7..8c693baca 100644 --- a/compiler/luci/export/src/CircleOps.lst +++ b/compiler/luci/export/src/CircleOps.lst @@ -26,11 +26,13 @@ CIRCLE_NODE(CircleAveragePool2D, BuiltinOperator_AVERAGE_POOL_2D , BuiltinOption CIRCLE_NODE(CircleBatchToSpaceND, BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions) CIRCLE_NODE(CircleBatchMatMul, BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions) CIRCLE_NODE(CircleBidirectionalSequenceLSTM, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_BidirectionalSequenceLSTMOptions) +CIRCLE_NODE(CircleBroadcastTo, BuiltinOperator_BROADCAST_TO, BuiltinOptions_NONE) CIRCLE_NODE(CircleCast, BuiltinOperator_CAST, BuiltinOptions_CastOptions) CIRCLE_NODE(CircleCeil, BuiltinOperator_CEIL, BuiltinOptions_NONE) CIRCLE_NODE(CircleConcatenation, BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions) CIRCLE_NODE(CircleConv2D, BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions) CIRCLE_NODE(CircleCos, BuiltinOperator_COS, BuiltinOptions_CosOptions) +CIRCLE_NODE(CircleCumSum, BuiltinOperator_CUMSUM, BuiltinOptions_CumsumOptions) CIRCLE_NODE(CircleCustom, BuiltinOperator_CUSTOM, BuiltinOptions_NONE) CIRCLE_NODE(CircleDensify, BuiltinOperator_DENSIFY, BuiltinOptions_DensifyOptions) CIRCLE_NODE(CircleDepthToSpace, BuiltinOperator_DEPTH_TO_SPACE, BuiltinOptions_DepthToSpaceOptions) @@ -92,6 +94,7 @@ CIRCLE_NODE(CircleReduceMax, BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerO CIRCLE_NODE(CircleReduceMin, BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions) CIRCLE_NODE(CircleReduceProd, BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions) CIRCLE_NODE(CircleRelu, BuiltinOperator_RELU, BuiltinOptions_NONE) +CIRCLE_NODE(CircleRelu0To1, BuiltinOperator_RELU_0_TO_1, BuiltinOptions_NONE) CIRCLE_NODE(CircleRelu6, BuiltinOperator_RELU6, BuiltinOptions_NONE) CIRCLE_NODE(CircleReluN1To1, BuiltinOperator_RELU_N1_TO_1, BuiltinOptions_NONE) CIRCLE_NODE(CircleReshape, BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions) @@ -136,6 +139,7 @@ CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLik // Circle Only CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions) CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions) +CIRCLE_NODE(CircleGRU, BuiltinOperator_GRU, BuiltinOptions_GRUOptions) CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions) // Virtual node(s) CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut) diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp index 97e81076b..0022a0e57 100644 --- a/compiler/luci/export/src/CircleTensorExporter.cpp +++ 
b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -409,6 +409,31 @@ encodeOpBufferByDType(FlatBufferBuilder &builder, luci::
   return CreateBuffer(builder, array_offset);
 }
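// --- Editorial sketch (not part of the upstream patch) ----------------------
// encodeOpBufferPack4bit below packs two 4-bit elements per byte, low nibble
// first. A worked round trip for the S4 pair {-3, 5}, assuming that layout:
//   uint8_t packed = (uint8_t(-3) & 0x0f) | (uint8_t(5) << 4); // == 0x5d
//   int8_t lo = int8_t(packed << 4) >> 4;                      // == -3
//   int8_t hi = int8_t(packed) >> 4;                           // == 5
// An odd element count leaves the final high nibble as zero.
// ----------------------------------------------------------------------------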
+template <loco::DataType DT>
+flatbuffers::Offset<circle::Buffer> encodeOpBufferPack4bit(FlatBufferBuilder &builder,
+                                                           luci::CircleConst *c)
+{
+  const uint32_t size = c->size<DT>();
+  const uint32_t raw_size = (size + 1) / 2;
+  std::vector<uint8_t> raw_data(raw_size);
+
+  for (uint32_t i = 0; i < raw_size; ++i)
+  {
+    uint32_t sidx = i * 2;
+    uint8_t data = static_cast<uint8_t>(c->at<DT>(sidx));
+    raw_data[i] = data & 0x0f;
+    sidx++;
+    if (sidx < size)
+    {
+      data = static_cast<uint8_t>(c->at<DT>
(sidx)); + raw_data[i] |= data << 4; + } + } + + auto array_offset = builder.CreateVector(raw_data.data(), raw_size); + return CreateBuffer(builder, array_offset); +} + template <> flatbuffers::Offset encodeOpBuffer(FlatBufferBuilder &builder, luci::CircleConst *c) { @@ -416,6 +441,8 @@ flatbuffers::Offset encodeOpBuffer(FlatBufferBuilder &builder, l { case loco::DataType::FLOAT32: return encodeOpBufferByDType(builder, c); + case loco::DataType::S4: + return encodeOpBufferPack4bit(builder, c); case loco::DataType::S8: return encodeOpBufferByDType(builder, c); case loco::DataType::S16: @@ -424,6 +451,8 @@ flatbuffers::Offset encodeOpBuffer(FlatBufferBuilder &builder, l return encodeOpBufferByDType(builder, c); case loco::DataType::S64: return encodeOpBufferByDType(builder, c); + case loco::DataType::U4: + return encodeOpBufferPack4bit(builder, c); case loco::DataType::U8: return encodeOpBufferByDType(builder, c); case loco::DataType::BOOL: @@ -477,7 +506,7 @@ encodeSparsityParameters(FlatBufferBuilder &builder, luci::SparsityParam *sparsi std::vector> dim_metadata_vec; auto luci_dim_metadata = sparsityparam->dim_metadata; - for (auto it : luci_dim_metadata) + for (const auto &it : luci_dim_metadata) { // array_segments auto circle_array_segments = to_circle_sparse_index_vector(builder, it.array_segments()); @@ -526,6 +555,9 @@ bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs) case loco::DataType::FLOAT32: return has_same_elements(lhs, rhs); + case loco::DataType::S4: + return has_same_elements(lhs, rhs); + case loco::DataType::S8: return has_same_elements(lhs, rhs); @@ -538,6 +570,9 @@ bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs) case loco::DataType::S64: return has_same_elements(lhs, rhs); + case loco::DataType::U4: + return has_same_elements(lhs, rhs); + case loco::DataType::U8: return has_same_elements(lhs, rhs); diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt index 2e7e88118..8c1da0e77 100644 --- a/compiler/luci/import/CMakeLists.txt +++ b/compiler/luci/import/CMakeLists.txt @@ -12,7 +12,7 @@ target_include_directories(luci_import PUBLIC include) target_link_libraries(luci_import PUBLIC luci_lang) target_link_libraries(luci_import PUBLIC luci_profile) target_link_libraries(luci_import PUBLIC luci_plan) -target_link_libraries(luci_import PUBLIC mio_circle06) +target_link_libraries(luci_import PUBLIC mio_circle08) target_link_libraries(luci_import PRIVATE luci_env) target_link_libraries(luci_import PRIVATE luci_log) target_link_libraries(luci_import PRIVATE luci_logex) @@ -20,7 +20,7 @@ target_link_libraries(luci_import PRIVATE nncc_common) target_link_libraries(luci_import PRIVATE locop) target_link_libraries(luci_import PRIVATE foder) target_link_libraries(luci_import PRIVATE oops) -target_link_libraries(luci_import PRIVATE mio_circle06_helper) +target_link_libraries(luci_import PRIVATE mio_circle08_helper) install(TARGETS luci_import DESTINATION lib) install(DIRECTORY include/ DESTINATION include FILES_MATCHING PATTERN "*.h") diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h index a0519f661..36e3cdf3c 100644 --- a/compiler/luci/import/include/luci/Import/CircleReader.h +++ b/compiler/luci/import/include/luci/Import/CircleReader.h @@ -106,7 +106,6 @@ public: // direct API VectorWrapper inputs() const { return wrap(_current_subgraph->inputs()); } VectorWrapper outputs() const { return wrap(_current_subgraph->outputs()); } 
std::string name() const { return fb_string2std_string(_current_subgraph->name()); } - circle::DataFormat data_format() const { return _current_subgraph->data_format(); } CircleMetadataSet metadata() const { return wrap(_model->metadata()); } uint32_t num_subgraph() const { return wrap(_model->subgraphs()).size(); } diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index e8c8d0aae..f3f4871b4 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -28,12 +28,14 @@ #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" #include "Nodes/CircleBidirectionalSequenceLSTM.h" +#include "Nodes/CircleBroadcastTo.h" #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" #include "Nodes/CircleConst.h" #include "Nodes/CircleConv2D.h" #include "Nodes/CircleCos.h" +#include "Nodes/CircleCumSum.h" #include "Nodes/CircleCustom.h" #include "Nodes/CircleDensify.h" #include "Nodes/CircleDepthToSpace.h" @@ -55,6 +57,7 @@ #include "Nodes/CircleGelu.h" #include "Nodes/CircleGreater.h" #include "Nodes/CircleGreaterEqual.h" +#include "Nodes/CircleGRU.h" #include "Nodes/CircleHardSwish.h" #include "Nodes/CircleIf.h" #include "Nodes/CircleInstanceNorm.h" @@ -96,6 +99,7 @@ #include "Nodes/CircleReduceMin.h" #include "Nodes/CircleReduceProd.h" #include "Nodes/CircleRelu.h" +#include "Nodes/CircleRelu0To1.h" #include "Nodes/CircleRelu6.h" #include "Nodes/CircleReluN1To1.h" #include "Nodes/CircleReshape.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleBroadcastTo.h b/compiler/luci/import/include/luci/Import/Nodes/CircleBroadcastTo.h new file mode 100644 index 000000000..9c6f6f49c --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleBroadcastTo.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_BROADCAST_TO_H__ +#define __LUCI_IMPORT_OP_CIRCLE_BROADCAST_TO_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleBroadcastToGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_BROADCAST_TO_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleCumSum.h b/compiler/luci/import/include/luci/Import/Nodes/CircleCumSum.h new file mode 100644 index 000000000..adf7c6cf2 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleCumSum.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_CUMSUM_H__ +#define __LUCI_IMPORT_OP_CIRCLE_CUMSUM_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleCumSumGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_CUMSUM_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h b/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h new file mode 100644 index 000000000..ed7935e1f --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_GRU_H__ +#define __LUCI_IMPORT_OP_CIRCLE_GRU_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleGRUGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_GRU_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleRelu0To1.h b/compiler/luci/import/include/luci/Import/Nodes/CircleRelu0To1.h new file mode 100644 index 000000000..3254d77a3 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleRelu0To1.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_RELU_0_TO_1_H__ +#define __LUCI_IMPORT_OP_CIRCLE_RELU_0_TO_1_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleRelu0To1GraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_RELU_0_TO_1_H__ diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp index 9c1fe7356..fbdea8a7c 100644 --- a/compiler/luci/import/src/CircleImportMetadata.cpp +++ b/compiler/luci/import/src/CircleImportMetadata.cpp @@ -236,7 +236,7 @@ const OriginTable CircleImportMetadata::origin_table(void) std::vector> origins; for (auto source_id : source_ids) { - const auto source_name = _source_table.at(source_id); + const auto &source_name = _source_table.at(source_id); origins.push_back(single_origin(source_id, source_name)); } diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp index a42c3f913..9d05a1a53 100644 --- a/compiler/luci/import/src/CircleReader.cpp +++ b/compiler/luci/import/src/CircleReader.cpp @@ -54,6 +54,8 @@ loco::DataType luci_datatype(const circle::TensorType type) return loco::DataType::S32; case circle::TensorType_UINT8: return loco::DataType::U8; + case circle::TensorType_UINT4: + return loco::DataType::U4; case circle::TensorType_INT64: return loco::DataType::S64; case circle::TensorType_STRING: @@ -66,6 +68,8 @@ loco::DataType luci_datatype(const circle::TensorType type) break; case circle::TensorType_INT8: return loco::DataType::S8; + case circle::TensorType_INT4: + return loco::DataType::S4; default: break; } diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index 9c868320d..29edf8348 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -38,12 +38,14 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedGraphBuilder); // 253 CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherGraphBuilder); // 252 CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMGraphBuilder); // 52 + CIRCLE_NODE(BROADCAST_TO, CircleBroadcastToGraphBuilder); // 130 CIRCLE_NODE(CAST, CircleCastGraphBuilder); // 53 CIRCLE_NODE(CEIL, CircleCeilGraphBuilder); // 104 CIRCLE_NODE(CUSTOM, CircleCustomGraphBuilder); // 32 CIRCLE_NODE(CONCATENATION, CircleConcatenationGraphBuilder); // 2 CIRCLE_NODE(CONV_2D, CircleConv2DGraphBuilder); // 3 CIRCLE_NODE(COS, CircleCosGraphBuilder); // 108 + CIRCLE_NODE(CUMSUM, CircleCumSumGraphBuilder); // 128 CIRCLE_NODE(DENSIFY, CircleDensifyGraphBuilder); // 124 CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpaceGraphBuilder); // 5 CIRCLE_NODE(DEPTHWISE_CONV_2D, CircleDepthwiseConv2DGraphBuilder); // 4 @@ -64,6 +66,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(GELU, CircleGeluGraphBuilder); // 150 CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61 CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62 + CIRCLE_NODE(GRU, CircleGRUGraphBuilder); // 251 CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117 CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118 CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254 @@ -105,6 +108,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() 
CIRCLE_NODE(REDUCE_MIN, CircleReduceMinGraphBuilder); // 89 CIRCLE_NODE(REDUCE_PROD, CircleReduceProdGraphBuilder); // 81 CIRCLE_NODE(RELU, CircleReluGraphBuilder); // 19 + CIRCLE_NODE(RELU_0_TO_1, CircleRelu0To1GraphBuilder); // 152 CIRCLE_NODE(RELU6, CircleRelu6GraphBuilder); // 21 CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1GraphBuilder); // 20 CIRCLE_NODE(RESHAPE, CircleReshapeGraphBuilder); // 22 diff --git a/compiler/luci/import/src/Importer.test.cpp b/compiler/luci/import/src/Importer.test.cpp index 91e4860ea..6967128c9 100644 --- a/compiler/luci/import/src/Importer.test.cpp +++ b/compiler/luci/import/src/Importer.test.cpp @@ -50,7 +50,6 @@ struct BasicCircleModel { model->subgraphs.push_back(std::make_unique()); model->subgraphs.back()->name = ""; - model->subgraphs.back()->data_format = circle::DataFormat_CHANNELS_LAST; return model->subgraphs.size() - 1; } diff --git a/compiler/luci/import/src/Nodes/CircleBroadcastTo.cpp b/compiler/luci/import/src/Nodes/CircleBroadcastTo.cpp new file mode 100644 index 000000000..2ccd13fb7 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleBroadcastTo.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleBroadcastTo.h" + +#include + +#include + +namespace luci +{ + +bool CircleBroadcastToGraphBuilder::validate(const ValidateArgs &args) const +{ + // TODO Support type check + return GraphBuilder::validate(args, 2); +} + +CircleNode *CircleBroadcastToGraphBuilder::build_node(const circle::OperatorT &, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->input(inputs.at(0)); + node->shape(inputs.at(1)); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp index 88f2ae3d0..189f4d897 100644 --- a/compiler/luci/import/src/Nodes/CircleConst.cpp +++ b/compiler/luci/import/src/Nodes/CircleConst.cpp @@ -102,6 +102,38 @@ void copy_data(const VectorWrapper &raw_data, } } +// NOTE copy_data for S4, U4. +// this method will unpack two 4bit elements, packed in 8bit, +// to two 8bit elements, having values -8~7, for S4 and 0~15 for U4. +template +void copy_data_4(const VectorWrapper &raw_data, uint32_t num_elements, + CircleConst *const_node) +{ + using T = typename loco::DataTypeImpl
<DT>::Type;
+
+  // TODO support sparse?
+  assert(const_node->sparsityparam() == nullptr);
+  if (const_node->sparsityparam())
+    return;
+
+  uint32_t raw_size = (num_elements + 1) / 2;
+  assert(raw_data.size() == raw_size);
+
+  const uint8_t *data = raw_data.data();
+  const_node->size<DT>(num_elements);
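+  // Editorial note (not in the upstream patch): each byte holds two packed
+  // elements, low nibble first. Unpacking casts through T so that S4 values
+  // sign-extend to -8..7 (arithmetic shift on int8_t) while U4 values
+  // zero-extend to 0..15 (logical shift on uint8_t).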
+  for (uint32_t i = 0; i < raw_size; ++i)
+  {
+    uint32_t idx = i * 2;
+    // for S4, 1 bit for sign, 3 bits for value
+    const_node->at<DT>(idx) = static_cast<T>(data[i] << 4) >> 4;
+    idx++;
+    if (idx < num_elements)
+    {
+      const_node->at<DT>
(idx) = static_cast(data[i]) >> 4; + } + } +} + } // namespace namespace luci @@ -170,10 +202,18 @@ CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index, copy_data(buffer, num_elements, const_node); break; + case loco::DataType::U4: + copy_data_4(buffer, num_elements, const_node); + break; + case loco::DataType::U8: copy_data(buffer, num_elements, const_node); break; + case loco::DataType::S4: + copy_data_4(buffer, num_elements, const_node); + break; + case loco::DataType::S8: copy_data(buffer, num_elements, const_node); break; diff --git a/compiler/luci/import/src/Nodes/CircleCumSum.cpp b/compiler/luci/import/src/Nodes/CircleCumSum.cpp new file mode 100644 index 000000000..b757fc743 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleCumSum.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleCumSum.h" + +#include + +#include + +namespace luci +{ + +bool CircleCumSumGraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 2); +} + +CircleNode *CircleCumSumGraphBuilder::build_node(const circle::OperatorT &op, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->input(inputs.at(0)); + node->axis(inputs.at(1)); + + const auto *options = op.builtin_options.AsCumsumOptions(); + node->exclusive(options->exclusive); + node->reverse(options->reverse); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleGRU.cpp b/compiler/luci/import/src/Nodes/CircleGRU.cpp new file mode 100644 index 000000000..51fe5f55c --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleGRU.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Import/Nodes/CircleGRU.h" + +#include + +#include + +namespace luci +{ + +bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 6); +} + +CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->input(inputs.at(0)); + node->hidden_hidden(inputs.at(1)); + node->hidden_hidden_bias(inputs.at(2)); + node->hidden_input(inputs.at(3)); + node->hidden_input_bias(inputs.at(4)); + node->state(inputs.at(5)); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleRelu0To1.cpp b/compiler/luci/import/src/Nodes/CircleRelu0To1.cpp new file mode 100644 index 000000000..a6957847c --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleRelu0To1.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleRelu0To1.h" + +#include + +#include + +namespace luci +{ + +bool CircleRelu0To1GraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 1); +} + +CircleNode *CircleRelu0To1GraphBuilder::build_node(const circle::OperatorT &, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->features(inputs.at(0)); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/ValidateHelpers.cpp b/compiler/luci/import/src/ValidateHelpers.cpp index fc027704b..9943b80f6 100644 --- a/compiler/luci/import/src/ValidateHelpers.cpp +++ b/compiler/luci/import/src/ValidateHelpers.cpp @@ -79,6 +79,7 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args) case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: case circle::TensorType_FLOAT64: + case circle::TensorType_INT16: case circle::TensorType_INT32: case circle::TensorType_INT64: case circle::TensorType_UINT8: diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h index d643b0893..49036537a 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h @@ -26,11 +26,13 @@ #include "Nodes/CircleBatchMatMul.h" #include "Nodes/CircleBatchToSpaceND.h" #include "Nodes/CircleBidirectionalSequenceLSTM.h" +#include "Nodes/CircleBroadcastTo.h" #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" #include "Nodes/CircleConv2D.h" #include "Nodes/CircleCos.h" +#include "Nodes/CircleCumSum.h" #include "Nodes/CircleCustom.h" #include "Nodes/CircleDensify.h" #include "Nodes/CircleDepthToSpace.h" @@ -92,6 +94,7 @@ #include "Nodes/CircleReduceMin.h" #include "Nodes/CircleReduceProd.h" #include "Nodes/CircleRelu.h" +#include "Nodes/CircleRelu0To1.h" #include "Nodes/CircleRelu6.h" #include "Nodes/CircleReluN1To1.h" #include 
"Nodes/CircleReshape.h" @@ -136,6 +139,7 @@ // Circle only #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" +#include "Nodes/CircleGRU.h" #include "Nodes/CircleInstanceNorm.h" // Virtual nodes #include "Nodes/CircleConst.h" diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst index 1646909e8..a97f6b60b 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst @@ -24,11 +24,13 @@ CIRCLE_NODE(AVERAGE_POOL_2D, CircleAveragePool2D) CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceND) CIRCLE_NODE(BATCH_MATMUL, CircleBatchMatMul) CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTM) +CIRCLE_NODE(BROADCAST_TO, CircleBroadcastTo) CIRCLE_NODE(CAST, CircleCast) CIRCLE_NODE(CEIL, CircleCeil) CIRCLE_NODE(CONCATENATION, CircleConcatenation) CIRCLE_NODE(CONV_2D, CircleConv2D) CIRCLE_NODE(COS, CircleCos) +CIRCLE_NODE(CUMSUM, CircleCumSum) CIRCLE_NODE(CUSTOM, CircleCustom) CIRCLE_NODE(DENSIFY, CircleDensify) CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpace) @@ -50,6 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd) CIRCLE_NODE(GELU, CircleGelu) CIRCLE_NODE(GREATER, CircleGreater) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual) +CIRCLE_NODE(GRU, CircleGRU) CIRCLE_NODE(HARD_SWISH, CircleHardSwish) CIRCLE_NODE(IF, CircleIf) CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize) @@ -90,6 +93,7 @@ CIRCLE_NODE(REDUCE_MAX, CircleReduceMax) CIRCLE_NODE(REDUCE_MIN, CircleReduceMin) CIRCLE_NODE(REDUCE_PROD, CircleReduceProd) CIRCLE_NODE(RELU, CircleRelu) +CIRCLE_NODE(RELU_0_TO_1, CircleRelu0To1) CIRCLE_NODE(RELU6, CircleRelu6) CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1) CIRCLE_NODE(RESHAPE, CircleReshape) diff --git a/compiler/luci/lang/include/luci/IR/DataTypeHelper.h b/compiler/luci/lang/include/luci/IR/DataTypeHelper.h new file mode 100644 index 000000000..fdc97f5ac --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/DataTypeHelper.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_LANG_IR_DATA_TYPE_HELPER_H__ +#define __LUCI_LANG_IR_DATA_TYPE_HELPER_H__ + +#include +#include + +namespace luci +{ + +/** + * @brief Returns the size of the data type. + * @note luci saves S4, U4 in a single byte. + * The extra 4 bits in the MSB side are filled with sign bits. 
+ */ +inline uint32_t size(loco::DataType data_type) +{ + switch (data_type) + { + case loco::DataType::S4: + return sizeof(loco::DataTypeImpl::Type); + case loco::DataType::U4: + return sizeof(loco::DataTypeImpl::Type); + default: + return loco::size(data_type); + } +} + +} // namespace luci + +#endif // __LUCI_LANG_IR_DATA_TYPE_HELPER_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBroadcastTo.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBroadcastTo.h new file mode 100644 index 000000000..8d0591bda --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBroadcastTo.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLEBROADCASTTO_H__ +#define __LUCI_IR_CIRCLEBROADCASTTO_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief BroadcastTo in Circle + */ +class CircleBroadcastTo final : public FixedArityNode<2, CircleNodeImpl> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + +public: + loco::Node *shape(void) const { return at(1)->node(); } + void shape(loco::Node *node) { at(1)->node(node); } +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLEBROADCASTTO_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCumSum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCumSum.h new file mode 100644 index 000000000..bbeb8a442 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCumSum.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IR_CIRCLE_CUMSUM_H__ +#define __LUCI_IR_CIRCLE_CUMSUM_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleNodeMixins.h" +#include "luci/IR/CircleOpcode.h" + +namespace luci +{ + +class CircleCumSum final : public FixedArityNode<2, CircleNodeImpl> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *axis(void) const { return at(1)->node(); } + void axis(loco::Node *node) { at(1)->node(node); } + +public: + bool exclusive(void) const { return _exclusive; } + void exclusive(bool exclusive) { _exclusive = exclusive; } + + bool reverse(void) const { return _reverse; } + void reverse(bool reverse) { _reverse = reverse; } + +private: + bool _exclusive{false}; + bool _reverse{false}; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_CUMSUM_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h new file mode 100644 index 000000000..897729910 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IR_CIRCLE_GRU_H__ +#define __LUCI_IR_CIRCLE_GRU_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief GRU in Circle + */ +class CircleGRU final : public FixedArityNode<6, CircleNodeImpl> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *hidden_hidden(void) const { return at(1)->node(); } + void hidden_hidden(loco::Node *node) { at(1)->node(node); } + + loco::Node *hidden_hidden_bias(void) const { return at(2)->node(); } + void hidden_hidden_bias(loco::Node *node) { at(2)->node(node); } + + loco::Node *hidden_input(void) const { return at(3)->node(); } + void hidden_input(loco::Node *node) { at(3)->node(node); } + + loco::Node *hidden_input_bias(void) const { return at(4)->node(); } + void hidden_input_bias(loco::Node *node) { at(4)->node(node); } + + loco::Node *state(void) const { return at(5)->node(); } + void state(loco::Node *node) { at(5)->node(node); } + +public: + FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } + void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } + + bool returnSequences() const { return _return_sequences; } + void returnSequences(bool return_sequences) { _return_sequences = return_sequences; } + + bool timeMajor() const { return _time_major; } + void timeMajor(bool time_major) { _time_major = time_major; } + +private: + FusedActFunc _fused_act_fun = FusedActFunc::NONE; + bool _return_sequences = false; + bool _time_major = false; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_GRU_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu0To1.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu0To1.h new file mode 100644 index 000000000..7c648df73 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu0To1.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IR_CIRCLERELU_0_TO_1_H__ +#define __LUCI_IR_CIRCLERELU_0_TO_1_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief RELU_0_TO_1 in Circle + */ +class CircleRelu0To1 final : public FixedArityNode<1, CircleNodeImpl> +{ +public: + loco::Node *features(void) const { return at(0)->node(); } + void features(loco::Node *node) { at(0)->node(node); } +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLERELU_0_TO_1_H__ diff --git a/compiler/luci/lang/include/luci/IR/SparsityParam.h b/compiler/luci/lang/include/luci/IR/SparsityParam.h index 6cfff67e1..15ddeb17c 100644 --- a/compiler/luci/lang/include/luci/IR/SparsityParam.h +++ b/compiler/luci/lang/include/luci/IR/SparsityParam.h @@ -202,8 +202,8 @@ public: } DimMetaData(DimensionType format, int32_t dense_size, const SparseIndexVector &array_segments, const SparseIndexVector &array_indices) - : _format{format}, _dense_size{dense_size}, _array_segments{array_segments}, _array_indices{ - array_indices} + : _format{format}, _dense_size{dense_size}, _array_segments{array_segments}, + _array_indices{array_indices} { // DO NOTHING } diff --git a/compiler/luci/lang/src/DataTypeHelper.cpp b/compiler/luci/lang/src/DataTypeHelper.cpp new file mode 100644 index 000000000..f2ea50338 --- /dev/null +++ b/compiler/luci/lang/src/DataTypeHelper.cpp @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This is to validate DataTypeHelper.h +#include "luci/IR/DataTypeHelper.h" diff --git a/compiler/luci/lang/src/Nodes/CircleBroadcastTo.test.cpp b/compiler/luci/lang/src/Nodes/CircleBroadcastTo.test.cpp new file mode 100644 index 000000000..f24f73981 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleBroadcastTo.test.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleBroadcastTo.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include + +TEST(CircleBroadcastToTest, constructor) +{ + luci::CircleBroadcastTo broadcast_to_node; + + ASSERT_EQ(luci::CircleDialect::get(), broadcast_to_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::BROADCAST_TO, broadcast_to_node.opcode()); + + ASSERT_EQ(nullptr, broadcast_to_node.input()); + + ASSERT_EQ(nullptr, broadcast_to_node.shape()); +} + +TEST(CircleBroadcastToTest, input_NEG) +{ + luci::CircleBroadcastTo broadcast_to_node; + luci::CircleBroadcastTo node; + + broadcast_to_node.input(&node); + broadcast_to_node.shape(&node); + ASSERT_NE(nullptr, broadcast_to_node.input()); + ASSERT_NE(nullptr, broadcast_to_node.shape()); + + broadcast_to_node.input(nullptr); + broadcast_to_node.shape(nullptr); + ASSERT_EQ(nullptr, broadcast_to_node.input()); + ASSERT_EQ(nullptr, broadcast_to_node.shape()); +} + +TEST(CircleBroadcastToTest, arity_NEG) +{ + luci::CircleBroadcastTo broadcast_to_node; + + ASSERT_NO_THROW(broadcast_to_node.arg(1)); + ASSERT_THROW(broadcast_to_node.arg(2), std::out_of_range); +} + +TEST(CircleBroadcastToTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor + { + }; + + luci::CircleBroadcastTo broadcast_to_node; + + TestVisitor tv; + ASSERT_THROW(broadcast_to_node.accept(&tv), std::exception); +} + +TEST(CircleBroadcastToTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor + { + }; + + luci::CircleBroadcastTo broadcast_to_node; + + TestVisitor tv; + ASSERT_THROW(broadcast_to_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/lang/src/Nodes/CircleConst.cpp b/compiler/luci/lang/src/Nodes/CircleConst.cpp index a4854ec59..c17a4e2c3 100644 --- a/compiler/luci/lang/src/Nodes/CircleConst.cpp +++ b/compiler/luci/lang/src/Nodes/CircleConst.cpp @@ -74,8 +74,10 @@ INSTANTIATE(loco::DataType::S64); INSTANTIATE(loco::DataType::S32); INSTANTIATE(loco::DataType::S16); INSTANTIATE(loco::DataType::S8); +INSTANTIATE(loco::DataType::S4); INSTANTIATE(loco::DataType::FLOAT32); INSTANTIATE(loco::DataType::U8); +INSTANTIATE(loco::DataType::U4); INSTANTIATE(loco::DataType::BOOL); INSTANTIATE(loco::DataType::FLOAT16); diff --git a/compiler/luci/lang/src/Nodes/CircleCumSum.test.cpp b/compiler/luci/lang/src/Nodes/CircleCumSum.test.cpp new file mode 100644 index 000000000..424d9a16c --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleCumSum.test.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleCumSum.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include + +TEST(CircleCumSumTest, constructor_P) +{ + luci::CircleCumSum node; + + ASSERT_EQ(luci::CircleDialect::get(), node.dialect()); + ASSERT_EQ(luci::CircleOpcode::CUMSUM, node.opcode()); + + ASSERT_EQ(nullptr, node.input()); +} + +TEST(CircleCumSumTest, input_NEG) +{ + luci::CircleCumSum node; + luci::CircleCumSum input; + + node.input(&input); + ASSERT_NE(nullptr, node.input()); + + node.input(nullptr); + ASSERT_EQ(nullptr, node.input()); +} + +// FIXME +TEST(CircleCumSumTest, arity_NEG) +{ + luci::CircleCumSum node; + + ASSERT_NO_THROW(node.arg(0)); + ASSERT_NO_THROW(node.arg(1)); + ASSERT_THROW(node.arg(2), std::out_of_range); +} + +TEST(CircleCumSumTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor + { + }; + + luci::CircleCumSum node; + + TestVisitor tv; + ASSERT_THROW(node.accept(&tv), std::exception); +} + +TEST(CircleCumSumTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor + { + }; + + luci::CircleCumSum node; + + TestVisitor tv; + ASSERT_THROW(node.accept(&tv), std::exception); +} diff --git a/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp b/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp new file mode 100644 index 000000000..f7e6f9594 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleGRU.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include + +TEST(CircleGRUTest, constructor) +{ + luci::CircleGRU gru_node; + + ASSERT_EQ(luci::CircleDialect::get(), gru_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::GRU, gru_node.opcode()); + + ASSERT_EQ(nullptr, gru_node.input()); + ASSERT_EQ(nullptr, gru_node.hidden_hidden()); + ASSERT_EQ(nullptr, gru_node.hidden_hidden_bias()); + ASSERT_EQ(nullptr, gru_node.hidden_input()); + ASSERT_EQ(nullptr, gru_node.hidden_input_bias()); + ASSERT_EQ(nullptr, gru_node.state()); +} + +TEST(CircleGRUTest, input_NEG) +{ + luci::CircleGRU gru_node; + luci::CircleGRU node; + + gru_node.input(&node); + ASSERT_NE(nullptr, gru_node.input()); + + gru_node.input(nullptr); + ASSERT_EQ(nullptr, gru_node.input()); +} + +TEST(CircleGRUTest, arity_NEG) +{ + luci::CircleGRU gru_node; + + ASSERT_NO_THROW(gru_node.arg(0)); + ASSERT_NO_THROW(gru_node.arg(1)); + ASSERT_NO_THROW(gru_node.arg(2)); + ASSERT_NO_THROW(gru_node.arg(3)); + ASSERT_NO_THROW(gru_node.arg(4)); + ASSERT_NO_THROW(gru_node.arg(5)); + ASSERT_THROW(gru_node.arg(6), std::out_of_range); +} + +TEST(CircleGRUTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor + { + }; + + luci::CircleGRU gru_node; + + TestVisitor tv; + ASSERT_THROW(gru_node.accept(&tv), std::exception); +} + +TEST(CircleGRUTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor + { + }; + + luci::CircleGRU gru_node; + + TestVisitor tv; + ASSERT_THROW(gru_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/lang/src/Nodes/CircleRelu0To1.test.cpp b/compiler/luci/lang/src/Nodes/CircleRelu0To1.test.cpp new file mode 100644 index 000000000..2ac89a7ad --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleRelu0To1.test.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleRelu0To1.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include + +TEST(CircleRelu0ToTest, constructor) +{ + luci::CircleRelu0To1 relu0to1_node; + + ASSERT_EQ(luci::CircleDialect::get(), relu0to1_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::RELU_0_TO_1, relu0to1_node.opcode()); + + ASSERT_EQ(nullptr, relu0to1_node.features()); +} + +TEST(CircleRelu0ToTest, input_NEG) +{ + luci::CircleRelu0To1 relu0to1_node; + luci::CircleRelu0To1 node; + + relu0to1_node.features(&node); + ASSERT_NE(nullptr, relu0to1_node.features()); + + relu0to1_node.features(nullptr); + ASSERT_EQ(nullptr, relu0to1_node.features()); +} + +TEST(CircleRelu0ToTest, arity_NEG) +{ + luci::CircleRelu0To1 relu0to1_node; + + ASSERT_NO_THROW(relu0to1_node.arg(0)); + ASSERT_THROW(relu0to1_node.arg(1), std::out_of_range); +} + +TEST(CircleRelu0ToTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor + { + }; + + luci::CircleRelu0To1 relu0to1_node; + + TestVisitor tv; + ASSERT_THROW(relu0to1_node.accept(&tv), std::exception); +} + +TEST(CircleRelu0ToTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor + { + }; + + luci::CircleRelu0To1 relu0to1_node; + + TestVisitor tv; + ASSERT_THROW(relu0to1_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp index e7f38d07b..2ff37afe1 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp @@ -144,12 +144,14 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedSummaryBuilder) CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherSummaryBuilder) CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMSummaryBuilder) + CIRCLE_NODE(BROADCAST_TO, CircleBroadcastToSummaryBuilder) CIRCLE_NODE(CAST, CircleCastSummaryBuilder) CIRCLE_NODE(CEIL, CircleCeilSummaryBuilder) CIRCLE_NODE(CONCATENATION, CircleConcatenationSummaryBuilder) CIRCLE_NODE(CIRCLECONST, CircleConstSummaryBuilder) CIRCLE_NODE(CONV_2D, CircleConv2DSummaryBuilder) CIRCLE_NODE(COS, CircleCosSummaryBuilder) + CIRCLE_NODE(CUMSUM, CircleCumsumSummaryBuilder) CIRCLE_NODE(CUSTOM, CircleCustomSummaryBuilder) CIRCLE_NODE(DENSIFY, CircleDensifySummaryBuilder) CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpaceSummaryBuilder) @@ -171,6 +173,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(GELU, CircleGeluSummaryBuilder) CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder) + CIRCLE_NODE(GRU, CircleGRUSummaryBuilder) CIRCLE_NODE(HARD_SWISH, CircleHardSwishSummaryBuilder) CIRCLE_NODE(IF, CircleIfSummaryBuilder) CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder) @@ -212,6 +215,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(REDUCE_MIN, CircleReduceMinSummaryBuilder) CIRCLE_NODE(REDUCE_PROD, CircleReduceProdSummaryBuilder) CIRCLE_NODE(RELU, CircleReluSummaryBuilder) + CIRCLE_NODE(RELU_0_TO_1, CircleRelu0To1SummaryBuilder) CIRCLE_NODE(RELU6, CircleRelu6SummaryBuilder) CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1SummaryBuilder) CIRCLE_NODE(RESHAPE, CircleReshapeSummaryBuilder) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index aba6a8681..f0a92ef91 
100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -30,6 +30,8 @@ std::string to_str(loco::DataType type) { switch (type) { + case loco::DataType::U4: + return "UINT4"; case loco::DataType::U8: return "UINT8"; case loco::DataType::U16: @@ -39,6 +41,8 @@ std::string to_str(loco::DataType type) case loco::DataType::U64: return "UINT64"; + case loco::DataType::S4: + return "INT4"; case loco::DataType::S8: return "INT8"; case loco::DataType::S16: @@ -108,6 +112,11 @@ std::string to_str(const luci::Stride *stride) return std::to_string(stride->h()) + "," + std::to_string(stride->w()); } +std::string to_str(const luci::Dilation *dilation) +{ + return std::to_string(dilation->h()) + "," + std::to_string(dilation->w()); +} + std::string to_str(const luci::Filter *filter) { return std::to_string(filter->h()) + "," + std::to_string(filter->w()); @@ -343,6 +352,25 @@ void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci: s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs())); } +std::vector CircleGRUSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "hidden_hidden", "hidden_hidden_bias", + "hidden_input", "hidden_input_bias", "state"}; +} + +void CircleGRUSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s) +{ + auto gru = loco::must_cast(node); + s.args().append("fused_act_function", to_str(gru->fusedActivationFunction())); + s.args().append("return_sequence", to_str(gru->returnSequences())); + s.args().append("time_major", to_str(gru->timeMajor())); +} + +std::vector CircleBroadcastToSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "shape"}; +} + std::vector CircleCastSummaryBuilder::get_input_names(const luci::CircleNode *) { return {"x"}; @@ -425,6 +453,19 @@ void CircleConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node, s.args().append("fused_activation_function", to_str(conv2d->fusedActivationFunction())); } +std::vector CircleCumsumSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "axis"}; +} + +void CircleCumsumSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ + auto cumsum = loco::must_cast(node); + s.args().append("exclusive", to_str(cumsum->exclusive())); + s.args().append("reverse", to_str(cumsum->reverse())); +} + std::vector CircleCustomSummaryBuilder::get_input_names(const luci::CircleNode *node) { auto input_names = std::vector(); diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h index 0bdb05d8d..f489e9b6e 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h @@ -145,6 +145,12 @@ private: void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); }; +class CircleBroadcastToSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector get_input_names(const luci::CircleNode *); +}; + class CircleCastSummaryBuilder final : public CircleNodeSummaryBuilder { private: @@ -183,6 +189,13 @@ class CircleCosSummaryBuilder final : public CircleNodeWithXSummaryBuilder { }; +class CircleCumsumSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + class 
CircleCustomSummaryBuilder final : public CircleNodeSummaryBuilder { private: @@ -294,6 +307,13 @@ class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBui { }; +class CircleGRUSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + class CircleHardSwishSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder { }; @@ -517,6 +537,10 @@ class CircleReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuild { }; +class CircleRelu0To1SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder +{ +}; + class CircleRelu6SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder { }; diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt index 304ef6307..001194eb6 100644 --- a/compiler/luci/partition/CMakeLists.txt +++ b/compiler/luci/partition/CMakeLists.txt @@ -13,7 +13,7 @@ target_link_libraries(luci_partition PUBLIC luci_lang) target_link_libraries(luci_partition PRIVATE luci_service) target_link_libraries(luci_partition PRIVATE luci_log) target_link_libraries(luci_partition PRIVATE luci_logex) -target_link_libraries(luci_partition PRIVATE mio_circle06) +target_link_libraries(luci_partition PRIVATE mio_circle08) target_link_libraries(luci_partition PRIVATE nncc_common) target_link_libraries(luci_partition PRIVATE pepper_csv2vec) target_link_libraries(luci_partition PRIVATE oops) diff --git a/compiler/luci/partition/include/luci/ConnectNode.h b/compiler/luci/partition/include/luci/ConnectNode.h index d8cbfc6c4..7539aaf6b 100644 --- a/compiler/luci/partition/include/luci/ConnectNode.h +++ b/compiler/luci/partition/include/luci/ConnectNode.h @@ -70,12 +70,14 @@ public: void visit(const luci::CircleAveragePool2D *) final; void visit(const luci::CircleBatchMatMul *) final; void visit(const luci::CircleBatchToSpaceND *) final; + void visit(const luci::CircleBroadcastTo *) final; void visit(const luci::CircleCast *) final; void visit(const luci::CircleCeil *) final; void visit(const luci::CircleConcatenation *) final; void visit(const luci::CircleConst *) final; void visit(const luci::CircleConv2D *) final; void visit(const luci::CircleCos *) final; + void visit(const luci::CircleCumSum *) final; void visit(const luci::CircleCustom *) final; void visit(const luci::CircleDensify *) final; void visit(const luci::CircleDepthToSpace *) final; @@ -137,6 +139,7 @@ public: void visit(const luci::CircleReduceMin *) final; void visit(const luci::CircleReduceProd *) final; void visit(const luci::CircleRelu *) final; + void visit(const luci::CircleRelu0To1 *) final; void visit(const luci::CircleRelu6 *) final; void visit(const luci::CircleReluN1To1 *) final; void visit(const luci::CircleReshape *) final; @@ -182,6 +185,7 @@ public: // Circle Only void visit(const luci::CircleBCQFullyConnected *) final; void visit(const luci::CircleBCQGather *) final; + void visit(const luci::CircleGRU *) final; void visit(const luci::CircleInstanceNorm *) final; // NOTE CircleInput and CircleOutput are not handled here as these need diff --git a/compiler/luci/partition/src/Nodes/CircleBroadcastTo.cpp b/compiler/luci/partition/src/Nodes/CircleBroadcastTo.cpp new file mode 100644 index 000000000..5c1e2481d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBroadcastTo.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleBroadcastTo *node) +{ + auto *cloned = loco::must_cast(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast(node->input()); + luci::CircleNode *shape = loco::must_cast(node->shape()); + + cloned->input(cn->find_clone(input)); + cloned->shape(cn->find_clone(shape)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleBroadcastTo *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleBroadcastTo.test.cpp b/compiler/luci/partition/src/Nodes/CircleBroadcastTo.test.cpp new file mode 100644 index 000000000..90f24c5ad --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBroadcastTo.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
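Every partition visitor added by this patch follows the same clone-and-reconnect idiom as the BroadcastTo case above: fetch the node's already-created clone, then repoint each input at the clone of the corresponding original input. A minimal sketch for a hypothetical one-input operator CircleFoo (illustrative only, not part of this diff):

  void connect(luci::ConnectNode *cn, const luci::CircleFoo *node)
  {
    // the clone was created earlier; find_clone() retrieves it
    auto *cloned = loco::must_cast<luci::CircleFoo *>(cn->find_clone(node));
    // rewire the clone's input to the clone of the original input
    luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
    cloned->input(cn->find_clone(input));
  }

find_clone() signals a missing clone by throwing, which is exactly what the *_NEG tests below provoke through prepare_inputs_miss().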
+ */ + +#include "luci/ConnectNode.h" + +#include "ConnectNode.test.h" + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->shape(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_BroadcastTo) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_BroadcastTo_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCumSum.cpp b/compiler/luci/partition/src/Nodes/CircleCumSum.cpp new file mode 100644 index 000000000..1110c42c0 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCumSum.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCumSum *node) +{ + auto *cloned = loco::must_cast(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast(node->input()); + luci::CircleNode *axis = loco::must_cast(node->axis()); + + cloned->input(cn->find_clone(input)); + cloned->axis(cn->find_clone(axis)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCumSum *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCumSum.test.cpp b/compiler/luci/partition/src/Nodes/CircleCumSum.test.cpp new file mode 100644 index 000000000..9e9cd828b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCumSum.test.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +#include "ConnectNode.test.h" + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT::init(g); + + _node->exclusive(false); + _node->reverse(false); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, {0}}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->axis(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_CumSum) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_CumSum_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.cpp b/compiler/luci/partition/src/Nodes/CircleGRU.cpp new file mode 100644 index 000000000..c0995f469 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGRU.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
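The exclusive and reverse flags that the CumSum graphlet above pins to false select which prefix sums the op produces. For reference, with input [1, 2, 3] along its only axis (standard cumulative-sum semantics; the kernel itself is outside this diff):

  cumsum([1,2,3])                       -> [1, 3, 6]
  cumsum([1,2,3], exclusive=true)       -> [0, 1, 3]
  cumsum([1,2,3], reverse=true)         -> [6, 5, 3]
  cumsum([1,2,3], exclusive, reverse)   -> [5, 3, 0]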
+ */ + +#include "luci/ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGRU *node) +{ + auto *cloned = loco::must_cast(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast(node->input()); + luci::CircleNode *hidden_input = loco::must_cast(node->hidden_input()); + luci::CircleNode *hidden_input_bias = + loco::must_cast(node->hidden_input_bias()); + luci::CircleNode *hidden_hidden = loco::must_cast(node->hidden_hidden()); + luci::CircleNode *hidden_hidden_bias = + loco::must_cast(node->hidden_hidden_bias()); + luci::CircleNode *state = loco::must_cast(node->state()); + + cloned->input(cn->find_clone(input)); + cloned->hidden_input(cn->find_clone(hidden_input)); + cloned->hidden_input_bias(cn->find_clone(hidden_input_bias)); + cloned->hidden_hidden(cn->find_clone(hidden_hidden)); + cloned->hidden_hidden_bias(cn->find_clone(hidden_hidden_bias)); + cloned->state(cn->find_clone(state)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGRU *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp b/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp new file mode 100644 index 000000000..c0720a178 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
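The six inputs wired above (input, hidden_hidden, hidden_hidden_bias, hidden_input, hidden_input_bias, state) map onto the usual gated-recurrent-unit formulation, in which one weight set multiplies the current input and another multiplies the previous hidden state. As a conventional reference (the kernel definition lives outside this hunk, so the exact correspondence is an assumption):

  z_t = sigmoid(W_z x_t + U_z h_{t-1} + b_z)   // update gate
  r_t = sigmoid(W_r x_t + U_r h_{t-1} + b_r)   // reset gate
  n_t = tanh(W_n x_t + r_t * (U_n h_{t-1}) + b_n)
  h_t = (1 - z_t) * n_t + z_t * h_{t-1}

Here the W_*/b_* terms would correspond to hidden_input(_bias), the U_* terms to hidden_hidden(_bias), and state carries h_{t-1}.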
+ */ + +#include "luci/ConnectNode.h" + +#include "ConnectNode.test.h" + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::NONE); + } +}; + +class TestNodeGraph : public TestIsOGraph<6>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<6>::init({shape, shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->hidden_hidden(input(1)); + node()->hidden_hidden_bias(input(2)); + node()->hidden_input(input(3)); + node()->hidden_input_bias(input(4)); + node()->state(input(5)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_CIRCLE_GRU) +{ + TestNodeGraph tng; + tng.init({10, 1, 4}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(6, clone->arity()); + // 24 separate checks is too much + for (uint32_t i = 0; i < 6; ++i) + ASSERT_EQ(cth.inputs(i), clone->arg(i)); +} + +TEST(ConnectNodeTest, connect_CIRCLE_GRU_NEG) +{ + TestNodeGraph tng; + tng.init({10, 1, 4}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRelu0To1.cpp b/compiler/luci/partition/src/Nodes/CircleRelu0To1.cpp new file mode 100644 index 000000000..8f9a07610 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu0To1.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRelu0To1 *node) +{ + auto *cloned = loco::must_cast(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRelu0To1 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRelu0To1.test.cpp b/compiler/luci/partition/src/Nodes/CircleRelu0To1.test.cpp new file mode 100644 index 000000000..ba7290f40 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu0To1.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +#include "ConnectNode.test.h" + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Relu0To1) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Relu0To1_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp index 0fabfc416..5a78d99c0 100644 --- a/compiler/luci/partition/src/PartitionIRDump.cpp +++ b/compiler/luci/partition/src/PartitionIRDump.cpp @@ -56,7 +56,7 @@ void dump(std::ostream &os, const PGroups *pgroups) for (auto it = pgroups->node2group.begin(); it != pgroups->node2group.end(); ++it) { auto node = it->first; - auto group = it->second; + auto &group = it->second; os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name() << "): " << group << std::endl; } diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt index ac18a5f8d..0dd884a20 100644 --- a/compiler/luci/pass/CMakeLists.txt +++ b/compiler/luci/pass/CMakeLists.txt @@ -1,4 +1,4 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) nnas_find_package(Fp16Source QUIET) if(NOT FlatBuffers_FOUND) @@ -35,7 +35,7 @@ target_link_libraries(luci_pass PRIVATE luci_compute) target_link_libraries(luci_pass PRIVATE nncc_common) target_link_libraries(luci_pass PRIVATE pepper_csv2vec) target_link_libraries(luci_pass PRIVATE oops) -target_link_libraries(luci_pass PRIVATE flatbuffers-2.0) +target_link_libraries(luci_pass PRIVATE flatbuffers-23.5.26) install(TARGETS luci_pass DESTINATION lib) install(DIRECTORY include/ DESTINATION include FILES_MATCHING PATTERN "*.h") @@ -51,5 +51,5 @@ target_include_directories(luci_pass_test PRIVATE src) target_link_libraries(luci_pass_test luci_pass) target_link_libraries(luci_pass_test luci_lang) 
target_link_libraries(luci_pass_test luci_testhelper) -target_link_libraries(luci_pass_test flatbuffers-2.0) +target_link_libraries(luci_pass_test flatbuffers-23.5.26) #target_link_libraries(luci_pass_test oops) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 6ebacee39..bdae7d57e 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -34,20 +34,28 @@ public: { enum Algorithm { + FuseAddToFullyConnectedBias, + FuseAddWithConv, FuseAddWithFullyConnected, FuseAddWithTConv, FuseBatchNormWithConv, FuseBatchNormWithDwConv, FuseBatchNormWithTConv, + FuseMulToFullyConnectedWeights, + FuseSliceWithTConv, FuseBCQ, + FuseHorizontalFullyConnected, FuseInstanceNorm, FuseMeanWithMean, + FuseMulWithConv, + FuseMulWithDiv, FuseTransposeWithMean, ResolveCustomOpAdd, ResolveCustomOpBatchMatMul, ResolveCustomOpMatMul, ResolveCustomOpMaxPoolWithArgmax, ResolveCustomOpSplitV, + ResolveFormerCustomOp, FoldAddV2, FoldCast, FoldDensify, @@ -55,7 +63,11 @@ public: FoldFullyConnected, FoldDequantize, FoldGather, + FoldMul, + FoldReshape, + FoldShape, FoldSparseToDense, + FoldSqueeze, ForwardReshapeToUnaryOp, ForwardTransposeOp, SparsifyTensorPass, @@ -64,32 +76,42 @@ public: FuseActivationFunction, FusePRelu, FuseGelu, + FuseRsqrt, ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, ReplaceNonConstFCWithBatchMatMul, ReplaceSubWithAdd, + ReplaceWithFCGeluFC, SubstitutePackToReshape, SubstitutePadV2ToPad, SubstituteSplitVToSplit, SubstituteSqueezeToReshape, ExpandBroadcastConst, ConvertNCHWToNHWC, + CommonSubExpressionElimination, + RemoveUnnecessaryAdd, RemoveUnnecessarySlice, RemoveUnnecessaryStridedSlice, RemoveUnnecessarySplit, RemoveUnnecessaryReshape, + RemoveUnnecessaryTranspose, TransformMinMaxToRelu6Pass, TransformMinReluToRelu6Pass, + TransformSqrtDivToRsqrtMul, DecomposeHardSwishPass, + DecomposeSoftmaxPass, SubstituteStridedSliceToReshape, SubstituteTransposeToReshape, RemoveRedundantQuantize, RemoveRedundantReshape, RemoveFakeQuant, + RemoveQDQForMixedPrecisionOp, RemoveQuantDequantSeq, RemoveDuplicateConst, UnrollUnidirSeqLSTM, + XpSepActFromTransposeConv, + RemoveGatherGuard, }; enum AlgorithmParameters diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h index 463f31790..8fd19cc8f 100644 --- a/compiler/luci/pass/include/luci/CircleQuantizer.h +++ b/compiler/luci/pass/include/luci/CircleQuantizer.h @@ -19,6 +19,7 @@ #include <loco.h> +#include <memory> #include <string> #include <vector> @@ -37,6 +38,24 @@ public: std::string granularity; }; + using LayerParams = std::vector<std::shared_ptr<LayerParam>>; + + // NOTE ...Set is not related with std::set but is used to denote + // multiple 'sets' of LayerParams.
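With the aliases above, a caller builds one LayerParams list per configuration and can carry several such lists in the LayerParamsSet declared just below. A hedged usage sketch against the reconstructed types (the real option plumbing lives in the tools that drive CircleQuantizer, and LayerParam may have more fields than shown):

  using Options = luci::CircleQuantizer::Options;

  Options::LayerParams params;
  auto lp = std::make_shared<Options::LayerParam>();
  lp->name = "conv2d_1";       // layer to override (illustrative name)
  lp->granularity = "channel"; // per-channel quantization
  params.emplace_back(lp);

  Options::LayerParamsSet lpset;
  lpset.emplace_back(params);  // one 'set' of per-layer overrides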
+ class LayerParamsSet + { + public: + // some helper methods + size_t size(void) const { return items.size(); } + template <typename... Args> void emplace_back(Args &&...args) { items.emplace_back(args...); } + std::vector<LayerParams>::iterator begin() { return items.begin(); }; + std::vector<LayerParams>::iterator end() { return items.end(); }; + + private: + // store multiple sets of LayerParams + std::vector<LayerParams> items; + }; + enum Algorithm { QuantizeDequantizeWeights, @@ -46,6 +65,7 @@ public: ForceQuantParam, ConvertToFakeQuantizedModel, QuantizeWeights, + QuantizeOnnxFakeQuantizedModel, }; enum AlgorithmParameters @@ -66,6 +86,7 @@ public: Quantize_input_type, Quantize_output_type, Quantize_TF_style_maxpool, + Quantize_save_min_max, }; virtual ~Options() = default; @@ -78,8 +99,10 @@ public: virtual std::vector<std::string> params(AlgorithmParameters) const = 0; // Quantization parameters for multiple layers - virtual void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) = 0; - virtual std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const = 0; + virtual void layer_params(AlgorithmParameters, LayerParams &) = 0; + virtual LayerParams layer_params(AlgorithmParameters) const = 0; + virtual void layer_params_set(LayerParamsSet &) = 0; + virtual LayerParamsSet layer_params_set(void) const = 0; }; public: diff --git a/compiler/luci/pass/include/luci/Pass/CommonSubExpressionEliminationPass.h b/compiler/luci/pass/include/luci/Pass/CommonSubExpressionEliminationPass.h new file mode 100644 index 000000000..30b2167fd --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/CommonSubExpressionEliminationPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_COMMON_SUB_EXPRESSION_ELIMINATION_PASS_H__ +#define __LUCI_COMMON_SUB_EXPRESSION_ELIMINATION_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Pass to perform CSE (Common Sub-expression Elimination) + */ +class CommonSubExpressionEliminationPass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "luci::CommonSubExpressionEliminationPass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace luci + +#endif //__LUCI_COMMON_SUB_EXPRESSION_ELIMINATION_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/DecomposeSoftmaxPass.h b/compiler/luci/pass/include/luci/Pass/DecomposeSoftmaxPass.h new file mode 100644 index 000000000..ecf95b34d --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/DecomposeSoftmaxPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_DECOMPOSE_SOFTMAX_PASS_H__ +#define __LUCI_DECOMPOSE_SOFTMAX_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to decompose Softmax into backend friendly structures + */ +struct DecomposeSoftmaxPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::DecomposeSoftmaxPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_DECOMPOSE_SOFTMAX_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldMulPass.h b/compiler/luci/pass/include/luci/Pass/FoldMulPass.h new file mode 100644 index 000000000..69b661fbe --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldMulPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_MUL_PASS_H__ +#define __LUCI_FOLD_MUL_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold Mul to a constant tensor + * + */ +struct FoldMulPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldMulPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_MUL_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldReshapePass.h b/compiler/luci/pass/include/luci/Pass/FoldReshapePass.h new file mode 100644 index 000000000..214b2bffc --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldReshapePass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
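DecomposeSoftmaxPass and the new Fold* passes are complementary: the former expands one op into primitives that backends handle well, while the latter collapse constant subgraphs. The numerically stable decomposition a softmax lowering conventionally targets (the concrete graph this pass emits is in its source file, not this hunk) is:

  m   = ReduceMax(x, axis)
  e   = Exp(Sub(x, m))
  s   = ReduceSum(e, axis)
  out = Div(e, s)

FoldMulPass, for its part, evaluates a Mul whose operands are constants at compile time and replaces the node with a single constant tensor.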
+ */ + +#ifndef __LUCI_FOLD_RESHAPE_PASS_H__ +#define __LUCI_FOLD_RESHAPE_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold Reshape to a constant tensor + * + */ +struct FoldReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldShapePass.h b/compiler/luci/pass/include/luci/Pass/FoldShapePass.h new file mode 100644 index 000000000..0e5dee164 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldShapePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_SHAPE_PASS_H__ +#define __LUCI_FOLD_SHAPE_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold Shape to a constant tensor + */ +struct FoldShapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldShapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_SHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldSqueezePass.h b/compiler/luci/pass/include/luci/Pass/FoldSqueezePass.h new file mode 100644 index 000000000..c7e7f19e5 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldSqueezePass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_SQUEEZE_PASS_H__ +#define __LUCI_FOLD_SQUEEZE_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold Squeeze to a constant tensor + * + */ +struct FoldSqueezePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldSqueezePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_SQUEEZE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseAddToFullyConnectedBiasPass.h b/compiler/luci/pass/include/luci/Pass/FuseAddToFullyConnectedBiasPass.h new file mode 100644 index 000000000..9aef47845 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseAddToFullyConnectedBiasPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__ +#define __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Add to following FC bias + */ +struct FuseAddToFullyConnectedBiasPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseAddToFullyConnectedBiasPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseAddWithConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseAddWithConvPass.h new file mode 100644 index 000000000..c6c5981be --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseAddWithConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_ADD_WITH_CONV_PASS_H__ +#define __LUCI_FUSE_ADD_WITH_CONV_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse CircleAdd into CircleConv2D + */ +struct FuseAddWithConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseAddWithConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_ADD_WITH_CONV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h new file mode 100644 index 000000000..49729c081 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
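Both fusions above rely on the same algebra: a constant Add adjacent to a linear op can be absorbed into that op's bias. Assuming the usual fully-connected formulation y = x * W^T + b (the precise matching conditions are in the pass sources):

  FC(x + c) = (x + c) * W^T + b = x * W^T + (b + c * W^T)

so FuseAddToFullyConnectedBiasPass rewrites the bias to b + c * W^T and drops the Add; the Conv2D variant folds an adjacent constant Add into the convolution bias analogously.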
+ */ + +#ifndef __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ +#define __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ + +#include + +namespace luci +{ + +struct FuseHorizontalFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseHorizontalFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h new file mode 100644 index 000000000..583f21ef8 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ +#define __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul into following FullyConnected + */ +struct FuseMulToFullyConnectedWeightsPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulToFullyConnectedWeightsPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulWithConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulWithConvPass.h new file mode 100644 index 000000000..08977f236 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulWithConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
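FuseMulToFullyConnectedWeightsPass uses the multiplicative mirror of the previous identity: a constant per-channel Mul feeding an FC is absorbed by scaling the weights. Under the same y = x * W^T + b reading (hedged; see the pass implementation for the exact constraints):

  FC(x * s) = (x * s) * W^T + b = x * (W with its k-th input column scaled by s_k)^T + b

FuseHorizontalFullyConnectedPass, declared above, instead merges parallel FCs that share one input into a single FC; see its source for the exact pattern it accepts.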
+ */ + +#ifndef __LUCI_FUSE_MUL_WITH_CONV_H__ +#define __LUCI_FUSE_MUL_WITH_CONV_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul operation with a preceding Conv + */ +struct FuseMulWithConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulWithConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_MUL_WITH_CONV_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulWithDivPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulWithDivPass.h new file mode 100644 index 000000000..fa9086c59 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulWithDivPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_WITH_DIV_PASS_H__ +#define __LUCI_FUSE_MUL_WITH_DIV_PASS_H__ + +#include + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul operation with a Div operation + */ +struct FuseMulWithDivPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulWithDivPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif //__LUCI_FUSE_MUL_WITH_DIV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseRsqrtPass.h b/compiler/luci/pass/include/luci/Pass/FuseRsqrtPass.h new file mode 100644 index 000000000..b6a26bc51 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseRsqrtPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_RSQRT_PASS_H__ +#define __LUCI_FUSE_RSQRT_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleRsqrt + */ +struct FuseRsqrtPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseRsqrtPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_RSQRT_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseSliceWithTConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseSliceWithTConvPass.h new file mode 100644 index 000000000..2863076c0 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseSliceWithTConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_SLICE_WITH_TCONV_PASS_H__ +#define __LUCI_FUSE_SLICE_WITH_TCONV_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Slice operation with a preceding TConv + */ +struct FuseSliceWithTConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseSliceWithTConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_SLICE_WITH_TCONV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeOnnxFakeQuantModelPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeOnnxFakeQuantModelPass.h new file mode 100644 index 000000000..d9a9569b0 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/QuantizeOnnxFakeQuantModelPass.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
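FuseSliceWithTConvPass reads as a crop-folding rule: exporters often emit a TransposeConv whose spatial output is slightly larger than needed, followed by a Slice that crops it back. When the crop is expressible through the TConv's padding, the Slice can disappear (a conceptual reading of the brief above; the legality checks are in the pass source):

  TransposeConv -> Slice(crop)   ==>   TransposeConv(padding adjusted)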
+ */ + +#ifndef __LUCI_QUANTIZE_ONNX_FAKE_QUANT_MODEL_PASS_H__ +#define __LUCI_QUANTIZE_ONNX_FAKE_QUANT_MODEL_PASS_H__ + +#include <logo/Pass.h> + +#include <loco.h> + +#include <memory> + +namespace luci +{ + +/** + * @brief Pass to create a quantized graph from a graph fake-quantized with ONNX + */ +class QuantizeOnnxFakeQuantModelPass : public logo::Pass +{ +public: + struct Context + { + loco::DataType default_activation_dtype = loco::DataType::Unknown; + }; + +public: + QuantizeOnnxFakeQuantModelPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} + { + assert(_ctx); // FIX_CALLER_UNLESS + assert(_ctx->default_activation_dtype != loco::DataType::Unknown); // FIX_CALLER_UNLESS + } + + virtual const char *name(void) const { return "luci::QuantizeOnnxFakeQuantModelPass"; } + +public: + bool run(loco::Graph *graph); + +private: + std::unique_ptr<Context> _ctx; +}; + +} // namespace luci + +#endif //__LUCI_QUANTIZE_ONNX_FAKE_QUANT_MODEL_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h index 6874046f0..560cc9025 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h @@ -42,6 +42,7 @@ public: std::vector<loco::DataType> input_types; std::vector<loco::DataType> output_types; bool TF_style_maxpool = false; + bool save_min_max = false; std::vector<LayerInfo> layers_info; }; diff --git a/compiler/luci/pass/include/luci/Pass/RemoveGatherGuardPass.h b/compiler/luci/pass/include/luci/Pass/RemoveGatherGuardPass.h new file mode 100644 index 000000000..d514dc382 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveGatherGuardPass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_GATHER_GUARD_PASS_H__ +#define __LUCI_REMOVE_GATHER_GUARD_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove Add+FloorMod guard ops of Gather + * @note If the indices of Gather are guaranteed to be positive by the user, + * Add/FloorMod guard ops can be removed. + * This pass removes Add+FloorMod having INT32/INT64 dtypes, + * since some backends cannot process them in quantized models. + */ +struct RemoveGatherGuardPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveGatherGuardPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_GATHER_GUARD_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveQDQForMixedPrecisionOpPass.h b/compiler/luci/pass/include/luci/Pass/RemoveQDQForMixedPrecisionOpPass.h new file mode 100644 index 000000000..2da2be8b6 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveQDQForMixedPrecisionOpPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_QDQ_FOR_MIXED_PRECISION_OP_PASS_H__ +#define __LUCI_REMOVE_QDQ_FOR_MIXED_PRECISION_OP_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove QDQ pattern for mixed-precision Ops + */ +struct RemoveQDQForMixedPrecisionOpPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveQDQForMixedPrecisionOpPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_QDQ_FOR_MIXED_PRECISION_OP_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryAddPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryAddPass.h new file mode 100644 index 000000000..432e834e1 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryAddPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_ADD_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_ADD_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove an unnecessary Add node (one whose input and output are the same). + */ +struct RemoveUnnecessaryAddPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryAddPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_ADD_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryTransposeNetPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryTransposeNetPass.h new file mode 100644 index 000000000..c7f746500 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryTransposeNetPass.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
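Conceptually, the guard pattern that RemoveGatherGuardPass (introduced just above) matches is:

  safe = FloorMod(Add(indices, axis_size), axis_size)
  out  = Gather(params, safe)

The Add+FloorMod pair exists only to wrap negative indices into [0, axis_size). When the user guarantees the indices are already non-negative, both ops are identities, and the pass reduces the net to Gather(params, indices).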
+ */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_TRANSPOSE_NET_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_TRANSPOSE_NET_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +struct RemoveUnnecessaryTransposeNetPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryTransposeNetPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_TRANSPOSE_NET_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceWithFCGeluFCPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceWithFCGeluFCPass.h new file mode 100644 index 000000000..091cd9e39 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ReplaceWithFCGeluFCPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REPLACE_WITH_FC_GELU_FC_PASS_H__ +#define __LUCI_REPLACE_WITH_FC_GELU_FC_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to generate FC-Gelu-FC from a certain Op pattern + * + * To see the target Op pattern, please see the implementation. + * NOTE: The target pattern includes FC fused with div/mul Ops. + */ +struct ReplaceWithFCGeluFCPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ReplaceWithFCGeluFCPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REPLACE_WITH_FC_GELU_FC_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ResolveFormerCustomOpPass.h b/compiler/luci/pass/include/luci/Pass/ResolveFormerCustomOpPass.h new file mode 100644 index 000000000..cb8a05030 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ResolveFormerCustomOpPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_RESOLVE_FORMER_CUSTOM_OP_PASS_H__ +#define __LUCI_RESOLVE_FORMER_CUSTOM_OP_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to convert a custom operator to a built-in operator. + * + * @details This pass changes an op formerly used as a custom op to the + * builtin op introduced by a later schema version upgrade.
+ */ +struct ResolveFormerCustomOpPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ResolveFormerCustomOpPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_RESOLVE_FORMER_CUSTOM_OP_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/TransformSqrtDivToRsqrtMulPass.h b/compiler/luci/pass/include/luci/Pass/TransformSqrtDivToRsqrtMulPass.h new file mode 100644 index 000000000..da065246e --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/TransformSqrtDivToRsqrtMulPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_TRANSFORM_SQRT_DIV_TO_RSQRT_MUL_PASS_H__ +#define __LUCI_TRANSFORM_SQRT_DIV_TO_RSQRT_MUL_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to transform Div(X,Sqrt(y)) to Mul(X,Rsqrt(y)) + */ +struct TransformSqrtDivToRsqrtMulPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::TransformSqrtDivToRsqrtMulPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_TRANSFORM_SQRT_DIV_TO_RSQRT_MUL_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h b/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h new file mode 100644 index 000000000..ed7833644 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
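TransformSqrtDivToRsqrtMulPass applies the identity x / sqrt(y) == x * rsqrt(y):

  Div(X, Sqrt(Y))   ==>   Mul(X, Rsqrt(Y))

which pays off on backends with a native Rsqrt kernel, where a multiply is typically cheaper than a division.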
+ */
+struct TransformSqrtDivToRsqrtMulPass final : public logo::Pass
+{
+  const char *name(void) const final { return "luci::TransformSqrtDivToRsqrtMulPass"; }
+
+  bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_TRANSFORM_SQRT_DIV_TO_RSQRT_MUL_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h b/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h
new file mode 100644
index 000000000..ed7833644
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__
+#define __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Experimental class to separate activation functions from TransposeConv
+ */
+struct XpSepActFromTransposeConvPass final : public logo::Pass
+{
+  const char *name(void) const final { return "luci::XpSepActFromTransposeConvPass"; }
+
+  bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index b011581af..aa98fb386 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -17,6 +17,7 @@
 #include "luci/CircleOptimizer.h"
 
 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
+#include "luci/Pass/CommonSubExpressionEliminationPass.h"
 #include "luci/Pass/ExpandBroadcastConstPass.h"
 #include "luci/Pass/FoldAddV2Pass.h"
 #include "luci/Pass/FoldCastPass.h"
@@ -25,42 +26,60 @@
 #include "luci/Pass/FoldDequantizePass.h"
 #include "luci/Pass/FoldFullyConnectedPass.h"
 #include "luci/Pass/FoldGatherPass.h"
+#include "luci/Pass/FoldMulPass.h"
+#include "luci/Pass/FoldReshapePass.h"
+#include "luci/Pass/FoldShapePass.h"
 #include "luci/Pass/FoldSparseToDensePass.h"
+#include "luci/Pass/FoldSqueezePass.h"
 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
 #include "luci/Pass/ForwardTransposeOpPass.h"
 #include "luci/Pass/FuseActivationFunctionPass.h"
+#include "luci/Pass/FuseAddToFullyConnectedBiasPass.h"
+#include "luci/Pass/FuseAddWithConvPass.h"
 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
 #include "luci/Pass/FuseAddWithTConvPass.h"
 #include "luci/Pass/FuseBatchNormWithConvPass.h"
 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
 #include "luci/Pass/FuseBCQPass.h"
+#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h"
 #include "luci/Pass/FuseInstanceNormPass.h"
 #include "luci/Pass/FuseMeanWithMeanPass.h"
+#include "luci/Pass/FuseMulWithConvPass.h"
+#include "luci/Pass/FuseMulWithDivPass.h"
 #include "luci/Pass/FusePreActivationBatchNormPass.h"
 #include "luci/Pass/FusePReluPass.h"
 #include "luci/Pass/FuseGeluPass.h"
+#include "luci/Pass/FuseRsqrtPass.h"
+#include "luci/Pass/FuseSliceWithTConvPass.h"
+#include "luci/Pass/FuseHorizontalFullyConnectedPass.h"
 #include "luci/Pass/FuseTransposeWithMeanPass.h"
 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
 #include "luci/Pass/RemoveDuplicateConstPass.h"
 #include "luci/Pass/RemoveFakeQuantPass.h"
+#include "luci/Pass/RemoveGatherGuardPass.h"
+#include "luci/Pass/RemoveQDQForMixedPrecisionOpPass.h"
 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
 #include "luci/Pass/RemoveRedundantReshapePass.h"
 #include "luci/Pass/RemoveRedundantTransposePass.h"
 #include "luci/Pass/RemoveRedundantQuantizePass.h"
+#include "luci/Pass/RemoveUnnecessaryAddPass.h"
 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
 #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
+#include "luci/Pass/RemoveUnnecessaryTransposeNetPass.h"
 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
 #include "luci/Pass/ReplaceSubWithAddPass.h"
+#include "luci/Pass/ReplaceWithFCGeluFCPass.h"
 #include "luci/Pass/ResolveCustomOpAddPass.h"
 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
"luci/Pass/ResolveCustomOpBatchMatMulPass.h" #include "luci/Pass/ResolveCustomOpMatMulPass.h" #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h" #include "luci/Pass/ResolveCustomOpSplitVPass.h" +#include "luci/Pass/ResolveFormerCustomOpPass.h" #include "luci/Pass/SparsifyTensorPass.h" #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" #include "luci/Pass/SubstitutePackToReshapePass.h" @@ -71,8 +90,11 @@ #include "luci/Pass/SubstituteTransposeToReshapePass.h" #include "luci/Pass/TransformMinMaxToRelu6Pass.h" #include "luci/Pass/TransformMinReluToRelu6Pass.h" +#include "luci/Pass/TransformSqrtDivToRsqrtMulPass.h" #include "luci/Pass/DecomposeHardSwishPass.h" +#include "luci/Pass/DecomposeSoftmaxPass.h" #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" +#include "luci/Pass/XpSepActFromTransposeConvPass.h" // TODO add more passes #include "luci/Pass/CircleShapeInferencePass.h" @@ -154,6 +176,7 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out phase.emplace_back(std::make_unique()); phase.emplace_back(std::make_unique()); phase.emplace_back(std::make_unique()); + phase.emplace_back(std::make_unique()); // Fuse FullyConnected with Add // Why we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass? @@ -237,6 +260,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const phase.emplace_back(std::make_unique()); phase.emplace_back(std::make_unique()); + if (_options->query(Options::Algorithm::CommonSubExpressionElimination)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::ResolveCustomOpAdd)) { phase.emplace_back(std::make_unique()); @@ -249,10 +276,22 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::ResolveFormerCustomOp)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FuseMeanWithMean)) { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseMulWithConv)) + { + phase.emplace_back(std::make_unique()); + } + if (_options->query(Options::Algorithm::FuseMulWithDiv)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax)) { phase.emplace_back(std::make_unique()); @@ -277,6 +316,18 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseSliceWithTConv)) + { + phase.emplace_back(std::make_unique()); + } + if (_options->query(Options::Algorithm::FuseAddToFullyConnectedBias)) + { + phase.emplace_back(std::make_unique()); + } + if (_options->query(Options::Algorithm::FuseAddWithConv)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FuseAddWithFullyConnected)) { phase.emplace_back(std::make_unique()); @@ -289,6 +340,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseMulToFullyConnectedWeights)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FusePRelu)) { phase.emplace_back(std::make_unique()); @@ -297,6 +352,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseRsqrt)) + { + phase.emplace_back(std::make_unique()); + } + if (_options->query(Options::Algorithm::FuseHorizontalFullyConnected)) + { + 
+    phase.emplace_back(std::make_unique<luci::FuseHorizontalFullyConnectedPass>());
+  }
   if (_options->query(Options::Algorithm::FuseTransposeWithMean))
   {
     phase.emplace_back(std::make_unique<luci::FuseTransposeWithMeanPass>());
@@ -329,10 +392,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique());
   }
+  if (_options->query(Options::Algorithm::FoldMul))
+  {
+    phase.emplace_back(std::make_unique<luci::FoldMulPass>());
+  }
+  if (_options->query(Options::Algorithm::FoldReshape))
+  {
+    phase.emplace_back(std::make_unique<luci::FoldReshapePass>());
+  }
+  if (_options->query(Options::Algorithm::FoldShape))
+  {
+    phase.emplace_back(std::make_unique<luci::FoldShapePass>());
+  }
   if (_options->query(Options::Algorithm::FoldSparseToDense))
   {
     phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
   }
+  if (_options->query(Options::Algorithm::FoldSqueeze))
+  {
+    phase.emplace_back(std::make_unique<luci::FoldSqueezePass>());
+  }
   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
   {
     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
@@ -357,10 +436,22 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique());
   }
+  if (_options->query(Options::Algorithm::RemoveGatherGuard))
+  {
+    phase.emplace_back(std::make_unique<luci::RemoveGatherGuardPass>());
+  }
+  if (_options->query(Options::Algorithm::RemoveQDQForMixedPrecisionOp))
+  {
+    phase.emplace_back(std::make_unique<luci::RemoveQDQForMixedPrecisionOpPass>());
+  }
   if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
   {
     phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
   }
+  if (_options->query(Options::Algorithm::RemoveUnnecessaryAdd))
+  {
+    phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryAddPass>());
+  }
   if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
   {
     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
@@ -378,6 +469,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique());
   }
+  if (_options->query(Options::Algorithm::RemoveUnnecessaryTranspose))
+  {
+    phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryTransposeNetPass>());
+  }
   if (_options->query(Options::Algorithm::RemoveRedundantReshape))
   {
     phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
@@ -402,6 +497,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique());
   }
+  if (_options->query(Options::Algorithm::ReplaceWithFCGeluFC))
+  {
+    phase.emplace_back(std::make_unique<luci::ReplaceWithFCGeluFCPass>());
+  }
   if (_options->query(Options::Algorithm::SubstitutePackToReshape))
   {
     phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
@@ -434,14 +533,30 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique());
   }
+  if (_options->query(Options::Algorithm::TransformSqrtDivToRsqrtMul))
+  {
+    phase.emplace_back(std::make_unique<luci::TransformSqrtDivToRsqrtMulPass>());
+  }
   if (_options->query(Options::Algorithm::DecomposeHardSwishPass))
   {
     phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>());
   }
+  if (_options->query(Options::Algorithm::DecomposeSoftmaxPass))
+  {
+    phase.emplace_back(std::make_unique<luci::DecomposeSoftmaxPass>());
+  }
   if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
   {
     phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
   }
+
+  // NOTE Experimental options; these will be removed someday
+  //      Add experimental options here
+  if (_options->query(Options::Algorithm::XpSepActFromTransposeConv))
+  {
+    phase.emplace_back(std::make_unique<luci::XpSepActFromTransposeConvPass>());
+  }
+
   // Forward Reshape/Transpose is done after
   // 1. SubstituteXXXToReshape
   // 2. RemoveRedundantReshape/Transpose
diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.cpp b/compiler/luci/pass/src/CircleOptimizerUtils.cpp
deleted file mode 100644
index 127573db4..000000000
--- a/compiler/luci/pass/src/CircleOptimizerUtils.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. 
All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "CircleOptimizerUtils.h" - -#include - -namespace luci -{ - -bool has_dynamic_shape(const loco::Node *node) -{ - const auto circle_node = loco::must_cast(node); - for (uint32_t i = 0; i < circle_node->rank(); ++i) - if (!circle_node->dim(i).known()) - return true; - return false; -} - -} // namespace luci diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.h b/compiler/luci/pass/src/CircleOptimizerUtils.h deleted file mode 100644 index e04942bfa..000000000 --- a/compiler/luci/pass/src/CircleOptimizerUtils.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __LUCI_CIRCLE_OPTIMIZER_UTILS_H__ -#define __LUCI_CIRCLE_OPTIMIZER_UTILS_H__ - -#include - -namespace luci -{ - -bool has_dynamic_shape(const loco::Node *node); - -} // namespace luci - -#endif // __LUCI_CIRCLE_OPTIMIZER_UTILS_H__ diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp index 9039a839f..86ada1f18 100644 --- a/compiler/luci/pass/src/CircleQuantizer.cpp +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -27,6 +27,7 @@ #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizeDequantizeWeightsPass.h" #include "luci/Pass/QuantizeWeightsPass.h" +#include "luci/Pass/QuantizeOnnxFakeQuantModelPass.h" #include "luci/Pass/CircleShapeInferencePass.h" #include "luci/Pass/CircleTypeInferencePass.h" @@ -43,13 +44,18 @@ #include #include +#include #include +#include +#include namespace { using namespace luci; using LayerParam = luci::CircleQuantizer::Options::LayerParam; +using LayerParams = luci::CircleQuantizer::Options::LayerParams; +using LayerParamsSet = luci::CircleQuantizer::Options::LayerParamsSet; // This function updates user-given input_type to match with the input signature of graph // If user gives only one input_type, it will be expanded to the number of graph inputs @@ -224,15 +230,18 @@ public: const std::string param(AlgorithmParameters) const final; void params(AlgorithmParameters, std::vector &) final; std::vector params(AlgorithmParameters) const final; - void layer_params(AlgorithmParameters, std::vector> &) final; - std::vector> layer_params(AlgorithmParameters) const final; + void layer_params(AlgorithmParameters, LayerParams &) final; + LayerParams layer_params(AlgorithmParameters) const final; + void layer_params_set(LayerParamsSet &) final; + LayerParamsSet layer_params_set(void) const final; bool query(Algorithm) final; private: std::vector _algorithms; std::map _algorithm_params; std::map> _multiple_params; - std::map>> _layer_params; + std::map _layer_params; + LayerParamsSet _layer_params_set; }; void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); } @@ -273,14 +282,12 @@ std::vector QuantizeOptionsImpl::params(AlgorithmParameters param) } } -void QuantizeOptionsImpl::layer_params(AlgorithmParameters param, - std::vector> &vec) +void QuantizeOptionsImpl::layer_params(AlgorithmParameters param, LayerParams &vec) { _layer_params[param] = vec; } -std::vector> -QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const +LayerParams QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const { auto param_vec = _layer_params.find(param); if (param_vec != _layer_params.end()) @@ -289,10 +296,14 @@ QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const } else { - return std::vector>(); + return LayerParams(); } } +void QuantizeOptionsImpl::layer_params_set(LayerParamsSet &vec) { _layer_params_set = vec; } + +LayerParamsSet QuantizeOptionsImpl::layer_params_set(void) const { return _layer_params_set; } + bool QuantizeOptionsImpl::query(Algorithm algo) { std::vector::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo); @@ -304,6 +315,66 @@ bool QuantizeOptionsImpl::query(Algorithm algo) } // namespace +namespace +{ + +bool is_valid_params(loco::Graph *g, LayerParams &lps) +{ + // no same name in lps + std::unordered_set us; + for (auto &lp : lps) + { + if (us.find(lp->name) != us.end()) + throw std::runtime_error("Duplicate name found in configuration: " + lp->name); + us.emplace(lp->name); + } + + // all name should be 
found in the graph
+  for (auto &lp : lps)
+  {
+    auto &name = lp->name;
+    bool found = false;
+    for (auto node : loco::active_nodes(loco::output_nodes(g)))
+    {
+      auto cnode = loco::must_cast<luci::CircleNode *>(node);
+      if (cnode->opcode() == luci::CircleOpcode::CIRCLEOUTPUT)
+        continue;
+
+      if (cnode->name() == name)
+      {
+        found = true;
+        break;
+      }
+    }
+    if (not found)
+      return false;
+  }
+  return true;
+}
+
+LayerParams find_valid_params(loco::Graph *g, LayerParamsSet &lpss)
+{
+  // valid condition: there should be exactly one LayerParams that is OK
+  uint32_t valid_count = 0;
+  LayerParams params;
+  for (auto &lps : lpss)
+  {
+    if (is_valid_params(g, lps))
+    {
+      valid_count++;
+      params = lps;
+    }
+  }
+  if (valid_count != 1)
+    throw std::runtime_error(
+      "Configuration file has layer names (and alternates) that can be mapped in multiple or no "
+      "ways. Please update the configuration file to have only one valid name mapping.");
+
+  return params;
+}
+
+} // namespace
+
 namespace luci
 {
 
@@ -332,6 +403,7 @@ void CircleQuantizer::quantize(loco::Graph *g) const
     _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
   auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
   auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+  auto layer_params_set = _options->layer_params_set();
 
   if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
     throw std::runtime_error("Unsupported input type. List of supported input type: " +
@@ -349,10 +421,15 @@ void CircleQuantizer::quantize(loco::Graph *g) const
       str_to_dtype(output_model_dtype) != loco::DataType::U8)
     throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
 
+  if (layer_params_set.size() > 1u)
+  {
+    layer_params = find_valid_params(g, layer_params_set);
+  }
+
   // Check dtype/granularity of layer params
   for (auto layer_param : layer_params)
   {
-    auto name = layer_param->name;
+    const auto &name = layer_param->name;
     if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype))
     {
       throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
@@ -367,7 +444,7 @@ void CircleQuantizer::quantize(loco::Graph *g) const
   }
 
   // Clear existing quantparams before doing fake quantization
-  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  for (auto &node : loco::active_nodes(loco::output_nodes(g)))
   {
     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
     if (circle_node->quantparam() != nullptr)
@@ -426,7 +503,11 @@
   bool TF_style_maxpool =
     _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
 
+  bool save_min_max =
+    _options->param(Options::AlgorithmParameters::Quantize_save_min_max) == "True";
+
   auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+  auto layer_params_set = _options->layer_params_set();
 
   if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
     throw std::runtime_error("Unsupported input type. List of supported input types: " +
@@ -458,10 +539,15 @@ void CircleQuantizer::quantize(loco::Graph *g) const
       str_to_dtype(output_model_dtype) != loco::DataType::U8)
     throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
 
+  if (layer_params_set.size() > 1u)
+  {
+    layer_params = find_valid_params(g, layer_params_set);
+  }
+
   // Check dtype/granularity of layer params
   for (auto layer_param : layer_params)
   {
-    auto name = layer_param->name;
+    const auto &name = layer_param->name;
     if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype))
     {
       throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
@@ -494,6 +580,7 @@ void CircleQuantizer::quantize(loco::Graph *g) const
     ctx->input_types = input_types;
     ctx->output_types = output_types;
     ctx->TF_style_maxpool = TF_style_maxpool;
+    ctx->save_min_max = save_min_max;
 
     for (auto layer_param : layer_params)
     {
@@ -540,7 +627,7 @@ void CircleQuantizer::quantize(loco::Graph *g) const
   if (_options->query(Options::Algorithm::QuantizeWeights))
   {
     static const std::vector<std::string> qw_supported_input_model_dtype{"float32"};
-    static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"};
+    static const std::vector<std::string> qw_supported_output_model_dtype{"int4", "int8", "int16"};
     static const std::vector<std::string> qw_supported_granularity{"channel"};
 
     auto input_model_dtype =
@@ -571,6 +658,30 @@ void CircleQuantizer::quantize(loco::Graph *g) const
     weights_quantizer.run(g);
   }
 
+  if (_options->query(Options::Algorithm::QuantizeOnnxFakeQuantizedModel))
+  {
+    auto ctx = std::make_unique<luci::QuantizeOnnxFakeQuantModelPass::Context>();
+    {
+      ctx->default_activation_dtype = loco::DataType::S16;
+    }
+
+    luci::QuantizeOnnxFakeQuantModelPass quantizer(std::move(ctx));
+
+    quantizer.run(g);
+
+    logo::Phase phase;
+
+    // Default passes
+    phase.emplace_back(std::make_unique());
+    phase.emplace_back(std::make_unique());
+    phase.emplace_back(std::make_unique());
+
+    ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+    logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+    phase_runner.attach(&prog);
+    phase_runner.run(phase);
+  }
+
   // Requantize
   if (_options->query(Options::Algorithm::Requantize))
   {
diff --git a/compiler/luci/pass/src/CircleQuantizer.test.cpp b/compiler/luci/pass/src/CircleQuantizer.test.cpp
index 5766d5fe5..a16132646 100644
--- a/compiler/luci/pass/src/CircleQuantizer.test.cpp
+++ b/compiler/luci/pass/src/CircleQuantizer.test.cpp
@@ -16,8 +16,12 @@
 
 #include "luci/CircleQuantizer.h"
 
+#include <luci/IR/CircleNodes.h>
+
 #include <gtest/gtest.h>
 
+#include <memory>
+
 using namespace luci;
 using Algorithms = luci::CircleQuantizer::Options::Algorithm;
 using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
@@ -189,3 +193,241 @@ TEST(CircleQuantizerTest, quantize_requant_output_NEG)
 
   EXPECT_THROW(o.quantize(&g), std::runtime_error);
 }
+
+struct SimpleQuantGraph
+{
+  void init(void);
+
+  loco::Graph g;
+
+  luci::CircleInput *input = nullptr;
+  luci::CircleOutput *output = nullptr;
+  luci::CircleConv2D *conv2d1 = nullptr;
+  luci::CircleConv2D *conv2d2 = nullptr;
+  luci::CircleConst *filter = nullptr;
+  luci::CircleConst *bias = nullptr;
+};
+
+// Have two conv layers named "c1" and "c2".
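+//
+// NOTE on the LayerParamsSet semantics exercised by the tests below (an
+// informal summary of find_valid_params in CircleQuantizer.cpp): a
+// LayerParamsSet holds alternative LayerParams lists, and quantization
+// succeeds only when exactly one alternative maps onto layer names that
+// actually exist in the graph. With the "c1"/"c2" graph built here, the
+// set { ["x1"], ["c1"] } is valid because only the second alternative
+// matches, while { ["x1"], ["x2"] } (no match) and { ["c1"], ["c1"] }
+// (multiple matches) must throw std::runtime_error.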
+void SimpleQuantGraph::init() +{ + auto graph_input = g.inputs()->create(); + graph_input->shape({1, 1, 1, 1}); + graph_input->dtype(loco::DataType::FLOAT32); + + auto graph_output = g.outputs()->create(); + graph_output->shape({1, 1, 1, 1}); + graph_output->dtype(loco::DataType::FLOAT32); + + input = g.nodes()->create(); + input->dtype(loco::DataType::FLOAT32); + input->shape({1, 1, 1, 1}); + input->shape_status(luci::ShapeStatus::VALID); + input->index(graph_input->index()); + + filter = g.nodes()->create(); + filter->dtype(loco::DataType::FLOAT32); + filter->size(1 * 1 * 1 * 1); + filter->shape({1, 1, 1, 1}); + filter->shape_status(luci::ShapeStatus::VALID); + + bias = g.nodes()->create(); + bias->dtype(loco::DataType::FLOAT32); + bias->size(1); + bias->shape({1}); + bias->shape_status(luci::ShapeStatus::VALID); + + conv2d1 = g.nodes()->create(); + conv2d1->dtype(loco::DataType::FLOAT32); + conv2d1->fusedActivationFunction(luci::FusedActFunc::NONE); + conv2d1->input(input); + conv2d1->filter(filter); + conv2d1->bias(bias); + conv2d1->padding(luci::Padding::VALID); + conv2d1->name("c1"); + + conv2d2 = g.nodes()->create(); + conv2d2->dtype(loco::DataType::FLOAT32); + conv2d2->fusedActivationFunction(luci::FusedActFunc::NONE); + conv2d2->input(input); + conv2d2->filter(filter); + conv2d2->bias(conv2d1); + conv2d2->padding(luci::Padding::VALID); + conv2d2->name("c2"); + + output = g.nodes()->create(); + output->dtype(loco::DataType::FLOAT32); + output->from(conv2d2); + output->index(graph_output->index()); +} + +struct SimpleCircleQuantizer +{ + CircleQuantizer::Options *init(); + void quantize(loco::Graph *g) { cq.quantize(g); } + + luci::CircleQuantizer cq; +}; + +CircleQuantizer::Options *SimpleCircleQuantizer::init(void) +{ + auto options = cq.options(); + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + return options; +} + +using LayerParam = luci::CircleQuantizer::Options::LayerParam; +using LayerParams = luci::CircleQuantizer::Options::LayerParams; +using LayerParamsSet = luci::CircleQuantizer::Options::LayerParamsSet; + +TEST(CircleQuantizerTest, quantize_layer_param_set) +{ + SimpleQuantGraph sqg; + sqg.init(); + + LayerParamsSet lpss; + { + LayerParams lps1; + { + auto lp1 = std::make_shared(); + lp1->name = "x1"; + lp1->dtype = "int16"; + lp1->granularity = "channel"; + lps1.emplace_back(lp1); + } + lpss.emplace_back(lps1); + + LayerParams lps2; + { + auto lp2 = std::make_shared(); + lp2->name = "c1"; + lp2->dtype = "int16"; + lp2->granularity = "channel"; + lps2.emplace_back(lp2); + } + lpss.emplace_back(lps2); + } + + SimpleCircleQuantizer scq; + auto options = scq.init(); + options->layer_params_set(lpss); + + EXPECT_NO_THROW(scq.quantize(&sqg.g)); +} + +TEST(CircleQuantizerTest, invalid_layer_params_NEG) +{ + SimpleQuantGraph sqg; + sqg.init(); + + LayerParamsSet lpss; + { + // there is no LayerParam with "c1" nor "c2" + LayerParams lps1; + { + auto lp1 = std::make_shared(); + lp1->name = "x1"; + lp1->dtype = "int16"; + lp1->granularity = "channel"; + lps1.emplace_back(lp1); + } + lpss.emplace_back(lps1); + + LayerParams lps2; + { + auto lp2 = std::make_shared(); + lp2->name = "x2"; + lp2->dtype = "int16"; + lp2->granularity = "channel"; + lps2.emplace_back(lp2); + } + lpss.emplace_back(lps2); + } + + SimpleCircleQuantizer scq; + auto options = 
scq.init(); + options->layer_params_set(lpss); + + EXPECT_THROW(scq.quantize(&sqg.g), std::runtime_error); +} + +TEST(CircleQuantizerTest, duplicate_name_in_layer_params_NEG) +{ + SimpleQuantGraph sqg; + sqg.init(); + + LayerParamsSet lpss; + { + LayerParams lps1; + { + // duplicate c1 name in a LayerParams + auto lp11 = std::make_shared(); + lp11->name = "c1"; + lp11->dtype = "int16"; + lp11->granularity = "channel"; + lps1.emplace_back(lp11); + + auto lp12 = std::make_shared(); + lp12->name = "c1"; + lp12->dtype = "int16"; + lp12->granularity = "channel"; + lps1.emplace_back(lp12); + } + lpss.emplace_back(lps1); + + LayerParams lps2; + { + auto lp2 = std::make_shared(); + lp2->name = "x1"; + lp2->dtype = "int16"; + lp2->granularity = "channel"; + lps2.emplace_back(lp2); + } + lpss.emplace_back(lps2); + } + + SimpleCircleQuantizer scq; + auto options = scq.init(); + options->layer_params_set(lpss); + + EXPECT_THROW(scq.quantize(&sqg.g), std::runtime_error); +} + +TEST(CircleQuantizerTest, duplicate_layer_params_NEG) +{ + SimpleQuantGraph sqg; + sqg.init(); + + LayerParamsSet lpss; + { + // duplicate "c1" name in a LayerParamsSet + LayerParams lps1; + { + auto lp1 = std::make_shared(); + lp1->name = "c1"; + lp1->dtype = "int16"; + lp1->granularity = "channel"; + lps1.emplace_back(lp1); + } + lpss.emplace_back(lps1); + + LayerParams lps2; + { + auto lp2 = std::make_shared(); + lp2->name = "c1"; + lp2->dtype = "int16"; + lp2->granularity = "channel"; + lps2.emplace_back(lp2); + } + lpss.emplace_back(lps2); + } + + SimpleCircleQuantizer scq; + auto options = scq.init(); + options->layer_params_set(lpss); + + EXPECT_THROW(scq.quantize(&sqg.g), std::runtime_error); +} diff --git a/compiler/luci/pass/src/CircleShapeInferencePass.cpp b/compiler/luci/pass/src/CircleShapeInferencePass.cpp index ddab22421..bcfb6d0b5 100644 --- a/compiler/luci/pass/src/CircleShapeInferencePass.cpp +++ b/compiler/luci/pass/src/CircleShapeInferencePass.cpp @@ -15,6 +15,7 @@ */ #include "helpers/InferenceCandidates.h" +#include "helpers/Shape.h" #include "luci/Pass/CircleShapeInferencePass.h" @@ -22,31 +23,6 @@ #include -namespace -{ - -bool is_same_shape(luci::CircleNode *node, loco::TensorShape shape) -{ - if (node->shape_status() != luci::ShapeStatus::VALID) - return false; - - if (node->rank() != shape.rank()) - return false; - - for (uint32_t i = 0; i < node->rank(); ++i) - { - if (node->dim(i).known() != shape.dim(i).known()) - return false; - - if (node->dim(i).value() != shape.dim(i).value()) - return false; - } - - return true; -} - -} // namespace - namespace luci { diff --git a/compiler/luci/pass/src/CommonSubExpressionEliminationPass.cpp b/compiler/luci/pass/src/CommonSubExpressionEliminationPass.cpp new file mode 100644 index 000000000..e90e385de --- /dev/null +++ b/compiler/luci/pass/src/CommonSubExpressionEliminationPass.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "luci/Pass/CommonSubExpressionEliminationPass.h"
+#include "helpers/ExpressionCache.h"
+
+#include <luci/IR/CircleNode.h>
+
+using namespace luci::pass;
+
+namespace
+{
+
+// Return true if node is a virtual node
+// TODO Extract this helper to somewhere else
+bool virtual_op(const luci::CircleOpcode opcode)
+{
+  switch (opcode)
+  {
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
+  case luci::CircleOpcode::OPCODE:        \
+    return false;
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
+  case luci::CircleOpcode::OPCODE:         \
+    return true;
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_NODE
+#undef CIRCLE_VNODE
+    default:
+      throw std::runtime_error("Unknown opcode detected");
+  }
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool CommonSubExpressionEliminationPass::run(loco::Graph *g)
+{
+  // Build cache
+  ExpressionCache cache;
+
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto cnode = loco::must_cast<luci::CircleNode *>(node);
+
+    // Skip virtual Ops
+    // Why? virtual Ops do not perform actual computations
+    // NOTE Fix this if the assumption is not true
+    if (virtual_op(cnode->opcode()))
+      continue;
+
+    // Build expression
+    auto expr = Expression::build(cnode);
+
+    // Invalid (NYI) expression
+    if (expr.op == nullptr)
+      continue;
+
+    // Cache hit
+    if (auto saved_node = cache.get(expr))
+    {
+      loco::replace(cnode).with(saved_node);
+      changed = true;
+    }
+    // Cache miss
+    else
+    {
+      cache.put(expr, cnode);
+    }
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/CommonSubExpressionEliminationPass.test.cpp b/compiler/luci/pass/src/CommonSubExpressionEliminationPass.test.cpp
new file mode 100644
index 000000000..fd350a7fc
--- /dev/null
+++ b/compiler/luci/pass/src/CommonSubExpressionEliminationPass.test.cpp
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "luci/Pass/CommonSubExpressionEliminationPass.h" + +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +using namespace luci; +using namespace luci::test; + +class QuantizeGraphlet +{ +public: + QuantizeGraphlet() = default; + + virtual ~QuantizeGraphlet() = default; + + void init(loco::Graph *g) + { + _quantize = g->nodes()->create(); + + auto qparam = std::make_unique(); + { + qparam->scale.emplace_back(1.0); + qparam->zerop.emplace_back(128); + } + _quantize->name("quantize"); + _quantize->quantparam(std::move(qparam)); + _quantize->dtype(loco::DataType::S16); + } + +protected: + luci::CircleQuantize *_quantize = nullptr; +}; + +class CSE_QuantizeTestGraph : public CommonSubExpressionEliminationTestGraph, + public QuantizeGraphlet +{ +public: + std::vector ops; + +protected: + virtual loco::Node *createExpression(luci::CircleNode *ifm, const std::string &name) override + { + auto expr = g()->nodes()->create(); + + auto qparam = std::make_unique(); + { + qparam->scale.emplace_back(1.0); + qparam->zerop.emplace_back(128); + } + expr->name(name + "_quantize"); + expr->quantparam(std::move(qparam)); + expr->dtype(loco::DataType::S16); + expr->shape({1, 8, 8, 32}); + + expr->input(ifm); + + ops.emplace_back(expr); + + // Set ifm dtype as uint8 + ifm->dtype(loco::DataType::U8); + + return expr; + }; + +public: + void init(void) + { + CommonSubExpressionEliminationTestGraph::init({{1, 8, 8, 32}}, {{1, 8, 8, 32}, {1, 8, 8, 32}}); + } +}; + +class CSE_TransposeTestGraph : public CommonSubExpressionEliminationTestGraph +{ +public: + std::vector ops; + +protected: + virtual loco::Node *createExpression(luci::CircleNode *ifm, const std::string &name) override + { + auto perm = g()->nodes()->create(); + perm->name(name + "_perm"); + perm->dtype(loco::DataType::S32); + perm->shape({4}); + perm->size(4); + perm->at(0) = 0; + perm->at(1) = 3; + perm->at(2) = 1; + perm->at(3) = 2; + + auto expr = g()->nodes()->create(); + expr->name(name + "_transpose"); + expr->dtype(loco::DataType::FLOAT32); + expr->shape({1, 32, 8, 8}); + expr->a(ifm); + expr->perm(perm); + + ops.emplace_back(expr); + + return expr; + }; + +public: + void init(void) + { + CommonSubExpressionEliminationTestGraph::init({{1, 8, 8, 32}}, {{1, 32, 8, 8}, {1, 32, 8, 8}}); + } +}; + +} // namespace + +TEST(CommonSubExpressionEliminationTest, Quantize) +{ + CSE_QuantizeTestGraph g; + luci::CommonSubExpressionEliminationPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(CommonSubExpressionEliminationTest, Quantize_NEG) +{ + CSE_QuantizeTestGraph g; + + luci::CommonSubExpressionEliminationPass pass; + + g.init(); + + // Different pattern + g.ops.at(0)->input(g.ops.at(1)); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(CommonSubExpressionEliminationTest, Transpose) +{ + CSE_TransposeTestGraph g; + luci::CommonSubExpressionEliminationPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(CommonSubExpressionEliminationTest, Transpose_NEG) +{ + CSE_TransposeTestGraph g; + + luci::CommonSubExpressionEliminationPass pass; + + g.init(); + + // Different pattern + g.ops.at(0)->a(g.ops.at(1)); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp index ac4320246..8d782b36e 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -15,7 +15,7 @@ */ #include "luci/Pass/ConvertNCHWToNHWCPass.h" -#include "CircleOptimizerUtils.h" 
+#include "helpers/Shape.h" #include #include @@ -98,22 +98,6 @@ bool is_output(const loco::Node *node) return false; } -bool is_same_shape(const luci::CircleNode *node, const std::vector &shape) -{ - if (not node) - return false; - - if (shape.size() != node->rank()) - return false; - - for (uint32_t i = 0; i < shape.size(); i++) - { - if (not(node->dim(i) == shape[i])) - return false; - } - return true; -} - enum class DataFormat { NCHW, @@ -273,6 +257,27 @@ int32_t nchw_axis_to_nhwc(int32_t axis) return to_nhwc[pos_axis]; } +// Return a new CircleConst with NHWC value +luci::CircleConst *create_nhwc_axis(luci::CircleConst *axis) +{ + assert(axis); // FIX_CALLER_UNLESS + assert(axis->dtype() == loco::DataType::S32); // FIX_CALLER_UNLESS + assert(axis->size() == 1); // FIX_CALLER_UNLESS + + auto new_axis = axis->graph()->nodes()->create(); + new_axis->dtype(loco::DataType::S32); + new_axis->size(1); + new_axis->rank(1); + new_axis->dim(0) = 1; + new_axis->at(0) = nchw_axis_to_nhwc(axis->at(0)); + new_axis->shape_status(luci::ShapeStatus::VALID); + new_axis->name(axis->name() + "_NHWC"); + + luci::add_origin(new_axis, luci::get_origin(axis)); + + return new_axis; +} + luci::CircleTranspose *create_post_transpose(luci::CircleNode *node) { return create_4d_transpose(node, {0, 3, 1, 2}); @@ -1270,13 +1275,15 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor if (axis->size() != 1) return false; - axis->at(0) = nchw_axis_to_nhwc(axis->at(0)); + auto new_axis = create_nhwc_axis(axis); + assert(new_axis); // FIX_ME_UNLESS // Insert pre-transpose const auto pred_node = loco::must_cast(node->input()); auto pre_trans = create_pre_transpose(node); pre_trans->a(pred_node); node->input(pre_trans); + node->split_dim(new_axis); // Do shape inference for this node again. node->shape_status(luci::ShapeStatus::UNDEFINED); diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp index ae5ab1519..4a0bc6633 100644 --- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp +++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp @@ -87,7 +87,8 @@ bool is_quant_act(const luci::CircleNode *node) // 1. dtype is not fp32 // 2. 
node has qparam // NOTE Quantized const can have the following types -// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias) +// s4 (weights), u4 (weights), u8 (weights, activation), +// s16 (weights, activation), s32 (bias), s64 (bias) bool is_quant_const(const luci::CircleConst *node) { if (node->dtype() == loco::DataType::FLOAT32) @@ -197,6 +198,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor void visit(luci::CircleConv2D *node) { fq_activation(node); } void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); } void visit(luci::CircleDiv *node) { fq_activation(node); } + void visit(luci::CircleExp *node) { fq_activation(node); } void visit(luci::CircleFullyConnected *node) { fq_activation(node); } void visit(luci::CircleGelu *node) { fq_activation(node); } void visit(luci::CircleInstanceNorm *node) { fq_activation(node); } @@ -204,6 +206,8 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor void visit(luci::CircleLogistic *node) { fq_activation(node); } void visit(luci::CircleLogSoftmax *node) { fq_activation(node); } void visit(luci::CircleMaxPool2D *node) { fq_activation(node); } + void visit(luci::CircleMaximum *node) { fq_activation(node); } + void visit(luci::CircleMinimum *node) { fq_activation(node); } void visit(luci::CircleMul *node) { fq_activation(node); } void visit(luci::CircleNeg *node) { fq_activation(node); } void visit(luci::CirclePad *node) { fq_activation(node); } diff --git a/compiler/luci/pass/src/CopyQuantParamPass.cpp b/compiler/luci/pass/src/CopyQuantParamPass.cpp index 9b1bb0ea9..fd6557a25 100644 --- a/compiler/luci/pass/src/CopyQuantParamPass.cpp +++ b/compiler/luci/pass/src/CopyQuantParamPass.cpp @@ -59,8 +59,8 @@ bool CopyQuantParamPass::run(loco::Graph *g) for (uint32_t i = 0; i < _src_tensors.size(); i++) { - auto src = _src_tensors[i]; - auto dst = _dst_tensors[i]; + auto &src = _src_tensors[i]; + auto &dst = _dst_tensors[i]; auto nodes = get_src_dst(src, dst); if (not nodes.src) @@ -71,6 +71,12 @@ bool CopyQuantParamPass::run(loco::Graph *g) copy_quantparam(nodes.src, nodes.dst); + if (auto output = dynamic_cast(nodes.dst)) + { + auto from_node = loco::must_cast(output->from()); + copy_quantparam(output, from_node); + } + INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl; } diff --git a/compiler/luci/pass/src/DecomposeHardSwishPass.cpp b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp index bd99d2de0..2963cd8a8 100644 --- a/compiler/luci/pass/src/DecomposeHardSwishPass.cpp +++ b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp @@ -16,9 +16,6 @@ #include "luci/Pass/DecomposeHardSwishPass.h" -#include "helpers/NodeFiller.h" -#include "helpers/TypeMapper.h" - #include #include diff --git a/compiler/luci/pass/src/DecomposeSoftmaxPass.cpp b/compiler/luci/pass/src/DecomposeSoftmaxPass.cpp new file mode 100644 index 000000000..0bfd85224 --- /dev/null +++ b/compiler/luci/pass/src/DecomposeSoftmaxPass.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DecomposeSoftmaxPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * BEFORE
+ *        [CircleNode]
+ *             |
+ *             |
+ *       [CircleSoftmax]
+ *             |
+ *             |
+ *        [CircleNode]
+ *
+ *
+ * AFTER
+ *
+ *   [CircleNode]          [CircleConst(=-1)]
+ *      |    \                /    |
+ *      |     \              /     |
+ *      |     [CircleReduceMax]    |
+ *      |        /                 |
+ *      |       /                  |
+ *      |      /                   |
+ *     [Sub]                       |
+ *      |                          |
+ *      |   [CircleConst(=beta)]   |
+ *      |    /                     |
+ *      |   /                      |
+ *     [Mul] (if beta != 1)        |
+ *      |                          |
+ *     [Exp]                       |
+ *      |    \                     |
+ *      |     \                    |
+ *      |     [CircleSum]----------+
+ *      |     /
+ *      |    /
+ *     [Div]
+ *      |
+ *      |
+ *  [CircleNode]
+ */
+bool decompose_softmax(luci::CircleSoftmax *softmax)
+{
+  if (!softmax)
+    return false;
+
+  if (softmax->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  auto const input = loco::must_cast<luci::CircleNode *>(softmax->logits());
+  auto g = softmax->graph();
+
+  auto const beta = softmax->beta();
+  auto const name = softmax->name();
+  assert(name.length() > 0);
+
+  // fill reduction index (-1) for CircleReduceMax and CircleSum
+  auto index_const = g->nodes()->create<luci::CircleConst>();
+  index_const->shape({}); // scalar
+  index_const->dtype(loco::DataType::S32);
+  index_const->rank(0);
+  index_const->size<loco::DataType::S32>(1);
+  index_const->at<loco::DataType::S32>(0) = -1;
+  index_const->name(name + "/Softmax/reduction_index");
+  luci::add_origin(index_const, luci::get_origin(softmax));
+
+  // Create CircleReduceMax operation
+  auto max = g->nodes()->create<luci::CircleReduceMax>();
+  max->input(input);
+  max->reduction_indices(index_const);
+  max->keep_dims(true);
+  max->name(name + "/Softmax/max");
+  luci::add_origin(max, luci::get_origin(softmax));
+
+  // Create CircleSub operation
+  auto sub = g->nodes()->create<luci::CircleSub>();
+  sub->x(input);
+  sub->y(max);
+  sub->fusedActivationFunction(luci::FusedActFunc::NONE);
+  sub->name(name + "/Softmax/sub");
+  luci::add_origin(sub, luci::get_origin(softmax));
+
+  // input for exp can be either sub or mul (in case beta != 1)
+  loco::Node *exp_input = sub;
+
+  // multiply sub by beta when beta is not 1
+  if (std::abs(beta - 1.f) > 1.e-05f)
+  {
+    // Create constant for beta
+    auto beta_const = g->nodes()->create<luci::CircleConst>();
+    beta_const->shape({}); // scalar
+    beta_const->dtype(loco::DataType::FLOAT32);
+    beta_const->rank(0);
+    beta_const->size<loco::DataType::FLOAT32>(1);
+    beta_const->at<loco::DataType::FLOAT32>(0) = beta;
+    beta_const->name(name + "/Softmax/beta_const");
+    luci::add_origin(beta_const, luci::get_origin(softmax));
+
+    // Create CircleMul
+    auto mul = g->nodes()->create<luci::CircleMul>();
+    mul->x(sub);
+    mul->y(beta_const);
+    mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+    mul->name(name + "/Softmax/beta_mul");
+    luci::add_origin(mul, luci::get_origin(softmax));
+
+    exp_input = mul;
+  }
+
+  // Create CircleExp operation
+  auto exp = g->nodes()->create<luci::CircleExp>();
+  exp->x(exp_input);
+  exp->name(name + "/Softmax/exp");
+  luci::add_origin(exp, luci::get_origin(softmax));
+
+  // Create CircleSum operation
+  auto sum = g->nodes()->create<luci::CircleSum>();
+  sum->input(exp);
+  sum->reduction_indices(index_const);
+  sum->keep_dims(true);
+  sum->name(name + "/Softmax/sum");
+  luci::add_origin(sum, luci::get_origin(softmax));
+
+  // Create CircleDiv operation
+  auto div = g->nodes()->create<luci::CircleDiv>();
+  div->x(exp);
+  div->y(sum);
+  div->fusedActivationFunction(luci::FusedActFunc::NONE);
+  div->name(name + "/Softmax/div");
+  luci::add_origin(div, luci::get_origin(softmax));
+
+  replace(softmax).with(div);
+
+  return true;
+}
+
+} // namespace
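+
+// NOTE Numerical-stability rationale (informal): the decomposition above
+// computes
+//   softmax(x)_i = exp(beta * (x_i - max(x))) / sum_j exp(beta * (x_j - max(x)))
+// which equals the textbook exp(beta * x_i) / sum_j exp(beta * x_j) because
+// the common factor exp(-beta * max(x)) cancels between numerator and
+// denominator; subtracting the row-wise max keeps exp() from overflowing
+// for large logits.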
+
+namespace luci
+{
+
+bool DecomposeSoftmaxPass::run(loco::Graph *g)
+{
+  bool changed = false;
+
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    if (auto softmax = dynamic_cast<luci::CircleSoftmax *>(node))
+    {
+      if (decompose_softmax(softmax))
+        changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/DecomposeSoftmaxPass.test.cpp b/compiler/luci/pass/src/DecomposeSoftmaxPass.test.cpp
new file mode 100644
index 000000000..481ed45ad
--- /dev/null
+++ b/compiler/luci/pass/src/DecomposeSoftmaxPass.test.cpp
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DecomposeSoftmaxPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ *  Softmax graph
+ *
+ *        [CircleInput]
+ *              |
+ *              |
+ *       [CircleSoftMax]
+ *              |
+ *              |
+ *        [CircleOutput]
+ */
+template <loco::DataType T> struct SoftmaxGraph
+{
+  loco::Graph _g;
+  luci::CircleInput *_input = nullptr;
+  luci::CircleSoftmax *_softmax = nullptr;
+  luci::CircleOutput *_output = nullptr;
+
+  SoftmaxGraph()
+  {
+    const int N = 1;
+    const int H = 4;
+    const int W = 4;
+    const int C = 3;
+
+    // graph input and output
+    auto graph_input = _g.inputs()->create();
+    auto graph_output = _g.outputs()->create();
+
+    // CircleInput
+    _input = _g.nodes()->create<luci::CircleInput>();
+    _input->index(graph_input->index());
+    _input->shape({N, H, W, C});
+    _input->dtype(T);
+    _input->name("input");
+
+    // CircleSoftmax
+    _softmax = _g.nodes()->create<luci::CircleSoftmax>();
+    _softmax->logits(_input);
+    _softmax->shape({N, H, W, C});
+    _softmax->dtype(T);
+    _softmax->name("softmax");
+    _softmax->beta(0.5f);
+
+    // CircleOutput
+    _output = _g.nodes()->create<luci::CircleOutput>();
+    _output->index(graph_output->index());
+    _output->from(_softmax);
+    _output->shape({N, H, W, C});
+    _output->dtype(T);
+    _output->name("output");
+  }
+};
+
+} // namespace
+
+TEST(DecomposeSoftmaxPass, simple_test)
+{
+  /**
+   * tests:
+   *   1) the decomposition pass has a non-null name
+   *   2) decomposition runs successfully for a float32 softmax graph
+   *   3) the resulting graph has the structure shown in the AFTER diagram
+   *      of DecomposeSoftmaxPass.cpp, with CircleConst(=0.5) as beta
+   */
+  luci::DecomposeSoftmaxPass pass;
+  SoftmaxGraph<loco::DataType::FLOAT32> softmax_g;
+
+  auto const name = pass.name();
+  ASSERT_NE(nullptr, name);
+
+  auto ret = pass.run(&softmax_g._g);
+  EXPECT_TRUE(ret);
+
+  auto div = dynamic_cast<luci::CircleDiv *>(softmax_g._output->from());
+  EXPECT_NE(nullptr, div);
+
+  auto exp = dynamic_cast<luci::CircleExp *>(div->x());
+  EXPECT_NE(nullptr, exp);
+
+  auto sum = dynamic_cast<luci::CircleSum *>(div->y());
+  EXPECT_NE(nullptr, sum);
+
+  auto exp_1 = dynamic_cast<luci::CircleExp *>(sum->input());
+  EXPECT_EQ(exp, exp_1);
+
+  auto indices = dynamic_cast<luci::CircleConst *>(sum->reduction_indices());
+  EXPECT_NE(nullptr, indices);
+  EXPECT_EQ(indices->dtype(), loco::DataType::S32);
+  EXPECT_EQ(indices->size<loco::DataType::S32>(), 1);
+  EXPECT_EQ(indices->scalar<loco::DataType::S32>(), -1);
+
+  auto mul = 
dynamic_cast(exp->x()); + EXPECT_NE(nullptr, mul); + + auto sub = dynamic_cast(mul->x()); + EXPECT_NE(nullptr, sub); + + auto beta = dynamic_cast(mul->y()); + EXPECT_NE(nullptr, beta); + EXPECT_EQ(beta->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(beta->size(), 1); + EXPECT_FLOAT_EQ(beta->scalar(), 0.5f); + + auto input = dynamic_cast(sub->x()); + EXPECT_NE(nullptr, input); + + auto max = dynamic_cast(sub->y()); + EXPECT_NE(nullptr, max); + + auto indices_1 = dynamic_cast(max->reduction_indices()); + EXPECT_NE(nullptr, indices_1); + EXPECT_EQ(indices, indices_1); + + auto input_1 = dynamic_cast(max->input()); + EXPECT_NE(nullptr, input_1); + EXPECT_EQ(input, input_1); +} + +TEST(DecomposeSoftmaxPass, wrong_condition_NEG) +{ + luci::DecomposeSoftmaxPass pass; + SoftmaxGraph softmax_g; + + auto ret = pass.run(&softmax_g._g); + EXPECT_FALSE(ret); + + auto softmax = dynamic_cast(softmax_g._output->from()); + EXPECT_NE(nullptr, softmax); +} diff --git a/compiler/luci/pass/src/FoldAddV2Pass.cpp b/compiler/luci/pass/src/FoldAddV2Pass.cpp index 20c1022f8..7952b6b11 100644 --- a/compiler/luci/pass/src/FoldAddV2Pass.cpp +++ b/compiler/luci/pass/src/FoldAddV2Pass.cpp @@ -18,8 +18,6 @@ #include -#include - namespace { diff --git a/compiler/luci/pass/src/FoldCastPass.cpp b/compiler/luci/pass/src/FoldCastPass.cpp index 00b86fe48..3ecbe2b4c 100644 --- a/compiler/luci/pass/src/FoldCastPass.cpp +++ b/compiler/luci/pass/src/FoldCastPass.cpp @@ -26,24 +26,41 @@ luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype { assert(node->dtype() == from_dtype); + enum CAST_TYPES + { + CAST_NONE = 0, + CAST_S64_S32, + }; + + CAST_TYPES cast_type = CAST_NONE; + if (from_dtype == loco::DataType::S64) + { + if (to_dtype == loco::DataType::S32) + cast_type = CAST_S64_S32; + } + // TODO: Support more data types + if (cast_type == CAST_NONE) + return nullptr; + auto name = node->name(); assert(name.length() > 0); + auto constant = node->graph()->nodes()->create(); constant->dtype(to_dtype); constant->rank(node->rank()); uint32_t num_elems = 1; + for (uint32_t i = 0; i < node->rank(); i++) { constant->dim(i).set(node->dim(i).value()); num_elems *= node->dim(i).value(); } - constant->shape_status(luci::ShapeStatus::VALID); // TODO: Support more data types - if (from_dtype == loco::DataType::S64) + switch (cast_type) { - if (to_dtype == loco::DataType::S32) + case CAST_S64_S32: { constant->size(num_elems); for (uint32_t i = 0; i < num_elems; i++) @@ -53,7 +70,8 @@ luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype constant->name(name + "_S32"); return constant; } - return nullptr; + default: + break; } return nullptr; diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp index 33f9f1d77..455a69586 100644 --- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp +++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp @@ -16,89 +16,51 @@ #include "luci/Pass/FoldDepthwiseConv2DPass.h" -#include +#include "helpers/Compute.h" +#include "helpers/Shape.h" #include #include #include -#include // std::numeric_limits +#include -namespace +#include + +namespace luci { -// TODO Share activation mix/max and compute_input/output code with luci-interpreter +namespace +{ -bool compute_output(uint32_t *output_size, luci::Padding padding, int32_t image_size, - int32_t filter_size, int32_t stride, int32_t dilation_rate) +bool set_params(const luci::CircleDepthwiseConv2D *node, compute::DepthwiseConv2D &cdc) { - auto const 
effective_filter_size = (filter_size - 1) * dilation_rate + 1; - switch (padding) - { - case luci::Padding::SAME: - *output_size = (image_size + stride - 1) / stride; - return true; + assert(node); - case luci::Padding::VALID: - *output_size = (image_size + stride - effective_filter_size) / stride; - return true; + LOGGER(l); - default: - { - LOGGER(l); - WARN(l) << "Unsupported padding: " << uint32_t(padding); - return false; - } + auto ¶ms = cdc.params(); + if (!to_compute(node->padding(), params.padding_type)) + { + WARN(l) << "FoldDepthwiseConv2DPass unsupported padding: " << uint32_t(node->padding()); + return false; } -} -uint32_t compute_padding(int32_t stride, int32_t dilation_rate, int32_t in_size, - int32_t filter_size, int32_t out_size) -{ - auto const effective_filter_size = (filter_size - 1) * dilation_rate + 1; - auto const padding = ((out_size - 1) * stride + effective_filter_size - in_size) / 2; - return padding > 0 ? padding : 0; -} + params.stride_height = node->stride()->h(); + params.stride_width = node->stride()->w(); + params.dilation_height_factor = node->dilation()->h(); + params.dilation_width_factor = node->dilation()->w(); + params.depth_multiplier = node->depthMultiplier(); -bool set_kernel_parameters(tflite::DepthwiseParams *params, luci::CircleDepthwiseConv2D *node, - uint32_t padding_height, uint32_t padding_width) -{ - switch (node->fusedActivationFunction()) + compute::FusedActFunc fac; + if (!to_compute(node->fusedActivationFunction(), fac)) { - case luci::FusedActFunc::NONE: - case luci::FusedActFunc::TANH: - params->float_activation_min = std::numeric_limits::lowest(); - params->float_activation_max = std::numeric_limits::max(); - break; - case luci::FusedActFunc::RELU: - params->float_activation_min = 0; - params->float_activation_max = std::numeric_limits::max(); - break; - case luci::FusedActFunc::RELU_N1_TO_1: - params->float_activation_min = -1; - params->float_activation_max = 1; - break; - case luci::FusedActFunc::RELU6: - params->float_activation_min = 0; - params->float_activation_max = 6; - break; - default: - { - LOGGER(l); - WARN(l) << "Unsupported activation: " << uint32_t(node->fusedActivationFunction()); - return false; - } + WARN(l) << "FoldDepthwiseConv2DPass unsupported activation: " + << uint32_t(node->fusedActivationFunction()); + return false; } - - params->stride_height = node->stride()->h(); - params->stride_width = node->stride()->w(); - params->dilation_height_factor = node->dilation()->h(); - params->dilation_width_factor = node->dilation()->w(); - params->depth_multiplier = node->depthMultiplier(); - - params->padding_values.height = padding_height; - params->padding_values.width = padding_width; + cdc.fused_act_func(fac); return true; } @@ -118,91 +80,59 @@ bool set_kernel_parameters(tflite::DepthwiseParams *params, luci::CircleDepthwis */ bool fold_depthwise_conv_2d(luci::CircleDepthwiseConv2D *node) { - LOGGER(l); - auto const input = dynamic_cast(node->input()); - if (input == nullptr) return false; // Constant input is required for folding auto const filter = dynamic_cast(node->filter()); - if (filter == nullptr) return false; // Constant filter is required for folding - if (filter->dim(0).value() != 1) return false; // Unsupported batch size auto const bias = dynamic_cast(node->bias()); - if (bias == nullptr) return false; // Constant bias is required for folding - auto const input_batches = input->dim(0).value(); - auto const input_height = input->dim(1).value(); - auto const input_width = input->dim(2).value(); - auto 
const input_depth = input->dim(3).value(); - - auto const filter_height = filter->dim(1).value(); - auto const filter_width = filter->dim(2).value(); - auto const filter_channels_out = filter->dim(3).value(); - - if (filter_channels_out % input_depth != 0) - return false; // Wrong input/output depth ratio - - if (node->depthMultiplier() != static_cast(filter_channels_out / input_depth)) - return false; // Wrong depth multiplier value - - if (bias->rank() != 1 || bias->dim(0).value() != filter_channels_out) - return false; // Unsupported bias value - - uint32_t output_height = 0; - uint32_t output_width = 0; - - if (!compute_output(&output_height, node->padding(), input_height, filter_height, - node->stride()->h(), node->dilation()->h())) - return false; // Unsupported output parameters + auto static_shape = [](luci::CircleNode *node) { + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); ++i) + shape.dim(i) = node->dim(i); + return shape; + }; - if (!compute_output(&output_width, node->padding(), input_width, filter_width, - node->stride()->w(), node->dilation()->w())) - return false; // Unsupported output parameters + auto const input_data = &input->at(0); + auto const filter_data = &filter->at(0); + auto const bias_data = &bias->at(0); - auto const padding_height = compute_padding(node->stride()->h(), node->dilation()->h(), - input_height, filter_height, output_height); - auto const padding_width = compute_padding(node->stride()->w(), node->dilation()->w(), - input_width, filter_width, output_width); + compute::DepthwiseConv2D comp_dwconv2d{}; + if (!set_params(node, comp_dwconv2d)) + return false; + comp_dwconv2d.input(static_shape(input), input_data); + comp_dwconv2d.filter(static_shape(filter), filter_data); + comp_dwconv2d.bias(static_shape(bias), bias_data); - tflite::DepthwiseParams params{}; + if (!comp_dwconv2d.prepare()) + return false; - if (!set_kernel_parameters(¶ms, node, padding_height, padding_width)) - return false; // Unsupported kernel parameter values + auto output_shape = comp_dwconv2d.output_shape(); + assert(is_same_shape(node, output_shape)); + auto output_size = loco::element_count(&output_shape); + // result folded constant node auto constant = node->graph()->nodes()->create(); - constant->name(node->name()); constant->dtype(node->dtype()); constant->rank(node->rank()); + for (uint32_t i = 0; i < output_shape.rank(); ++i) + constant->dim(i).set(output_shape.dim(i).value()); constant->shape_status(luci::ShapeStatus::VALID); - for (uint32_t i = 0; i < node->rank(); ++i) - constant->dim(i).set(node->dim(i).value()); - - constant->size(input_batches * output_height * output_width * - filter_channels_out); - - auto const input_data = &input->at(0); - auto const filter_data = &filter->at(0); - auto const bias_data = &bias->at(0); - auto const constant_data = &constant->at(0); - - auto tensor_shape = [](luci::CircleNode *node) { - tflite::RuntimeShape runtime_shape(node->rank()); - for (uint32_t i = 0; i < node->rank(); ++i) - runtime_shape.SetDim(i, node->dim(i).value()); - return runtime_shape; - }; + constant->size(output_size); + constant->name(node->name()); - tflite::reference_ops::DepthwiseConv(params, tensor_shape(input), input_data, - tensor_shape(filter), filter_data, tensor_shape(bias), - bias_data, tensor_shape(constant), constant_data); + auto constant_data = &constant->at(0); + comp_dwconv2d.output(constant_data); + comp_dwconv2d.compute(); loco::replace(node).with(constant); @@ -211,9 +141,6 @@ bool 
fold_depthwise_conv_2d(luci::CircleDepthwiseConv2D *node) } // namespace -namespace luci -{ - /** * Constant Folding for DepthwiseConv2D Op **/ diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp index 36cae0437..3a1b54492 100644 --- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp +++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp @@ -56,6 +56,7 @@ public: _dconv->filter(_dconv_filter); _dconv->bias(_dconv_bias); _dconv->shape({1, 4, 4, 1}); + _dconv->shape_status(luci::ShapeStatus::VALID); _dconv->stride()->h(1); _dconv->stride()->w(1); _dconv->depthMultiplier(1); diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp index b6526deb0..39652a5aa 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.cpp @@ -40,6 +40,10 @@ bool is_foldable_const(luci::CircleConst *node) if (node->quantparam() == nullptr) return false; + if (node->dtype() == loco::DataType::S4) + return true; + if (node->dtype() == loco::DataType::U4) + return true; if (node->dtype() == loco::DataType::S8) return true; if (node->dtype() == loco::DataType::U8) @@ -105,6 +109,18 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) switch (const_node->dtype()) { + case loco::DataType::U4: + new_const_node->at(i) = + static_cast(const_node->at(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S4: + new_const_node->at(i) = + static_cast(const_node->at(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; case loco::DataType::S8: new_const_node->at(i) = static_cast(const_node->at(i) - diff --git a/compiler/luci/pass/src/FoldDequantizePass.test.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp index 87dff5dc0..8873fc345 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.test.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp @@ -86,6 +86,14 @@ protected: luci::CircleConst *_input = nullptr; }; +class U4FoldDequantizeTest : public FoldDequantizeTest +{ +}; + +class S4FoldDequantizeTest : public FoldDequantizeTest +{ +}; + class S8FoldDequantizeTest : public FoldDequantizeTest { }; @@ -205,6 +213,80 @@ TEST_F(U8FoldDequantizeTest, fold_dequant_basic_NEG) EXPECT_EQ(nullptr, folded_const); } +TEST_F(U4FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at(0)); + EXPECT_EQ(0.0, folded_const->at(1)); + EXPECT_EQ(0.0, folded_const->at(2)); + EXPECT_EQ(10.0, folded_const->at(3)); + EXPECT_EQ(15.0, folded_const->at(4)); + EXPECT_EQ(20.0, folded_const->at(5)); + EXPECT_EQ(40.0, folded_const->at(6)); + EXPECT_EQ(50.0, folded_const->at(7)); +} + +TEST_F(U4FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(S4FoldDequantizeTest, fold_dequant_basic) +{ + 
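+  // The expected values below follow dequantized_const_node above: each folded
+  // element is (quantized - zerop) * scale, with scale/zerop taken from the
+  // shared FoldDequantizeTest fixture (defined outside this hunk).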
luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at(0)); + EXPECT_EQ(0.0, folded_const->at(1)); + EXPECT_EQ(0.0, folded_const->at(2)); + EXPECT_EQ(10.0, folded_const->at(3)); + EXPECT_EQ(15.0, folded_const->at(4)); + EXPECT_EQ(20.0, folded_const->at(5)); + EXPECT_EQ(40.0, folded_const->at(6)); + EXPECT_EQ(50.0, folded_const->at(7)); +} + +TEST_F(S4FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + TEST_F(S8FoldDequantizeTest, fold_dequant_basic) { luci::FoldDequantizePass pass; diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp index a3bca7eda..8fa0720e6 100644 --- a/compiler/luci/pass/src/FoldFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp @@ -16,50 +16,47 @@ #include "luci/Pass/FoldFullyConnectedPass.h" -#include +#include "helpers/Compute.h" +#include "helpers/Shape.h" #include #include #include -#include // std::numeric_limits +#include + +#include + +namespace luci +{ namespace { -bool set_kernel_parameters(tflite::FullyConnectedParams *params, luci::CircleFullyConnected *node) +bool set_params(const luci::CircleFullyConnected *node, compute::FullyConnected &cfc) { - switch (node->fusedActivationFunction()) + assert(node); + + LOGGER(l); + + // NOTE only support default for now + if (node->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT) { - case luci::FusedActFunc::NONE: - case luci::FusedActFunc::TANH: - params->float_activation_min = std::numeric_limits::lowest(); - params->float_activation_max = std::numeric_limits::max(); - break; - case luci::FusedActFunc::RELU: - params->float_activation_min = 0; - params->float_activation_max = std::numeric_limits::max(); - break; - case luci::FusedActFunc::RELU_N1_TO_1: - params->float_activation_min = -1; - params->float_activation_max = 1; - break; - case luci::FusedActFunc::RELU6: - params->float_activation_min = 0; - params->float_activation_max = 6; - break; - default: - { - LOGGER(l); - WARN(l) << "Unsupported activation: " << uint32_t(node->fusedActivationFunction()); - return false; - } + WARN(l) << "FoldFullyConnectedPass unsupported weights_format: " + << uint32_t(node->weights_format()); + return false; } + cfc.params().weights_format = compute::FullyConnectedWeightsFormat::kDefault; - assert(node->weights_format() == - luci::CircleFullyConnected::WeightsFormat::DEFAULT); // FIX_CALLER_UNLESS - params->weights_format = tflite::FullyConnectedWeightsFormat::kDefault; + compute::FusedActFunc fac; + if (!to_compute(node->fusedActivationFunction(), fac)) + { + WARN(l) << "FoldFullyConnectedPass unsupported activation: " + << uint32_t(node->fusedActivationFunction()); + return false; + } + cfc.fused_act_func(fac); return true; } @@ -85,8 +82,6 @@ bool fold_fully_connected(luci::CircleFullyConnected *node) { RETURN_FALSE_UNLESS(node != nullptr); - LOGGER(l); - auto const input = dynamic_cast(node->input()); auto const weights = 
dynamic_cast(node->weights()); auto const bias = dynamic_cast(node->bias()); @@ -94,77 +89,59 @@ bool fold_fully_connected(luci::CircleFullyConnected *node) RETURN_FALSE_UNLESS(input != nullptr); RETURN_FALSE_UNLESS(weights != nullptr); - RETURN_FALSE_UNLESS(node->weights_format() == luci::CircleFullyConnected::WeightsFormat::DEFAULT); RETURN_FALSE_UNLESS(bias != nullptr or no_bias != nullptr); RETURN_FALSE_UNLESS(input->dtype() == loco::DataType::FLOAT32); RETURN_FALSE_UNLESS(weights->dtype() == loco::DataType::FLOAT32); + + auto const input_data = &input->at(0); + auto const weights_data = &weights->at(0); + float *bias_data = nullptr; if (bias) + { RETURN_FALSE_UNLESS(bias->dtype() == loco::DataType::FLOAT32); + bias_data = &bias->at(0); + } - auto const input_elems = input->size(); + auto static_shape = [](luci::CircleNode *node) { + loco::TensorShape shape; + if (not node) + return shape; + shape.rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); ++i) + shape.dim(i) = node->dim(i); + return shape; + }; - RETURN_FALSE_UNLESS(weights->rank() == 2); - RETURN_FALSE_UNLESS(input_elems % weights->dim(1).value() == 0); - auto const batch_size = input_elems / weights->dim(1).value(); - auto const num_units = weights->dim(0).value(); + compute::FullyConnected comp_fc{}; + if (!set_params(node, comp_fc)) + return false; + comp_fc.input(static_shape(input), input_data); + comp_fc.weights(static_shape(weights), weights_data); + comp_fc.bias(static_shape(bias), bias_data); - if (bias) - RETURN_FALSE_UNLESS(bias->size() == num_units); + comp_fc.keep_num_dims(node->keep_num_dims()); - tflite::FullyConnectedParams params{}; - if (!set_kernel_parameters(¶ms, node)) - return false; // Unsupported kernel parameter values + if (!comp_fc.prepare()) + return false; - std::vector output_shape; - if (node->keep_num_dims() == false) - { - output_shape.push_back(batch_size); - output_shape.push_back(num_units); - } - else - { - output_shape.resize(input->rank()); - for (uint32_t i = 0; i < input->rank(); i++) - output_shape[i] = input->dim(i).value(); - output_shape[input->rank() - 1] = num_units; - } + const auto &output_shape = comp_fc.output_shape(); + assert(is_same_shape(node, output_shape)); + auto output_size = loco::element_count(&output_shape); auto constant = node->graph()->nodes()->create(); { - constant->name(node->name()); constant->dtype(node->dtype()); constant->rank(node->rank()); - constant->shape_status(luci::ShapeStatus::VALID); - uint32_t num_elem = 1; for (uint32_t i = 0; i < node->rank(); ++i) - { constant->dim(i).set(node->dim(i).value()); - num_elem *= node->dim(i).value(); - } - constant->size(num_elem); + constant->shape_status(luci::ShapeStatus::VALID); + constant->size(output_size); + constant->name(node->name()); } - - auto tensor_shape = [](luci::CircleNode *node) { - if (node == nullptr) - return tflite::RuntimeShape(); - - tflite::RuntimeShape runtime_shape(node->rank()); - for (uint32_t i = 0; i < node->rank(); ++i) - runtime_shape.SetDim(i, node->dim(i).value()); - return runtime_shape; - }; - - auto tensor_data = [](luci::CircleConst *node) -> float * { - if (node == nullptr) - return nullptr; - - return &node->at(0); - }; - - tflite::reference_ops::FullyConnected( - params, tensor_shape(input), tensor_data(input), tensor_shape(weights), tensor_data(weights), - tensor_shape(bias), tensor_data(bias), tensor_shape(constant), tensor_data(constant)); + auto constant_data = &constant->at(0); + comp_fc.output(constant_data); + comp_fc.compute(); 
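+  // compute() has written the folded result into constant_data; the replace
+  // below reroutes every user of the FullyConnected node to the new constant,
+  // leaving the original op dead for later cleanup.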
loco::replace(node).with(constant); @@ -173,9 +150,6 @@ bool fold_fully_connected(luci::CircleFullyConnected *node) } // namespace -namespace luci -{ - /** * Constant Folding for FullyConnected Op **/ diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp index a8e64a24b..4a1dd8cb7 100644 --- a/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp @@ -58,6 +58,7 @@ public: _fc->weights(_fc_weights); _fc->bias(_fc_bias); _fc->shape({NUM_UNITS}); + _fc->shape_status(luci::ShapeStatus::VALID); _fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT); _fc->keep_num_dims(true); diff --git a/compiler/luci/pass/src/FoldGatherPass.cpp b/compiler/luci/pass/src/FoldGatherPass.cpp index f179d74bd..43422cb87 100644 --- a/compiler/luci/pass/src/FoldGatherPass.cpp +++ b/compiler/luci/pass/src/FoldGatherPass.cpp @@ -15,7 +15,6 @@ */ #include "luci/Pass/FoldGatherPass.h" -#include "CircleOptimizerUtils.h" #include diff --git a/compiler/luci/pass/src/FoldMulPass.cpp b/compiler/luci/pass/src/FoldMulPass.cpp new file mode 100644 index 000000000..65112911e --- /dev/null +++ b/compiler/luci/pass/src/FoldMulPass.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FoldMulPass.h" + +#include + +#include + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +namespace +{ + +/** + * @return higher rank of x, y or nullptr if not compatible + */ +const luci::CircleConst *compatible_shape(const luci::CircleConst *x, const luci::CircleConst *y) +{ + if (x->rank() >= y->rank()) + { + uint32_t d = x->rank() - y->rank(); + for (uint32_t i = 0; i < y->rank(); i++) + { + // NOTE dim() has only '==' operator + if (!(x->dim(i + d) == y->dim(i))) + return nullptr; + } + return x; + } + else + { + uint32_t d = y->rank() - x->rank(); + for (uint32_t i = 0; i < x->rank(); i++) + { + if (!(x->dim(i) == y->dim(i + d))) + return nullptr; + } + return y; + } +} + +/** + * Fold Mul to const if both inputs are const + **/ +bool fold_mul(luci::CircleMul *mul) +{ + CHECK_OR_FALSE(mul); + CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32); + + // Check inputs are const and compatible + auto x = dynamic_cast(mul->x()); + auto y = dynamic_cast(mul->y()); + CHECK_OR_FALSE(x); + CHECK_OR_FALSE(y); + CHECK_OR_FALSE(x->dtype() == y->dtype()); + const auto xy = compatible_shape(x, y); + CHECK_OR_FALSE(xy); + + auto name_x = x->name(); + auto name_y = y->name(); + assert(name_x.length() > 0); + assert(name_y.length() > 0); + auto folded_const = mul->graph()->nodes()->create(); + folded_const->dtype(xy->dtype()); + folded_const->rank(xy->rank()); + for (uint32_t i = 0; i < xy->rank(); i++) + folded_const->dim(i).set(xy->dim(i).value()); + + const auto size_x = x->size(); + const auto size_y = y->size(); + const auto size_xy = xy->size(); + folded_const->size(size_xy); + for (uint32_t i = 0; i < size_xy; i++) + { + auto xv = x->at(i % size_x); + auto yv = y->at(i % size_y); + folded_const->at(i) = xv * yv; + } + + folded_const->shape_status(luci::ShapeStatus::VALID); + folded_const->name(name_x + "_" + name_y); + + loco::replace(mul).with(folded_const); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Mul Op + **/ +bool FoldMulPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto mul = dynamic_cast(node)) + { + if (fold_mul(mul)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldMulPass.test.cpp b/compiler/luci/pass/src/FoldMulPass.test.cpp new file mode 100644 index 000000000..0c6de971f --- /dev/null +++ b/compiler/luci/pass/src/FoldMulPass.test.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FoldMulPass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +/** + * Graph has an Mul Op with constant inputs + * + * BEFORE + * + * [CircleConst] [CircleConst] + * | | + * [CircleMul] + * | + * [CircleNode] + * AFTER + * [CircleConst] [CircleConst] + * | | + * [CircleConst] [CircleMul] + * | + * [CircleNode] + */ + +template class FoldMulTest : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldMulTest(std::initializer_list shape) : luci::ConstantFoldingAddTestGraph(shape, T) + { + _mul = _g.nodes()->template create(); + _x = _g.nodes()->template create(); + _y = _g.nodes()->template create(); + + _mul->dtype(T); + _x->dtype(T); + _y->dtype(T); + + _mul->shape(shape); + _x->shape(shape); + _y->shape(shape); + + uint32_t num_elems = 1; + for (auto dim = shape.begin(); dim != shape.end(); dim++) + num_elems *= *dim; + + _x->size(num_elems); + _y->size(num_elems); + + for (uint32_t i = 0; i < num_elems; i++) + { + _x->at(i) = i + 1; + _y->at(i) = i + 1; + } + + _mul->x(_x); + _mul->y(_y); + _mul->name("mul"); + _x->name("x"); + _y->name("y"); + } + + loco::Node *createFoldedPattern() override { return _mul; } + + virtual ~FoldMulTest() = default; + +protected: + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_x = nullptr; + luci::CircleConst *_y = nullptr; +}; + +class FoldF32MulTest : public FoldMulTest, public ::testing::Test +{ +public: + FoldF32MulTest() : FoldMulTest({3}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST_F(FoldF32MulTest, name) +{ + luci::FoldMulPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldF32MulTest, fold_mul) +{ + luci::FoldMulPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(1, folded_const->at(0)); + EXPECT_EQ(4, folded_const->at(1)); + EXPECT_EQ(9, folded_const->at(2)); +} + +TEST_F(FoldF32MulTest, input_type_mismatch_NEG) +{ + _x->dtype(loco::DataType::U4); + + luci::FoldMulPass pass; + EXPECT_FALSE(pass.run(graph())); +} diff --git a/compiler/luci/pass/src/FoldReshapePass.cpp b/compiler/luci/pass/src/FoldReshapePass.cpp new file mode 100644 index 000000000..56b19fa69 --- /dev/null +++ b/compiler/luci/pass/src/FoldReshapePass.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FoldReshapePass.h" + +#include +#include +#include + +namespace +{ + +/** + * Fold Reshape to const if it has const input + **/ +bool fold_reshape(luci::CircleReshape *reshape) +{ + // Check const input + auto const_input = dynamic_cast(reshape->tensor()); + if (not const_input) + return false; + + // Check const shape + auto const_shape = dynamic_cast(reshape->shape()); + if (not const_shape) + return false; + + // Check all dimensions are known + const auto input_rank = const_input->rank(); + for (uint32_t i = 0; i < input_rank; i++) + { + if (not const_input->dim(i).known()) + return false; + } + + // Check all dimensions are known + const auto shape_rank = const_shape->rank(); + if (shape_rank != 1) + return false; + + if (not const_shape->dim(0).known()) + return false; + + std::vector new_shape; + switch (const_shape->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < const_shape->size(); i++) + { + const auto val = const_shape->at(i); + if (val < 0) + return false; + + new_shape.push_back(static_cast(val)); + } + break; + // TODO Support S64 + default: + return false; + } + + if (auto input_qparam = const_input->quantparam()) + { + // Only support per-tensor quantization + if (input_qparam->scale.size() != 1) + return false; + + if (input_qparam->zerop.size() != 1) + return false; + } + + auto new_const = luci::clone(const_input); + new_const->rank(new_shape.size()); + for (uint32_t i = 0; i < new_shape.size(); i++) + { + new_const->dim(i).set(new_shape[i]); + } + + new_const->shape_status(luci::ShapeStatus::VALID); + + new_const->name(const_input->name() + "_reshaped"); + luci::add_origin( + new_const, luci::composite_origin({luci::get_origin(reshape), luci::get_origin(const_input)})); + + loco::replace(reshape).with(new_const); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Reshape Op + **/ +bool FoldReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto reshape = dynamic_cast(node)) + { + if (fold_reshape(reshape)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldReshapePass.test.cpp b/compiler/luci/pass/src/FoldReshapePass.test.cpp new file mode 100644 index 000000000..0ae2bce54 --- /dev/null +++ b/compiler/luci/pass/src/FoldReshapePass.test.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FoldReshapePass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +std::unique_ptr gen_qparam(const std::vector &s, + const std::vector &zp) +{ + auto qparam = std::make_unique(); + { + for (auto scale : s) + qparam->scale.push_back(scale); + + for (auto zerop : zp) + qparam->zerop.push_back(zerop); + } + + return std::move(qparam); +} + +template class FoldReshapeTest : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldReshapeTest(std::initializer_list input_shape, + std::initializer_list output_shape) + : luci::ConstantFoldingAddTestGraph(output_shape, DT) + { + _reshape = _g.nodes()->template create(); + _x = _g.nodes()->template create(); + _shape = _g.nodes()->template create(); + + _reshape->dtype(DT); + _x->dtype(DT); + _shape->dtype(loco::DataType::S32); + + _reshape->shape(_shape); + _x->shape(input_shape); + _shape->shape({static_cast(output_shape.size())}); + + uint32_t num_elems = 1; + for (auto dim : input_shape) + num_elems *= dim; + + _x->size
(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + _x->at
(i) = i; + + _shape->size(output_shape.size()); + uint32_t i = 0; + for (auto dim : output_shape) + { + _shape->at(i++) = static_cast(dim); + } + + if (DT == loco::DataType::S16) + { + _x->quantparam(gen_qparam({1.0}, {0})); + _reshape->quantparam(gen_qparam({1.0}, {0})); + } + + _reshape->tensor(_x); + _reshape->shape(_shape); + + _reshape->name("reshape"); + _shape->name("shape"); + _x->name("x"); + } + + loco::Node *createFoldedPattern() override { return _reshape; } + +public: + void set_unknown_dim() { _x->dim(0).unset(); } + void set_non_per_tensor() { _x->quantparam(gen_qparam({1.0, 2.0}, {0, 0})); } + +protected: + luci::CircleReshape *_reshape = nullptr; + luci::CircleConst *_x = nullptr; + luci::CircleConst *_shape = nullptr; +}; + +/** + * Graph that has a Reshape Op with constant input + * + * BEFORE + * + * [CircleConst] + * | + * [Reshape] + * + * AFTER + * + * [CircleConst] + * + */ +class FoldFP32ReshapeTest : public FoldReshapeTest, public ::testing::Test +{ +public: + FoldFP32ReshapeTest() : FoldReshapeTest({1, 3}, {3}) {} + + virtual void SetUp() { init(); } +}; + +class FoldS16ReshapeTest : public FoldReshapeTest, public ::testing::Test +{ +public: + FoldS16ReshapeTest() : FoldReshapeTest({1, 3}, {3}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST_F(FoldFP32ReshapeTest, fold_reshape_fp32) +{ + luci::FoldReshapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(0, folded_const->at(0)); + EXPECT_EQ(1, folded_const->at(1)); + EXPECT_EQ(2, folded_const->at(2)); +} + +TEST_F(FoldFP32ReshapeTest, fold_reshape_unkown_dim_NEG) +{ + set_unknown_dim(); + + luci::FoldReshapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(FoldS16ReshapeTest, fold_reshape_s16) +{ + luci::FoldReshapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S16, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(0, folded_const->at(0)); + EXPECT_EQ(1, folded_const->at(1)); + EXPECT_EQ(2, folded_const->at(2)); + + auto qparam = folded_const->quantparam(); + EXPECT_NE(nullptr, qparam); + EXPECT_EQ(1.0, qparam->scale.at(0)); + EXPECT_EQ(0, qparam->zerop.at(0)); +} + +TEST_F(FoldS16ReshapeTest, fold_non_per_tensor_quant_NEG) +{ + set_non_per_tensor(); + + luci::FoldReshapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} diff --git a/compiler/luci/pass/src/FoldShapePass.cpp b/compiler/luci/pass/src/FoldShapePass.cpp new file mode 100644 index 000000000..bcf538ae3 --- /dev/null +++ b/compiler/luci/pass/src/FoldShapePass.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldShapePass.h" + +#include +#include + +namespace +{ + +template luci::CircleConst *folding(luci::CircleShape *shape) +{ + auto input_node = loco::must_cast(shape->input()); + auto name = input_node->name(); + assert(name.length() > 0); + auto shape_status = input_node->shape_status(); + if (shape_status != luci::ShapeStatus::VALID) + return nullptr; + auto rank = input_node->rank(); + // TODO support rank == 0 when possible + if (rank == 0) + return nullptr; + for (uint32_t i = 0; i < rank; i++) + { + auto dim = input_node->dim(i); + if (!dim.known()) + return nullptr; + } + + auto folded_shape = input_node->graph()->nodes()->create(); + folded_shape->name(name + "_ConstShape"); + folded_shape->dtype(OutType); + folded_shape->rank(1); + folded_shape->dim(0).set(rank); + luci::add_origin(folded_shape, luci::get_origin(shape)); + + folded_shape->size(rank); + for (uint32_t i = 0; i < rank; i++) + folded_shape->at(i) = input_node->dim(i).value(); + + return folded_shape; +} + +// Fold Shape to const if the input shape is static +template bool fold_shape(luci::CircleShape *shape) +{ + auto folded_shape = folding(shape); + if (not folded_shape) + return false; + + loco::replace(shape).with(folded_shape); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleNode] + * | + * [CircleShape] + * | + * [CircleNode] + * + * AFTER + * + * [CircleConst] [CircleNode] + * | + * [CircleNode] + * + */ +bool FoldShapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto shape = dynamic_cast(node)) + { + auto out_type = shape->out_type(); + switch (out_type) + { + case loco::DataType::S32: + if (fold_shape(shape)) + changed = true; + break; + case loco::DataType::S64: + if (fold_shape(shape)) + changed = true; + break; + default: + break; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldShapePass.test.cpp b/compiler/luci/pass/src/FoldShapePass.test.cpp new file mode 100644 index 000000000..cece597cf --- /dev/null +++ b/compiler/luci/pass/src/FoldShapePass.test.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FoldShapePass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +template class FoldShapeGraph : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldShapeGraph(std::vector input_shape) + : luci::ConstantFoldingAddTestGraph(input_shape, OutType) + { + _x = _g.nodes()->template create(); + _x->name("x"); + _x->dtype(loco::DataType::S32); + _x->rank(input_shape.size()); + for (uint32_t i = 0; i < input_shape.size(); i++) + _x->dim(i).set(input_shape.at(i)); + _x->shape_status(luci::ShapeStatus::VALID); + + _shape = _g.nodes()->template create(); + _shape->name("shape"); + _shape->out_type(OutType); + _shape->input(_x); + _shape->shape({4}); + _shape->rank(1); + _shape->dim(0).set(4); + } + + loco::Node *createFoldedPattern() override { return _shape; } + +protected: + luci::CircleShape *_shape = nullptr; + luci::CircleConst *_x = nullptr; +}; + +/** + * Graph that has a Shape Op + * + * BEFORE + * + * [CircleConst] + * | + * [CircleInput] [CircleShape] + * \ / + * [CircleAdd] + * | + * [CircleOutput] + * + * AFTER + * + * [CircleInput] [CircleConst] + * \ / + * [CircleAdd] + * | + * [CircleOutput] + * + */ +class FoldShapePassGraphTest : public FoldShapeGraph, public ::testing::Test +{ +public: + FoldShapePassGraphTest() : FoldShapeGraph({1, 8, 8, 64}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST(FoldShapePassTest, name) +{ + luci::FoldShapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldShapePassGraphTest, fold_shape) +{ + luci::FoldShapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded shape + EXPECT_EQ(loco::DataType::S32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(4, folded_const->dim(0).value()); + EXPECT_EQ(1, folded_const->at(0)); + EXPECT_EQ(8, folded_const->at(1)); + EXPECT_EQ(8, folded_const->at(2)); + EXPECT_EQ(64, folded_const->at(3)); +} + +TEST_F(FoldShapePassGraphTest, undefined_shape_NEG) +{ + _x->shape_status(luci::ShapeStatus::UNDEFINED); + + luci::FoldShapePass pass; + EXPECT_FALSE(pass.run(graph())); +} + +TEST_F(FoldShapePassGraphTest, unallowed_rank_NEG) +{ + _x->rank(0); + + luci::FoldShapePass pass; + EXPECT_FALSE(pass.run(graph())); +} + +TEST_F(FoldShapePassGraphTest, unknown_dimension_NEG) +{ + _x->dim(0).unset(); + + luci::FoldShapePass pass; + EXPECT_FALSE(pass.run(graph())); +} diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.cpp index ed60d8899..0dbc09a3f 100644 --- a/compiler/luci/pass/src/FoldSparseToDensePass.cpp +++ b/compiler/luci/pass/src/FoldSparseToDensePass.cpp @@ -15,7 +15,6 @@ */ #include "luci/Pass/FoldSparseToDensePass.h" -#include "CircleOptimizerUtils.h" #include diff --git a/compiler/luci/pass/src/FoldSqueezePass.cpp b/compiler/luci/pass/src/FoldSqueezePass.cpp new file mode 100644 index 000000000..1ec2f836a --- /dev/null +++ b/compiler/luci/pass/src/FoldSqueezePass.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldSqueezePass.h" + +#include +#include +#include + +namespace +{ + +/** + * Fold Squeeze to const if it has const input + **/ +bool fold_squeeze(luci::CircleSqueeze *squeeze) +{ + // Check squeeze has const input + auto const_input = dynamic_cast(squeeze->input()); + if (not const_input) + return false; + + // Check all dimensions are known + const auto input_rank = const_input->rank(); + for (uint32_t i = 0; i < input_rank; i++) + { + if (not const_input->dim(i).known()) + return false; + } + + const auto squeeze_dims = squeeze->squeeze_dims(); + uint32_t num_squeeze_dims = squeeze_dims.size(); + std::vector should_squeeze(input_rank, false); + uint32_t num_squeezed_dims = 0; + + // Squeeze all dimensions whose value is 1 + if (num_squeeze_dims == 0) + { + for (uint32_t idx = 0; idx < input_rank; ++idx) + { + if (const_input->dim(idx).value() == 1) + { + should_squeeze.at(idx) = true; + ++num_squeezed_dims; + } + } + } + else + { + for (uint32_t idx = 0; idx < num_squeeze_dims; ++idx) + { + const int32_t current = + squeeze_dims.at(idx) < 0 ? squeeze_dims.at(idx) + input_rank : squeeze_dims.at(idx); + assert(current >= 0); + assert(current < static_cast(input_rank)); + assert(const_input->dim(current).value() == 1); + + if (not should_squeeze[current]) + ++num_squeezed_dims; + should_squeeze[current] = true; + } + } + + auto new_const = luci::clone(const_input); + new_const->rank(input_rank - num_squeezed_dims); + for (uint32_t in_idx = 0, out_idx = 0; in_idx < input_rank; ++in_idx) + { + if (should_squeeze.at(in_idx)) + continue; + + new_const->dim(out_idx++) = const_input->dim(in_idx); + } + + new_const->shape_status(luci::ShapeStatus::VALID); + + new_const->name(const_input->name() + "_squeezed"); + luci::add_origin( + new_const, luci::composite_origin({luci::get_origin(squeeze), luci::get_origin(const_input)})); + + loco::replace(squeeze).with(new_const); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Squeeze Op + **/ +bool FoldSqueezePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto squeeze = dynamic_cast(node)) + { + if (fold_squeeze(squeeze)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldSqueezePass.test.cpp b/compiler/luci/pass/src/FoldSqueezePass.test.cpp new file mode 100644 index 000000000..b7c6efe68 --- /dev/null +++ b/compiler/luci/pass/src/FoldSqueezePass.test.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldSqueezePass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +template class FoldSqueezeTest : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldSqueezeTest(std::initializer_list input_shape, + std::initializer_list output_shape) + : luci::ConstantFoldingAddTestGraph(output_shape, DT) + { + _squeeze = _g.nodes()->template create(); + _x = _g.nodes()->template create(); + + _squeeze->dtype(DT); + _x->dtype(DT); + + _squeeze->shape(output_shape); + _x->shape(input_shape); + + _squeeze->squeeze_dims({0}); + + uint32_t num_elems = 1; + for (auto dim = input_shape.begin(); dim != input_shape.end(); dim++) + num_elems *= *dim; + + _x->size
(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + _x->at
(i) = i; + + _squeeze->input(_x); + + _squeeze->name("squeeze"); + _x->name("x"); + } + + loco::Node *createFoldedPattern() override { return _squeeze; } + +public: + void set_unknown_dim() { _x->dim(0).unset(); } + +protected: + luci::CircleSqueeze *_squeeze = nullptr; + luci::CircleConst *_x = nullptr; +}; + +/** + * Graph that has a Squeeze Op with constant input + * + * BEFORE + * + * [CircleConst] + * | + * [Squeeze] + * + * AFTER + * + * [CircleConst] + * + */ +class FoldFP32SqueezeTest : public FoldSqueezeTest, public ::testing::Test +{ +public: + FoldFP32SqueezeTest() : FoldSqueezeTest({1, 3}, {3}) {} + + virtual void SetUp() { init(); } +}; + +class FoldS16SqueezeTest : public FoldSqueezeTest, public ::testing::Test +{ +public: + FoldS16SqueezeTest() : FoldSqueezeTest({1, 3}, {3}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST_F(FoldFP32SqueezeTest, fold_squeeze_fp32) +{ + luci::FoldSqueezePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(0, folded_const->at(0)); + EXPECT_EQ(1, folded_const->at(1)); + EXPECT_EQ(2, folded_const->at(2)); +} + +TEST_F(FoldFP32SqueezeTest, fold_squeeze_unkown_dim_NEG) +{ + set_unknown_dim(); + + luci::FoldSqueezePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(FoldS16SqueezeTest, fold_squeeze_s16) +{ + luci::FoldSqueezePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S16, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(3, folded_const->dim(0).value()); + EXPECT_EQ(0, folded_const->at(0)); + EXPECT_EQ(1, folded_const->at(1)); + EXPECT_EQ(2, folded_const->at(2)); +} diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp index 3494a6e60..21ac7adfb 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/ForwardReshapeToUnaryOpPass.h" +#include "helpers/NodeFiller.h" + #include #include #include @@ -76,6 +78,34 @@ luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape) return new_reshape; } +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleMean *mean, uint32_t axis) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(mean != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(mean).with(new_reshape); + mean->input(reshape->tensor()); + new_reshape->tensor(mean); + + // Change const shape axis value + auto *shape_reshape = loco::must_cast(new_reshape->shape()); + assert(shape_reshape->dtype() == loco::DataType::S32); // FIX_CALLER_UNLESS + assert(axis < shape_reshape->size()); // FIX_CALLER_UNLESS + // Mean reduction will make value to '1' + shape_reshape->at(axis) = 1; + + // Do shape inference for this node again. 
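+  // Resetting both nodes to UNDEFINED forces the next shape-inference run to
+  // recompute their shapes, since Mean and Reshape have swapped positions and
+  // the old static shapes no longer hold.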
+ mean->shape_status(luci::ShapeStatus::UNDEFINED); + reshape->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + bool forward_reshape(luci::CircleReshape *reshape, luci::CircleAbs *abs) { assert(reshape != nullptr); // FIX_CALLER_UNLESS @@ -146,6 +176,64 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleLogistic *logit) return true; } +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleMul *div, + luci::CircleConst *const_value) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(div != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(div).with(new_reshape); + if (div->x() == const_value) + { + div->y(reshape->tensor()); + } + else + { + assert(div->y() == const_value); + div->x(reshape->tensor()); + } + new_reshape->tensor(div); + + // Do shape inference for this node again. + div->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleDiv *div, + luci::CircleConst *const_value) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(div != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(div).with(new_reshape); + if (div->x() == const_value) + { + div->y(reshape->tensor()); + } + else + { + assert(div->y() == const_value); + div->x(reshape->tensor()); + } + new_reshape->tensor(div); + + // Do shape inference for this node again. + div->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + class ForwardReshape final : public luci::CircleNodeMutableVisitor { protected: @@ -156,6 +244,81 @@ protected: return false; } + /** + * Graph example: + * + * BEFORE + * [Input] + * (3, 4, 4) [Shape_Const = (1, -1, 4)] + * | | + * [Reshape] ---------------- + * (1, 12, 4) + * | + * [Mean, keep_dims = true] + * (1, 12, 1) + * | + * [Output] + * + * AFTER + * [Input] + * (3, 4, 4) + * | + * [Mean, keep_dims = true] + * (3, 4, 1) [Shape_Const = (1, -1, 1)] + * | | + * [Reshape]----------------- + * (1, 12, 1) + * | + * [Output] + * + */ + bool visit(luci::CircleMean *node) + { + luci::CircleReshape *reshape = nullptr; + luci::CircleConst *axis = nullptr; + + reshape = dynamic_cast(node->input()); + axis = dynamic_cast(node->reduction_indices()); + + if (reshape == nullptr or axis == nullptr) + return false; + + if (axis->dtype() != loco::DataType::S32) + return false; + + // Should be scalar + if (axis->size() != 1) + return false; + + // axis value + auto axis_value = axis->at(0); + + if (axis_value < 0) + axis_value += static_cast(reshape->rank()); + + assert(axis_value >= 0); + + if (node->keep_dims() != true) + return false; + + auto reshape_input = loco::must_cast(reshape->tensor()); + + // reshape shouldn't change rank + if (reshape_input->rank() != reshape->rank()) + return false; + + assert(reshape_input->rank() > static_cast(axis_value)); + + for (int32_t i = 0; i <= axis_value; ++i) + { + if (not reshape_input->dim(i).known() or + reshape_input->dim(i).value() != reshape->dim(i).value()) + return false; + } + + return forward_reshape(reshape, node, axis_value); + } + bool visit(luci::CircleAbs *node) { auto reshape = as_reshape(node->x()); @@ -180,6 +343,43 @@ protected: return forward_reshape(reshape, node); } + + bool visit(luci::CircleDiv *node) + { + luci::CircleReshape *reshape = nullptr; + luci::CircleConst 
*const_value = nullptr; + + if (not luci::fill(&reshape, &const_value).with_commutative_args_of(node)) + return false; + + if (const_value->dtype() != loco::DataType::FLOAT32) + return false; + + // Should be scalar + if (const_value->size() != 1) + return false; + + return forward_reshape(reshape, node, const_value); + } + + bool visit(luci::CircleMul *node) + { + luci::CircleReshape *reshape = nullptr; + luci::CircleConst *const_value = nullptr; + + if (not luci::fill(&reshape, &const_value).with_commutative_args_of(node)) + return false; + + if (const_value->dtype() != loco::DataType::FLOAT32) + return false; + + // Should be scalar + if (const_value->size() != 1) + return false; + + return forward_reshape(reshape, node, const_value); + } + // TODO add more unary operators }; @@ -201,6 +401,11 @@ namespace luci * | | | * * UnaryOp: CircleNeg, ... + * Note: Binary Op (Div, Mul) can also be considered as a unary operation + * if one of its inputs is a constant. + * For CircleMean in which the axis is a scalar + * constant and reshape Op does not change the axis on which the mean is + * taken, the Reshape Op can be forwarded. * * AFTER * | diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp index 373513270..ae89e4fad 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp @@ -173,6 +173,114 @@ protected: luci::ForwardReshapeToUnaryOpPass _pass; }; +/** + * Simple graph for test + * + * BEFORE + * [Input] + * (3, 4, 4) [Shape_Const = (1, -1, 4)] + * | | + * [Reshape] ---------------- + * (1, 12, 4) + * | + * [Mean, keep_dims = true] + * (1, 12, 1) + * | + * [Output] + * + * AFTER + * [Input] + * (3, 4, 4) + * | + * [Mean, keep_dims = true] + * (3, 4, 1) [Shape_Const = (1, -1, 1)] + * | | + * [Reshape]----------------- + * (1, 12, 1) + * | + * [Output] + * + */ +class PatternReshapeMeanGraphlet +{ +public: + PatternReshapeMeanGraphlet() = default; + + void init(loco::Graph *g) + { + _mean = g->nodes()->create(); + _mean_const = g->nodes()->create(); + _reshape = g->nodes()->create(); + _reshape_const = g->nodes()->create(); + + _mean->name("_mean"); + _mean_const->name("_mean_const"); + _reshape->name("_reshape"); + _reshape_const->name("_reshape_const"); + } + +public: + luci::CircleMean *mean() { return _mean; } + luci::CircleConst *mean_const() { return _mean_const; } + luci::CircleReshape *reshape() { return _reshape; } + luci::CircleConst *reshape_const() { return _reshape_const; } + +protected: + luci::CircleMean *_mean = nullptr; + luci::CircleConst *_mean_const = nullptr; + luci::CircleReshape *_reshape = nullptr; + luci::CircleConst *_reshape_const = nullptr; +}; + +class ForwardReshapeToMeanPatternTestGraph : public TestIOGraph, public PatternReshapeMeanGraphlet +{ +public: + ForwardReshapeToMeanPatternTestGraph() = default; + + void init(void) + { + TestIOGraph::init({3, 4, 4}, {3, 4, 4}); + PatternReshapeMeanGraphlet::init(g()); + + _reshape_const->rank(1); + _reshape_const->dtype(loco::DataType::S32); + _reshape_const->size(3); + _reshape_const->at(0) = 1; + _reshape_const->at(1) = -1; + _reshape_const->at(2) = 4; + _reshape_const->shape_status(luci::ShapeStatus::VALID); + + _reshape->rank(3); + _reshape->dim(0).set(3); + _reshape->dim(1).set(4); + _reshape->dim(2).set(4); + _reshape->dtype(loco::DataType::FLOAT32); + _reshape->shape_status(luci::ShapeStatus::VALID); + _reshape->tensor(input()); + 
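+    // the (1, -1, 4) shape const reshapes the (3, 4, 4) input to (1, 12, 4),
+    // matching the BEFORE diagram above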
_reshape->shape(_reshape_const); + + _mean_const->rank(1); + _mean_const->dtype(loco::DataType::S32); + _mean_const->size(1); + _mean_const->at(0) = -1; + _mean_const->shape_status(luci::ShapeStatus::VALID); + + _mean->rank(3); + _mean->dim(0).set(1); + _mean->dim(1).set(12); + _mean->dim(2).set(1); + _mean->dtype(loco::DataType::FLOAT32); + _mean->shape_status(luci::ShapeStatus::VALID); + _mean->input(_reshape); + _mean->reduction_indices(_mean_const); + _mean->keep_dims(true); + + output()->from(_mean); + } + + void invalid_type() { _mean_const->dtype(loco::DataType::FLOAT32); } +}; + } // namespace TEST(ForwardReshapeToUnaryOpPassTest, name) @@ -209,3 +317,25 @@ TEST_F(ForwardReshapeToLogisticGraphTest, forward) log = dynamic_cast(reshape->tensor()); ASSERT_NE(nullptr, log); } + +TEST(ForwardReshapeToUnaryOpPassTest, forward_reshape_to_mean_pattern) +{ + ForwardReshapeToMeanPatternTestGraph g; + luci::ForwardReshapeToUnaryOpPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(ForwardReshapeToUnaryOpPassTest, forward_reshape_to_mean_pattern_NEG) +{ + ForwardReshapeToMeanPatternTestGraph g; + luci::ForwardReshapeToUnaryOpPass pass; + + g.init(); + + g.invalid_type(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp index c76d73344..b9f7ae5a8 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -164,6 +164,7 @@ private: { if (auto y = dynamic_cast(node->y())) { + RETURN_FALSE_UNLESS(node->rank() == y->rank()); RETURN_FALSE_UNLESS(check_rank_four(x)); RETURN_FALSE_UNLESS(check_perm(y)); @@ -190,6 +191,7 @@ private: { if (auto x = dynamic_cast(node->x())) { + RETURN_FALSE_UNLESS(node->rank() == x->rank()); RETURN_FALSE_UNLESS(check_rank_four(y)); RETURN_FALSE_UNLESS(check_perm(x)); @@ -289,6 +291,8 @@ public: bool visit(luci::CircleNode *) { return false; } bool visit(luci::CircleAbs *node) { return has_pattern_x(node); } + + bool visit(luci::CircleLogistic *node) { return has_pattern_x(node); } }; } // namespace diff --git a/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.cpp b/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.cpp new file mode 100644 index 000000000..297a345c0 --- /dev/null +++ b/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.cpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FuseAddToFullyConnectedBiasPass.h" + +#include +#include + +#include "helpers/NodeFiller.h" + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +namespace +{ + +/** + * Fuse Add to following FullyConnected bias if possible + * + * BEFORE + * | + * [CircleAdd] [CircleConst] [CircleConst] + * | | | + * [CircleFullyConnected] ----------+ + * | + * + * AFTER + * | + * | [CircleConst] [CircleConst] [CircleConst] + * | | | | + * | [CircleConst] [CircleFullyConnected] [CircleAdd] + * | | | + * [CircleFullyConnected] ------+ + * | + * + */ +bool fuse_add_to_fc_bias(luci::CircleFullyConnected *fc) +{ + CHECK_OR_FALSE(fc); + + // check input is Add + auto add = dynamic_cast(fc->input()); + CHECK_OR_FALSE(add); + // conditions of Add, FC: to expect constant folding, support only F32 + CHECK_OR_FALSE(add->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(add->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(fc->dtype() == loco::DataType::FLOAT32); + // support weight with constant + auto weights = dynamic_cast(fc->weights()); + CHECK_OR_FALSE(weights); + // bias can be constant or outputexclude + auto bias = dynamic_cast(fc->bias()); + CHECK_OR_FALSE(bias); + + // Check addition of Add is constant + luci::CircleNode *add_input = nullptr; + luci::CircleConst *add_shift = nullptr; + CHECK_OR_FALSE(luci::fill(&add_input, &add_shift).with_commutative_args_of(add)); + // support only 1D constant + CHECK_OR_FALSE(add_shift->rank() == 1); + + auto graph = fc->graph(); + + auto fc_bias = graph->nodes()->create(); + fc_bias->input(add_shift); + fc_bias->weights(weights); + fc_bias->bias(bias); + fc_bias->keep_num_dims(true); + fc_bias->fusedActivationFunction(luci::FusedActFunc::NONE); + fc_bias->name(fc->name() + "_" + add->name() + "_bias"); + luci::add_origin(fc_bias, + luci::composite_origin( + {luci::get_origin(add), luci::get_origin(add_shift), luci::get_origin(bias)})); + + auto fc_new = graph->nodes()->create(); + fc_new->input(add_input); + fc_new->weights(weights); + fc_new->bias(fc_bias); + fc_new->weights_format(fc->weights_format()); + fc_new->keep_num_dims(fc->keep_num_dims()); + fc_new->fusedActivationFunction(fc->fusedActivationFunction()); + fc_new->name(fc->name()); + luci::add_origin(fc_new, luci::get_origin(fc)); + + replace(fc).with(fc_new); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseAddToFullyConnectedBiasPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast(node); + if (not fc) + continue; + + if (fuse_add_to_fc_bias(fc)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.test.cpp b/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.test.cpp new file mode 100644 index 000000000..445ba4b0b --- /dev/null +++ b/compiler/luci/pass/src/FuseAddToFullyConnectedBiasPass.test.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseAddToFullyConnectedBiasPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +template class FuseAddToFullyConnectedBiasPassTestGraph : public TestIOGraph +{ +public: + FuseAddToFullyConnectedBiasPassTestGraph() = default; + + void init(void) + { + TestIOGraph::init({3, 4}, {3, 6}); + + _add = g()->nodes()->create(); + _add_s = g()->nodes()->create(); + _fc = g()->nodes()->create(); + _fc_w = g()->nodes()->create(); + _fc_b = g()->nodes()->create(); + + _add->name("add"); + _add_s->name("add_s"); + _fc->name("fc"); + _fc_w->name("fc_w"); + _fc_b->name("fc_b"); + + _add->dtype(DT); + _fc->dtype(DT); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + + _add_s->rank(1); + _add_s->dim(0) = 3; + _add_s->dtype(DT); + _add_s->size
(3); + for (uint32_t i = 0; i < 3; ++i) + { + _add_s->at
(i) = 1.0f; + } + + _fc_w->rank(2); + _fc_w->dim(0) = 6; + _fc_w->dim(1) = 4; + _fc_w->dtype(DT); + _fc_w->size
(4 * 6); + for (uint32_t i = 0; i < 4 * 6; ++i) + { + _fc_w->at
(i) = 1.0f; + } + + _fc_b->rank(1); + _fc_b->dim(0) = 6; + _fc_b->dtype(DT); + _fc_b->size
(6); + for (uint32_t i = 0; i < 6; ++i) + { + _fc_b->at
(i) = 1.0f; + } + + _add->x(input()); + _add->y(_add_s); + _fc->input(_add); + _fc->weights(_fc_w); + _fc->bias(_fc_b); + + output()->from(_fc); + } + + luci::CircleAdd *_add = nullptr; + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleConst *_add_s = nullptr; + luci::CircleConst *_fc_w = nullptr; + luci::CircleConst *_fc_b = nullptr; +}; + +class FuseAddToFullyConnectedBiasPassTest : public ::testing::Test +{ +public: + FuseAddToFullyConnectedBiasPassTest() = default; + +protected: + FuseAddToFullyConnectedBiasPassTestGraph _graph; + luci::FuseAddToFullyConnectedBiasPass _pass; +}; + +class FuseAddToFullyConnectedBiasPassS32Test : public ::testing::Test +{ +public: + FuseAddToFullyConnectedBiasPassS32Test() = default; + +protected: + FuseAddToFullyConnectedBiasPassTestGraph _graph; + luci::FuseAddToFullyConnectedBiasPass _pass; +}; + +} // namespace + +TEST_F(FuseAddToFullyConnectedBiasPassTest, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FuseAddToFullyConnectedBiasPassTest, fuse_add_to_fc_bias) +{ + _graph.init(); + + EXPECT_TRUE(_pass.run(_graph.g())); +} + +TEST_F(FuseAddToFullyConnectedBiasPassTest, add_fused_act_NEG) +{ + _graph.init(); + + _graph._add->fusedActivationFunction(luci::FusedActFunc::RELU); + + EXPECT_FALSE(_pass.run(_graph.g())); +} + +TEST_F(FuseAddToFullyConnectedBiasPassTest, add_d2_NEG) +{ + _graph.init(); + + _graph._add_s->rank(2); + _graph._add_s->dim(0) = 1; + _graph._add_s->dim(1) = 3; + + EXPECT_FALSE(_pass.run(_graph.g())); +} + +TEST_F(FuseAddToFullyConnectedBiasPassS32Test, dtype_s32_NEG) +{ + _graph.init(); + + EXPECT_FALSE(_pass.run(_graph.g())); +} diff --git a/compiler/luci/pass/src/FuseAddWithConvPass.cpp b/compiler/luci/pass/src/FuseAddWithConvPass.cpp new file mode 100644 index 000000000..f6c1e574c --- /dev/null +++ b/compiler/luci/pass/src/FuseAddWithConvPass.cpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseAddWithConvPass.h" + +#include "helpers/NodeFiller.h" + +#include +#include + +namespace +{ +/** + * Fuse Add to Conv2D if possible. 
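+ * The Add's constant operand is folded into a new bias, one value per output
+ * channel: fused_bias[c] = shift[c] + (bias ? bias[c] : 0).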
+ *
+ * BEFORE
+ *         |          [CircleConst]
+ *         |         /    [CircleConst]
+ *         |        /    /
+ *    [CircleConv2D]    [CircleConst]
+ *         |           /
+ *     [CircleAdd]
+ *         |
+ *
+ * AFTER
+ *                  |                  [CircleConst]
+ *   +--------------+                 /    [CircleConst]
+ *   |              |                /    /
+ *   |        [CircleConv2D]    [CircleConst]
+ * [CircleConst]    |     |         /
+ * [CircleConst] \  |    [CircleAdd]
+ *        \       \ |
+ *      [CircleConv2D]
+ *           |
+ */
+bool fused_add_with_conv(luci::CircleAdd *add)
+{
+  // find the pattern of CircleAdd(CircleConv2D, CircleConst)
+  luci::CircleConst *shift = nullptr;
+  luci::CircleConv2D *conv2d = nullptr;
+  if (not luci::fill(&conv2d, &shift).with_commutative_args_of(add))
+    return false;
+
+  // check conditions for conv2d
+  if (conv2d->rank() != 4)
+    return false;
+  if (conv2d->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  luci::CircleConst *filter = dynamic_cast<luci::CircleConst *>(conv2d->filter());
+  luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(conv2d->bias());
+  luci::CircleOutputExclude *biasex = dynamic_cast<luci::CircleOutputExclude *>(conv2d->bias());
+
+  // filter should exist, bias should be const or none(output exclude)
+  if (filter == nullptr || (bias == nullptr && biasex == nullptr))
+    return false;
+  if (filter->rank() != 4)
+    return false;
+  if (filter->dtype() != shift->dtype())
+    return false;
+  // TODO support more data type
+  if (filter->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  // filter is OHWI
+  uint32_t out_channel = filter->dim(0).value();
+
+  // shape of shift should be [1, 1, 1, out_channel] or [out_channel]
+  if (shift->rank() == 4)
+  {
+    for (uint32_t i = 0; i < 3; ++i)
+      if (shift->dim(i).value() != 1)
+        return false;
+    if (shift->dim(3).value() != out_channel)
+      return false;
+  }
+  else if (shift->rank() == 1)
+  {
+    if (shift->dim(0).value() != out_channel)
+      return false;
+  }
+  else
+    return false;
+
+  auto conv2d_name = conv2d->name();
+  auto shift_name = shift->name();
+  assert(conv2d_name.length() > 0);
+  assert(shift_name.length() > 0);
+  auto bias_name = (bias ? bias->name() : conv2d_name) + ";" + shift_name;
+
+  luci::CircleConv2D *fused_conv2d = add->graph()->nodes()->create<luci::CircleConv2D>();
+  luci::CircleConst *fused_bias = add->graph()->nodes()->create<luci::CircleConst>();
+
+  fused_bias->dtype(conv2d->dtype());
+  fused_bias->rank(1);
+  fused_bias->dim(0).set(out_channel);
+  fused_bias->shape_status(luci::ShapeStatus::VALID);
+  fused_bias->name(bias_name);
+  fused_bias->size<loco::DataType::FLOAT32>(out_channel);
+  // fuse shift to bias
+  for (uint32_t b = 0; b < out_channel; ++b)
+  {
+    auto bias_val = shift->at<loco::DataType::FLOAT32>(b);
+    if (bias)
+      bias_val += bias->at<loco::DataType::FLOAT32>(b);
+    fused_bias->at<loco::DataType::FLOAT32>(b) = bias_val;
+  }
+
+  // Set attributes of fused_conv2d
+  fused_conv2d->input(conv2d->input());
+  fused_conv2d->filter(conv2d->filter());
+  fused_conv2d->bias(fused_bias);
+  fused_conv2d->fusedActivationFunction(add->fusedActivationFunction());
+  fused_conv2d->padding(conv2d->padding());
+  fused_conv2d->stride()->h(conv2d->stride()->h());
+  fused_conv2d->stride()->w(conv2d->stride()->w());
+  fused_conv2d->dilation()->h(conv2d->dilation()->h());
+  fused_conv2d->dilation()->w(conv2d->dilation()->w());
+  fused_conv2d->name(conv2d_name);
+  luci::add_origin(fused_conv2d,
+                   luci::composite_origin({luci::get_origin(add), luci::get_origin(conv2d)}));
+
+  replace(add).with(fused_conv2d);
+
+  return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseAddWithConvPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+    {
+      if (fused_add_with_conv(add))
+        changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseAddWithConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithConvPass.test.cpp
new file mode 100644
index 000000000..c64b457fa
--- /dev/null
+++ b/compiler/luci/pass/src/FuseAddWithConvPass.test.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseAddWithConvPass.h"
+
+#include "helpers/CreateCircleConst.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+using namespace luci::test;
+
+namespace
+{
+
+#define FILTER_O 4
+#define FILTER_H 1
+#define FILTER_W 1
+#define FILTER_I 6
+
+class Conv2DAddGraphlet
+{
+public:
+  Conv2DAddGraphlet() = default;
+
+  void init(loco::Graph *g)
+  {
+    const ShapeU32 filter_shape = {FILTER_O, FILTER_H, FILTER_W, FILTER_I};
+    const ShapeU32 bias_shape = {FILTER_O};
+
+    _conv_f = luci::create_const_node(g, loco::DataType::FLOAT32, filter_shape, 0.5f);
+    _conv_b = luci::create_const_node(g, loco::DataType::FLOAT32, bias_shape, 0.5f);
+    _conv_f->name("conv_f");
+    _conv_b->name("conv_b");
+
+    _conv = g->nodes()->create<luci::CircleConv2D>();
+    _conv->filter(_conv_f);
+    _conv->bias(_conv_b);
+    _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _conv->dtype(loco::DataType::FLOAT32);
+    _conv->shape({1, 3, 3, FILTER_O});
+    _conv->name("conv");
+
+    const ShapeU32 add_shape = {1, 1, 1, FILTER_O};
+    _add_y = luci::create_const_node(g, loco::DataType::FLOAT32, add_shape, 0.5f);
+    _add_y->name("add_y");
+
+    _add = g->nodes()->create<luci::CircleAdd>();
+    _add->x(_conv);
+    _add->y(_add_y);
+    _add->fusedActivationFunction(luci::FusedActFunc::RELU);
+    _add->dtype(loco::DataType::FLOAT32);
+    _add->shape({1, 3, 3, FILTER_O});
+    _add->name("add");
+
+    // for negative test
+    const ShapeU32 add_shape_2 = {FILTER_O, FILTER_I};
+    _add_y_2 = luci::create_const_node(g, loco::DataType::FLOAT32, add_shape_2, 0.5f);
+    _add_y_2->name("add_y_2");
+  }
+
+protected:
+  luci::CircleConv2D *_conv = nullptr;
+  luci::CircleAdd *_add = nullptr;
+  luci::CircleConst *_conv_f = nullptr;
+  luci::CircleConst *_conv_b = nullptr;
+  luci::CircleConst *_add_y = nullptr;
+  luci::CircleConst *_add_y_2 = nullptr;
+};
+
+class FuseAddWithConvTestGraph : public TestIOGraph, public Conv2DAddGraphlet
+{
+public:
+  FuseAddWithConvTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({1, 3, 3, FILTER_I}, {1, 3, 3, FILTER_O});
+    Conv2DAddGraphlet::init(g());
+
+    _conv->input(input());
+    output()->from(_add);
+  }
+
+  void add_use_2()
+  {
+    // set to not compatible shape
+    _add->y(_add_y_2);
+  }
+};
+
+class FuseAddWithConvPassTest : public ::testing::Test, public FuseAddWithConvTestGraph
+{
+public:
+  luci::FuseAddWithConvPass pass;
+};
+
+} // namespace
+
+TEST_F(FuseAddWithConvPassTest, simple_test)
+{
+  init();
+
+  // Add should exist
+  auto add = dynamic_cast<luci::CircleAdd *>(output()->from());
+  EXPECT_NE(nullptr, add);
+
+  EXPECT_TRUE(pass.run(g()));
+
+  // expect Add is fused into Conv
+  auto conv = dynamic_cast<luci::CircleConv2D *>(output()->from());
+  EXPECT_NE(nullptr, conv);
+}
+
+TEST_F(FuseAddWithConvPassTest, wrong_add_shape_NEG)
+{
+  init();
+  add_use_2();
+
+  // Add const shape is not compatible
+  EXPECT_FALSE(pass.run(g()));
+}
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
index 1d4a2e3bf..45b7ef648 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
@@ -20,6 +20,8 @@
 #include <luci/IR/CircleNodes.h>
 #include <luci/Profile/CircleNodeOrigin.h>
 
+#include <cmath>
+
 namespace
 {
 /**
@@ -124,6 +126,184 @@ bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
   return true;
 }
 
+// Return qparam if it exists and its scale/zp's size is the same with len
+// Return nullptr otherwise
+luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len)
+{
+  if (node->quantparam() == nullptr)
+    return nullptr;
+
+  if (node->quantparam()->scale.size() != len)
+    return nullptr;
+
+  if (node->quantparam()->zerop.size() != len)
+    return nullptr;
+
+  return node->quantparam();
+}
+
+bool fuse_add_with_s16_fc(luci::CircleFullyConnected *fc)
+{
+  assert(fc);                                // FIX_CALLER_UNLESS
+  assert(fc->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS
+
+  if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
+  if (not weights)
+    return false;
+
+  auto fc_output = loco::succs(fc);
+  // Fuse only when FC has a single successor (to avoid weight increase)
+  if (fc_output.size() != 1)
+    return false;
+
+  auto add = dynamic_cast<luci::CircleAdd *>(*fc_output.begin());
+  if (not add)
+    return false;
+
+  // Only support the same dtype with fc
+  if (add->dtype() != loco::DataType::S16)
+    return false;
+
+  // Get addition
+  auto addition = add->x() == fc ? dynamic_cast<luci::CircleConst *>(add->y())
+                                 : dynamic_cast<luci::CircleConst *>(add->x());
+
+  // Non-const addition
+  if (not addition)
+    return false;
+
+  // Check addition dtype
+  if (addition->dtype() != loco::DataType::S16)
+    return false;
+
+  auto rank = addition->rank();
+  // TODO Support scalar addition
+  if (rank == 0)
+    return false;
+
+  for (uint32_t i = 0; i < rank - 1; i++)
+  {
+    if (addition->dim(i).value() != 1)
+      return false;
+  }
+
+  // Check the last dim of addition is the same with the output dim of weight
+  const auto last_dim = addition->dim(rank - 1).value();
+  if (last_dim != weights->dim(0).value())
+    return false;
+
+  auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());
+
+  // Only support (1) constant bias, or (2) no bias
+  if (bias->opcode() != luci::CircleOpcode::CIRCLECONST and
+      bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+    return false;
+
+  // If bias is const, its dtype must be s64
+  if (bias->opcode() == luci::CircleOpcode::CIRCLECONST and bias->dtype() != loco::DataType::S64)
+    return false;
+
+  const auto addition_qparam = get_qparam(addition, last_dim);
+  if (addition_qparam == nullptr)
+    return false;
+
+  std::vector<float> fp32_bias(last_dim);
+  for (uint32_t i = 0; i < last_dim; i++)
+  {
+    auto scale = addition_qparam->scale.at(i);
+    if (addition_qparam->zerop.at(i) != 0)
+      return false; // FIX_ME_UNLESS
+
+    auto val = addition->at<loco::DataType::S16>(i);
+    fp32_bias[i] = val * scale;
+  }
+
+  // Add existing bias values
+  if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias))
+  {
+    const auto bias_qparam = get_qparam(const_bias, last_dim);
+    if (bias_qparam == nullptr)
+      return false;
+
+    for (uint32_t i = 0; i < last_dim; i++)
+    {
+      auto scale = bias_qparam->scale.at(i);
+      if (bias_qparam->zerop.at(i) != 0)
+        return false; // FIX_ME_UNLESS
+
+      auto val = const_bias->at<loco::DataType::S64>(i);
+      fp32_bias[i] += val * scale;
+    }
+  }
+
+  const auto add_qparam = get_qparam(add, 1);
+  if (add_qparam == nullptr)
+    return false;
+
+  auto input = loco::must_cast<luci::CircleNode *>(fc->input());
+  const auto input_qparam = get_qparam(input, 1);
+  if (input_qparam == nullptr)
+    return false;
+
+  const auto weights_qparam = get_qparam(weights, last_dim);
+  if (weights_qparam == nullptr)
+    return false;
+
+  auto fused_bias = luci::clone(addition);
+  fused_bias->dtype(loco::DataType::S64);
+  fused_bias->size<loco::DataType::S64>(last_dim);
+
+  // The shape is normalized to [N] to become the bias of FC
+  fused_bias->rank(1);
+  fused_bias->dim(0) = last_dim;
+
+  std::vector<float> new_bias_scale;
+  for (uint32_t i = 0; i < last_dim; i++)
+  {
+    const auto input_scale = input_qparam->scale.at(0);
+    const auto weight_scale = weights_qparam->scale.at(i);
+
+    const float scale = input_scale * weight_scale;
+    const float scale_inv = (scale == 0) ? 0 : 1.0 / scale;
+
+    fused_bias->at<loco::DataType::S64>(i) =
+      static_cast<int64_t>(std::round(fp32_bias.at(i) * scale_inv));
+
+    new_bias_scale.push_back(scale);
+  }
+  std::vector<int64_t> new_bias_zerop(new_bias_scale.size(), 0);
+
+  auto bias_qparam = std::make_unique<luci::CircleQuantParam>();
+  {
+    bias_qparam->scale = new_bias_scale;
+    bias_qparam->zerop = new_bias_zerop;
+  }
+
+  fused_bias->quantparam(std::move(bias_qparam));
+
+  // In-place update. This works because fc is guaranteed to have a single successor
+  fc->bias(fused_bias);
+  fc->fusedActivationFunction(add->fusedActivationFunction());
+
+  auto qparam = std::make_unique<luci::CircleQuantParam>();
+  {
+    qparam->scale.push_back(add_qparam->scale.at(0));
+    qparam->zerop.push_back(add_qparam->zerop.at(0));
+  }
+
+  fc->quantparam(std::move(qparam));
+
+  // set origin
+  luci::add_origin(fc, luci::get_origin(add));
+
+  replace(add).with(fc);
+
+  return true;
+}
+
 } // namespace
 
 namespace luci
@@ -138,8 +318,19 @@ bool FuseAddWithFullyConnectedPass::run(loco::Graph *g)
     if (not fc)
       continue;
 
-    if (fuse_add_with_fc(fc))
-      changed = true;
+    switch (fc->dtype())
+    {
+      case loco::DataType::FLOAT32:
+        if (fuse_add_with_fc(fc))
+          changed = true;
+        break;
+      case loco::DataType::S16:
+        if (fuse_add_with_s16_fc(fc))
+          changed = true;
+        break;
+      default:
+        break;
+    }
   }
 
   return changed;
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
index b132c6bd9..c96846f54 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
@@ -121,6 +121,142 @@ public:
   luci::FuseAddWithFullyConnectedPass pass;
 };
 
+std::unique_ptr<luci::CircleQuantParam> gen_qparam(const std::vector<float> &s,
+                                                   const std::vector<int64_t> &zp)
+{
+  auto qparam = std::make_unique<luci::CircleQuantParam>();
+  {
+    for (auto scale : s)
+      qparam->scale.push_back(scale);
+
+    for (auto zerop : zp)
+      qparam->zerop.push_back(zerop);
+  }
+
+  return std::move(qparam);
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ *   [FC]
+ *    |
+ *   [Add w/ Relu]
+ *
+ * AFTER
+ *
+ *   [FC w/ Relu] (bias updated)
+ *
+ */
+class S16FCAddGraphlet
+{
+public:
+  void init(loco::Graph *g)
+  {
+    std::vector<int16_t> weights_val(16 * 4);
+    _fc_f = luci::create_const_node(g, loco::DataType::S16, {16, 4}, weights_val);
+    {
+      auto qparam = std::make_unique<luci::CircleQuantParam>();
+      {
+        for (uint32_t i = 0; i < 16; i++)
+        {
+          qparam->scale.push_back(1.0);
+          qparam->zerop.push_back(0);
+        }
+      }
+      _fc_f->quantparam(std::move(qparam));
+    }
+
+    std::vector<int64_t> bias_val(16);
+    for (uint32_t i = 0; i < 16; i++)
+      bias_val.at(i) = i;
+
+    _fc_b = luci::create_const_node(g, loco::DataType::S64, {1, 16}, bias_val);
+    {
+      std::vector<float> scale(16, 1.0);
+      std::vector<int64_t> zerop(16, 0);
+
+      auto qparam = gen_qparam(scale, zerop);
+      _fc_b->quantparam(std::move(qparam));
+    }
+
+    _fc = g->nodes()->create<luci::CircleFullyConnected>();
+    _fc->weights(_fc_f);
+    _fc->bias(_fc_b);
+    _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _fc->dtype(loco::DataType::S16);
+    _fc->shape({1, 16});
+    _fc->name("fc");
+
+    std::vector<int16_t> addition_val;
+    for (uint32_t i = 0; i < 16; i++)
+      addition_val.push_back(static_cast<int16_t>(i));
+
+    _add_c = luci::create_const_node(g, loco::DataType::S16, {1, 16}, addition_val);
+    {
+      std::vector<float> scale(16, 1.0);
+      std::vector<int64_t> zerop(16, 0);
+
+      auto qparam = gen_qparam(scale, zerop);
+      _add_c->quantparam(std::move(qparam));
+    }
+
+    _add = g->nodes()->create<luci::CircleAdd>();
+    {
+      auto qparam = gen_qparam({2.0}, {0});
+      _add->quantparam(std::move(qparam));
+    }
+
+    _add->x(_fc);
+    _add->y(_add_c);
+    _add->fusedActivationFunction(luci::FusedActFunc::RELU);
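+    // NOTE With input/weight/addition scales all 1.0, bias value i and
+    //      addition value i, the fused S64 bias is expected to be 2 * i
+    //      (checked by the fuse_s16 test below).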
+    _add->dtype(loco::DataType::S16);
+    _add->shape({1, 16});
+    _add->name("add");
+  }
+
+public:
+  luci::CircleFullyConnected *fc() { return _fc; }
+
+protected:
+  luci::CircleFullyConnected *_fc = nullptr;
+  luci::CircleAdd *_add = nullptr;
+  luci::CircleConst *_fc_f = nullptr;
+  luci::CircleConst *_fc_b = nullptr;
+  luci::CircleConst *_add_c = nullptr;
+};
+
+class S16FuseAddWithFCTestGraph : public TestIOGraph, public S16FCAddGraphlet
+{
+public:
+  void init(void)
+  {
+    TestIOGraph::init({1, 4}, {1, 16});
+    input()->dtype(loco::DataType::S16);
+    {
+      auto qparam = gen_qparam({1.0}, {0});
+      input()->quantparam(std::move(qparam));
+    }
+
+    output()->dtype(loco::DataType::S16);
+
+    S16FCAddGraphlet::init(g());
+
+    _fc->input(input());
+
+    output()->from(_add);
+  }
+};
+
+class S16FuseAddWithFullyConnectedPassTest : public ::testing::Test
+{
+public:
+  S16FuseAddWithFCTestGraph g;
+  luci::FuseAddWithFullyConnectedPass pass;
+};
+
 } // namespace
 
 TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
@@ -150,3 +286,39 @@ TEST_F(FuseAddWithFullyConnectedPassTest, fm_bias_NEG)
   auto ret = pass.run(g.g());
   EXPECT_EQ(false, ret);
 }
+
+TEST_F(S16FuseAddWithFullyConnectedPassTest, fuse_s16)
+{
+  g.init();
+
+  auto ret = pass.run(g.g());
+  EXPECT_EQ(true, ret);
+
+  auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
+  EXPECT_NE(nullptr, fc);
+  EXPECT_NE(nullptr, fc->quantparam());
+  EXPECT_EQ(1, fc->quantparam()->scale.size());
+  EXPECT_EQ(2.0, fc->quantparam()->scale.at(0));
+  EXPECT_EQ(luci::FusedActFunc::RELU, fc->fusedActivationFunction());
+
+  auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
+  EXPECT_EQ(loco::DataType::S64, bias->dtype());
+  for (uint32_t i = 0; i < bias->size<loco::DataType::S64>(); i++)
+  {
+    EXPECT_EQ(2 * i, bias->at<loco::DataType::S64>(i));
+  }
+
+  auto qparam = bias->quantparam();
+  EXPECT_NE(nullptr, qparam);
+  EXPECT_EQ(1.0, qparam->scale.at(0));
+  EXPECT_EQ(0, qparam->zerop.at(0));
+}
+
+TEST_F(S16FuseAddWithFullyConnectedPassTest, fc_with_null_weights_NEG)
+{
+  g.init();
+  g.fc()->weights(nullptr);
+
+  auto ret = pass.run(g.g());
+  EXPECT_EQ(false, ret);
+}
diff --git a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp
new file mode 100644
index 000000000..3aa37256a
--- /dev/null
+++ b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseHorizontalFullyConnectedPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
+
+namespace luci
+{
+
+namespace
+{
+
+bool check_type_and_shape_equality(const CircleNode *left, const CircleNode *right)
+{
+  if (left->dtype() != right->dtype())
+    return false;
+
+  if (left->rank() != right->rank())
+    return false;
+
+  for (uint32_t i = 0; i < left->rank(); ++i)
+  {
+    if (left->dim(i).value() != right->dim(i).value())
+      return false;
+  }
+
+  return true;
+}
+
+// Add right const to left const (left is updated)
+template <loco::DataType DT> void sum_const_values(CircleConst *left, const CircleConst *right)
+{
+  assert(check_type_and_shape_equality(left, right)); // FIX CALLER UNLESS
+  const auto size = left->template size<DT>();
+
+  for (uint32_t i = 0; i < size; ++i)
+  {
+    left->template at<DT>(i) += right->template at<DT>(i);
+  }
+}
+
+bool fuse_horizontal_fc_nodes(CircleAdd *add_node)
+{
+  // Let's check left and right FC nodes
+  auto left_fc_node = dynamic_cast<CircleFullyConnected *>(add_node->x());
+  auto right_fc_node = dynamic_cast<CircleFullyConnected *>(add_node->y());
+
+  if (left_fc_node == nullptr or right_fc_node == nullptr)
+    return false;
+
+  if (not check_type_and_shape_equality(left_fc_node, right_fc_node))
+    return false;
+
+  if (left_fc_node->fusedActivationFunction() != FusedActFunc::NONE)
+    return false;
+
+  if (right_fc_node->fusedActivationFunction() != FusedActFunc::NONE)
+    return false;
+
+  // Let's check that FC nodes have the same input
+  if (left_fc_node->input() != right_fc_node->input())
+    return false;
+
+  // Let's check left and right FC weights: type and shape
+  auto left_fc_weights = dynamic_cast<CircleConst *>(left_fc_node->weights());
+  auto right_fc_weights = dynamic_cast<CircleConst *>(right_fc_node->weights());
+
+  if (left_fc_weights == nullptr or right_fc_weights == nullptr)
+    return false;
+
+  if (not check_type_and_shape_equality(left_fc_weights, right_fc_weights))
+    return false;
+
+  // Let's check left and right FC bias: type and shape
+  auto left_fc_bias = dynamic_cast<CircleConst *>(left_fc_node->bias());
+  auto right_fc_bias = dynamic_cast<CircleConst *>(right_fc_node->bias());
+
+  // Support only if both biases are const, or both are non-const
+  // TODO Support the case that one FC has a const bias and another FC has no bias.
+  if ((left_fc_bias == nullptr and right_fc_bias != nullptr) or
+      (left_fc_bias != nullptr and right_fc_bias == nullptr))
+  {
+    return false;
+  }
+
+  // Both left/right biases are const. Check dtype/shape.
+  if (left_fc_bias != nullptr and not check_type_and_shape_equality(left_fc_bias, right_fc_bias))
+    return false;
+
+  // Both left/right biases are non-const. Check that both FCs have no bias (CircleOutputExclude).
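+  // NOTE Summing weights/biases is valid because FC is linear in W and b:
+  //      FC(x, W1, b1) + FC(x, W2, b2) == FC(x, W1 + W2, b1 + b2).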
+  if (left_fc_bias == nullptr)
+  {
+    auto left_no_bias = dynamic_cast<CircleOutputExclude *>(left_fc_node->bias());
+    auto right_no_bias = dynamic_cast<CircleOutputExclude *>(right_fc_node->bias());
+    if (not left_no_bias or not right_no_bias)
+      return false;
+  }
+
+  if (left_fc_weights->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  if (left_fc_bias != nullptr)
+  {
+    if (left_fc_bias->dtype() != loco::DataType::FLOAT32)
+      return false;
+  }
+
+  // Let's create fused FC weights and bias
+  auto fused_fc_weights = clone(left_fc_weights);
+  add_origin(fused_fc_weights,
+             composite_origin({get_origin(left_fc_weights), get_origin(right_fc_weights)}));
+
+  CircleConst *fused_fc_bias = nullptr;
+  if (left_fc_bias != nullptr)
+  {
+    fused_fc_bias = clone(left_fc_bias);
+    add_origin(fused_fc_bias,
+               composite_origin({get_origin(left_fc_bias), get_origin(right_fc_bias)}));
+  }
+
+  assert(fused_fc_weights->dtype() == loco::DataType::FLOAT32);
+  sum_const_values<loco::DataType::FLOAT32>(fused_fc_weights, right_fc_weights);
+
+  if (fused_fc_bias != nullptr)
+  {
+    assert(fused_fc_bias->dtype() == loco::DataType::FLOAT32);
+    sum_const_values<loco::DataType::FLOAT32>(fused_fc_bias, right_fc_bias);
+  }
+
+  // Create fused FC node
+  auto graph = left_fc_node->graph();
+  auto fused_fc_node = graph->nodes()->create<CircleFullyConnected>();
+  fused_fc_node->input(left_fc_node->input());
+  fused_fc_node->weights(fused_fc_weights);
+  if (fused_fc_bias)
+  {
+    fused_fc_node->bias(fused_fc_bias);
+  }
+  else
+  {
+    assert(nullptr != dynamic_cast<CircleOutputExclude *>(left_fc_node->bias())); // FIX ME UNLESS
+    fused_fc_node->bias(left_fc_node->bias());
+  }
+
+  fused_fc_node->fusedActivationFunction(add_node->fusedActivationFunction());
+  fused_fc_node->name(left_fc_node->name() + "_" + right_fc_node->name() + "_fused");
+
+  add_origin(fused_fc_node, composite_origin({get_origin(left_fc_node), get_origin(right_fc_node),
+                                              get_origin(add_node)}));
+
+  replace(add_node).with(fused_fc_node);
+
+  return true;
+}
+
+} // namespace
+
+/**
+ * @brief Class to fuse horizontal FC layers
+ *
+ * Before
+ *
+ *     +---- [In] ----+
+ *     |              |
+ *     V              V
+ *   fc1 (w1, b1)   fc2 (w2, b2)
+ *     |              |
+ *     |              |
+ *     +---> add <----+
+ *            |
+ *            V
+ *          [Out]
+ *
+ * After
+ *
+ *   [In]
+ *    |
+ *    V
+ *   fc3 (w1+w2, b1+b2)
+ *    |
+ *    V
+ *   [Out]
+ *
+ * Shape/dtype of fc1, fc2, and fc3 should be the same.
+ */
+bool FuseHorizontalFullyConnectedPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto add_node = dynamic_cast<CircleAdd *>(node);
+    if (not add_node)
+      continue;
+
+    if (fuse_horizontal_fc_nodes(add_node))
+      changed = true;
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp
new file mode 100644
index 000000000..3dba7f89a
--- /dev/null
+++ b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseHorizontalFullyConnectedPass.h"
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+/*
+ * Before
+ *
+ *     +---- [In] ----+
+ *     |              |
+ *     V              V
+ *   fc1 (w1, b1)   fc2 (w2, b2)
+ *     |              |
+ *     |              |
+ *     +---> add <----+
+ *            |
+ *            V
+ *          [Out]
+ *
+ * After
+ *
+ *   [In]
+ *    |
+ *    V
+ *   fc3 (w1+w2, b1+b2)
+ *    |
+ *    V
+ *   [Out]
+ */
+class FuseHorizontalFCLayersTestGraph : public TestIOGraph
+{
+public:
+  FuseHorizontalFCLayersTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({1, 10}, {1, 10});
+
+    _left_fc = g()->nodes()->create<luci::CircleFullyConnected>();
+    _right_fc = g()->nodes()->create<luci::CircleFullyConnected>();
+    _left_weight = g()->nodes()->create<luci::CircleConst>();
+    _right_weight = g()->nodes()->create<luci::CircleConst>();
+
+    _left_fc->name("left FC");
+    _right_fc->name("right FC");
+    _left_weight->name("left weight");
+    _right_weight->name("right weight");
+
+    _left_fc->dtype(loco::DataType::FLOAT32);
+    _right_fc->dtype(loco::DataType::FLOAT32);
+
+    _left_fc->shape_status(luci::ShapeStatus::VALID);
+    _right_fc->shape_status(luci::ShapeStatus::VALID);
+
+    _left_fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _right_fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+    _left_fc->rank(2);
+    _right_fc->rank(2);
+
+    _right_fc->dim(0) = 1;
+    _right_fc->dim(1) = 10;
+
+    _left_fc->dim(0) = 1;
+    _left_fc->dim(1) = 10;
+
+    _left_weight->rank(2);
+    _left_weight->dtype(loco::DataType::FLOAT32);
+    _left_weight->size<loco::DataType::FLOAT32>(5 * 10);
+    for (uint32_t i = 0; i < 5 * 10; ++i)
+    {
+      _left_weight->at<loco::DataType::FLOAT32>(i) = 1.0f;
+    }
+    _left_weight->dim(0) = 5;
+    _left_weight->dim(1) = 10;
+    _left_weight->shape_status(luci::ShapeStatus::VALID);
+
+    _right_weight->rank(2);
+    _right_weight->dtype(loco::DataType::FLOAT32);
+    _right_weight->size<loco::DataType::FLOAT32>(5 * 10);
+    for (uint32_t i = 0; i < 5 * 10; ++i)
+    {
+      _right_weight->at<loco::DataType::FLOAT32>(i) = 2.0f;
+    }
+    _right_weight->dim(0) = 5;
+    _right_weight->dim(1) = 10;
+    _right_weight->shape_status(luci::ShapeStatus::VALID);
+
+    const auto left_output_exclude = g()->nodes()->create<luci::CircleOutputExclude>();
+    const auto right_output_exclude = g()->nodes()->create<luci::CircleOutputExclude>();
+
+    _left_fc->input(input());
+    _left_fc->weights(_left_weight);
+    _left_fc->bias(left_output_exclude);
+    _right_fc->input(input());
+    _right_fc->weights(_right_weight);
+    _right_fc->bias(right_output_exclude);
+
+    _add = g()->nodes()->create<luci::CircleAdd>();
+    _add->dtype(loco::DataType::FLOAT32);
+    _add->rank(2);
+    _add->dim(0) = 1;
+    _add->dim(1) = 5;
+    _add->x(_left_fc);
+    _add->y(_right_fc);
+    _add->shape_status(luci::ShapeStatus::VALID);
+
+    output()->from(_add);
+  }
+
+  luci::CircleFullyConnected *get_left_fc() { return _left_fc; }
+
+private:
+  luci::CircleFullyConnected *_left_fc = nullptr;
+  luci::CircleConst *_left_weight = nullptr;
+  luci::CircleFullyConnected *_right_fc = nullptr;
+  luci::CircleConst *_right_weight = nullptr;
+  luci::CircleAdd *_add = nullptr;
+};
+
+} // namespace
+
+TEST(FuseHorizontalFCLayersPassTest, name)
+{
+  luci::FuseHorizontalFullyConnectedPass pass;
+  auto const name = pass.name();
+  ASSERT_NE(nullptr, name);
+}
+
+TEST(FuseHorizontalFCLayersPassTest, fuse_horizontal_nodes)
+{
+  FuseHorizontalFCLayersTestGraph g;
+  luci::FuseHorizontalFullyConnectedPass pass;
+
+  g.init();
+
+  EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FuseHorizontalFCLayersPassTest, fuse_horizontal_nodes_NEG)
+{
+  FuseHorizontalFCLayersTestGraph g;
+  luci::FuseHorizontalFullyConnectedPass pass;
+
+  g.init();
+
+  g.get_left_fc()->fusedActivationFunction(luci::FusedActFunc::RELU6);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FuseHorizontalFCLayersPassTest, wrong_dtype_NEG)
+{
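+  // Only FLOAT32 weights are supported; an S32 FC must be rejected.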
+  FuseHorizontalFCLayersTestGraph g;
+  luci::FuseHorizontalFullyConnectedPass pass;
+
+  g.init();
+
+  g.get_left_fc()->dtype(loco::DataType::S32);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp
new file mode 100644
index 000000000..1f41d16f0
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include "helpers/NodeFiller.h"
+
+#define CHECK_OR_FALSE(condition) \
+  if (not(condition))             \
+    return false;
+
+namespace
+{
+
+/**
+ * Fuse Mul to following FullyConnected if possible
+ *
+ * BEFORE
+ *                  |
+ *     [CircleMul]  [CircleConst] [CircleConst]
+ *          |              |        |
+ *     [CircleFullyConnected] ----------+
+ *          |
+ *
+ * AFTER
+ *          |
+ *          |   [CircleConst] [CircleConst]
+ *          |        |          |
+ *          |     [CircleMul] [CircleConst] [CircleMul]
+ *          |        |                        |
+ *     [CircleFullyConnected] ------------+
+ *          |
+ *
+ */
+bool fuse_fc_with_mul(luci::CircleFullyConnected *fc)
+{
+  CHECK_OR_FALSE(fc);
+
+  // check input is Mul
+  auto mul = dynamic_cast<luci::CircleMul *>(fc->input());
+  CHECK_OR_FALSE(mul);
+  // conditions of Mul, FC: to expect constant folding, support only F32
+  CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32);
+  CHECK_OR_FALSE(mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
+  CHECK_OR_FALSE(fc->dtype() == loco::DataType::FLOAT32);
+  // weights should be constant
+  auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
+  CHECK_OR_FALSE(weights);
+
+  // Check the multiplicand of Mul is constant
+  luci::CircleNode *mul_input = nullptr;
+  luci::CircleConst *mul_scale = nullptr;
+  CHECK_OR_FALSE(luci::fill(&mul_input, &mul_scale).with_commutative_args_of(mul));
+  // support only 1D constant
+  CHECK_OR_FALSE(mul_scale->rank() == 1);
+
+  auto graph = fc->graph();
+
+  auto fc_weights = graph->nodes()->create<luci::CircleMul>();
+  fc_weights->x(weights);
+  fc_weights->y(mul_scale);
+  fc_weights->fusedActivationFunction(luci::FusedActFunc::NONE);
+  fc_weights->name(mul->name() + "_" + fc->name() + "_weight");
+  luci::add_origin(fc_weights,
+                   luci::composite_origin({luci::get_origin(mul), luci::get_origin(weights),
+                                           luci::get_origin(mul_scale)}));
+
+  auto fc_new = graph->nodes()->create<luci::CircleFullyConnected>();
+  fc_new->input(mul_input);
+  fc_new->weights(fc_weights);
+  fc_new->bias(fc->bias());
+  fc_new->weights_format(fc->weights_format());
+  fc_new->keep_num_dims(fc->keep_num_dims());
+  fc_new->fusedActivationFunction(fc->fusedActivationFunction());
+  fc_new->name(fc->name());
+  luci::add_origin(fc_new, luci::get_origin(fc));
+
+  replace(fc).with(fc_new);
+
+  return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseMulToFullyConnectedWeightsPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+    if (not fc)
+      continue;
+
+    if (fuse_fc_with_mul(fc))
+      changed = true;
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp
new file mode 100644
index 000000000..2cb7a4e9f
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+template <loco::DataType DT>
+class FuseMulToFullyConnectedWeightsPassTestGraph : public TestIOGraph
+{
+public:
+  FuseMulToFullyConnectedWeightsPassTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({3, 4}, {3, 6});
+
+    _mul = g()->nodes()->create<luci::CircleMul>();
+    _mul_s = g()->nodes()->create<luci::CircleConst>();
+    _fc = g()->nodes()->create<luci::CircleFullyConnected>();
+    _fc_w = g()->nodes()->create<luci::CircleConst>();
+    _fc_b = g()->nodes()->create<luci::CircleConst>();
+
+    _mul->name("mul");
+    _mul_s->name("mul_s");
+    _fc->name("fc");
+    _fc_w->name("fc_w");
+    _fc_b->name("fc_b");
+
+    _mul->dtype(DT);
+    _fc->dtype(DT);
+    _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+    _mul_s->rank(1);
+    _mul_s->dim(0) = 3;
+    _mul_s->dtype(DT);
+    _mul_s->size<DT>(3);
+    for (uint32_t i = 0; i < 3; ++i)
+    {
+      _mul_s->at<DT>(i) = 1.0f;
+    }
+
+    _fc_w->rank(2);
+    _fc_w->dim(0) = 3;
+    _fc_w->dim(1) = 4;
+    _fc_w->dtype(DT);
+    _fc_w->size<DT>(4 * 6);
+    for (uint32_t i = 0; i < 4 * 6; ++i)
+    {
+      _fc_w->at<DT>(i) = 1.0f;
+    }
+
+    _fc_b->rank(1);
+    _fc_b->dim(0) = 6;
+    _fc_b->dtype(DT);
+    _fc_b->size<DT>(6);
+    for (uint32_t i = 0; i < 6; ++i)
+    {
+      _fc_b->at<DT>(i) = 1.0f;
+    }
+
+    _mul->x(input());
+    _mul->y(_mul_s);
+    _fc->input(_mul);
+    _fc->weights(_fc_w);
+    _fc->bias(_fc_b);
+
+    output()->from(_fc);
+  }
+
+  luci::CircleMul *_mul = nullptr;
+  luci::CircleFullyConnected *_fc = nullptr;
+  luci::CircleConst *_mul_s = nullptr;
+  luci::CircleConst *_fc_w = nullptr;
+  luci::CircleConst *_fc_b = nullptr;
+};
+
+class FuseMulToFullyConnectedWeightsPassTest : public ::testing::Test
+{
+public:
+  FuseMulToFullyConnectedWeightsPassTest() = default;
+
+protected:
+  FuseMulToFullyConnectedWeightsPassTestGraph<loco::DataType::FLOAT32> _graph;
+  luci::FuseMulToFullyConnectedWeightsPass _pass;
+};
+
+class FuseMulToFullyConnectedWeightsPassS32Test : public ::testing::Test
+{
+public:
+  FuseMulToFullyConnectedWeightsPassS32Test() = default;
+
+protected:
+  FuseMulToFullyConnectedWeightsPassTestGraph<loco::DataType::S32> _graph;
+  luci::FuseMulToFullyConnectedWeightsPass _pass;
+};
+
+} // namespace
+
+TEST_F(FuseMulToFullyConnectedWeightsPassTest, name)
+{
+  auto const name = _pass.name();
+  ASSERT_NE(nullptr, name);
+}
+
+TEST_F(FuseMulToFullyConnectedWeightsPassTest, fuse_mul_to_fc_weights)
+{
+  _graph.init();
+
+  EXPECT_TRUE(_pass.run(_graph.g()));
+}
+
+TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_fused_act_NEG)
+{
+  _graph.init();
+
+  _graph._mul->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+  EXPECT_FALSE(_pass.run(_graph.g()));
+}
+
+TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_d2_NEG)
+{
+  _graph.init();
+
+  _graph._mul_s->rank(2);
+  _graph._mul_s->dim(0) = 1;
+  _graph._mul_s->dim(1) = 3;
+
+  EXPECT_FALSE(_pass.run(_graph.g()));
+}
+
+TEST_F(FuseMulToFullyConnectedWeightsPassS32Test, dtype_s32_NEG)
+{
+  _graph.init();
+
+  EXPECT_FALSE(_pass.run(_graph.g()));
+}
diff --git a/compiler/luci/pass/src/FuseMulWithConvPass.cpp b/compiler/luci/pass/src/FuseMulWithConvPass.cpp
new file mode 100644
index 000000000..94ad0863d
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulWithConvPass.cpp
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMulWithConvPass.h"
+
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
+#include <vector>
+
+namespace
+{
+
+#define RETURN_FALSE_UNLESS(cond) \
+  if (not(cond))                  \
+    return false;
+
+inline uint32_t cal_offset(const luci::CircleConst *node, const std::vector<uint32_t> &indices)
+{
+  // sanity check for node's rank
+  assert(node != nullptr && node->rank() == 4);
+
+  // sanity check for indices
+  assert(indices.size() == 4);
+
+  return indices[0] * node->dim(1).value() * node->dim(2).value() * node->dim(3).value() +
+         indices[1] * node->dim(2).value() * node->dim(3).value() +
+         indices[2] * node->dim(3).value() + indices[3];
+}
+
+/**
+ * Fuse Mul with Conv if possible
+ *
+ * NOTE: In case mul is a channel-wise constant, we can try to merge mul with conv.
+ *
+ * BEFORE
+ *                  |
+ *          [CircleConv2D] (no activation)
+ *                  |
+ *               [Mul] (channel-wise/scalar constant)
+ *                  |
+ *
+ * AFTER
+ *                  |
+ *          [CircleConv2D] (with updated kernels, bias, and activation)
+ *                  |
+ *
+ */
+
+bool fuse_mul_with_conv(luci::CircleMul *mul)
+{
+  // sanity check
+  RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
+
+  luci::CircleConst *const_mul_operand = nullptr;
+  luci::CircleConv2D *conv = nullptr;
+  RETURN_FALSE_UNLESS(luci::fill(&const_mul_operand, &conv).with_commutative_args_of(mul));
+
+  // sanity check
+  RETURN_FALSE_UNLESS(conv->dtype() == loco::DataType::FLOAT32 &&
+                      const_mul_operand->dtype() == loco::DataType::FLOAT32);
+
+  // NOTE for general activation function: S * Act(A * B) != Act(A*(SB))
+  if (conv->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  // check that const_mul_operand is channel-wise constant or just a scalar
+  RETURN_FALSE_UNLESS(const_mul_operand->rank() == 4 || const_mul_operand->rank() == 1 ||
+                      const_mul_operand->rank() == 0);
+
+  std::vector<float> mul_values;
+  if (const_mul_operand->rank() == 4)
+  {
+    // check channel-wise broadcasting
+    RETURN_FALSE_UNLESS(const_mul_operand->dim(0).value() == 1 &&
+                        const_mul_operand->dim(1).value() == 1 &&
+                        const_mul_operand->dim(2).value() == 1);
+  }
+  else if (const_mul_operand->rank() == 1 || const_mul_operand->rank() == 0)
+  {
+    // sanity check
+    RETURN_FALSE_UNLESS(const_mul_operand->size<loco::DataType::FLOAT32>() != 0);
+  }
+
+  mul_values.resize(const_mul_operand->size<loco::DataType::FLOAT32>());
+  for (uint32_t idx = 0; idx < mul_values.size(); idx++)
+  {
+    mul_values[idx] = const_mul_operand->at<loco::DataType::FLOAT32>(idx);
+  }
+
+  // filter
+  auto const conv_filter = dynamic_cast<luci::CircleConst *>(conv->filter());
+
+  // sanity check
+  RETURN_FALSE_UNLESS(conv_filter != nullptr && conv_filter->rank() == 4 &&
+                      conv_filter->dtype() == loco::DataType::FLOAT32);
+
+  auto const out_channels = conv_filter->dim(0).value();
+
+  // multiplier is either channelwise constant or scalar
+  RETURN_FALSE_UNLESS(out_channels == mul_values.size() || mul_values.size() == 1);
+
+  // bias
+  auto const conv_bias = dynamic_cast<luci::CircleConst *>(conv->bias());
+
+  RETURN_FALSE_UNLESS(conv_bias == nullptr ||
+                      (conv_bias->rank() == 1 && conv_bias->dim(0) == out_channels &&
+                       conv_bias->dtype() == loco::DataType::FLOAT32));
+
+  luci::CircleConst *fused_conv_filter = nullptr;
+  {
+    // fused filter
+    fused_conv_filter = luci::clone(conv_filter);
+    // set values of conv filter multiplied by constant channel-wise
+    for (uint32_t out_chan = 0; out_chan < out_channels; out_chan++)
+    {
+      // for scalar - first element, otherwise - channelwise
+      float mult = mul_values[out_chan % mul_values.size()];
+      for (uint32_t out_height = 0; out_height < fused_conv_filter->dim(1).value(); out_height++)
+      {
+        for (uint32_t out_width = 0; out_width < fused_conv_filter->dim(2).value(); out_width++)
+        {
+          for (uint32_t in_chan = 0; in_chan < fused_conv_filter->dim(3).value(); in_chan++)
+          {
+            std::vector<uint32_t> indices = {out_chan, out_height, out_width, in_chan};
+            auto const data =
+              conv_filter->at<loco::DataType::FLOAT32>(cal_offset(conv_filter, indices));
+            fused_conv_filter->at<loco::DataType::FLOAT32>(cal_offset(fused_conv_filter, indices)) =
+              mult * data;
+          }
+        }
+      }
+    }
+    fused_conv_filter->name(conv_filter->name() + "/FusedMul");
+    luci::add_origin(fused_conv_filter, luci::get_origin(conv_filter));
+  }
+
+  luci::CircleConst *fused_conv_bias = nullptr;
+  if (conv_bias != nullptr)
+  {
+    // fused bias
+    fused_conv_bias = luci::clone(conv_bias);
+    // update bias values
+    for (uint32_t c = 0; c < conv_bias->size<loco::DataType::FLOAT32>(); c++)
+    {
+      // for scalar - first element, otherwise - channelwise
+      float mult = mul_values[c % mul_values.size()];
+      auto const data = conv_bias->at<loco::DataType::FLOAT32>(c);
+      fused_conv_bias->at<loco::DataType::FLOAT32>(c) = mult * data;
+    }
+
+    fused_conv_bias->name(conv_bias->name() + "/FusedMul");
+    luci::add_origin(fused_conv_bias, luci::get_origin(conv_bias));
+  }
+
+  // Configure new CircleConv2D operation.
+  auto *fused_conv = loco::must_cast<luci::CircleConv2D *>(luci::clone_node(conv, mul->graph()));
+  fused_conv->input(conv->input());
+  fused_conv->filter(fused_conv_filter);
+  fused_conv->bias(fused_conv_bias);
+  fused_conv->name(conv->name() + "/FusedMul");
+  fused_conv->fusedActivationFunction(mul->fusedActivationFunction());
+  luci::add_origin(fused_conv,
+                   luci::composite_origin({luci::get_origin(conv), luci::get_origin(mul)}));
+
+  // Replace old mul operation with new fused_conv with updated kernel/bias
+  replace(mul).with(fused_conv);
+
+  return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseMulWithConvPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto mul = dynamic_cast<luci::CircleMul *>(node);
+    if (not mul)
+      continue;
+
+    if (fuse_mul_with_conv(mul))
+      changed = true;
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseMulWithConvPass.test.cpp b/compiler/luci/pass/src/FuseMulWithConvPass.test.cpp
new file mode 100644
index 000000000..9f215ded7
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulWithConvPass.test.cpp
@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/FuseMulWithConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class ConvMulGraphlet
+{
+public:
+  ConvMulGraphlet() = default;
+
+public:
+  void init(loco::Graph *g, bool activation, bool mul_const_shape)
+  {
+    _conv = g->nodes()->create<luci::CircleConv2D>();
+    _conv_filter = g->nodes()->create<luci::CircleConst>();
+    _conv_bias = g->nodes()->create<luci::CircleConst>();
+    _mul = g->nodes()->create<luci::CircleMul>();
+    _mul_const = g->nodes()->create<luci::CircleConst>();
+
+    if (activation)
+    {
+      _conv->fusedActivationFunction(luci::FusedActFunc::RELU);
+    }
+    else
+    {
+      _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
+    }
+
+    _conv->dtype(loco::DataType::FLOAT32);
+    _conv_filter->dtype(loco::DataType::FLOAT32);
+    _conv_bias->dtype(loco::DataType::FLOAT32);
+    _mul->dtype(loco::DataType::FLOAT32);
+    _mul_const->dtype(loco::DataType::FLOAT32);
+
+    _conv->name("conv");
+    _conv_filter->name("conv_filter");
+    _conv_bias->name("conv_bias");
+    _mul->name("mul");
+    _mul_const->name("mul_const");
+
+    _conv_filter->shape({_output_channels, 1, 1, _input_channels});
+    _conv_bias->shape({_output_channels});
+    if (mul_const_shape)
+    {
+      _mul_const->shape({1, 1, _input_dim, _output_channels});
+    }
+    else
+    {
+      _mul_const->shape({1, 1, 1, _output_channels});
+      // initialize _mul_const for positive test
+      _mul_const->size<loco::DataType::FLOAT32>(_output_channels);
+      for (uint32_t i = 0; i < _output_channels; i++)
+      {
+        _mul_const->at<loco::DataType::FLOAT32>(i) = 1.f;
+      }
+    }
+
+    {
+      // initialize bias
+      _conv_bias->size<loco::DataType::FLOAT32>(_output_channels);
+      for (uint32_t i = 0; i < _output_channels; i++)
+      {
+        _conv_bias->at<loco::DataType::FLOAT32>(i) = 0.f;
+      }
+    }
+
+    {
+      // initialize filter
+      _conv_filter->size<loco::DataType::FLOAT32>(_output_channels * _input_channels);
+      for (uint32_t i = 0; i < _conv_filter->size<loco::DataType::FLOAT32>(); i++)
+      {
+        _conv_filter->at<loco::DataType::FLOAT32>(i) = 1.f;
+      }
+    }
+
+    _conv->filter(_conv_filter);
+    _conv->bias(_conv_bias);
+    _conv->padding(luci::Padding::VALID);
+    _conv->stride()->h(1);
+    _conv->stride()->w(1);
+    _mul->x(_mul_const);
+    _mul->y(_conv);
+  }
+
+protected:
+  luci::CircleConv2D *_conv = nullptr;
+  luci::CircleConst *_conv_filter = nullptr;
+  luci::CircleConst *_conv_bias = nullptr;
+  luci::CircleMul *_mul = nullptr;
+  luci::CircleConst *_mul_const = nullptr;
+
+  const uint32_t _input_channels = 32;
+  const uint32_t _output_channels = 64;
+  const uint32_t _input_dim = 64;
+  const ShapeU32 _input_shape = {1, _input_dim, _input_dim, _input_channels};
+  const ShapeU32 _output_shape = {1, _input_dim, _input_dim, _output_channels};
+};
+
+class ConvMulGraph : public TestIOGraph, public ConvMulGraphlet
+{
+public:
+  ConvMulGraph() = default;
+
+public:
+  void init(bool activation, bool mul_const_shape)
+  {
+    TestIOGraph::init(_input_shape, _output_shape);
+    ConvMulGraphlet::init(g(), activation, mul_const_shape);
+
+    _conv->input(input());
+    output()->from(_mul);
+  }
+};
+
+} // namespace
+
+TEST(FuseMulWithConvPass, name_test)
+{
+  luci::FuseMulWithConvPass pass;
+  auto const name = pass.name();
+  ASSERT_NE(nullptr, name);
+}
+
+TEST(FuseMulWithConvPass, simple_test)
+{
+  luci::FuseMulWithConvPass pass;
+
+  ConvMulGraph g;
+  g.init(false, false);
+
+  ASSERT_TRUE(pass.run(g.g()));
+
+  // check Mul is removed
+  int count = 0;
+  for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+  {
+    if (auto mul = dynamic_cast<luci::CircleMul *>(node))
+      count++;
+  }
+  ASSERT_EQ(0, count);
+}
+
+TEST(FuseMulWithConvPass, not_removed_NEG)
+{
+  luci::FuseMulWithConvPass pass;
+  ConvMulGraph g;
+  g.init(false, true);
+
+  ASSERT_FALSE(pass.run(g.g()));
+
+  // check Mul is not removed
+  int count = 0;
+  for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+  {
+    if (auto mul = dynamic_cast<luci::CircleMul *>(node))
+      count++;
+  }
+  ASSERT_EQ(1, count);
+}
+
+TEST(FuseMulWithConvPass, activation_blocks_removal_NEG)
+{
+  luci::FuseMulWithConvPass pass;
+  ConvMulGraph g;
+  g.init(true, false);
+
+  ASSERT_FALSE(pass.run(g.g()));
+
+  // check Mul is not removed
+  int count = 0;
+  for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+  {
+    if (auto mul = dynamic_cast<luci::CircleMul *>(node))
+      count++;
+  }
+  ASSERT_EQ(1, count);
+}
diff --git a/compiler/luci/pass/src/FuseMulWithDivPass.cpp b/compiler/luci/pass/src/FuseMulWithDivPass.cpp
new file mode 100644
index 000000000..c4c7d8170
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulWithDivPass.cpp
@@ -0,0 +1,246 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMulWithDivPass.h"
+
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace luci
+{
+
+namespace
+{
+
+// Return a new CircleConst with a new value
+luci::CircleConst *create_div_const_with_new_value(luci::CircleConst *div_const,
+                                                   luci::CircleConst *mul_const, float new_value)
+{
+  assert(div_const);                                     // FIX_CALLER_UNLESS
+  assert(div_const->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS
+  assert(div_const->size<loco::DataType::FLOAT32>() == 1); // FIX_CALLER_UNLESS
+
+  auto new_div_const = div_const->graph()->nodes()->create<luci::CircleConst>();
+  new_div_const->dtype(loco::DataType::FLOAT32);
+  new_div_const->size<loco::DataType::FLOAT32>(1);
+  new_div_const->rank(1);
+  new_div_const->dim(0) = 1;
+  new_div_const->at<loco::DataType::FLOAT32>(0) = new_value;
+  new_div_const->shape_status(luci::ShapeStatus::VALID);
+  new_div_const->name(div_const->name() + ";" + mul_const->name());
+
+  luci::add_origin(new_div_const, luci::composite_origin(
+                                    {luci::get_origin(div_const), luci::get_origin(mul_const)}));
+
+  return new_div_const;
+}
+
+// Return a new CircleConst with a new value
+luci::CircleConst *create_mul_const_with_new_value(luci::CircleConst *mul_const,
+                                                   luci::CircleConst *div_const, float new_value)
+{
+  assert(mul_const);                                     // FIX_CALLER_UNLESS
+  assert(mul_const->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS
+  assert(mul_const->size<loco::DataType::FLOAT32>() == 1); // FIX_CALLER_UNLESS
+
+  auto new_mul_const = mul_const->graph()->nodes()->create<luci::CircleConst>();
+  new_mul_const->dtype(loco::DataType::FLOAT32);
+  new_mul_const->rank(0);
+  new_mul_const->size<loco::DataType::FLOAT32>(1);
+  new_mul_const->scalar<loco::DataType::FLOAT32>() = new_value;
+  new_mul_const->shape_status(luci::ShapeStatus::VALID);
+  new_mul_const->name(mul_const->name() + ";" + div_const->name());
+
+  luci::add_origin(new_mul_const, luci::composite_origin(
+                                    {luci::get_origin(mul_const), luci::get_origin(div_const)}));
+
+  return new_mul_const;
+}
+
+/**
+ * Pass to fuse Mul (one input is a const scalar) and
+ * Div (numerator is a const scalar) into a single Div
+ *
+ * BEFORE
+ *   [CircleNode]  [Scalar_Mul_Const]
+ *        |               |
+ *   [CircleMul, (x=CircleNode, y=Scalar_Mul_Const)] --------
+ *        |
+ *        |        [Scalar_Div_Const]
+ *        |               |
+ *   [CircleDiv, (x=Scalar_Div_Const, y=CircleMul)] ------
+ *        |
+ *   [CircleNode]
+ *
+ * AFTER
+ *   [CircleNode]
+ *        |        [Scalar_new_Div_Const]
+ *        |               |
+ *   [CircleDiv, (x=Scalar_new_Div_Const, y=CircleNode)] -------
+ *        |
+ *   [CircleNode]
+ *
+ * where Scalar_new_Div_Const = Scalar_Div_Const / Scalar_Mul_Const
+ *
+ **/
+bool fuse_mul_with_div_to_div(luci::CircleDiv *div)
+{
+  if (div->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  luci::CircleConst *div_const = nullptr;
+  luci::CircleMul *mul = nullptr;
+  if (not luci::fill(&div_const, &mul).with_args_of(div))
+    return false;
+
+  if (div_const->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  if (div_const->size<loco::DataType::FLOAT32>() != 1)
+    return false;
+
+  if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  luci::CircleNode *mul_input = nullptr;
+  luci::CircleConst *mul_const = nullptr;
+  if (not luci::fill(&mul_input, &mul_const).with_commutative_args_of(mul))
+    return false;
+
+  if (mul_const->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  if (mul_const->size<loco::DataType::FLOAT32>() != 1)
+    return false;
+
+  const auto div_value = div_const->at<loco::DataType::FLOAT32>(0);
+  const auto mul_value = mul_const->at<loco::DataType::FLOAT32>(0);
+
+  if (mul_value == 0)
+    return false;
+
+  const auto new_value = div_value / mul_value;
+
+  auto new_div_const = create_div_const_with_new_value(div_const, mul_const, new_value);
+  auto new_div = div->graph()->nodes()->create<luci::CircleDiv>();
+  new_div->fusedActivationFunction(luci::FusedActFunc::NONE);
+  new_div->x(new_div_const);
+  new_div->y(mul_input);
+  new_div->name(div->name());
+  luci::add_origin(new_div, luci::composite_origin({luci::get_origin(div), luci::get_origin(mul)}));
+
+  replace(div).with(new_div);
+
+  return true;
+}
+
+/**
+ * Pass to fuse Mul (one input is a const scalar) and
+ * Div (denominator is a const scalar) into a single Mul
+ *
+ * BEFORE
+ *   [CircleNode]  [Scalar_Mul_Const]
+ *        |               |
+ *   [CircleMul, (x=CircleNode, y=Scalar_Mul_Const)] --------
+ *        |
+ *        |        [Scalar_Div_Const]
+ *        |               |
+ *   [CircleDiv, (x=CircleMul, y=Scalar_Div_Const)] ------
+ *        |
+ *   [CircleNode]
+ *
+ * AFTER
+ *   [CircleNode]
+ *        |        [Scalar_new_Mul_Const]
+ *        |               |
+ *   [CircleMul, (x=CircleNode, y=Scalar_new_Mul_Const)] -------
+ *        |
+ *   [CircleNode]
+ *
+ * where Scalar_new_Mul_Const = Scalar_Mul_Const / Scalar_Div_Const
+ *
+ **/
+bool fuse_mul_with_div_to_mul(luci::CircleDiv *div)
+{
+  if (div->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  luci::CircleMul *mul = nullptr;
+  luci::CircleConst *div_const = nullptr;
+  if (not luci::fill(&mul, &div_const).with_args_of(div))
+    return false;
+
+  if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  if (div_const->dtype() != loco::DataType::FLOAT32)
+    return false;
+  // TODO support other shape
+  if (div_const->size<loco::DataType::FLOAT32>() != 1)
+    return false;
+
+  luci::CircleNode *mul_input = nullptr;
+  luci::CircleConst *mul_const = nullptr;
+  if (not luci::fill(&mul_input, &mul_const).with_commutative_args_of(mul))
+    return false;
+
+  if (mul_const->dtype() != loco::DataType::FLOAT32)
+    return false;
+  // TODO support other shape
+  if (mul_const->size<loco::DataType::FLOAT32>() != 1)
+    return false;
+
+  const auto mul_value = mul_const->at<loco::DataType::FLOAT32>(0);
+  const auto div_value = div_const->at<loco::DataType::FLOAT32>(0);
+  const auto new_value = mul_value / div_value;
+  auto new_mul_const = create_mul_const_with_new_value(mul_const, div_const, new_value);
+
+  auto new_mul = div->graph()->nodes()->create<luci::CircleMul>();
+  new_mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+  new_mul->x(mul_input);
+  new_mul->y(new_mul_const);
+  new_mul->name(mul->name());
+  luci::add_origin(new_mul,
+                   luci::composite_origin({luci::get_origin(div), luci::get_origin(mul)}));
+
+  replace(div).with(new_mul);
+
+  return true;
+}
+
+} // namespace
+
+bool FuseMulWithDivPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto div = dynamic_cast<luci::CircleDiv *>(node);
+    if (not div)
+      continue;
+
+    if (fuse_mul_with_div_to_div(div))
+      changed = true;
+
+    if (fuse_mul_with_div_to_mul(div))
+      changed = true;
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp b/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp
new file mode 100644
index 000000000..67ad48e1d
--- /dev/null
+++ b/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMulWithDivPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class PatternMulDivGraphlet
+{
+public:
+  PatternMulDivGraphlet() = default;
+
+  void init(loco::Graph *g)
+  {
+    _mul = g->nodes()->create<luci::CircleMul>();
+    _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _mul_const = g->nodes()->create<luci::CircleConst>();
+    _mul_const->rank(1);
+    _mul_const->dtype(loco::DataType::FLOAT32);
+    _mul_const->size<loco::DataType::FLOAT32>(1);
+    _mul_const->at<loco::DataType::FLOAT32>(0) = 1.1f;
+    _mul_const->shape_status(luci::ShapeStatus::VALID);
+
+    _div = g->nodes()->create<luci::CircleDiv>();
+    _div->fusedActivationFunction(luci::FusedActFunc::NONE);
+    _div_const = g->nodes()->create<luci::CircleConst>();
+    _div_const->rank(1);
+    _div_const->dtype(loco::DataType::FLOAT32);
+    _div_const->size<loco::DataType::FLOAT32>(1);
+    _div_const->at<loco::DataType::FLOAT32>(0) = 2.2f;
+    _div_const->shape_status(luci::ShapeStatus::VALID);
+
+    _mul->name("_mul");
+    _mul_const->name("_mul_const");
+
+    _div->name("_div");
+    _div_const->name("_div_const");
+  }
+
+public:
+  luci::CircleMul *mul() { return _mul; }
+  luci::CircleConst *mul_const() { return _mul_const; }
+  luci::CircleDiv *div() { return _div; }
+  luci::CircleConst *div_const() { return _div_const; }
+
+protected:
+  luci::CircleMul *_mul = nullptr;
+  luci::CircleConst *_mul_const = nullptr;
+  luci::CircleDiv *_div = nullptr;
+  luci::CircleConst *_div_const = nullptr;
+};
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *   [Input]
+ *      |
+ *   [Mul, MUL_Scalar_Const]
+ *      |
+ *   [Div, DIV_Scalar_Const]
+ *      |
+ *   [Output]
+ *
+ * AFTER
+ *   [Input]
+ *      |
+ *   [Div, Scalar_Const_new]
+ *      |
+ *   [Output]
+ *
+ * WHERE: Scalar_Const_new = DIV_Scalar_Const / MUL_Scalar_Const
+ */
+class FuseMulDivPatternTestGraph : public TestIOGraph, public PatternMulDivGraphlet
+{
+public:
+  FuseMulDivPatternTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({1, 2, 3}, {1, 2, 3});
+    PatternMulDivGraphlet::init(g());
+
+    _mul->x(input());
+    _mul->y(_mul_const);
+
+    _div->x(_div_const);
+    _div->y(_mul);
+
+    output()->from(_div);
+  }
+};
+
+class FuseMulDivToMulPatternTestGraph : public TestIOGraph, public PatternMulDivGraphlet
+{
+public:
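+  // Here Div takes the Mul result as numerator (x) and the constant as
+  // denominator (y), so the pair should fold to a single Mul.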
+  FuseMulDivToMulPatternTestGraph() = default;
+
+  void init(void)
+  {
+    TestIOGraph::init({1, 2, 3}, {1, 2, 3});
+    PatternMulDivGraphlet::init(g());
+
+    _mul->x(input());
+    _mul->y(_mul_const);
+
+    _div->x(_mul);
+    _div->y(_div_const);
+
+    output()->from(_div);
+  }
+};
+
+} // namespace
+
+TEST(FuseMulWithDivPassTest, fuse_mul_div_pattern)
+{
+  FuseMulDivPatternTestGraph g;
+  luci::FuseMulWithDivPass pass;
+
+  g.init();
+
+  EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FuseMulWithDivPassTest, fuse_mul_div_NEG)
+{
+  FuseMulDivPatternTestGraph g;
+  luci::FuseMulWithDivPass pass;
+
+  g.init();
+
+  // Add CircleRelu operation between Mul and Div operations
+  auto relu = g.g()->nodes()->create<luci::CircleRelu>();
+  relu->name("relu");
+  relu->features(g.mul());
+  g.div()->y(relu);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FuseMulWithDivPassTest, fuse_mul_div_to_mul_pattern)
+{
+  FuseMulDivToMulPatternTestGraph g;
+  luci::FuseMulWithDivPass pass;
+
+  g.init();
+
+  auto div = dynamic_cast<luci::CircleDiv *>(g.output()->from());
+  EXPECT_NE(div, nullptr);
+
+  EXPECT_TRUE(pass.run(g.g()));
+
+  auto mul = dynamic_cast<luci::CircleMul *>(g.output()->from());
+  EXPECT_NE(mul, nullptr);
+}
+
+TEST(FuseMulWithDivPassTest, fuse_mul_div_to_mul_NEG)
+{
+  FuseMulDivToMulPatternTestGraph g;
+  luci::FuseMulWithDivPass pass;
+
+  g.init();
+
+  // Add CircleRelu operation between CircleMul and Div operations
+  auto relu = g.g()->nodes()->create<luci::CircleRelu>();
+  relu->name("relu");
+  relu->features(g.mul());
+  g.div()->x(relu);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/FuseRsqrtPass.cpp b/compiler/luci/pass/src/FuseRsqrtPass.cpp
new file mode 100644
index 000000000..eb3b2c67e
--- /dev/null
+++ b/compiler/luci/pass/src/FuseRsqrtPass.cpp
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "luci/Pass/FuseRsqrtPass.h" +#include "helpers/NodeFiller.h" + +#include +#include + +#include +#include + +namespace +{ + +/** + * Fuse Sqrt + Div to Rsqrt (1/Sqrt -> Rsqrt) + * + * BEFORE + * | + * [CircleSqrt] + * [CircleConst] | where Const has value 1.0 + * | | + * [CircleDiv] + * | + * + * AFTER + * | + * | [CircleSqrt] + * [CircleRsqrt] [CircleConst] | + * | | | + * | [CircleDiv] + */ + +// Float comparison +bool same(float a, float b) { return fabs(a - b) < 1e-5; } + +class RsqrtPattern +{ +public: + RsqrtPattern(luci::CircleDiv *candidate) + { + assert(candidate); // FIX_CALLER_UNLESS + _div = candidate; + } + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +public: + bool matched() + { + // Check pattern + CHECK_OR_FALSE(luci::fill(&_div_const, &_sqrt).with_args_of(_div)); + _ifm = loco::must_cast(_sqrt->x()); + + CHECK_OR_FALSE(_div->fusedActivationFunction() == luci::FusedActFunc::NONE); + + // Check div_const = 1 + switch (_div->dtype()) + { + case loco::DataType::S16: + CHECK_OR_FALSE(_div_const->quantparam() != nullptr); + CHECK_OR_FALSE(_div_const->quantparam()->scale.size() == 1); + CHECK_OR_FALSE(_div_const->quantparam()->zerop.size() == 1); + CHECK_OR_FALSE(_div_const->quantparam()->zerop.at(0) == 0); + CHECK_OR_FALSE(_div_const->size() == 1); + CHECK_OR_FALSE(same(1.0, _div_const->at(0) * + _div_const->quantparam()->scale.at(0))); + break; + // TODO Support more dtypes + default: + return false; + } + + return true; + } +#undef CHECK_OR_FALSE + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleSqrt *_sqrt = nullptr; + luci::CircleDiv *_div = nullptr; + luci::CircleConst *_div_const = nullptr; +}; + +class FuseRsqrt final +{ +public: + FuseRsqrt(const RsqrtPattern *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleRsqrt *create_rsqrt(loco::Graph *graph); + +private: + const RsqrtPattern *_p = nullptr; +}; + +luci::CircleRsqrt *FuseRsqrt::create_rsqrt(loco::Graph *graph) +{ + assert(graph); + + auto rsqrt = graph->nodes()->create(); + rsqrt->x(_p->_ifm); + rsqrt->name(_p->_div->name() + "_rsqrt"); + + luci::copy_quantparam(_p->_div, rsqrt); + + return rsqrt; +} + +void FuseRsqrt::apply() +{ + auto graph = _p->_div->graph(); + + auto rsqrt = create_rsqrt(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p->_sqrt), luci::get_origin(_p->_div), luci::get_origin(_p->_div_const)}; + + luci::add_origin(rsqrt, luci::composite_origin(origin_vec)); + + replace(_p->_div).with(rsqrt); +} + +} // namespace + +namespace +{ + +bool fuse_rsqrt(luci::CircleDiv *div) +{ + assert(div); + + RsqrtPattern pattern(div); + if (pattern.matched()) + { + FuseRsqrt fuse(&pattern); + fuse.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseRsqrtPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto div = dynamic_cast(node); + if (not div) + continue; + + if (fuse_rsqrt(div)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseRsqrtPass.test.cpp b/compiler/luci/pass/src/FuseRsqrtPass.test.cpp new file mode 100644 index 000000000..425179a5a --- /dev/null +++ b/compiler/luci/pass/src/FuseRsqrtPass.test.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseRsqrtPass.h" + +#include + +#include + +#include +#include + +namespace +{ + +using namespace luci::test; + +std::unique_ptr gen_qparam(const std::vector &s, + const std::vector &zp) +{ + auto qparam = std::make_unique(); + { + for (auto scale : s) + qparam->scale.push_back(scale); + + for (auto zerop : zp) + qparam->zerop.push_back(zerop); + } + + return std::move(qparam); +} + +class S16RsqrtGraphlet +{ +public: + S16RsqrtGraphlet() = default; + + void init(loco::Graph *g) + { + _sqrt = g->nodes()->create(); + _div = g->nodes()->create(); + _div_const = g->nodes()->create(); + + _div->fusedActivationFunction(luci::FusedActFunc::NONE); + + _sqrt->dtype(loco::DataType::S16); + _div->dtype(loco::DataType::S16); + _div_const->dtype(loco::DataType::S16); + + _div_const->size(1); + _div_const->shape({1}); + _div_const->at(0) = 1; + _div_const->shape_status(luci::ShapeStatus::VALID); + + _sqrt->quantparam(gen_qparam({1.0}, {0})); + _div->quantparam(gen_qparam({2.0}, {0})); + _div_const->quantparam(gen_qparam({1.0}, {0})); + } + + void invalid_act() { _div->fusedActivationFunction(luci::FusedActFunc::RELU); } + +protected: + luci::CircleSqrt *_sqrt = nullptr; + luci::CircleDiv *_div = nullptr; + luci::CircleConst *_div_const = nullptr; +}; + +class FuseS16RsqrtTestGraph : public TestIOGraph, public S16RsqrtGraphlet +{ +public: + FuseS16RsqrtTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + S16RsqrtGraphlet::init(g()); + + _sqrt->x(input()); + _div->x(_div_const); + _div->y(_sqrt); + + output()->from(_div); + } +}; + +} // namespace + +TEST(FuseRsqrtPassTest, s16) +{ + FuseS16RsqrtTestGraph g; + luci::FuseRsqrtPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseRsqrtPassTest, fuse_invalid_act_NEG) +{ + FuseS16RsqrtTestGraph g; + luci::FuseRsqrtPass pass; + + g.init(); + g.invalid_act(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/FuseSliceWithTConvPass.cpp b/compiler/luci/pass/src/FuseSliceWithTConvPass.cpp new file mode 100644 index 000000000..981ec1b44 --- /dev/null +++ b/compiler/luci/pass/src/FuseSliceWithTConvPass.cpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "luci/Pass/FuseSliceWithTConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+namespace
+{
+
+#define RETURN_FALSE_UNLESS(cond) \
+  if (not(cond))                  \
+    return false;
+
+inline int32_t compute_input_size(luci::Padding padding, int32_t image_size, int32_t filter_size,
+                                  int32_t stride)
+{
+  switch (padding)
+  {
+    case luci::Padding::SAME:
+      return (image_size + stride - 1) / stride;
+    case luci::Padding::VALID:
+      return (image_size + stride - filter_size) / stride;
+    default:
+      throw std::runtime_error("Unsupported padding");
+  }
+}
+
+inline int32_t extract_pad_value(int32_t stride, int32_t in_size, int32_t filter_size,
+                                 int32_t out_size)
+{
+  const int32_t padding = ((in_size - 1) * stride + filter_size - out_size) / 2;
+  return padding > 0 ? padding : 0;
+}
+
+inline uint32_t cal_offset(const luci::CircleConst *node, const uint32_t *indices)
+{
+  return indices[0] * node->dim(1).value() * node->dim(2).value() * node->dim(3).value() +
+         indices[1] * node->dim(2).value() * node->dim(3).value() +
+         indices[2] * node->dim(3).value() + indices[3];
+}
+
+luci::Padding compute_padding(const luci::CircleTransposeConv *tconv, int32_t out_height,
+                              int32_t out_width, int32_t pad_top, int32_t pad_left)
+{
+  auto const filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+  if (!filter)
+    return luci::Padding::UNDEFINED;
+
+  auto tconv_shape = dynamic_cast<luci::CircleConst *>(tconv->inputSizes());
+  if (!tconv_shape)
+    return luci::Padding::UNDEFINED;
+
+  luci::Padding padding = luci::Padding::UNDEFINED;
+  std::initializer_list<luci::Padding> paddings_to_check = {luci::Padding::VALID,
+                                                            luci::Padding::SAME};
+
+  auto const filter_height = filter->dim(1).value();
+  auto const filter_width = filter->dim(2).value();
+  auto const stride_height = tconv->stride()->h();
+  auto const stride_width = tconv->stride()->w();
+
+  for (auto padding_to_check : paddings_to_check)
+  {
+    auto const in_height =
+      compute_input_size(padding_to_check, out_height, filter_height, stride_height);
+    auto const pad_top_virtual =
+      extract_pad_value(stride_height, in_height, filter_height, out_height);
+    if (pad_top_virtual != pad_top)
+      continue;
+
+    auto const in_width =
+      compute_input_size(padding_to_check, out_width, filter_width, stride_width);
+    auto const pad_left_virtual =
+      extract_pad_value(stride_width, in_width, filter_width, out_width);
+    if (pad_left_virtual == pad_left)
+    {
+      padding = padding_to_check; // correct padding is found
+      break;
+    }
+  }
+
+  return padding;
+}
+
+/**
+ * Fuse Slice with CircleTransposeConv if possible
+ *
+ * NOTE: When the predecessor of Slice is a TConv, the Slice can be merged into
+ * the TConv: a spatial slice is a reduction, just like TConv padding, and a
+ * channel slice can be modeled directly by shrinking the TConv filter and bias.
+ * For now there is no way to set explicit pad values on CircleTransposeConv;
+ * the only way is VALID/SAME padding combined with an output shape. That is
+ * why not every numerical pad value is legal for this transform.
+ *
+ * BEFORE
+ *          |
+ * [CircleTransposeConv]
+ *          |
+ *    [CircleSlice]
+ *          |
+ *
+ * AFTER
+ *          |
+ * [CircleTransposeConv] (with possibly
changed padding, output shape, and filter/bias) + * | + * + */ + +bool fuse_slice_with_tconv(luci::CircleSlice *slice) +{ + // NOTE: assume NHWC layout + auto tconv = dynamic_cast(slice->input()); + RETURN_FALSE_UNLESS(tconv != nullptr); + + // offset + auto begin = dynamic_cast(slice->begin()); + // sanity check + RETURN_FALSE_UNLESS(begin != nullptr && begin->dtype() == loco::DataType::S32 && + begin->rank() == 1); + + // output shape + auto out_shape = dynamic_cast(slice->size()); + // sanity check + RETURN_FALSE_UNLESS(out_shape != nullptr && out_shape->dtype() == loco::DataType::S32 && + out_shape->rank() == 1); + + // output shape of tconv + auto tconv_shape = dynamic_cast(tconv->inputSizes()); + // sanity check + RETURN_FALSE_UNLESS(tconv_shape != nullptr && tconv_shape->dtype() == loco::DataType::S32 && + tconv_shape->rank() == 1); + + // no update if batch dimension is processed in slice + RETURN_FALSE_UNLESS(begin->at(0) == 0 && + out_shape->at(0) == + tconv_shape->at(0)); + + // filter + auto const tconv_filter = dynamic_cast(tconv->filter()); + // sanity check + RETURN_FALSE_UNLESS(tconv_filter != nullptr && tconv_filter->rank() == 4 && + tconv_filter->dtype() == loco::DataType::FLOAT32); + + // bias + auto const tconv_bias = dynamic_cast(tconv->bias()); + // Only support const bias + // TODO Support non-const bias + RETURN_FALSE_UNLESS(tconv_bias != nullptr && tconv_bias->rank() == 1 && + tconv_bias->dtype() == loco::DataType::FLOAT32); + + auto const out_height = out_shape->at(1); + auto const out_width = out_shape->at(2); + + auto const pad_top = begin->at(1); + auto const pad_left = begin->at(2); + + // As there is no option to set numerical values of pad explicitly for CircleTransposeConv + // we need to be sure that interpretation of PADDING + OUTPUT_SHAPE will produce + // the pad values, defined by slice. If possible compute_padding will return correct + // padding value, otherwise it will return UNDEFINED + auto const padding = compute_padding(tconv, out_height, out_width, pad_top, pad_left); + if (padding == luci::Padding::UNDEFINED) + return false; // impossible to fuse + + auto const out_channels = out_shape->at(3); + // update filter and bias in case it's needed + loco::Node *fused_filter = tconv->filter(); + loco::Node *fused_bias = tconv->bias(); + // Channel-direction slice + // Corresponding weights/bias of TConv is sliced. 
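+  // For illustration (numbers assumed, not taken from a real model): slicing
+  // channels [8, 24) of a [32, kh, kw, in_ch] filter means ch_offset == 8 and
+  // out_channels == 16; the loops below copy filter rows 8..23 along dim(0)
+  // and bias entries 8..23, using cal_offset() to index the flat buffer.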
+ if (begin->at(3) != 0 || + out_channels != tconv_shape->at(3)) + { + // fused filter + auto const in_channels = tconv_filter->dim(3).value(); + + luci::CircleConst *fused_tconv_filter = luci::clone(tconv_filter); + fused_tconv_filter->dim(0).set(out_channels); // out_channels + // update size due to channels change + fused_tconv_filter->size(out_channels * tconv_filter->dim(1).value() * + tconv_filter->dim(2).value() * in_channels); + auto const ch_offset = begin->at(3); + // set reduced filter values + for (uint32_t out_chan = 0; out_chan < fused_tconv_filter->dim(0).value(); out_chan++) + { + for (uint32_t out_height = 0; out_height < fused_tconv_filter->dim(1).value(); out_height++) + { + for (uint32_t out_width = 0; out_width < fused_tconv_filter->dim(2).value(); out_width++) + { + for (uint32_t in_chan = 0; in_chan < fused_tconv_filter->dim(3).value(); in_chan++) + { + uint32_t indices[4] = {out_chan, out_height, out_width, in_chan}; + uint32_t old_indices[4] = {out_chan + ch_offset, out_height, out_width, in_chan}; + auto const data = + tconv_filter->at(cal_offset(tconv_filter, old_indices)); + fused_tconv_filter->at( + cal_offset(fused_tconv_filter, indices)) = data; + } + } + } + } + fused_tconv_filter->name(tconv_filter->name() + "/FusedSlice"); + luci::add_origin(fused_tconv_filter, luci::get_origin(tconv_shape)); + fused_filter = fused_tconv_filter; + + // fused bias + luci::CircleConst *fused_tconv_bias = luci::clone(tconv_bias); + fused_tconv_bias->size(out_channels); + fused_tconv_bias->dim(0).set(out_channels); // out_channels + // set reduced bias values + for (int32_t c = 0; c < out_channels; c++) + { + auto const data = tconv_bias->at(c + ch_offset); + fused_tconv_bias->at(c) = data; + } + + fused_tconv_bias->name(tconv_bias->name() + "/FusedSlice"); + luci::add_origin(fused_tconv_bias, luci::get_origin(tconv_bias)); + fused_bias = fused_tconv_bias; + } + + auto *fused_tconv_shape = luci::clone(tconv_shape); + // spatial dimensions + fused_tconv_shape->at(1) = out_height; + fused_tconv_shape->at(2) = out_width; + // channels + fused_tconv_shape->at(3) = out_channels; + fused_tconv_shape->name(tconv_shape->name() + "/FusedSlice"); + luci::add_origin(fused_tconv_shape, luci::get_origin(tconv_shape)); + + // Configure new CircleTransposeConv operation. 
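+  // NOTE luci::clone_node duplicates the node's attributes but not its
+  // connections; the inputs are re-wired and the padding overridden below.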
+ auto *fused_tconv = + loco::must_cast(luci::clone_node(tconv, slice->graph())); + fused_tconv->inputSizes(fused_tconv_shape); + fused_tconv->outBackprop(tconv->outBackprop()); + fused_tconv->filter(fused_filter); + fused_tconv->bias(fused_bias); + fused_tconv->padding(padding); + fused_tconv->name(tconv->name() + "/FusedSlice"); + luci::add_origin(fused_tconv, + luci::composite_origin({luci::get_origin(tconv), luci::get_origin(slice)})); + + // Replace old slice operation with new fused_tconv with merged pad values + replace(slice).with(fused_tconv); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseSliceWithTConvPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto slice = dynamic_cast(node); + if (not slice) + continue; + + if (fuse_slice_with_tconv(slice)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseSliceWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseSliceWithTConvPass.test.cpp new file mode 100644 index 000000000..71372d8ac --- /dev/null +++ b/compiler/luci/pass/src/FuseSliceWithTConvPass.test.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FuseSliceWithTConvPass.h" + +#include + +#include + +#include + +namespace +{ + +/** + * TConv->Slice graph for test + * + * [CircleInput] + * | + * | + * [CircleTransposeConv] + * | + * | + * [CircleSlice] + * | + * | + * [CircleOutput] + */ +struct TConvSliceGraph : public luci::test::TestIOGraph +{ + luci::CircleTransposeConv *_tconv = nullptr; + luci::CircleSlice *_slice = nullptr; + luci::CircleConst *_filter = nullptr; + luci::CircleConst *_bias = nullptr; + luci::CircleConst *_tconv_shape = nullptr; + luci::CircleConst *_slice_offset = nullptr; + luci::CircleConst *_slice_size = nullptr; + + TConvSliceGraph(uint32_t h, uint32_t w, uint32_t pads[4]) + { + // pads={pad_top, pad_bottom, pad_left, pad_right} + uint32_t channels = 32; + uint32_t k_h = 3, k_w = 3; + auto const tconv_h = (h - 1) * 2 + k_h; + auto const tconv_w = (w - 1) * 2 + k_w; + auto const out_h = tconv_h - pads[0] - pads[1]; + auto const out_w = tconv_w - pads[2] - pads[3]; + + // graph input and output + TestIOGraph::init({1, h, w, channels}, {1, out_h, out_w, channels}); + + _filter = g()->nodes()->create(); + _filter->dtype(loco::DataType::FLOAT32); + _filter->rank(4); + _filter->shape({channels, k_h, k_w, channels}); + _filter->shape_status(luci::ShapeStatus::VALID); + _filter->size(channels * k_h * k_w * channels); + _filter->name("filter"); + + _bias = g()->nodes()->create(); + _bias->dtype(loco::DataType::FLOAT32); + _bias->rank(1); + _bias->shape({channels}); + _bias->shape_status(luci::ShapeStatus::VALID); + _bias->size(channels); + _bias->name("bias"); + + _tconv_shape = g()->nodes()->create(); + _tconv_shape->dtype(loco::DataType::S32); + _tconv_shape->rank(1); + _tconv_shape->shape({4}); + _tconv_shape->shape_status(luci::ShapeStatus::VALID); + _tconv_shape->size(4); + _tconv_shape->at(0) = 1; + _tconv_shape->at(3) = channels; + _tconv_shape->at(1) = tconv_h; + _tconv_shape->at(2) = tconv_w; + _tconv_shape->name("tconv_shape"); + + _tconv = g()->nodes()->create(); + _tconv->filter(_filter); + _tconv->bias(_bias); + _tconv->inputSizes(_tconv_shape); + _tconv->outBackprop(input()); + _tconv->fusedActivationFunction(luci::FusedActFunc::NONE); + _tconv->dtype(loco::DataType::FLOAT32); + _tconv->padding(luci::Padding::VALID); + _tconv->stride()->h(2); + _tconv->stride()->w(2); + _tconv->name("tconv"); + + // offset to be used in slice + _slice_offset = g()->nodes()->create(); + _slice_offset->dtype(loco::DataType::S32); + _slice_offset->rank(1); + _slice_offset->shape({4}); + _slice_offset->shape_status(luci::ShapeStatus::VALID); + _slice_offset->size(4); + _slice_offset->at(0) = 0; + _slice_offset->at(3) = 0; + _slice_offset->at(1) = pads[0]; + _slice_offset->at(2) = pads[2]; + _slice_offset->name("slice_offset"); + + _slice_size = g()->nodes()->create(); + _slice_size->dtype(loco::DataType::S32); + _slice_size->rank(1); + _slice_size->shape({4}); + _slice_size->shape_status(luci::ShapeStatus::VALID); + _slice_size->size(4); + _slice_size->at(0) = 1; + _slice_size->at(3) = channels; + _slice_size->at(1) = out_h; + _slice_size->at(2) = out_w; + _slice_size->name("slice_size"); + + _slice = g()->nodes()->create(); + _slice->begin(_slice_offset); + _slice->size(_slice_size); + _slice->input(_tconv); + _slice->name("slice"); + + output()->from(_slice); + } +}; + +} // namespace + +TEST(FuseSliceWithTConvPassTest, simple_test) +{ + /** + * tests: + * 1) fusion pass has nonnull name + * 2) fusion runs successfully for float32 TConvSlice graph + * 3) resulting graph has the following 
structure: + * + * [CircleTransposeConv] (with output_shape = shape_of_the_slice) + * | + * | + * [Output] + */ + luci::FuseSliceWithTConvPass pass; + uint32_t pads[4] = {0, 2, 0, 2}; + uint32_t h = 8, w = 8; + TConvSliceGraph graph(h, w, pads); + auto const out_h = graph._slice_size->at(1); + auto const out_w = graph._slice_size->at(2); + + auto const name = pass.name(); + ASSERT_NE(nullptr, name); + + auto ret = pass.run(graph.g()); + EXPECT_TRUE(ret); + + auto const fused_tconv = dynamic_cast(graph.output()->from()); + EXPECT_NE(nullptr, fused_tconv); + + EXPECT_EQ(luci::Padding::VALID, fused_tconv->padding()); + + auto const out_size = dynamic_cast(fused_tconv->inputSizes()); + EXPECT_NE(nullptr, out_size); + EXPECT_EQ(out_h, out_size->at(1)); // h + EXPECT_EQ(out_w, out_size->at(2)); // 2 +} + +TEST(FuseSliceWithTConvPassTest, wrong_condition_NEG) +{ + luci::FuseSliceWithTConvPass pass; + uint32_t pads[4] = {3, 3, 3, 3}; // no fusion is possible with these pads + TConvSliceGraph graph(8, 8, pads); + + auto ret = pass.run(graph.g()); + EXPECT_FALSE(ret); +} diff --git a/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.cpp b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.cpp new file mode 100644 index 000000000..b10071b08 --- /dev/null +++ b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.cpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "InsertQuantizeOpOnDTypeMismatch.h" +#include "QuantizationUtils.h" + +#include +#include + +#include // std::numeric_limits + +using namespace luci; + +namespace +{ + +// Update u8 node to i16 +// Qparam of i16 is inferred from the qparam of u8 +void update_u8_to_i16(luci::CircleNode *node) +{ + assert(node->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + node->dtype(loco::DataType::S16); + + auto qparam = node->quantparam(); + assert(qparam); + assert(qparam->scale.size() == 1); + assert(qparam->zerop.size() == 1); + + auto u8_scale = qparam->scale[0]; + auto u8_zerop = qparam->zerop[0]; + + auto min = u8_scale * (-u8_zerop); + auto max = u8_scale * (255 - u8_zerop); + + float s16_scale{0}; + int64_t s16_zerop{0}; + float nudged_min{0}; + float nudged_max{0}; + + compute_sym_scale(min, max, s16_scale, nudged_min, nudged_max); + + auto quantparam = std::make_unique(); + quantparam->scale.push_back(s16_scale); + quantparam->zerop.push_back(s16_zerop); + + node->quantparam(std::move(quantparam)); +} + +// Update i16 node to u8 node +// Qparam of u8 is inferred from the qparam of i16 +void update_i16_to_u8(luci::CircleNode *node) +{ + assert(node->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS + + node->dtype(loco::DataType::U8); + + auto qparam = node->quantparam(); + assert(qparam); + assert(qparam->scale.size() == 1); + assert(qparam->zerop.size() == 1); + + auto s16_scale = qparam->scale[0]; + assert(qparam->zerop[0] == 0); + + auto max = s16_scale * std::numeric_limits::max(); + auto min = -max; + + float u8_scale{0}; + int64_t u8_zerop{0}; + float nudged_min{0}; + float nudged_max{0}; + + compute_asym_scale_zp(min, max, u8_scale, u8_zerop, nudged_min, nudged_max); + + auto quantparam = std::make_unique(); + quantparam->scale.push_back(u8_scale); + quantparam->zerop.push_back(u8_zerop); + + node->quantparam(std::move(quantparam)); +} + +// Create a Quantize Op which has the same +// dtype, shape, and qparam with node +luci::CircleQuantize *create_quantize_op(luci::CircleNode *node) +{ + auto quantize = node->graph()->nodes()->create(); + quantize->name(node->name() + "_Quantize"); + quantize->dtype(node->dtype()); + quantize->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + quantize->dim(i).set(node->dim(i).value()); + + quantize->shape_status(luci::ShapeStatus::VALID); + + assert(node->quantparam()); // FIX_CALLER_UNLESS + copy_quantparam(node, quantize); + + luci::add_origin(quantize, luci::get_origin(node)); + + return quantize; +} + +} // namespace + +namespace luci +{ + +void InsertQuantizeOpOnDTypeMismatch::visit(luci::CircleFullyConnected *node) +{ + auto input = loco::must_cast(node->input()); + + // Input dtype == Output dtype. No problem + if (input->dtype() == node->dtype()) + return; + + // Skip if node has bias + if (dynamic_cast(node->bias()) == nullptr) + return; + + if (node->fusedActivationFunction() != luci::FusedActFunc::NONE) + return; + + // Only cares quantized case + if (not is_quantized(input)) + return; + + if (not is_quantized(node)) + return; + + // Let's support limited case + // TODO Extend this to another dtype + if (input->dtype() != loco::DataType::U8) + return; + + if (node->dtype() != loco::DataType::S16) + return; + + // Create Quantize Op + auto quant_op = create_quantize_op(node); + + // Insert Quantize Op after node + loco::replace(node).with(quant_op); + quant_op->input(node); + + // Update node's dtype and qparam from i16 to u8 + // NOTE This would severely degrade accuracy. 
It is + // important to mitigate this accuracy drop in backend. + update_i16_to_u8(node); +} + +void InsertQuantizeOpOnDTypeMismatch::visit(luci::CircleMul *node) +{ + auto x = loco::must_cast(node->x()); + auto y = loco::must_cast(node->y()); + + assert(x->dtype() == y->dtype()); // FIX_CALLER_UNLESS + + // Ignore invalid dtype + if (x->dtype() != y->dtype()) + return; + + if (node->fusedActivationFunction() != luci::FusedActFunc::NONE) + return; + + // Input dtype == Output dtype. No problem + if (x->dtype() == node->dtype()) + return; + + // Only cares quantized case + if (not is_quantized(x)) + return; + + if (not is_quantized(y)) + return; + + if (not is_quantized(node)) + return; + + // Let's support limited case + // TODO Extend this to another dtype + if (x->dtype() != loco::DataType::S16) + return; + + if (node->dtype() != loco::DataType::U8) + return; + + // Create Quantize Op + auto quant_op = create_quantize_op(node); + + // Insert Quantize Op after node + loco::replace(node).with(quant_op); + quant_op->input(node); + + // Update node's dtype and qparam from u8 to i16 + update_u8_to_i16(node); +} + +void InsertQuantizeOpOnDTypeMismatch::visit(luci::CircleBatchMatMul *node) +{ + auto x = loco::must_cast(node->x()); + auto y = loco::must_cast(node->y()); + + assert(x->dtype() == y->dtype()); // FIX_CALLER_UNLESS + + // Ignore invalid dtype + if (x->dtype() != y->dtype()) + return; + + if (node->adj_x() or node->adj_y()) + return; + + // Input dtype == Output dtype. No problem + if (x->dtype() == node->dtype()) + return; + + // Only cares quantized case + if (not is_quantized(x)) + return; + + if (not is_quantized(y)) + return; + + if (not is_quantized(node)) + return; + + // Let's support limited case + // TODO Extend this to another dtype + if (x->dtype() != loco::DataType::S16) + return; + + if (node->dtype() != loco::DataType::U8) + return; + + // Create Quantize Op + auto quant_op = create_quantize_op(node); + + // Insert Quantize Op after node + loco::replace(node).with(quant_op); + quant_op->input(node); + + // Update node's dtype and qparam from i16 to u8 + update_u8_to_i16(node); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.h b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.h new file mode 100644 index 000000000..e4fe231b4 --- /dev/null +++ b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_INSERT_QUANTIZE_OP_ON_DTYPE_MISMATCH_H__ +#define __LUCI_INSERT_QUANTIZE_OP_ON_DTYPE_MISMATCH_H__ + +#include + +namespace luci +{ + +struct InsertQuantizeOpOnDTypeMismatch final : public luci::CircleNodeMutableVisitor +{ + InsertQuantizeOpOnDTypeMismatch() = default; + +private: + void visit(luci::CircleNode *) {} + + void visit(luci::CircleFullyConnected *node); + void visit(luci::CircleMul *node); + void visit(luci::CircleBatchMatMul *node); + + // TODO Support more operators +}; + +} // namespace luci + +#endif // __LUCI_INSERT_QUANTIZE_OP_ON_DTYPE_MISMATCH_H__ diff --git a/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.test.cpp b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.test.cpp new file mode 100644 index 000000000..0d2ec5244 --- /dev/null +++ b/compiler/luci/pass/src/InsertQuantizeOpOnDTypeMismatch.test.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "InsertQuantizeOpOnDTypeMismatch.h" +#include "PassTestGraphs.h" + +#include + +namespace +{ + +using namespace luci::test; + +std::unique_ptr gen_qparam(float s, int64_t zp) +{ + auto qparam = std::make_unique(); + { + qparam->scale.push_back(s); + qparam->zerop.push_back(zp); + } + + return std::move(qparam); +} + +/** + * Mul graph for test + * + * BEFORE + * + * [Input(s16)] [Const(s16)] + * \ / + * [Mul(u8)] + * | + * [Output(u8)] + * + * AFTER + * + * [Input(s16)] [Const(s16)] + * \ / + * [Mul(s16)] + * | + * [Quantize(u8)] + * | + * [Output(u8)] + */ +class MulGraphlet +{ +public: + MulGraphlet() = default; + + void init(loco::Graph *g) + { + _mul = g->nodes()->create(); + _const = g->nodes()->create(); + + _mul->dtype(loco::DataType::U8); + _const->dtype(loco::DataType::S16); + + _mul->quantparam(std::move(gen_qparam(1, 0))); + _const->quantparam(std::move(gen_qparam(1, 0))); + + _mul->shape({2, 2, 2}); + + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + + _mul->name("mul"); + _const->name("const"); + } + +public: + luci::CircleMul *mul(void) { return _mul; } + +protected: + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const = nullptr; +}; + +class DtypeMisMatchMulTestGraph : public TestIOGraph, public MulGraphlet +{ +public: + void init(void) + { + TestIOGraph::init({2, 2, 2}, {2, 2, 2}); + + input()->dtype(loco::DataType::S16); + output()->dtype(loco::DataType::U8); + + input()->quantparam(std::move(gen_qparam(1, 0))); + output()->quantparam(std::move(gen_qparam(1, 0))); + + MulGraphlet::init(g()); + + _mul->x(input()); + _mul->y(_const); + + output()->from(_mul); + } +}; + +} // namespace + +TEST(InsertQuantizeOpOnDTypeMismatchTest, mul) +{ + DtypeMisMatchMulTestGraph g; + + luci::InsertQuantizeOpOnDTypeMismatch visitor; + + g.init(); + + auto node = g.mul(); + node->accept(&visitor); + + // Quantize Op is created + EXPECT_NE(nullptr, dynamic_cast(g.output()->from())); + + // Mul's dtype is changed from U8 to S16 + 
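+  // (update_u8_to_i16 also re-derives a symmetric S16 qparam from the U8
+  // scale/zerop, so only the dtype is checked here)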
EXPECT_EQ(loco::DataType::S16, g.mul()->dtype()); +} + +TEST(InsertQuantizeOpOnDTypeMismatchTest, mul_dtype_match_NEG) +{ + DtypeMisMatchMulTestGraph g; + + luci::InsertQuantizeOpOnDTypeMismatch visitor; + + g.init(); + + auto node = g.mul(); + node->dtype(loco::DataType::S16); + + node->accept(&visitor); + + // Quantize Op is not created + EXPECT_EQ(nullptr, dynamic_cast(g.output()->from())); +} diff --git a/compiler/luci/pass/src/PassTestGraphs.h b/compiler/luci/pass/src/PassTestGraphs.h index f5ae24f0b..f0b0f4ec9 100644 --- a/compiler/luci/pass/src/PassTestGraphs.h +++ b/compiler/luci/pass/src/PassTestGraphs.h @@ -20,6 +20,9 @@ #include #include +#include +#include + namespace luci { @@ -137,6 +140,44 @@ protected: luci::CircleAdd *_add = nullptr; }; +/** + * CommonSubExpressionEliminationTestGraph is a base class for testing + * common subexpression elimination pass. It creates Input and Output + * in the below graph. Child classes must implement Expression. + * + * [Input] + * / \ + * [Expression] [Expression] + * | | + * [Output 1] [Output 2] + * + * Expression should satisfy the below conditions + * - Input type == Output type + * - Input shape == Output shape + * - Expression 1 and 2 are semantically equal + */ +class CommonSubExpressionEliminationTestGraph : public test::TestIsGraphlet<1>, + public test::TestOsGraphlet<2> +{ +public: + virtual void init(const std::initializer_list shape_in, + const std::initializer_list shape_out) + { + test::TestIsGraphlet<1>::init(g(), shape_in); + test::TestOsGraphlet<2>::init(g(), shape_out); + + auto expr1 = createExpression(input(0), "expr1"); + auto expr2 = createExpression(input(0), "expr2"); + + output(0)->from(expr1); + output(1)->from(expr2); + } + + virtual ~CommonSubExpressionEliminationTestGraph() = default; + + virtual loco::Node *createExpression(luci::CircleNode *ifm, const std::string &name) = 0; +}; + } // namespace luci #endif // __LUCI_PASS_TEST_GRAPHS_H__ diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp index aaadb2864..b7c5ee231 100644 --- a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp @@ -141,7 +141,7 @@ struct PropagateQParamForward final : public luci::CircleNodeMutableVisitorquantparam(luci::make_predefined_qparam(qtype, node->dtype())); + node->quantparam(luci::make_predefined_qparam(qtype, node->dtype(), node->quantparam())); break; case luci::ActivationQType::IntScale: luci::set_int_scale(node); diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 3e3cdde34..33d75ce04 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -101,14 +101,26 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float } } +int32_t max_for_sym_quant(const loco::DataType &type) +{ + if (type == loco::DataType::S4) + return std::numeric_limits::max() >> 4; + else if (type == loco::DataType::S8) + return std::numeric_limits::max(); + else if (type == loco::DataType::S16) + return std::numeric_limits::max(); + else + throw std::runtime_error("Unsupported dtype for symmetric quantization"); +}; + void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min, float &nudged_max, loco::DataType out_type) { assert(min <= max); - assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16); + assert(out_type == loco::DataType::S4 || 
out_type == loco::DataType::S8 || + out_type == loco::DataType::S16); - const int32_t kMaxScale = (out_type == loco::DataType::S16) ? std::numeric_limits::max() - : std::numeric_limits::max(); + const int32_t kMaxScale = max_for_sym_quant(out_type); const int32_t kMinScale = -kMaxScale; const double qmin_double = kMinScale; const double qmax_double = kMaxScale; @@ -343,13 +355,16 @@ ActivationQType activation_qtype(const CircleNode *node) } std::unique_ptr make_predefined_qparam(ActivationQType qtype, - loco::DataType dtype) + loco::DataType dtype, + CircleQuantParam *old_quant_param) { auto qparam = std::make_unique(); - auto set_qparam = [&qparam](float scale, int64_t zp) { + auto set_qparam = [&qparam, old_quant_param](float scale, int64_t zp) { qparam->scale.emplace_back(scale); qparam->zerop.emplace_back(zp); + qparam->min = old_quant_param->min; + qparam->max = old_quant_param->max; }; switch (qtype) @@ -435,6 +450,12 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type) auto quantparam = std::make_unique(); quantparam->scale.push_back(scaling_factor); quantparam->zerop.push_back(zp); + // Copy min and max values if it exists + if (node->quantparam()) + { + quantparam->min = node->quantparam()->min; + quantparam->max = node->quantparam()->max; + } node->quantparam(std::move(quantparam)); } @@ -508,4 +529,32 @@ void warn_accuracy_with_range(luci::CircleNode *n) } } +bool is_onnx_dequantize_linear(const luci::CircleCustom *node) +{ + if (node->numInputs() != 3) + return false; + + if (node->numOutputs() != 1) + return false; + + if (node->custom_code() != "ONNXDequantizeLinear") + return false; + + return true; +} + +bool is_onnx_quantize_linear(const luci::CircleCustom *node) +{ + if (node->numInputs() != 3) + return false; + + if (node->numOutputs() != 1) + return false; + + if (node->custom_code() != "ONNXQuantizeLinear") + return false; + + return true; +} + } // namespace luci diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 93c4045b5..0bf3270d5 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -23,6 +23,9 @@ namespace luci { +// Return the max value of dtype for symmetric quantization (int4/int8/int16) +int32_t max_for_sym_quant(const loco::DataType &type); + // Compute scale using given min/max for symmetric quantization (int8/int16) void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min, float &nudged_max, loco::DataType out_type = loco::DataType::S16); @@ -77,7 +80,8 @@ ActivationQType activation_qtype(const CircleNode *node); // Create qparam with pre-defined values for speical operators std::unique_ptr make_predefined_qparam(CircleNode *node, loco::DataType dtype); std::unique_ptr make_predefined_qparam(ActivationQType qtype, - loco::DataType dtype); + loco::DataType dtype, + CircleQuantParam *old_quant_param); // Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil) void set_int_scale(luci::CircleNode *node); @@ -89,6 +93,12 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type); // Emits warnings to log with WARN void warn_accuracy_with_range(luci::CircleNode *n); +// Return true if the node is OnnxDequantizeLinear +bool is_onnx_dequantize_linear(const luci::CircleCustom *node); + +// Return true if the node is OnnxQuantizeLinear +bool is_onnx_quantize_linear(const luci::CircleCustom *node); + } // namespace luci #endif // __LUCI_QUANTIZATION_UTILS_H__ diff --git 
a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp index 913450083..8d953989e 100644 --- a/compiler/luci/pass/src/QuantizeActivation.cpp +++ b/compiler/luci/pass/src/QuantizeActivation.cpp @@ -110,26 +110,30 @@ void QuantizeSpecialActivation::visit(luci::CircleNode *node) auto fused_act_node = dynamic_cast *>(node); if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) { - auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type, + node->quantparam()); node->quantparam(std::move(qparam)); } } void QuantizeSpecialActivation::visit(luci::CircleLogistic *node) { - auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedLogistic, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedLogistic, output_type, + node->quantparam()); node->quantparam(std::move(qparam)); } void QuantizeSpecialActivation::visit(luci::CircleTanh *node) { - auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type); + auto qparam = + make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type, node->quantparam()); node->quantparam(std::move(qparam)); } void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node) { - auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedSoftmax, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedSoftmax, output_type, + node->quantparam()); node->quantparam(std::move(qparam)); } @@ -271,6 +275,7 @@ QUANTIZE_TWO_CONST_INPUTS(luci::CircleMinimum, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleMul, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleNotEqual, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CirclePow, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleSelectV2, t, e) QUANTIZE_TWO_CONST_INPUTS(luci::CircleSub, x, y) // AddN has arbitrary number of inputs diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h index ba3bc59f2..162ec2c66 100644 --- a/compiler/luci/pass/src/QuantizeActivation.h +++ b/compiler/luci/pass/src/QuantizeActivation.h @@ -27,12 +27,8 @@ namespace luci */ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor { - QuantizeActivation(loco::DataType input, loco::DataType output) - : input_type(input), output_type(output) - { - } + QuantizeActivation(loco::DataType output) : output_type(output) {} - loco::DataType input_type; loco::DataType output_type; // Quantize each node using recorded min/max @@ -44,12 +40,8 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor */ struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor { - QuantizeSpecialActivation(loco::DataType input, loco::DataType output) - : input_type(input), output_type(output) - { - } + QuantizeSpecialActivation(loco::DataType output) : output_type(output) {} - loco::DataType input_type; loco::DataType output_type; void visit(luci::CircleNode *node); @@ -158,6 +150,7 @@ private: void visit(luci::CircleMul *node); void visit(luci::CircleNotEqual *node); void visit(luci::CirclePow *node); + void visit(luci::CircleSelectV2 *node); void visit(luci::CircleSub *node); // AddN has arbitrary number of inputs diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp new file mode 100644 index 
000000000..9194af031 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeOnnxDequantizeLinearPass.h" +#include "QuantizationUtils.h" + +#include +#include + +#include + +namespace +{ + +using namespace luci; + +// Return true if all values of node are within value_range +// value_range: [min, max] +template +bool value_range(const luci::CircleConst *node, const std::pair &value_range) +{ + const auto min = value_range.first; + const auto max = value_range.second; + + auto size = node->size
<DT>();
+  for (uint32_t i = 0; i < size; i++)
+  {
+    const auto val = static_cast<int64_t>(node->at<DT>(i));
+    if (val < min or val > max)
+      return false;
+  }
+
+  return true;
+}
+
+std::vector<float> get_scales(const luci::CircleConst *node)
+{
+  assert(node); // FIX_CALLER_UNLESS
+
+  const auto num_scales = node->size<loco::DataType::FLOAT32>();
+  std::vector<float> scales(num_scales);
+  for (uint32_t i = 0; i < num_scales; ++i)
+  {
+    scales[i] = node->at<loco::DataType::FLOAT32>(i);
+  }
+
+  return scales;
+}
+
+template <loco::DataType DT> std::vector<int64_t> get_zerops(const luci::CircleConst *node)
+{
+  assert(node); // FIX_CALLER_UNLESS
+
+  const auto num_zerops = node->size<DT>();
+  std::vector<int64_t> zerops(num_zerops);
+  for (uint32_t i = 0; i < num_zerops; ++i)
+  {
+    zerops[i] = node->at<DT>
(i); + } + + return zerops; +} + +int32_t get_axis(const luci::CircleCustom *node) +{ + assert(node); // FIX_CALLER_UNLESS + + const auto custom_options = node->custom_options(); + const auto map = flexbuffers::GetRoot(custom_options).AsMap(); + + return map["axis"].IsNull() ? 0 : map["axis"].AsInt32(); +} + +class OnnxDequantizeLinearPattern final +{ +public: + OnnxDequantizeLinearPattern(luci::CircleCustomOut *candidate) { custom_out = candidate; } + +public: + bool matched() + { + if (not custom_out) + return false; + + dequantize = loco::must_cast(custom_out->input()); + if (not is_onnx_dequantize_linear(dequantize)) + return false; + + input = dynamic_cast(dequantize->inputs(0)); + if (not input) + return false; + + scale = dynamic_cast(dequantize->inputs(1)); + if (not scale) + return false; + + zerop = dynamic_cast(dequantize->inputs(2)); + if (not zerop) + return false; + + const auto input_dtype = input->dtype(); + const auto scale_dtype = scale->dtype(); + const auto zerop_dtype = zerop->dtype(); + + if (scale_dtype != loco::DataType::FLOAT32) + return false; + + // Invariant from onnx DequantizeLinear operator + if (input_dtype != zerop_dtype) + return false; + + return true; + } + +public: + luci::CircleCustomOut *custom_out = nullptr; + luci::CircleCustom *dequantize = nullptr; + luci::CircleConst *input = nullptr; + luci::CircleConst *scale = nullptr; + luci::CircleConst *zerop = nullptr; +}; + +class QuantizeOnnxDequantizeLinear final +{ +public: + QuantizeOnnxDequantizeLinear(const OnnxDequantizeLinearPattern &p) : _p(p) {} + +public: + void apply(void) + { + // The final const's dtype is the same with input_dtype by default + auto const_dtype = _p.input->dtype(); + if (const_dtype == loco::DataType::U8) + { + // Onnx does not support int4/uint4 as of writing. 
We assume uint8 + // tensor is quantized in int4/uint4 if values are within [0,15] + if (value_range(_p.input, {0, 15})) + { + if (value_range(_p.zerop, {8, 8})) + { + const_dtype = loco::DataType::S4; + } + else if (value_range(_p.zerop, {0, 15})) + { + const_dtype = loco::DataType::U4; + } + } + } + + luci::CircleConst *quant_const = nullptr; + switch (const_dtype) + { + case loco::DataType::S4: + quant_const = gen_s4_quant(); + break; + case loco::DataType::U4: + quant_const = gen_u4_quant(); + break; + case loco::DataType::U8: + quant_const = gen_u8_quant(); + break; + case loco::DataType::S16: + quant_const = gen_s16_quant(); + break; + default: + throw std::runtime_error("Unsupported quantized dtype"); + } + + assert(quant_const); // FIX_ME_UNLESS + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p.dequantize), luci::get_origin(_p.input), luci::get_origin(_p.scale), + luci::get_origin(_p.zerop)}; + + luci::add_origin(quant_const, luci::composite_origin(origin_vec)); + + replace(_p.custom_out).with(quant_const); + } + +private: + luci::CircleConst *gen_s4_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::S4); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create S4 CircleConst + // NOTE S4 is saved as S8 in luci::CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + assert(u8_val <= 15); // FIX_CALLER_UNLESS + quantized_node->at(i) = static_cast(u8_val) - 8; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + + luci::CircleConst *gen_u4_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::U4); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create U4 CircleConst + // NOTE U4 is saved as U8 in luci::CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + assert(u8_val <= 15); // FIX_CALLER_UNLESS + quantized_node->at(i) = u8_val; + } + + auto qparam = std::make_unique(); + { + const std::vector 
scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + + luci::CircleConst *gen_u8_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::U8); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create U8 CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + quantized_node->at(i) = u8_val; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + + luci::CircleConst *gen_s16_quant(void) + { + assert(_p.input->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::S16); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create S16 CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const int16_t s16_val = _p.input->at(i); + quantized_node->at(i) = s16_val; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + +private: + const OnnxDequantizeLinearPattern &_p; +}; + +} // namespace + +namespace luci +{ + +/** + * + * Quantize pattern + * + * [Before] + * + * [CircleConst(quantized)] + * | + * [CircleCustom(OnnxDequantizeLinear)] + * | + * 
[CircleNode] + * + * [After] + * + * [CircleConst(quantized)] + * | + * [CircleNode] + */ +bool QuantizeOnnxDequantizeLinearPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto circle_custom_out = dynamic_cast(node)) + { + OnnxDequantizeLinearPattern p(circle_custom_out); + if (p.matched()) + { + QuantizeOnnxDequantizeLinear quantize(p); + quantize.apply(); + changed = true; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h new file mode 100644 index 000000000..17436672b --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ +#define __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to quantize ONNXDequantizeLinear operator + * + */ +struct QuantizeOnnxDequantizeLinearPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::QuantizeOnnxDequantizeLinear"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp new file mode 100644 index 000000000..bd409c9c0 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "QuantizeOnnxDequantizeLinearPass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +template +class QuantizeOnnxDequantizeLinearTest : public luci::ConstantFoldingAddTestGraph, + public ::testing::Test +{ +public: + QuantizeOnnxDequantizeLinearTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _dequantize = _g.nodes()->template create(3, 1); + _dequantize_out = _g.nodes()->template create(); + _input = _g.nodes()->template create(); + _scale = _g.nodes()->template create(); + _zerop = _g.nodes()->template create(); + + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _input->dtype(DT); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(DT); + + _input->shape({2, 2, 2}); + _scale->shape({2}); + _zerop->shape({2}); + + _input->size
<DT>(8);
+
+    _scale->size<loco::DataType::FLOAT32>(2);
+    _scale->at<loco::DataType::FLOAT32>(0) = 5.0;
+    _scale->at<loco::DataType::FLOAT32>(1) = 10.0;
+
+    _zerop->size<DT>
(2); + + // custom option + auto flex_buffers = std::make_unique(); + size_t map_start = flex_buffers->StartMap(); + flex_buffers->Int("axis", 1); + flex_buffers->EndMap(map_start); + flex_buffers->Finish(); + + _dequantize->inputs(0, _input); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize->custom_options(flex_buffers->GetBuffer()); + + _dequantize_out->input(_dequantize); + _dequantize_out->index(0); + + _input->name("input"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + + return _dequantize_out; + } + + void createNotQuantizablePattern() { _input->dtype(loco::DataType::FLOAT32); } + +protected: + luci::CircleCustom *_dequantize = nullptr; + luci::CircleCustomOut *_dequantize_out = nullptr; + luci::CircleConst *_input = nullptr; + luci::CircleConst *_scale = nullptr; + luci::CircleConst *_zerop = nullptr; +}; + +class S4QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 15] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 1; + } + + // Zerop = 8 + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 8; + } + } +}; + +class U4QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 15] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 1; + } + + // Zerop = [0, 15] + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 1; + } + } +}; + +class U8QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 255] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 255; + } + + // Zerop = [0, 255] + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 128; + } + } +}; + +class S16QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 1024; + } + + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 0; + } + } +}; + +} // namespace + +TEST_F(S4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) +{ + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(loco::DataType::S4, folded_const->dtype()); +} + +TEST_F(S4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG) +{ + createNotQuantizablePattern(); + + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(U4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) +{ + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(loco::DataType::U4, folded_const->dtype()); +} + +TEST_F(U4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG) +{ + createNotQuantizablePattern(); + + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(U8QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) 
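+// NOTE These tests assert only the dtype of the folded constant. Per the
+// [CircleConst(quantized)] pattern comment in the pass, the already-quantized
+// payload is expected to be carried over unchanged by the fold.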
+{
+  luci::QuantizeOnnxDequantizeLinearPass pass;
+  while (pass.run(graph()))
+    ;
+
+  auto folded_const = getFoldedPattern();
+  EXPECT_NE(nullptr, folded_const);
+
+  EXPECT_EQ(loco::DataType::U8, folded_const->dtype());
+}
+
+TEST_F(U8QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG)
+{
+  createNotQuantizablePattern();
+
+  luci::QuantizeOnnxDequantizeLinearPass pass;
+  while (pass.run(graph()))
+    ;
+
+  auto folded_const = getFoldedPattern();
+  EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(S16QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic)
+{
+  luci::QuantizeOnnxDequantizeLinearPass pass;
+  while (pass.run(graph()))
+    ;
+
+  auto folded_const = getFoldedPattern();
+  EXPECT_NE(nullptr, folded_const);
+
+  EXPECT_EQ(loco::DataType::S16, folded_const->dtype());
+}
+
+TEST_F(S16QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG)
+{
+  createNotQuantizablePattern();
+
+  luci::QuantizeOnnxDequantizeLinearPass pass;
+  while (pass.run(graph()))
+    ;
+
+  auto folded_const = getFoldedPattern();
+  EXPECT_EQ(nullptr, folded_const);
+}
diff --git a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp
new file mode 100644
index 000000000..face706b2
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizeOnnxFakeQuantModelPass.h"
+#include "QuantizeOnnxQDQPass.h"
+#include "QuantizeOnnxDequantizeLinearPass.h"
+#include "QuantizeWithPredecessorPass.h"
+#include "InsertQuantizeOpOnDTypeMismatch.h"
+#include "QuantizeActivation.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
+#include <loco.h>
+
+namespace luci
+{
+
+/**
+ * How does QuantizeOnnxFakeQuantModelPass work?
+ *
+ * 1. Activations are quantized as below.
+ *
+ * Before
+ *
+ * [node(fp32)] -> [OnnxQuantizeLinear] -> [OnnxDequantizeLinear]
+ *
+ * After
+ *
+ * [node(q)]
+ *
+ *
+ * 2. Weights (constants) are quantized as below.
+ *
+ * Before
+ *
+ * [Const(q w/o qparam)] -> [OnnxDequantizeLinear]
+ *
+ * After
+ *
+ * [Const(q)]
+ *
+ * 3. Quantize constant activations
+ *
+ * 4. Quantize with predecessors' qparams
+ *
+ * 5. Update qparams of special operators
+ *
+ * 6. Insert Quantize Op if an Op's input and output dtypes mismatch
+ */
+bool QuantizeOnnxFakeQuantModelPass::run(loco::Graph *g)
+{
+  LOGGER(l);
+  INFO(l) << "QuantizeOnnxFakeQuantModelPass Start" << std::endl;
+
+  // Quantize Onnx QuantizeLinear-DequantizeLinear pattern
+  {
+    QuantizeOnnxQDQPass pass;
+    pass.run(g);
+  }
+
+  // Quantize Onnx const-DequantizeLinear pattern
+  {
+    QuantizeOnnxDequantizeLinearPass pass;
+    pass.run(g);
+  }
+
+  // Quantize const input activation
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+    QuantizeConstInputActivation qcia(_ctx->default_activation_dtype);
+    circle_node->accept(&qcia);
+  }
+
+  // Quantize nodes using their predecessors' qparams
+  {
+    QuantizeWithPredecessorPass pass;
+    pass.run(g);
+  }
+
+  // Update qparam of output of special Ops
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+    if (is_quantized(circle_node))
+    {
+      QuantizeSpecialActivation qsa(circle_node->dtype());
+      circle_node->accept(&qsa);
+    }
+  }
+
+  // Insert QuantizeOp if input/output dtype does not match
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+    InsertQuantizeOpOnDTypeMismatch iqoodm;
+    circle_node->accept(&iqoodm);
+  }
+
+  // Update output dtype
+  auto graph_outputs = g->outputs();
+  for (auto node : loco::output_nodes(g))
+  {
+    auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
+    auto from = loco::must_cast<luci::CircleNode *>(circle_node->from());
+    circle_node->dtype(from->dtype());
+
+    auto graph_output = graph_outputs->at(circle_node->index());
+    graph_output->dtype(circle_node->dtype());
+  }
+
+  INFO(l) << "QuantizeOnnxFakeQuantModelPass End" << std::endl;
+  return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.test.cpp b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.test.cpp
new file mode 100644
index 000000000..953f20655
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.test.cpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "luci/Pass/QuantizeOnnxFakeQuantModelPass.h" +#include "PassTestGraphs.h" + +#include + +namespace +{ + +using namespace luci::test; + +class S16OnnxFakeQuantGraphlet +{ +public: + S16OnnxFakeQuantGraphlet() = default; + + void init(loco::Graph *g) + { + _quantize = g->nodes()->create(3, 1); + _quantize_out = g->nodes()->create(); + _dequantize = g->nodes()->create(3, 1); + _dequantize_out = g->nodes()->create(); + _scale = g->nodes()->create(); + _zerop = g->nodes()->create(); + + _quantize->dtype(loco::DataType::S16); + _quantize_out->dtype(loco::DataType::S16); + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(loco::DataType::S16); + + _scale->shape({1}); + _zerop->shape({1}); + + _scale->size(1); + _scale->at(0) = 5.0; + + _zerop->size(1); + _zerop->at(0) = 0; + + _quantize->custom_code("ONNXQuantizeLinear"); + _quantize_out->index(0); + + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize_out->index(0); + + _scale->name("scale"); + _zerop->name("zerop"); + _quantize->name("quantize"); + _quantize_out->name("quantize_out"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + } + +protected: + luci::CircleCustom *_quantize = nullptr; + luci::CircleCustomOut *_quantize_out = nullptr; + luci::CircleCustom *_dequantize = nullptr; + luci::CircleCustomOut *_dequantize_out = nullptr; + luci::CircleConst *_scale = nullptr; + luci::CircleConst *_zerop = nullptr; +}; + +class S16QuantizeOnnxFakeQuantModelTestGraph : public TestIOGraph, public S16OnnxFakeQuantGraphlet +{ +public: + void init(void) + { + TestIOGraph::init({2, 2, 2}, {2, 2, 2}); + S16OnnxFakeQuantGraphlet::init(g()); + + _quantize->inputs(0, input()); + _quantize->inputs(1, _scale); + _quantize->inputs(2, _zerop); + _quantize_out->input(_quantize); + _dequantize->inputs(0, _quantize_out); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize_out->input(_dequantize); + + output()->from(_dequantize_out); + } +}; + +} // namespace + +TEST(QuantizeOnnxFakeQuantModelTest, s16_quantize_onnx_qdq) +{ + S16QuantizeOnnxFakeQuantModelTestGraph g; + + auto ctx = std::make_unique(); + { + ctx->default_activation_dtype = loco::DataType::S16; + } + + luci::QuantizeOnnxFakeQuantModelPass pass(std::move(ctx)); + + g.init(); + + // Always return false + EXPECT_FALSE(pass.run(g.g())); + + EXPECT_EQ(loco::DataType::S16, g.input()->dtype()); +} diff --git a/compiler/luci/pass/src/QuantizeOnnxQDQPass.cpp b/compiler/luci/pass/src/QuantizeOnnxQDQPass.cpp new file mode 100644 index 000000000..efe8b59fa --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxQDQPass.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "QuantizeOnnxQDQPass.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+#include <cmath>
+#include <limits>
+
+namespace
+{
+
+using namespace luci;
+
+struct OnnxQDQPattern final
+{
+public:
+  OnnxQDQPattern(luci::CircleCustomOut *candidate) { dq_out = candidate; }
+
+public:
+  bool matched()
+  {
+    if (not dq_out)
+      return false;
+
+    dq = loco::must_cast<luci::CircleCustom *>(dq_out->input());
+    if (not is_onnx_dequantize_linear(dq))
+      return false;
+
+    q_out = dynamic_cast<luci::CircleCustomOut *>(dq->inputs(0));
+    if (not q_out)
+      return false;
+
+    dq_scale = dynamic_cast<luci::CircleConst *>(dq->inputs(1));
+    if (not dq_scale)
+      return false;
+
+    dq_zerop = dynamic_cast<luci::CircleConst *>(dq->inputs(2));
+    if (not dq_zerop)
+      return false;
+
+    q = loco::must_cast<luci::CircleCustom *>(q_out->input());
+    if (not is_onnx_quantize_linear(q))
+      return false;
+
+    input = loco::must_cast<luci::CircleNode *>(q->inputs(0));
+    if (input->dtype() != loco::DataType::FLOAT32)
+      return false;
+
+    q_scale = dynamic_cast<luci::CircleConst *>(q->inputs(1));
+    if (not q_scale)
+      return false;
+
+    q_zerop = dynamic_cast<luci::CircleConst *>(q->inputs(2));
+    if (not q_zerop)
+      return false;
+
+    const auto q_dtype = q->dtype();
+    const auto q_scale_dtype = q_scale->dtype();
+    const auto q_zerop_dtype = q_zerop->dtype();
+    const auto dq_scale_dtype = dq_scale->dtype();
+    const auto dq_zerop_dtype = dq_zerop->dtype();
+
+    if (q_scale_dtype != loco::DataType::FLOAT32)
+      return false;
+
+    if (dq_scale_dtype != loco::DataType::FLOAT32)
+      return false;
+
+    // Invariant from onnx Quantize operator
+    if (q_dtype != q_zerop_dtype)
+      return false;
+
+    // Invariant from onnx Dequantize operator
+    if (q_dtype != dq_zerop_dtype)
+      return false;
+
+    // Check that the lengths of scale and zerop are 1
+    if (q_scale->size<loco::DataType::FLOAT32>() != 1)
+      return false;
+
+    if (dq_scale->size<loco::DataType::FLOAT32>() != 1)
+      return false;
+
+    auto q_zerop_size = 0;
+    auto dq_zerop_size = 0;
+    switch (q_zerop_dtype)
+    {
+      case loco::DataType::S16:
+        q_zerop_size = q_zerop->size<loco::DataType::S16>();
+        dq_zerop_size = dq_zerop->size<loco::DataType::S16>();
+        break;
+      case loco::DataType::U8:
+        q_zerop_size = q_zerop->size<loco::DataType::U8>();
+        dq_zerop_size = dq_zerop->size<loco::DataType::U8>();
+        break;
+      default:
+        throw std::runtime_error("Unsupported zerop dtype in " + q_zerop->name());
+    }
+
+    if (q_zerop_size != 1)
+      return false;
+
+    if (dq_zerop_size != 1)
+      return false;
+
+    return true;
+  }
+
+public:
+  luci::CircleCustomOut *dq_out = nullptr;
+  luci::CircleCustom *dq = nullptr;
+  luci::CircleConst *dq_scale = nullptr;
+  luci::CircleConst *dq_zerop = nullptr;
+  luci::CircleCustomOut *q_out = nullptr;
+  luci::CircleCustom *q = nullptr;
+  luci::CircleConst *q_scale = nullptr;
+  luci::CircleConst *q_zerop = nullptr;
+  luci::CircleNode *input = nullptr;
+};
+
+class QuantizeOnnxQDQ final
+{
+public:
+  QuantizeOnnxQDQ(const OnnxQDQPattern &p) : _p(p) {}
+
+public:
+  void apply(void)
+  {
+    const auto quantized_dtype = _p.q->dtype();
+
+    // Get scale
+    assert(_p.q_scale->dtype() == loco::DataType::FLOAT32);          // FIX_CALLER_UNLESS
+    assert(_p.q_scale->size<loco::DataType::FLOAT32>() == 1);        // FIX_CALLER_UNLESS
+    const float q_scale = _p.q_scale->at<loco::DataType::FLOAT32>(0);
+
+    assert(_p.dq_scale->dtype() == loco::DataType::FLOAT32);           // FIX_CALLER_UNLESS
+    assert(_p.dq_scale->size<loco::DataType::FLOAT32>() == 1);         // FIX_CALLER_UNLESS
+    const float dq_scale = _p.dq_scale->at<loco::DataType::FLOAT32>(0);
+
+    if (q_scale != dq_scale)
+      throw std::runtime_error("Invalid scale value in " + _p.dq_scale->name());
+
+    // Get zerop
+    int64_t q_zerop = 0;
+    int64_t dq_zerop = 0;
+    switch (quantized_dtype)
+    {
+      case loco::DataType::S16:
+        assert(_p.q_zerop->size<loco::DataType::S16>() == 1);  // FIX_CALLER_UNLESS
+        assert(_p.dq_zerop->size<loco::DataType::S16>() == 1); // FIX_CALLER_UNLESS
+        q_zerop = _p.q_zerop->at<loco::DataType::S16>(0);
+        dq_zerop = _p.dq_zerop->at<loco::DataType::S16>(0);
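+        // NOTE The S16 and U8 branches read the same fields; the switch exists
+        // only because CircleConst::at<>() / size<>() are templated on the
+        // element dtype.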
break;
+      case loco::DataType::U8:
+        assert(_p.q_zerop->size<loco::DataType::U8>() == 1);  // FIX_CALLER_UNLESS
+        assert(_p.dq_zerop->size<loco::DataType::U8>() == 1); // FIX_CALLER_UNLESS
+        q_zerop = _p.q_zerop->at<loco::DataType::U8>(0);
+        dq_zerop = _p.dq_zerop->at<loco::DataType::U8>(0);
+        break;
+      default:
+        throw std::runtime_error("Unsupported zerop dtype in " + _p.q_zerop->name());
+    }
+
+    if (q_zerop != dq_zerop)
+      throw std::runtime_error("Invalid zerop value in " + _p.dq_zerop->name());
+
+    auto qparam = std::make_unique<luci::CircleQuantParam>();
+    {
+      qparam->scale.push_back(q_scale);
+      qparam->zerop.push_back(q_zerop);
+      qparam->quantized_dimension = 0;
+    }
+
+    if (auto const_input = dynamic_cast<luci::CircleConst *>(_p.input))
+    {
+      assert(const_input->dtype() == loco::DataType::FLOAT32); // FIX_ME_UNLESS
+      const auto num_elem = const_input->size<loco::DataType::FLOAT32>();
+
+      auto new_const = luci::clone(const_input);
+      new_const->name(new_const->name() + "_quant");
+      add_origin(new_const, luci::get_origin(const_input));
+
+      new_const->dtype(quantized_dtype);
+
+      // Quantize const
+      switch (quantized_dtype)
+      {
+        case loco::DataType::S16:
+        {
+          new_const->size<loco::DataType::S16>(num_elem);
+
+          const int64_t max_val = std::numeric_limits<int16_t>::max();
+          const int64_t min_val = -max_val;
+          for (uint32_t i = 0; i < num_elem; i++)
+          {
+            const float fp_val = const_input->at<loco::DataType::FLOAT32>(i);
+            const int64_t q_val = std::round(fp_val / q_scale) + q_zerop;
+            new_const->at<loco::DataType::S16>(i) = std::min(max_val, std::max(min_val, q_val));
+          }
+          break;
+        }
+        default:
+          throw std::runtime_error("Unsupported quantized_dtype");
+      }
+      new_const->quantparam(std::move(qparam));
+
+      replace(_p.dq_out).with(new_const);
+    }
+    else
+    {
+      // clang-format off
+      // NOTE We overwrite dtype and qparam to _p.input
+      // This can be problematic if a single tensor has
+      // multiple different qparams. Let's fix later.
+      _p.input->dtype(quantized_dtype);
+      _p.input->quantparam(std::move(qparam));
+
+      replace(_p.dq_out).with(_p.input);
+      // clang-format on
+    }
+  }
+
+private:
+  const OnnxQDQPattern &_p;
+};
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ *
+ * Quantize pattern
+ *
+ * [Before]
+ *
+ * [CircleNode(fp32)]
+ *          |
+ * [CircleCustom(OnnxQuantizeLinear)]
+ *          |
+ * [CircleCustom(OnnxDequantizeLinear)]
+ *          |
+ *     [CircleNode]
+ *
+ * [After]
+ *
+ * [CircleNode(quantized)]
+ *          |
+ *     [CircleNode]
+ */
+bool QuantizeOnnxQDQPass::run(loco::Graph *g)
+{
+  bool changed = false;
+
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    if (auto circle_custom_out = dynamic_cast<luci::CircleCustomOut *>(node))
+    {
+      OnnxQDQPattern p(circle_custom_out);
+      if (p.matched())
+      {
+        QuantizeOnnxQDQ quantize(p);
+        quantize.apply();
+        changed = true;
+      }
+    }
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeOnnxQDQPass.h b/compiler/luci/pass/src/QuantizeOnnxQDQPass.h
new file mode 100644
index 000000000..ac61d9ee8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeOnnxQDQPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __LUCI_QUANTIZE_ONNX_QDQ_PASS_H__ +#define __LUCI_QUANTIZE_ONNX_QDQ_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to quantize ONNX QuantizeLinear-DequantizeLinear operator + * + */ +struct QuantizeOnnxQDQPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::QuantizeOnnxQDQ"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_ONNX_QDQ_PASS_H__ diff --git a/compiler/luci/pass/src/QuantizeOnnxQDQPass.test.cpp b/compiler/luci/pass/src/QuantizeOnnxQDQPass.test.cpp new file mode 100644 index 000000000..6053bdf96 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxQDQPass.test.cpp @@ -0,0 +1,322 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeOnnxQDQPass.h" +#include "PassTestGraphs.h" + +#include + +namespace +{ + +using namespace luci::test; + +class U8OnnxQDQGraphlet +{ +public: + U8OnnxQDQGraphlet() = default; + + void init(loco::Graph *g) + { + _quantize = g->nodes()->create(3, 1); + _quantize_out = g->nodes()->create(); + _dequantize = g->nodes()->create(3, 1); + _dequantize_out = g->nodes()->create(); + _scale = g->nodes()->create(); + _zerop = g->nodes()->create(); + + _quantize->dtype(loco::DataType::U8); + _quantize_out->dtype(loco::DataType::U8); + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(loco::DataType::U8); + + _scale->shape({1}); + _zerop->shape({1}); + + _scale->size(1); + _scale->at(0) = 5.0; + + _zerop->size(1); + _zerop->at(0) = 0; + + _quantize->custom_code("ONNXQuantizeLinear"); + _quantize_out->index(0); + + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize_out->index(0); + + _scale->name("scale"); + _zerop->name("zerop"); + _quantize->name("quantize"); + _quantize_out->name("quantize_out"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + } + +protected: + luci::CircleCustom *_quantize = nullptr; + luci::CircleCustomOut *_quantize_out = nullptr; + luci::CircleCustom *_dequantize = nullptr; + luci::CircleCustomOut *_dequantize_out = nullptr; + luci::CircleConst *_scale = nullptr; + luci::CircleConst *_zerop = nullptr; +}; + +class U8QuantizeOnnxQDQTestGraph : public TestIOGraph, public U8OnnxQDQGraphlet +{ +public: + void init(void) + { + TestIOGraph::init({2, 2, 2}, {2, 2, 2}); + U8OnnxQDQGraphlet::init(g()); + + _quantize->inputs(0, input()); + _quantize->inputs(1, _scale); + _quantize->inputs(2, _zerop); + _quantize_out->input(_quantize); + _dequantize->inputs(0, _quantize_out); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize_out->input(_dequantize); + + output()->from(_dequantize_out); + } +}; + +class S16OnnxQDQGraphlet +{ +public: + S16OnnxQDQGraphlet() = default; + + void init(loco::Graph *g) + { + _quantize = g->nodes()->create(3, 1); + _quantize_out = 
g->nodes()->create(); + _dequantize = g->nodes()->create(3, 1); + _dequantize_out = g->nodes()->create(); + _scale = g->nodes()->create(); + _zerop = g->nodes()->create(); + + _quantize->dtype(loco::DataType::S16); + _quantize_out->dtype(loco::DataType::S16); + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(loco::DataType::S16); + + _scale->shape({1}); + _zerop->shape({1}); + + _scale->size(1); + _scale->at(0) = 5.0; + + _zerop->size(1); + _zerop->at(0) = 0; + + _quantize->custom_code("ONNXQuantizeLinear"); + _quantize_out->index(0); + + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize_out->index(0); + + _scale->name("scale"); + _zerop->name("zerop"); + _quantize->name("quantize"); + _quantize_out->name("quantize_out"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + } + +protected: + luci::CircleCustom *_quantize = nullptr; + luci::CircleCustomOut *_quantize_out = nullptr; + luci::CircleCustom *_dequantize = nullptr; + luci::CircleCustomOut *_dequantize_out = nullptr; + luci::CircleConst *_scale = nullptr; + luci::CircleConst *_zerop = nullptr; +}; + +class S16QuantizeOnnxQDQTestGraph : public TestIOGraph, public S16OnnxQDQGraphlet +{ +public: + void init(void) + { + TestIOGraph::init({2, 2, 2}, {2, 2, 2}); + S16OnnxQDQGraphlet::init(g()); + + _quantize->inputs(0, input()); + _quantize->inputs(1, _scale); + _quantize->inputs(2, _zerop); + _quantize_out->input(_quantize); + _dequantize->inputs(0, _quantize_out); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize_out->input(_dequantize); + + output()->from(_dequantize_out); + } +}; + +class S16ConstOnnxQDQTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test +{ +public: + S16ConstOnnxQDQTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, loco::DataType::S16) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _quantize = _g.nodes()->create(3, 1); + _quantize_out = _g.nodes()->create(); + _dequantize = _g.nodes()->create(3, 1); + _dequantize_out = _g.nodes()->create(); + _scale = _g.nodes()->create(); + _zerop = _g.nodes()->create(); + _input = _g.nodes()->create(); + + _quantize->dtype(loco::DataType::S16); + _quantize_out->dtype(loco::DataType::S16); + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(loco::DataType::S16); + _input->dtype(loco::DataType::FLOAT32); + + _scale->shape({1}); + _zerop->shape({1}); + _input->shape({2, 2, 2}); + + _scale->size(1); + _scale->at(0) = 5.0; + + _zerop->size(1); + _zerop->at(0) = 0; + + _input->size(8); + for (uint32_t i = 0; i < 8; i++) + _input->at(i) = i; + + _quantize->custom_code("ONNXQuantizeLinear"); + _quantize_out->index(0); + + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize_out->index(0); + + _scale->name("scale"); + _zerop->name("zerop"); + _quantize->name("quantize"); + _quantize_out->name("quantize_out"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + + _quantize->inputs(0, _input); + _quantize->inputs(1, _scale); + _quantize->inputs(2, _zerop); + _quantize_out->input(_quantize); + _dequantize->inputs(0, _quantize_out); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize_out->input(_dequantize); + + return _dequantize_out; + } + +protected: + luci::CircleCustom 
*_quantize = nullptr;
+  luci::CircleCustomOut *_quantize_out = nullptr;
+  luci::CircleCustom *_dequantize = nullptr;
+  luci::CircleCustomOut *_dequantize_out = nullptr;
+  luci::CircleConst *_scale = nullptr;
+  luci::CircleConst *_zerop = nullptr;
+  luci::CircleConst *_input = nullptr;
+};
+
+} // namespace
+
+TEST(QuantizeOnnxQDQTest, s16_quantize_onnx_qdq)
+{
+  S16QuantizeOnnxQDQTestGraph g;
+
+  luci::QuantizeOnnxQDQPass pass;
+
+  g.init();
+
+  EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(QuantizeOnnxQDQTest, s16_quantize_onnx_qdq_NEG)
+{
+  S16QuantizeOnnxQDQTestGraph g;
+
+  luci::QuantizeOnnxQDQPass pass;
+
+  g.init();
+
+  g.input()->dtype(loco::DataType::S16);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(QuantizeOnnxQDQTest, u8_quantize_onnx_qdq)
+{
+  U8QuantizeOnnxQDQTestGraph g;
+
+  luci::QuantizeOnnxQDQPass pass;
+
+  g.init();
+
+  EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(QuantizeOnnxQDQTest, u8_quantize_onnx_qdq_NEG)
+{
+  U8QuantizeOnnxQDQTestGraph g;
+
+  luci::QuantizeOnnxQDQPass pass;
+
+  g.init();
+
+  g.input()->dtype(loco::DataType::U8);
+
+  EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST_F(S16ConstOnnxQDQTest, s16_const_qdq)
+{
+  luci::QuantizeOnnxQDQPass pass;
+  while (pass.run(graph()))
+    ;
+
+  auto folded_const = getFoldedPattern();
+  EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, and values of the folded const
+  EXPECT_EQ(loco::DataType::S16, folded_const->dtype());
+  EXPECT_EQ(3, folded_const->rank());
+  EXPECT_EQ(2, folded_const->dim(0).value());
+  EXPECT_EQ(2, folded_const->dim(1).value());
+  EXPECT_EQ(2, folded_const->dim(2).value());
+  EXPECT_EQ(0, folded_const->at<loco::DataType::S16>(0));
+  EXPECT_EQ(0, folded_const->at<loco::DataType::S16>(1));
+  EXPECT_EQ(0, folded_const->at<loco::DataType::S16>(2));
+  EXPECT_EQ(1, folded_const->at<loco::DataType::S16>(3));
+  EXPECT_EQ(1, folded_const->at<loco::DataType::S16>(4));
+  EXPECT_EQ(1, folded_const->at<loco::DataType::S16>(5));
+  EXPECT_EQ(1, folded_const->at<loco::DataType::S16>(6));
+  EXPECT_EQ(1, folded_const->at<loco::DataType::S16>(7));
+}
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
index 59329c19e..17a887cfa 100644
--- a/compiler/luci/pass/src/QuantizeWeights.cpp
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -407,8 +407,6 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
     {
       sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
     }
-    quantparam->min.clear();
-    quantparam->max.clear();
     quantparam->quantized_dimension = channel_dim_index;
   }
   // Find min/max per layer-wise
@@ -449,8 +447,6 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
     auto min = quantparam->min[0];
     auto scaling_factor = quantparam->scale[0];
     asym_wquant_per_layer(weights, min, scaling_factor);
-    quantparam->min.clear();
-    quantparam->max.clear();
   }
 }
 void QuantizeWeights::visit(luci::CircleConv2D *node)
diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.cpp b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp
index e69a7b6a8..edaf13e59 100644
--- a/compiler/luci/pass/src/QuantizeWeightsOnly.cpp
+++ b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp
@@ -68,9 +68,10 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
                             std::vector<float> &nudged_max, int32_t &channel_dim_index)
 {
   assert(node->dtype() == loco::DataType::FLOAT32);
-  assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
-  const int32_t kMaxScale = (out_type == loco::DataType::S8) ? std::numeric_limits<int8_t>::max()
-                                                             : std::numeric_limits<int16_t>::max();
+  assert(out_type == loco::DataType::S4 || out_type == loco::DataType::S8 ||
+         out_type == loco::DataType::S16);
+
+  const int32_t kMaxScale = max_for_sym_quant(out_type);
   const int32_t kMinScale = -kMaxScale;
 
   uint32_t size = node->size<loco::DataType::FLOAT32>();
@@ -163,7 +164,12 @@ void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)
     std::vector<float> scaling_factor(min.size());
     std::vector<int64_t> zp(min.size());
 
-    if (output_type == loco::DataType::S8)
+    if (output_type == loco::DataType::S4)
+    {
+      sym_wquant_per_channel<loco::DataType::S4>(weights, min, max, scaling_factor, nudged_min,
+                                                 nudged_max, channel_dim_index);
+    }
+    else if (output_type == loco::DataType::S8)
     {
       sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
                                                  nudged_max, channel_dim_index);
@@ -205,6 +211,20 @@ void QuantizeWeightsOnly::visit(luci::CircleConv2D *node)
   }
 }
 
+void QuantizeWeightsOnly::visit(luci::CircleFullyConnected *node)
+{
+  LOGGER(l);
+  INFO(l) << "QuantizeWeightsOnly visit node: " << node->name() << std::endl;
+
+  auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+  if (!is_quantized(weights))
+  {
+    auto new_weights = luci::clone(weights);
+    node->weights(new_weights);
+    quantize_weights(new_weights);
+  }
+}
+
 void QuantizeWeightsOnly::visit(luci::CircleDepthwiseConv2D *node)
 {
   LOGGER(l);
diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.h b/compiler/luci/pass/src/QuantizeWeightsOnly.h
index ff6ad3261..8d1421f4b 100644
--- a/compiler/luci/pass/src/QuantizeWeightsOnly.h
+++ b/compiler/luci/pass/src/QuantizeWeightsOnly.h
@@ -43,6 +43,7 @@ private:
   void visit(luci::CircleConv2D *node);
   void visit(luci::CircleDepthwiseConv2D *node);
+  void visit(luci::CircleFullyConnected *node);
   void visit(luci::CircleNode *);
 };
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index 4f4edaf36..bdb50d67a 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -78,7 +78,7 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
   auto qtype = luci::activation_qtype(node);
   if (use_predefined_values(qtype))
   {
-    quantize->quantparam(luci::make_predefined_qparam(qtype, out_type));
+    quantize->quantparam(luci::make_predefined_qparam(qtype, out_type, node->quantparam()));
     return quantize;
   }
 
@@ -232,16 +232,25 @@ private:
   }
 
 // INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
-#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2)         \
-  void visit(NODE *node)                                                     \
-  {                                                                          \
-    if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node))   \
-      node->INPUT_NAME1(input1_quant);                                       \
-                                                                             \
-    if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node))   \
-      node->INPUT_NAME2(input2_quant);                                       \
-                                                                             \
-    insert_out_quantize(node);                                               \
+#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2)         \
+  void visit(NODE *node)                                                     \
+  {                                                                          \
+    if (node->INPUT_NAME1() == node->INPUT_NAME2())                          \
+    {                                                                        \
+      if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \
+      {                                                                      \
+        node->INPUT_NAME1(input1_quant);                                     \
+        node->INPUT_NAME2(input1_quant);                                     \
+      }                                                                      \
+      return;                                                                \
+    }                                                                        \
+    if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node))   \
+      node->INPUT_NAME1(input1_quant);                                       \
+                                                                             \
+    if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node))   \
+      node->INPUT_NAME2(input2_quant);                                       \
+                                                                             \
+    insert_out_quantize(node);                                               \
   }
 
 // Default behavior (NYI)
@@ -584,7 +593,7 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::all_nodes(g)) { auto circle_node = loco::must_cast(node); - QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node)); + QuantizeActivation qa(quantize_dtype(circle_node)); circle_node->accept(&qa); } @@ -637,7 +646,7 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) if (circle_node->quantparam() == nullptr) continue; - QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node)); + QuantizeSpecialActivation qsa(quantize_dtype(circle_node)); circle_node->accept(&qsa); } @@ -700,15 +709,18 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) phase_runner.run(phase); } - // Remove min/max values - for (auto node : loco::active_nodes(loco::output_nodes(g))) + if (not _ctx->save_min_max) { - auto circle_node = loco::must_cast(node); - if (auto qparam = circle_node->quantparam()) + // Remove min/max values + for (auto node : loco::all_nodes(g)) { - warn_accuracy_with_range(circle_node); - qparam->min.clear(); - qparam->max.clear(); + auto circle_node = loco::must_cast(node); + if (auto qparam = circle_node->quantparam()) + { + warn_accuracy_with_range(circle_node); + qparam->min.clear(); + qparam->max.clear(); + } } } diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp index 49c2d4652..9e4716533 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp @@ -51,6 +51,51 @@ public: luci::CircleConst *input_2 = nullptr; }; +class SimpleConcatGraphWithMinMaxValues +{ +public: + SimpleConcatGraphWithMinMaxValues(loco::DataType quant_type) + { + concat_node = g.nodes()->create(2); + input_1 = g.nodes()->create(); + input_2 = g.nodes()->create(); + + concat_node->dtype(quant_type); + concat_node->fusedActivationFunction(luci::FusedActFunc::NONE); + auto concat_node_quant_params = std::make_unique(); + concat_node_quant_params->min = {-1}; + concat_node_quant_params->max = {1}; + concat_node->quantparam(std::move(concat_node_quant_params)); + + input_1->dtype(quant_type); + auto input1_node_quant_params = std::make_unique(); + input1_node_quant_params->min = {-1}; + input1_node_quant_params->max = {1}; + input_1->quantparam(std::move(input1_node_quant_params)); + + input_2->dtype(quant_type); + auto input2_node_quant_params = std::make_unique(); + input2_node_quant_params->min = {-1}; + input2_node_quant_params->max = {1}; + input_2->quantparam(std::move(input2_node_quant_params)); + + concat_node->values(0, input_1); + concat_node->values(1, input_2); + } + + ~SimpleConcatGraphWithMinMaxValues() + { + concat_node->values(0, nullptr); + concat_node->values(1, nullptr); + } + +public: + loco::Graph g; + luci::CircleConcatenation *concat_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; +}; + TEST(QuantizeWithMinMaxPassTest, name) { auto ctx = std::make_unique(); @@ -105,3 +150,30 @@ TEST(QuantizeWithMinMaxPassTest, inactive_input) EXPECT_NO_THROW(qwmm.run(&g.g)); } + +// Test saving min max +TEST(QuantizeWithMinMaxPassTest, save_min_max_test) +{ + SimpleConcatGraphWithMinMaxValues g(loco::DataType::FLOAT32); + + auto ctx = std::make_unique(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::LayerWise; + ctx->save_min_max = true; + } + + luci::QuantizeWithMinMaxPass qwmm(std::move(ctx)); + + qwmm.run(&g.g); + + EXPECT_NE(0, 
g.input_1->quantparam()->min.size());
+  EXPECT_NE(0, g.input_1->quantparam()->max.size());
+
+  EXPECT_NE(0, g.input_2->quantparam()->min.size());
+  EXPECT_NE(0, g.input_2->quantparam()->max.size());
+
+  EXPECT_NE(0, g.concat_node->quantparam()->min.size());
+  EXPECT_NE(0, g.concat_node->quantparam()->max.size());
+}
diff --git a/compiler/luci/pass/src/QuantizeWithPredecessorPass.cpp b/compiler/luci/pass/src/QuantizeWithPredecessorPass.cpp
new file mode 100644
index 000000000..352653b3e
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWithPredecessorPass.cpp
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeWithPredecessorPass.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <limits>
+
+namespace
+{
+
+// Quantize dst node using src node's qparam
+// Return true if dst node is quantized with src
+// Return false otherwise
+bool quantize_with_same_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+  // src node is not quantized. Skip this case.
+  auto src_qparam = src->quantparam();
+  if (not src_qparam)
+    return false;
+
+  auto dst_qparam = dst->quantparam();
+  // dst node is already quantized. Skip this case.
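+  // (e.g. its qparam was already assigned by an earlier quantization step or
+  // by a previous run of this pass)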
+  if (dst_qparam)
+    return false;
+
+  luci::copy_quantparam(src, dst);
+
+  dst->dtype(src->dtype());
+
+  return true;
+}
+
+// Visitor to quantize nodes using their predecessors' qparams
+struct QuantizeWithPredecessor final : public luci::CircleNodeMutableVisitor<bool>
+{
+  QuantizeWithPredecessor() = default;
+
+  bool visit(luci::CircleNode *) { return false; }
+
+  bool visit(luci::CircleReshape *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleTranspose *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleStridedSlice *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleSqueeze *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleGather *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->params());
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleMul *node)
+  {
+    // Skip if node is already quantized
+    if (luci::is_quantized(node))
+      return false;
+
+    auto x = loco::must_cast<luci::CircleNode *>(node->x());
+    auto y = loco::must_cast<luci::CircleNode *>(node->y());
+
+    // Only support square for now
+    if (x != y)
+      return false;
+
+    // Only support S16 for now
+    if (x->dtype() != loco::DataType::S16)
+      return false;
+
+    const auto input_qparam = x->quantparam();
+    if (not input_qparam)
+      return false;
+
+    if (input_qparam->scale.size() != 1)
+      return false;
+
+    const auto input_scale = input_qparam->scale.at(0);
+
+    const auto s16_max = std::numeric_limits<int16_t>::max();
+    // How to determine a new scale of x^2?
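+    // Concretely (hypothetical numbers): with s = 0.001, the largest
+    // representable |x| is 0.001 * 32767 ~= 32.77, so x^2 can reach about
+    // 1073.7; spreading that range over s16_max gives a new scale of about
+    // 0.0328 = 0.001^2 * 32767. In general: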
+    // x's scale would have been determined by its max or min
+    //
+    // Max value of x^2 = (s * s16_max)^2
+    // Min value of x^2 = 0
+    // New scale = (s * s16_max)^2 / s16_max = s^2 * s16_max
+    //
+    // NOTE s16_max = -s16_min (symmetric quantization)
+    const auto output_scale = input_scale * input_scale * s16_max;
+
+    auto new_qparam = std::make_unique<luci::CircleQuantParam>();
+    {
+      new_qparam->scale.push_back(output_scale);
+      new_qparam->zerop.push_back(0);
+    }
+
+    node->quantparam(std::move(new_qparam));
+    node->dtype(x->dtype());
+
+    return true;
+  }
+
+  bool visit(luci::CircleNeg *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->x());
+    // Only support S16 for now
+    if (input_node->dtype() != loco::DataType::S16)
+      return false;
+
+    return quantize_with_same_qparam(input_node, node);
+  }
+
+  bool visit(luci::CircleConcatenation *node)
+  {
+    const auto num_inputs = node->numValues();
+
+    for (uint32_t i = 0; i < num_inputs; i++)
+    {
+      auto input = loco::must_cast<luci::CircleNode *>(node->values(i));
+      // Only support S16 for now
+      if (input->dtype() != loco::DataType::S16)
+        return false;
+
+      if (input->quantparam() == nullptr)
+        return false;
+
+      if (input->quantparam()->scale.size() != 1)
+        return false;
+    }
+
+    luci::CircleNode *max_scale_node = nullptr;
+    float max_scale = 0.0;
+    for (uint32_t i = 0; i < num_inputs; i++)
+    {
+      auto input = loco::must_cast<luci::CircleNode *>(node->values(i));
+      auto qparam = input->quantparam();
+      auto scale = qparam->scale.at(0);
+      if (max_scale < scale)
+      {
+        max_scale = scale;
+        max_scale_node = input;
+      }
+    }
+
+    if (max_scale_node == nullptr)
+    {
+      throw std::runtime_error{"Invalid max_scale_node"};
+    }
+
+    return quantize_with_same_qparam(max_scale_node, node);
+  }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool QuantizeWithPredecessorPass::run(loco::Graph *g)
+{
+  bool changed = false;
+
+  LOGGER(l);
+
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+    INFO(l) << "QuantizeWithPredecessorPass visit node: " << circle_node->name() << std::endl;
+
+    QuantizeWithPredecessor qwp;
+    if (circle_node->accept(&qwp))
+    {
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWithPredecessorPass.h b/compiler/luci/pass/src/QuantizeWithPredecessorPass.h
new file mode 100644
index 000000000..e182c7921
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWithPredecessorPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __LUCI_QUANTIZE_WITH_PREDECESSOR_PASS_H__ +#define __LUCI_QUANTIZE_WITH_PREDECESSOR_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to quantize nodes using their predecessors' qparam + * + */ +struct QuantizeWithPredecessorPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::QuantizeWithPredecessor"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_WITH_PREDECESSOR_PASS_H__ diff --git a/compiler/luci/pass/src/QuantizeWithPredecessorPass.test.cpp b/compiler/luci/pass/src/QuantizeWithPredecessorPass.test.cpp new file mode 100644 index 000000000..20d4f908a --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWithPredecessorPass.test.cpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeWithPredecessorPass.h" + +#include + +#include + +namespace +{ + +void addQuantParam(luci::CircleNode *node, const std::vector &scale, + const std::vector &zp) +{ + assert(node->quantparam() == nullptr); + + auto quantparam = std::make_unique(); + quantparam->scale = scale; + quantparam->zerop = zp; + node->quantparam(std::move(quantparam)); +} + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (int16) + * | + * [Reshape] (fp32) + * + * AFTER + * + * [Conv] (int16) + * | + * [Reshape] (int16) + * + */ +class ConvReshapeGraph +{ +public: + ConvReshapeGraph() + { + input = g.nodes()->create(); + conv = g.nodes()->create(); + reshape = g.nodes()->create(); + output = g.nodes()->create(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + conv->dtype(loco::DataType::S16); + reshape->dtype(loco::DataType::FLOAT32); + + addQuantParam(conv, {1.0}, {0}); + + conv->input(input); + reshape->tensor(conv); + output->from(reshape); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = nullptr; + luci::CircleReshape *reshape = nullptr; + luci::CircleOutput *output = nullptr; +}; + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (int16) + * | + * [Squeeze] (fp32) + * + * AFTER + * + * [Conv] (int16) + * | + * [Squeeze] (int16) + * + */ +class ConvSqueezeGraph +{ +public: + ConvSqueezeGraph() + { + input = g.nodes()->create(); + conv = g.nodes()->create(); + squeeze = g.nodes()->create(); + output = g.nodes()->create(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + conv->dtype(loco::DataType::S16); + squeeze->dtype(loco::DataType::FLOAT32); + + addQuantParam(conv, {1.0}, {0}); + + squeeze->squeeze_dims({0}); + + conv->input(input); + squeeze->input(conv); + output->from(squeeze); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = 
nullptr; + luci::CircleSqueeze *squeeze = nullptr; + luci::CircleOutput *output = nullptr; +}; + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (int16) + * | + * [Mul] (fp32) + * + * AFTER + * + * [Conv] (int16) + * | + * [Mul] (int16) + * + */ +class ConvMulGraph +{ +public: + ConvMulGraph() + { + input = g.nodes()->create(); + conv = g.nodes()->create(); + mul = g.nodes()->create(); + output = g.nodes()->create(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + conv->dtype(loco::DataType::S16); + mul->dtype(loco::DataType::FLOAT32); + + addQuantParam(conv, {1.0}, {0}); + + conv->input(input); + mul->x(conv); + mul->y(conv); + output->from(mul); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = nullptr; + luci::CircleMul *mul = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(QuantizeWithPredecessor, reshape) +{ + ConvReshapeGraph g; + + luci::QuantizeWithPredecessorPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_NE(nullptr, g.reshape->quantparam()); + EXPECT_FLOAT_EQ(1.0, g.reshape->quantparam()->scale[0]); + EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]); +} + +TEST(QuantizeWithPredecessor, reshape_NEG) +{ + ConvReshapeGraph g; + g.conv->quantparam(nullptr); + + luci::QuantizeWithPredecessorPass pass; + EXPECT_FALSE(pass.run(&g.g)); +} + +TEST(QuantizeWithPredecessor, squeeze) +{ + ConvSqueezeGraph g; + + luci::QuantizeWithPredecessorPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_NE(nullptr, g.squeeze->quantparam()); + EXPECT_FLOAT_EQ(1.0, g.squeeze->quantparam()->scale[0]); + EXPECT_EQ(0, g.squeeze->quantparam()->zerop[0]); +} + +TEST(QuantizeWithPredecessor, squeeze_NEG) +{ + ConvSqueezeGraph g; + g.conv->quantparam(nullptr); + + luci::QuantizeWithPredecessorPass pass; + EXPECT_FALSE(pass.run(&g.g)); +} + +TEST(QuantizeWithPredecessor, mul) +{ + ConvMulGraph g; + + luci::QuantizeWithPredecessorPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_NE(nullptr, g.mul->quantparam()); + EXPECT_FLOAT_EQ(32767, g.mul->quantparam()->scale[0]); + EXPECT_EQ(0, g.mul->quantparam()->zerop[0]); +} + +TEST(QuantizeWithPredecessor, mul_NEG) +{ + ConvMulGraph g; + g.conv->quantparam(nullptr); + + luci::QuantizeWithPredecessorPass pass; + EXPECT_FALSE(pass.run(&g.g)); +} diff --git a/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp index e50dda9e0..bec911914 100644 --- a/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp +++ b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp @@ -122,9 +122,15 @@ bool RemoveDuplicateConstPass::remove_duplicate_const() case loco::DataType::S8: is_equal = is_equal_consts(reference_const, cur_const); break; + case loco::DataType::S4: + is_equal = is_equal_consts(reference_const, cur_const); + break; case loco::DataType::U8: is_equal = is_equal_consts(reference_const, cur_const); break; + case loco::DataType::U4: + is_equal = is_equal_consts(reference_const, cur_const); + break; default: continue; } @@ -211,9 +217,15 @@ bool RemoveDuplicateConstPass::run(loco::Graph *g) case loco::DataType::S8: add_to_map(const_node); break; + case loco::DataType::S4: + add_to_map(const_node); + break; case loco::DataType::U8: add_to_map(const_node); break; + case loco::DataType::U4: + add_to_map(const_node); + break; default: continue; } diff --git a/compiler/luci/pass/src/RemoveGatherGuardPass.cpp 
b/compiler/luci/pass/src/RemoveGatherGuardPass.cpp new file mode 100644 index 000000000..920595c54 --- /dev/null +++ b/compiler/luci/pass/src/RemoveGatherGuardPass.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveGatherGuardPass.h" + +#include + +#include + +namespace +{ + +/* + * BEFORE + * + * [CircleNode] [CircleNode] + * | | + * | [CircleAdd] + * | | + * | [CircleFloorMod] + * | / + * [CircleGather] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] [CircleNode] + * | | \ + * | | [CircleAdd] + * | / | + * | / [CircleFloorMod] + * | / + * [CircleGather] + * | + * [CircleNode] + */ + +bool is_single_value_equal(const loco::Node *node, int32_t value) +{ + assert(node); + + auto const cnode = dynamic_cast(node); + if (cnode == nullptr) + return false; + if (not(cnode->rank() == 0 || (cnode->rank() == 1 && cnode->dim(0).value() == 1))) + return false; + + if (cnode->dtype() == loco::DataType::S32) + return cnode->at(0) == value; + else if (cnode->dtype() == loco::DataType::S64) + return cnode->at(0) == static_cast(value); + + return false; +} + +bool remove_guards(luci::CircleGather *gather) +{ + assert(gather); + // check if sequence is Add+FloorMod + auto floormod = dynamic_cast(gather->indices()); + if (floormod == nullptr) + return false; + auto add = dynamic_cast(floormod->x()); + if (add == nullptr) + return false; + + if (add->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + // check if gather axis is 0 for now + // TODO support other axis + if (gather->axis() != 0) + return false; + // check if RHS of Add and FloorMod is Const and is scalar/single element and + // the value is same as gather.params.dim(0) + luci::CircleNode *params = loco::must_cast(gather->params()); + if (params->shape_status() != luci::ShapeStatus::VALID || params->rank() == 0) + return false; + // safe range check + if (params->dim(gather->axis()).value() >= INT_MAX) + return false; + int32_t params_axis_dim = static_cast(params->dim(gather->axis()).value()); + if (not is_single_value_equal(add->y(), params_axis_dim)) + return false; + if (not is_single_value_equal(floormod->y(), params_axis_dim)) + return false; + + // disconnect Add+FloorMod + gather->indices(add->x()); + + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveGatherGuardPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto gather = dynamic_cast(node)) + { + if (remove_guards(gather)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveGatherGuardPass.test.cpp b/compiler/luci/pass/src/RemoveGatherGuardPass.test.cpp new file mode 100644 index 000000000..6de51dd8e --- /dev/null +++ b/compiler/luci/pass/src/RemoveGatherGuardPass.test.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveGatherGuardPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class GatherGuardGraphlet +{ +public: + GatherGuardGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 params_s, const ShapeU32 indices_s, + const ShapeU32 output_s) + { + std::vector params_shape{params_s}; + + _add_y = g->nodes()->create(); + _add_y->rank(0); + _add_y->shape_status(luci::ShapeStatus::VALID); + _add_y->dtype(loco::DataType::S32); + _add_y->size(1); + _add_y->at(0) = params_shape[0]; + + _add = g->nodes()->create(); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + _add->dtype(loco::DataType::S32); + _add->shape(indices_s); + _add->shape_status(luci::ShapeStatus::VALID); + + _fm_y = g->nodes()->create(); + _fm_y->rank(0); + _fm_y->shape_status(luci::ShapeStatus::VALID); + _fm_y->dtype(loco::DataType::S32); + _fm_y->size(1); + _fm_y->at(0) = params_shape[0]; + + _fm = g->nodes()->create(); + _fm->dtype(loco::DataType::S32); + _fm->shape(indices_s); + _fm->shape_status(luci::ShapeStatus::VALID); + + _gather = g->nodes()->create(); + _gather->axis(0); + _gather->dtype(loco::DataType::FLOAT32); + _gather->shape(output_s); + _gather->shape_status(luci::ShapeStatus::VALID); + } + +protected: + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_add_y = nullptr; + luci::CircleFloorMod *_fm = nullptr; + luci::CircleConst *_fm_y = nullptr; + luci::CircleGather *_gather = nullptr; +}; + +class GatherGuardGraph : public TestIsGraphlet<2>, public TestOGraphlet, public GatherGuardGraphlet +{ +public: + GatherGuardGraph() = default; + +public: + void init(const ShapeU32 params_s, const ShapeU32 indices_s, const ShapeU32 output_s) + { + TestIsGraphlet<2>::init(g(), {params_s, indices_s}); + TestOGraphlet::init(g(), output_s); + GatherGuardGraphlet::init(g(), params_s, indices_s, output_s); + + // connect graph + _add->x(input(1)); + _add->y(_add_y); + _fm->x(_add); + _fm->y(_fm_y); + _gather->params(input(0)); + _gather->indices(_fm); + output()->from(_gather); + } +}; + +class GatherGuardGraphTest : public ::testing::Test, public GatherGuardGraph +{ +protected: + luci::RemoveGatherGuardPass _pass; + + ShapeU32 _input_s0 = {10, 3}; + ShapeU32 _input_s1 = {5, 4}; + ShapeU32 _output_s = {5, 4, 3}; +}; + +} // namespace + +TEST_F(GatherGuardGraphTest, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(GatherGuardGraphTest, removed) +{ + // test to check pass is working as expected + + init(_input_s0, _input_s1, _output_s); + + auto *indices_before = loco::must_cast(_gather->indices()); + EXPECT_NE(input(1), indices_before); + + EXPECT_TRUE(_pass.run(g())); + + auto *indices_after = loco::must_cast(_gather->indices()); + EXPECT_EQ(input(1), indices_after); +} + +TEST_F(GatherGuardGraphTest, axis_value_NEG) +{ + // test if fails when gather->params->dim(0) != add/floormod rhs value + + init(_input_s0, _input_s1, 
_output_s); + + _add_y->at(0) = 11; + EXPECT_FALSE(_pass.run(g())); + _add_y->at(0) = 10; + + _fm_y->at(0) = 11; + EXPECT_FALSE(_pass.run(g())); +} + +TEST_F(GatherGuardGraphTest, add_act_not_none_NEG) +{ + // test if fails when add activation function is not none + + init(_input_s0, _input_s1, _output_s); + + _add->fusedActivationFunction(luci::FusedActFunc::RELU); + + EXPECT_FALSE(_pass.run(g())); +} diff --git a/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.cpp b/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.cpp new file mode 100644 index 000000000..1f796cd53 --- /dev/null +++ b/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveQDQForMixedPrecisionOpPass.h" + +#include + +/** + * Remove Quantize-Dequantize pattern for backends with mixed-precision operator. + * + * BEFORE + * [CircleNode_1] + * | + * [CircleQuantize, dtype_1, scale_1, zero_point_1] + * | + * [CircleDequantize] + * | + * [CircleQuantize, dtype_2, scale_2, zero_point_2] + * | + * [CircleDequantize] + * | + * [CircleNode_2] + * + * AFTER + * + * [CircleNode_1] + * | + * [CircleQuantize, dtype_2, scale_2, zero_point_2] + * | + * [CircleDequantize] + * | + * [CircleNode_2] + * + */ + +namespace +{ + +bool remove_qdq_for_mpo(luci::CircleDequantize *node) +{ + auto prev = dynamic_cast(node->input()); + if (not prev) + return false; + + auto prev_prev = dynamic_cast(prev->input()); + if (not prev_prev) + return false; + + auto prev_prev_prev = dynamic_cast(prev_prev->input()); + if (not prev_prev_prev) + return false; + + auto input = loco::must_cast(prev_prev_prev->input()); + + const static std::set supported_ops{luci::CircleOpcode::FULLY_CONNECTED, + luci::CircleOpcode::BATCH_MATMUL}; + + if (supported_ops.find(input->opcode()) == supported_ops.end()) + return false; + + prev->input(input); + + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveQDQForMixedPrecisionOpPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (auto dq = dynamic_cast(node)) + { + if (remove_qdq_for_mpo(dq)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.test.cpp b/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.test.cpp new file mode 100644 index 000000000..69e18da3c --- /dev/null +++ b/compiler/luci/pass/src/RemoveQDQForMixedPrecisionOpPass.test.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveQDQForMixedPrecisionOpPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class QuantDequantGraphlet +{ +public: + QuantDequantGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _fc = g->nodes()->create(); + _fc->name("fc"); + + _qu = g->nodes()->create(); + _qu->name("qu"); + + _de = g->nodes()->create(); + _de->name("de"); + + _qu_2 = g->nodes()->create(); + _qu_2->name("qu"); + + _de_2 = g->nodes()->create(); + _de_2->name("de"); + } + +public: + luci::CircleFullyConnected *fc(void) { return _fc; } + luci::CircleQuantize *qu(void) { return _qu; } + luci::CircleQuantize *qu_2(void) { return _qu_2; } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleQuantize *_qu = nullptr; + luci::CircleDequantize *_de = nullptr; + luci::CircleQuantize *_qu_2 = nullptr; + luci::CircleDequantize *_de_2 = nullptr; +}; + +class QuantDequantGraph : public TestIOGraph, public QuantDequantGraphlet +{ +public: + QuantDequantGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1}, {1}); + QuantDequantGraphlet::init(g()); + + _fc->input(input()); + _qu->input(_fc); + _de->input(_qu); + _qu_2->input(_de); + _de_2->input(_qu_2); + + output()->from(_de_2); + } +}; + +} // namespace + +TEST(RemoveQDQForMixedPrecisionOpPass, remove_qdq_FC) +{ + QuantDequantGraph g; + luci::RemoveQDQForMixedPrecisionOpPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); + + EXPECT_EQ(g.fc(), g.qu_2()->input()); +} + +TEST(RemoveQDQForMixedPrecisionOpPass, remove_qdq_wrong_op_NEG) +{ + QuantDequantGraph g; + luci::RemoveQDQForMixedPrecisionOpPass pass; + + g.init(); + + g.qu()->input(g.input()); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessaryAddPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryAddPass.cpp new file mode 100644 index 000000000..93887ccdd --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryAddPass.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/RemoveUnnecessaryAddPass.h" + +#include "helpers/NodeFiller.h" + +#include + +#include // std::numeric_limits + +namespace +{ + +bool remove_no_effect_add(luci::CircleNode *node) +{ + auto target_node = dynamic_cast(node); + if (target_node == nullptr || target_node->dtype() != loco::DataType::FLOAT32) + return false; + + // NOTE for general activation function A: Act(A + 0) != A + if (target_node->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + luci::CircleConst *const_operand = nullptr; + luci::CircleNode *nonconst_operand = nullptr; + if (not luci::fill(&const_operand, &nonconst_operand).with_commutative_args_of(target_node)) + return false; + + if (dynamic_cast(nonconst_operand) != nullptr) + { + // NOTE this is degenerated '(const1 + const2)' case + return false; + } + + // check const_operand is zero + + // NOTE we assume graph is valid, so no need to check shape. + // just check that const operand is zero + auto const size = const_operand->size(); + for (uint32_t index = 0; index < size; index++) + { + auto const value = const_operand->at(index); + if (std::abs(value) > std::numeric_limits::min()) + { + // at least one value is not zero + return false; + } + } + + replace(target_node).with(nonconst_operand); + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleNode] + * | + * | [CircleConst(=0)] + * | / + * | / + * [CircleAdd] (no activation) + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | + * | + * [CircleNode] + * + **/ +bool RemoveUnnecessaryAddPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast(node); + if (remove_no_effect_add(circle_node)) + { + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryAddPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryAddPass.test.cpp new file mode 100644 index 000000000..c2b264022 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryAddPass.test.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveUnnecessaryAddPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class AddGraphlet +{ +public: + AddGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 input_shape, bool fill_with_zeros, bool activation) + { + // zero Create. + _zero = g->nodes()->create(); + _zero->rank(1); + _zero->dim(0).set(input_shape.size()); + _zero->shape_status(luci::ShapeStatus::VALID); + _zero->dtype(loco::DataType::FLOAT32); + _zero->size(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) + _zero->at(i) = fill_with_zeros ? 0 : 1; + _zero->name("begin"); + + // Add Create. 
+ _add = g->nodes()->create(); + _add->y(_zero); + if (activation) + { + _add->fusedActivationFunction(luci::FusedActFunc::RELU); + } + else + { + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + } + _add->dtype(loco::DataType::FLOAT32); + _add->shape(input_shape); + _add->name("add"); + } + +protected: + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_zero = nullptr; +}; + +class AddGraph : public TestIOGraph, public AddGraphlet +{ +public: + AddGraph() = default; + +public: + void init(const ShapeU32 shape, bool fill_with_zeros, bool activation) + { + TestIOGraph::init(shape, shape); + AddGraphlet::init(g(), shape, fill_with_zeros, activation); + + _add->x(input()); + output()->from(_add); + } +}; + +} // namespace + +TEST(RemoveUnnecessaryAddPass, name_test) +{ + luci::RemoveUnnecessaryAddPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveUnnecessaryAddPass, simple_test) +{ + luci::RemoveUnnecessaryAddPass pass; + + AddGraph g; + g.init({1, 14, 21, 32}, true, false); + + ASSERT_TRUE(pass.run(g.g())); + + // check Add is removed + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (auto add = dynamic_cast(node)) + count++; + } + ASSERT_EQ(0, count); +} + +TEST(RemoveUnnecessaryAddPass, not_removed_NEG) +{ + luci::RemoveUnnecessaryAddPass pass; + AddGraph g; + g.init({1, 14, 21, 32}, false, false); + + ASSERT_FALSE(pass.run(g.g())); + + // check Add is not removed + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (auto add = dynamic_cast(node)) + count++; + } + ASSERT_EQ(1, count); +} + +TEST(RemoveUnnecessaryAddPass, activation_blocks_removal_NEG) +{ + luci::RemoveUnnecessaryAddPass pass; + AddGraph g; + g.init({1, 14, 21, 32}, true, true); + + ASSERT_FALSE(pass.run(g.g())); + + // check Add is not removed + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (auto add = dynamic_cast(node)) + count++; + } + ASSERT_EQ(1, count); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp new file mode 100644 index 000000000..f9e0bf99a --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -0,0 +1,498 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessaryTransposeNetPass.h" + +#include +#include + +#include + +namespace +{ + +class TaggedShapeAnalyzer final +{ +public: + /** + * @brief check 'Transpose-Reshape-Transpose' can be replaced by one 'Reshape'. + * + * @example + * Let's explain how analyzer check Transpose-Reshape-Transpose pattern with an exact example. 
+ *
+ * Let's assume the pattern below is given:
+ *
+ *    Input(1, 7, 7, 448)
+ *      |
+ *    Transpose(perm=(0, 3, 1, 2))
+ *      |
+ *    Reshape(shape=(1, 448, 49))
+ *      |
+ *    Transpose(perm=(0, 2, 1))
+ *      |
+ *    Output(1, 49, 448)
+ *
+ * It simulates how each dimension of the tensor's shape is transformed/moved
+ * using a member variable named '_shape'.
+ * 'tags' in _shape record the initial order of each dimension.
+ *
+ * TIMELINE            | _shape states :
+ *                     |
+ * init_shape_with_tag | - value : (1) (7) (7) (448)
+ *                     | - tags  : (-) (0) (1) (2)
+ *                     |
+ * analyze_transpose   | - value : (1) (448) (7) (7)
+ *                     | - tags  : (-) (2)   (0) (1)
+ *                     |
+ * analyze_reshape     | - value : (1) (448) (49)
+ *                     | - tags  : (-) (2)   (0, 1)
+ *                     |
+ * analyze_transpose   | - value : (1) (49)   (448)
+ *                     | - tags  : (-) (0, 1) (2)
+ *
+ * After the simulation is done, if the tags appear in the same order as in the
+ * initial _shape, the Transposes have no effect on the final shape and can be
+ * removed as unnecessary ops.
+ */
+  template <loco::DataType DType>
+  bool can_remove_transposes(const luci::CircleTranspose *f_tr, const luci::CircleReshape *m_rs,
+                             const luci::CircleTranspose *b_tr);
+
+private:
+  void init_shape_with_tag(const luci::CircleNode *);
+
+  template <loco::DataType PermType> void analyze_transpose(const luci::CircleTranspose *);
+
+  template <loco::DataType ReshapeType> bool analyze_reshape(const luci::CircleReshape *);
+
+  bool verify_tag() const;
+
+  struct Dim final
+  {
+    int32_t value;
+    std::vector<uint8_t> tags;
+  };
+
+  const uint8_t START_TAG = 0;
+
+  using Shape = std::vector<Dim>;
+  Shape _shape;
+
+  int32_t flatsize(const Shape &shape) const;
+  bool inference_incomplete_shape(const Shape &src, Shape &dst);
+};
+
+int32_t TaggedShapeAnalyzer::flatsize(const Shape &shape) const
+{
+  int32_t size = 1;
+  for (const auto &dim : shape)
+  {
+    if (dim.value >= 1)
+      size *= dim.value;
+  }
+  return size;
+}
+
+/**
+ * @brief if 'dst' has a -1 valued dim, replace the -1 with the inferred value
+ *
+ * @return true, if the -1 value is successfully replaced
+ *         false, otherwise
+ */
+bool TaggedShapeAnalyzer::inference_incomplete_shape(const Shape &src, Shape &dst)
+{
+  std::vector<size_t> incomplete_indexes;
+  for (size_t i = 0; i < dst.size(); i++)
+  {
+    if (dst[i].value == -1)
+      incomplete_indexes.push_back(i);
+  }
+
+  if (incomplete_indexes.size() == 0)
+    return true;
+  else if (incomplete_indexes.size() == 1)
+  {
+    if (flatsize(src) % flatsize(dst) == 0)
+      dst[incomplete_indexes[0]].value = flatsize(src) / flatsize(dst);
+    else
+      return false;
+  }
+  else // incomplete_indexes.size() >= 2
+    return false;
+
+  return true;
+}
+
+/**
+ * @brief initialize _shape with the input tensor 'in_tensor'
+ *
+ * @note 'tags' are attached only to non-1 valued dimensions.
+ */
+void TaggedShapeAnalyzer::init_shape_with_tag(const luci::CircleNode *in_tensor)
+{
+  _shape.clear();
+  uint8_t tag = START_TAG;
+
+  for (uint32_t i = 0; i < in_tensor->rank(); i++)
+  {
+    TaggedShapeAnalyzer::Dim dim;
+    {
+      dim.value = in_tensor->dim(i).value();
+      if (dim.value != 1)
+        dim.tags.push_back(tag++);
+    }
+    _shape.push_back(dim);
+  }
+}
+
+/**
+ * @brief update _shape based on the 'Transpose' permutation value
+ *
+ * @example Let's assume Transpose(perm=0, 3, 1, 2) is given to the [before] _shape.
+ *
+ * This function reorders the Dims based on the permutation value.
+ *
+ * [before] _shape :
+ * - value : (1) (7) (7) (448)
+ * - tags  : (-) (0) (1) (2)
+ *
+ * [after] _shape :
+ * - value : (1) (448) (7) (7)
+ * - tags  : (-) (2)   (0) (1)
+ */
+template <loco::DataType PermType>
+void TaggedShapeAnalyzer::analyze_transpose(const luci::CircleTranspose *transpose_node)
+{
+  const luci::CircleConst *perm_node = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+  assert(perm_node->dtype() == PermType);
+
+  TaggedShapeAnalyzer::Shape new_shape;
+  const auto size = perm_node->size<PermType>();
+  for (uint32_t i = 0; i < size; i++)
+  {
+    auto perm_idx = perm_node->at<PermType>(i);
+    new_shape.push_back(_shape.at(perm_idx));
+  }
+  _shape = new_shape;
+}
+
+/**
+ * @brief update _shape based on the 'Reshape' shape value
+ *
+ * @return false, if it is determined that removing the transposes is impossible
+ *
+ * @example Let's assume Reshape(shape=1, 448, 49) is given to the [before] _shape.
+ *
+ * [before] _shape :
+ * - value : (1) (448) (7) (7)
+ * - tags  : (-) (2)   (0) (1)
+ *
+ * [after] _shape :
+ * - value : (1) (448) (49)
+ * - tags  : (-) (2)   (0, 1)
+ */
+template <loco::DataType ReshapeType>
+bool TaggedShapeAnalyzer::analyze_reshape(const luci::CircleReshape *reshape_node)
+{
+  const luci::CircleConst *shape_node = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+  assert(shape_node->dtype() == ReshapeType);
+
+  // At least one element must be in reshape's output-tensor.
+  if (shape_node->size<ReshapeType>() <= 0)
+    return false;
+
+  // Create new_shape based on reshape_node/shape
+  Shape new_shape;
+  for (uint32_t i = 0; i < shape_node->size<ReshapeType>(); i++)
+  {
+    TaggedShapeAnalyzer::Dim dim;
+    dim.value = shape_node->at<ReshapeType>(i);
+
+    new_shape.push_back(dim);
+  }
+
+  // infer the value of any -1 dim in new_shape
+  if (inference_incomplete_shape(_shape, new_shape) == false)
+    return false;
+
+  // indexing for _shape [old_shape_start_idx, old_shape_end_idx)
+  uint32_t old_shape_start_idx = 0;
+  uint32_t old_shape_end_idx = 1;
+  auto old_shape_product = _shape[old_shape_start_idx].value;
+
+  auto expand_range = [&]() -> bool {
+    if (old_shape_end_idx >= _shape.size())
+      return false;
+
+    old_shape_product *= _shape[old_shape_end_idx].value;
+    old_shape_end_idx++;
+    return true;
+  };
+
+  auto move_to_next_range = [&]() -> bool {
+    if (old_shape_end_idx >= _shape.size())
+      return false;
+
+    old_shape_start_idx = old_shape_end_idx;
+    old_shape_end_idx++;
+    old_shape_product = _shape[old_shape_start_idx].value;
+    return true;
+  };
+
+  // Add tags from '_shape' to the 'new_shape'
+  uint32_t new_shape_idx = 0;
+  while (new_shape_idx < new_shape.size())
+  {
+    Dim &target_dim = new_shape[new_shape_idx];
+
+    // Ignore dim == 1
+    if (target_dim.value == 1)
+    {
+      new_shape_idx++;
+      continue;
+    }
+
+    while (old_shape_product < target_dim.value)
+    {
+      if (expand_range() == false)
+        break;
+    }
+
+    if (old_shape_product != target_dim.value)
+      return false;
+
+    assert(old_shape_product == target_dim.value);
+    for (uint32_t idx = old_shape_start_idx; idx < old_shape_end_idx; idx++)
+    {
+      const auto &old_tags = _shape[idx].tags;
+      target_dim.tags.insert(target_dim.tags.end(), old_tags.begin(), old_tags.end());
+    }
+
+    new_shape_idx++;
+    move_to_next_range();
+  }
+  _shape = new_shape;
+  return true;
+}
+
+bool TaggedShapeAnalyzer::verify_tag() const
+{
+  // check whether the tags in _shape are incremental
+  uint8_t tag = START_TAG;
+  for (const auto &dim : _shape)
+  {
+    for (const auto &t : dim.tags)
+    {
+      if (t == tag)
+        tag++;
+      else
+        return false;
+    }
+  }
+  return true;
+}
+
+// For implementation details, please refer to the comment at the declaration.
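+
+// A worked example of the simulation above (illustration only; the shapes are
+// taken from the rank_reduction_pattern2 test case):
+// Input(1, 100, 10, 12) -> Transpose(0, 2, 3, 1) -> Reshape(120, 100) -> Transpose(1, 0)
+//
+//   init_shape_with_tag : value (1) (100) (10) (12)   tags (-) (0) (1) (2)
+//   analyze_transpose   : value (1) (10) (12) (100)   tags (-) (1) (2) (0)
+//   analyze_reshape     : value (120) (100)           tags (1, 2) (0)
+//   analyze_transpose   : value (100) (120)           tags (0) (1, 2)
+//
+// verify_tag() then sees the tags 0, 1, 2 in increasing order, so both
+// Transposes are shape-only and the whole net collapses into one Reshape.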
+template <loco::DataType DType>
+bool TaggedShapeAnalyzer::can_remove_transposes(const luci::CircleTranspose *f_tr,
+                                                const luci::CircleReshape *m_rs,
+                                                const luci::CircleTranspose *b_tr)
+{
+  assert(loco::must_cast<luci::CircleConst *>(f_tr->perm())->dtype() == DType);
+  assert(loco::must_cast<luci::CircleConst *>(m_rs->shape())->dtype() == DType);
+  assert(loco::must_cast<luci::CircleConst *>(b_tr->perm())->dtype() == DType);
+
+  const luci::CircleNode *in_tensor = loco::must_cast<luci::CircleNode *>(f_tr->a());
+
+  init_shape_with_tag(in_tensor);
+
+  analyze_transpose<DType>(f_tr);
+
+  if (not analyze_reshape<DType>(m_rs))
+    return false;
+
+  analyze_transpose<DType>(b_tr);
+
+  if (not verify_tag())
+    return false;
+
+  return true;
+}
+
+/**
+ * @brief create a CircleReshape node that reshapes 'front_transpose's input tensor shape'
+ *        into 'back_transpose's output tensor shape'
+ */
+template <loco::DataType ShapeType>
+luci::CircleReshape *create_reshape_node(loco::Graph *graph,
+                                         const luci::CircleTranspose *front_transpose,
+                                         const luci::CircleReshape *mid_reshape,
+                                         const luci::CircleTranspose *back_transpose)
+{
+  std::string composed_name =
+    front_transpose->name() + ";" + mid_reshape->name() + ";" + back_transpose->name();
+
+  std::vector<std::shared_ptr<luci::CircleNodeOrigin>> src_origin{
+    luci::get_origin(front_transpose), luci::get_origin(mid_reshape),
+    luci::get_origin(back_transpose)};
+  auto const composed_origin = luci::composite_origin(src_origin);
+
+  auto shape_node = graph->nodes()->create<luci::CircleConst>();
+  {
+    shape_node->dtype(ShapeType);
+    shape_node->rank(1);
+    shape_node->dim(0).set(back_transpose->rank());
+
+    shape_node->size<ShapeType>(back_transpose->rank());
+    for (uint32_t i = 0; i < back_transpose->rank(); i++)
+    {
+      shape_node->at<ShapeType>(i) = back_transpose->dim(i).value();
+    }
+    shape_node->shape_status(luci::ShapeStatus::VALID);
+    shape_node->name(composed_name + "/shape");
+    luci::add_origin(shape_node, composed_origin);
+  }
+
+  auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
+  {
+    reshape_node->name(composed_name);
+    reshape_node->tensor(front_transpose->a());
+    reshape_node->shape(shape_node);
+    luci::add_origin(reshape_node, composed_origin);
+  }
+  return reshape_node;
+}
+
+bool remove_unnecessary_transpose(luci::CircleTranspose *node)
+{
+  // find the 'front_transpose - mid_reshape - back_transpose' pattern
+  const auto back_transpose = node;
+  const auto mid_reshape = dynamic_cast<luci::CircleReshape *>(back_transpose->a());
+  {
+    if (mid_reshape == nullptr)
+      return false;
+  }
+  const auto front_transpose = dynamic_cast<luci::CircleTranspose *>(mid_reshape->tensor());
+  {
+    if (not front_transpose)
+      return false;
+  }
+
+  // check that perm and shape are CircleConst nodes and their datatype is S32
+  const auto back_perm = dynamic_cast<luci::CircleConst *>(back_transpose->perm());
+  {
+    if (back_perm == nullptr)
+      return false;
+
+    if (back_perm->dtype() != loco::DataType::S32)
+      return false;
+  }
+  const auto shape = dynamic_cast<luci::CircleConst *>(mid_reshape->shape());
+  {
+    if (shape == nullptr)
+      return false;
+
+    if (shape->dtype() != loco::DataType::S32)
+      return false;
+  }
+  const auto front_perm = dynamic_cast<luci::CircleConst *>(front_transpose->perm());
+  {
+    if (front_perm == nullptr)
+      return false;
+
+    if (front_perm->dtype() != loco::DataType::S32)
+      return false;
+  }
+
+  // for now, handle only rank-preserving or rank-reducing (not rank-expanding) cases
+  const auto output_rank = back_transpose->rank();
+  const auto input_rank = front_transpose->rank();
+  if (input_rank < output_rank)
+    return false;
+
+  // analyze the pattern to check whether this pass is applicable
+  TaggedShapeAnalyzer analyzer;
+  if (not analyzer.can_remove_transposes<loco::DataType::S32>(front_transpose, mid_reshape,
+                                                              back_transpose))
+    return false;
+
+  // replace with new_node
+  luci::CircleReshape *new_node = create_reshape_node<loco::DataType::S32>(
node->graph(), front_transpose, mid_reshape, back_transpose); + + replace(node).with(new_node); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * Current pass only targets below cases: + * - in.rank() >= out.rank() + * - 'Reshape' used to reduce N dimension into one (e.g. A x B x C => A x BC) + * + * + * [CircleNode] [CircleConst] + * (in) (perm) + * \ / + * [CircleTranspose] [CircleConst] + * \ (shape) + * \ / + * [CircleReshape] [CircleConst] + * \ (perm) + * \ / + * [CircleTranspose] + * \ + * \ + * [CircleNode] + * (out) + * + * AFTER + * + * [CircleNode] [CircleConst] + * (in) (shape) + * \ / + * [CircleReshape] + * (new) + * \ + * [CircleNode] + * (out) + * + */ + +bool RemoveUnnecessaryTransposeNetPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto transpose_node = dynamic_cast(node)) + { + if (remove_unnecessary_transpose(transpose_node)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.test.cpp new file mode 100644 index 000000000..5d0a965cc --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.test.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/RemoveUnnecessaryTransposeNetPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class TransposeReshapeTransposeGraph : public TestIOGraph +{ + +public: + // create input-transpose-reshape-transpose-output graph + void init_whole_graph(ShapeU32 in_shape, ShapeU32 front_perm, ShapeI32 mid_shape, + ShapeU32 back_perm, ShapeU32 out_shape) + { + TestIOGraph::init(in_shape, out_shape); + + _front_perm = g()->nodes()->create(); + { + _front_perm->name("front_transpose/perm"); + init_circle_const(_front_perm, front_perm); + } + + _front_transpose = g()->nodes()->create(); + { + _front_transpose->a(input()); + _front_transpose->name("front_transpose"); + _front_transpose->perm(_front_perm); + } + + _mid_shape = g()->nodes()->create(); + { + _mid_shape->name("mid_reshpae/shape"); + init_circle_const(_mid_shape, mid_shape); + } + + _mid_reshape = g()->nodes()->create(); + { + _mid_reshape->name("mid_reshape"); + _mid_reshape->tensor(_front_transpose); + _mid_reshape->shape(_mid_shape); + } + + _back_perm = g()->nodes()->create(); + { + _back_perm->name("back_transpose/perm"); + init_circle_const(_back_perm, back_perm); + } + + _back_transpose = g()->nodes()->create(); + { + _back_transpose->name("back_transpose"); + _back_transpose->a(_mid_reshape); + _back_transpose->perm(_back_perm); + } + + output()->from(_back_transpose); + } + + // create input-transpose-transpose-output graph + void init_transpose_only(ShapeU32 in_shape, ShapeU32 front_perm, ShapeU32 back_perm, + ShapeU32 out_shape) + { + TestIOGraph::init(in_shape, out_shape); + + _front_perm = g()->nodes()->create(); + { + _front_perm->name("front_transpose/perm"); + init_circle_const(_front_perm, front_perm); + } + + _front_transpose = g()->nodes()->create(); + { + _front_transpose->a(input()); + _front_transpose->name("front_transpose"); + _front_transpose->perm(_front_perm); + } + + _back_perm = g()->nodes()->create(); + { + _back_perm->name("back_transpose/perm"); + init_circle_const(_back_perm, back_perm); + } + + _back_transpose = g()->nodes()->create(); + { + _back_transpose->name("back_transpose"); + _back_transpose->a(_front_transpose); + _back_transpose->perm(_back_perm); + } + + output()->from(_back_transpose); + } + +private: + void init_circle_const(luci::CircleConst *const_node, ShapeU32 shape) + { + const_node->dtype(loco::DataType::S32); + const_node->size(shape.size()); + uint32_t i = 0; + for (auto v : shape) + { + const_node->at(i++) = v; + } + } + + void init_circle_const(luci::CircleConst *const_node, ShapeI32 shape) + { + const_node->dtype(loco::DataType::S32); + const_node->size(shape.size()); + uint32_t i = 0; + for (auto v : shape) + { + const_node->at(i++) = v; + } + } + + luci::CircleTranspose *_front_transpose = nullptr; + luci::CircleConst *_front_perm = nullptr; + + luci::CircleReshape *_mid_reshape = nullptr; + luci::CircleConst *_mid_shape = nullptr; + + luci::CircleTranspose *_back_transpose = nullptr; + luci::CircleConst *_back_perm = nullptr; +}; + +bool is_transpose_removed(loco::Graph *g) +{ + bool transpose_exist = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (dynamic_cast(node)) + { + transpose_exist = true; + break; + } + } + return not transpose_exist; +} + +} // namespace + +TEST(RemoveUnnecessaryTransposeNetPass, rank_reduction_pattern1) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 14, 14, 192) + * | + * (1, 192, 14, 14) + * | + * (1, 192, 
196) + * | + * (1, 196, 192) + */ + g.init_whole_graph(/*in*/ {1, 14, 14, 192}, /*perm*/ {0, 3, 1, 2}, /*reshape*/ {1, 192, 196}, + /*perm*/ {0, 2, 1}, /*out*/ {1, 196, 192}); + + EXPECT_TRUE(pass.run(g.g())); + EXPECT_TRUE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, rank_reduction_pattern2) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 100, 10, 12) + * | + * (1, 10, 12, 100) + * | + * (120, 100) + * | + * (100, 120) + */ + g.init_whole_graph(/*in*/ {1, 100, 10, 12}, /*perm*/ {0, 2, 3, 1}, /*reshape*/ {120, 100}, + /*perm*/ {1, 0}, + /*out*/ {100, 120}); + + EXPECT_TRUE(pass.run(g.g())); + EXPECT_TRUE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, identity_pattern) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 2, 3) + * | + * (1, 2, 3) + * | + * (1, 2, 3) + * | + * (1, 2, 3) + */ + g.init_whole_graph(/*in*/ {1, 2, 3}, /*perm*/ {0, 1, 2}, /*reshape*/ {1, 2, 3}, + /*perm*/ {0, 1, 2}, + /*out*/ {1, 2, 3}); + + EXPECT_TRUE(pass.run(g.g())); + EXPECT_TRUE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, incomplete_reshape_pattern) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 7, 7, 448) + * | + * (1, 448, 7, 7) + * | + * (1, 448, -1) + * | + * (1, 49, 448) + */ + g.init_whole_graph(/*in*/ {1, 7, 7, 448}, /*perm*/ {0, 3, 1, 2}, /*reshape*/ {1, 448, -1}, + /*perm*/ {0, 2, 1}, /*out*/ {1, 49, 448}); + + EXPECT_TRUE(pass.run(g.g())); + EXPECT_TRUE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, basic_pattern1_NEG) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 2, 4, 6) + * | + * (1, 2, 6, 4) + * | + * (1, 12, 4) + * | + * (1, 4, 12) + */ + g.init_whole_graph(/*in*/ {1, 2, 4, 6}, /*perm*/ {0, 1, 3, 2}, /*reshape*/ {1, 12, 4}, + /*perm*/ {0, 2, 1}, + /*out*/ {1, 4, 12}); + + EXPECT_FALSE(pass.run(g.g())); + EXPECT_FALSE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, basic_pattern2_NEG) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (15, 10, 10) + * | + * (10, 10, 15) + * | + * (1, 1, 1500) + * | + * (1500, 1, 1) + */ + g.init_whole_graph(/*in*/ {15, 10, 10}, /*perm*/ {1, 2, 0}, /*reshape*/ {1, 1, 1500}, + /*perm*/ {2, 0, 1}, + /*out*/ {1500, 1, 1}); + + EXPECT_FALSE(pass.run(g.g())); + EXPECT_FALSE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, basic_pattern3_NEG) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 2, 3, 4) + * | + * perm (0, 3, 1, 2) + * | + * (1, 4, 2, 3) + * | + * perm (0, 2, 3, 1) + * | + * (1, 2, 3, 4) + */ + g.init_transpose_only(/*in*/ {1, 2, 3, 4}, /*perm*/ {0, 3, 1, 2}, /*perm*/ {0, 2, 3, 1}, + /*out*/ {1, 2, 3, 4}); + + EXPECT_FALSE(pass.run(g.g())); + EXPECT_FALSE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, incomplete_reshape_pattern1_NEG) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 10, 10, 14) + * | + * (1, 14, 10, 10) + * | + * (1, 14, 2, -1) #(1, 14, 2, 50) + * | + * (1, 14, 50, 2) + */ + g.init_whole_graph(/*in*/ {1, 10, 10, 14}, /*perm*/ {0, 3, 1, 2}, /*reshape*/ {1, 14, 2, -1}, + /*perm*/ {0, 1, 3, 2}, /*out*/ {1, 14, 50, 2}); + + EXPECT_FALSE(pass.run(g.g())); + 
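+  // NOTE the inferred reshape is (1, 14, 2, 50): after the first transpose the
+  //      tagged values are (1) (14) (10) (10), and the target dim 2 cannot be
+  //      assembled from a contiguous run of those dims, so analyze_reshape
+  //      rejects the pattern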
EXPECT_FALSE(is_transpose_removed(g.g())); +} + +TEST(RemoveUnnecessaryTransposeNetPass, incomplete_reshape_pattern2_NEG) +{ + TransposeReshapeTransposeGraph g; + luci::RemoveUnnecessaryTransposeNetPass pass; + + /** + * (1, 10, 10, 14) + * | + * (1, 14, 10, 10) + * | + * (1, 14, -1, -1) # unexpected shape + * | + * (1, 14, 50, 2) + */ + g.init_whole_graph(/*in*/ {1, 10, 10, 14}, /*perm*/ {0, 3, 1, 2}, /*reshape*/ {1, 14, -1, -1}, + /*perm*/ {0, 1, 3, 2}, /*out*/ {1, 14, 50, 2}); + + EXPECT_FALSE(pass.run(g.g())); + EXPECT_FALSE(is_transpose_removed(g.g())); +} diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp index 07457c1e8..a1ff82f83 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp @@ -138,6 +138,11 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) if (dynamic_cast(fc->weights())) return false; // NonConst + // NOTE For const inputs, it is possible to block this conversion, + // because we can make transposed FC rather than matmul to improve + // accuracy of quantized models by sacrificing latency. + // See https://github.com/Samsung/ONE/discussions/11941 for more details. + if ((ty = dynamic_cast(fc->weights()))) // is y a transpose? { adj_y = false; diff --git a/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.cpp b/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.cpp new file mode 100644 index 000000000..a9288a37a --- /dev/null +++ b/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.cpp @@ -0,0 +1,316 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ReplaceWithFCGeluFCPass.h" +#include "helpers/NodeFiller.h" + +#include +#include +#include +#include + +#include +#include + +namespace +{ + +// Float comparison +bool same(float a, float b) { return fabs(a - b) < 1e-5; } + +luci::CircleConst *multiply_const(luci::CircleConst *node, float multiplier) +{ + auto cloned = luci::clone(node); + + assert(node->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(cloned->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + + for (uint32_t i = 0; i < cloned->size(); i++) + { + cloned->at(i) *= multiplier; + } + + luci::add_origin(cloned, luci::get_origin(node)); + + return cloned; +} + +/** + * Below diagram shows the target pattern. + * - The pattern will be converted to FC (front) -> Gelu -> FC (back). + * - FC (front) has the same weights with fc1 + * - FC (back)'s weights is twice of fc3's weights + * + * +---- [In] + * | | + * | V + * | fc2 (w = w of fc1 * sqrt(0.5). 
bias as well) -> const folded + * | | + * fc1 V + * | erf + * | | + * | V + * | add_one (1.0) + * | | + * | V + * +---> mul + * | + * V + * fc3 + * | + * V + * [Out] + * + */ +class FCGeluFCPattern final +{ +public: + FCGeluFCPattern(luci::CircleFullyConnected *cand) + { + assert(cand); + _fc3 = cand; + } + +public: + bool matched(); + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleFullyConnected *_fc1 = nullptr; + luci::CircleFullyConnected *_fc2 = nullptr; + luci::CircleFullyConnected *_fc3 = nullptr; + luci::CircleCustom *_erf = nullptr; + luci::CircleCustomOut *_erf_out = nullptr; + luci::CircleAdd *_add_one = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const_one = nullptr; + luci::CircleConst *_fc1_w = nullptr; + luci::CircleConst *_fc2_w = nullptr; + luci::CircleConst *_fc3_w = nullptr; + luci::CircleConst *_fc1_b = nullptr; + luci::CircleConst *_fc2_b = nullptr; + luci::CircleConst *_fc3_b = nullptr; +}; + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +bool FCGeluFCPattern::matched() +{ + // check pattern + _fc3_w = dynamic_cast(_fc3->weights()); + CHECK_OR_FALSE(_fc3_w != nullptr); + + _mul = dynamic_cast(_fc3->input()); + CHECK_OR_FALSE(_mul != nullptr); + + CHECK_OR_FALSE(luci::fill(&_fc1, &_add_one).with_commutative_args_of(_mul)); + + _fc1_w = dynamic_cast(_fc1->weights()); + CHECK_OR_FALSE(_fc1_w != nullptr); + + CHECK_OR_FALSE(_fc1->weights_format() == luci::CircleFullyConnected::WeightsFormat::DEFAULT); + + _ifm = loco::must_cast(_fc1->input()); + + CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one)); + + _erf = dynamic_cast(_erf_out->input()); + CHECK_OR_FALSE(_erf != nullptr); + + // Check erf + CHECK_OR_FALSE(_erf->custom_code() == "Erf"); + CHECK_OR_FALSE(_erf->numInputs() == 1); + CHECK_OR_FALSE(_erf->numOutputs() == 1); + + _fc2 = dynamic_cast(_erf->inputs(0)); + CHECK_OR_FALSE(_fc2 != nullptr); + _fc2_w = dynamic_cast(_fc2->weights()); + CHECK_OR_FALSE(_fc2_w != nullptr); + + CHECK_OR_FALSE(_fc2->weights_format() == luci::CircleFullyConnected::WeightsFormat::DEFAULT); + CHECK_OR_FALSE(_ifm == _fc2->input()); + + // Check Activation to be NONE + CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_fc1->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_fc2->fusedActivationFunction() == luci::FusedActFunc::NONE); + // fc3 can have activation + + // Check dtype + CHECK_OR_FALSE(_fc1->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_fc2->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_fc3->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_erf->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_erf_out->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_add_one->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_mul->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_fc1_w->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_fc2_w->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_fc3_w->dtype() == loco::DataType::FLOAT32); + + // Check _const_one condition + CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_one->size() == 1); + CHECK_OR_FALSE(_const_one->at(0) == 1); + + // Check fc2_w = fc1_w * sqrt(0.5) + CHECK_OR_FALSE(_fc1_w->size() == + _fc2_w->size()); + for (uint32_t i = 0; i < _fc1_w->size(); i++) + { + const auto fc1_val = _fc1_w->at(i); + 
const auto fc2_val = _fc2_w->at(i); + CHECK_OR_FALSE(::same(fc1_val * sqrtf(0.5f), fc2_val)); + } + + // Start to check bias + _fc1_b = dynamic_cast(_fc1->bias()); + _fc2_b = dynamic_cast(_fc2->bias()); + _fc3_b = dynamic_cast(_fc3->bias()); + + // Check there is no non-constant bias + if (_fc1_b == nullptr) + CHECK_OR_FALSE(dynamic_cast(_fc1->bias()) != nullptr); + + if (_fc2_b == nullptr) + CHECK_OR_FALSE(dynamic_cast(_fc2->bias()) != nullptr); + + if (_fc3_b == nullptr) + CHECK_OR_FALSE(dynamic_cast(_fc3->bias()) != nullptr); + + // Check both fc1 and fc2 have biases, or both have no bias + CHECK_OR_FALSE((_fc1_b == nullptr and _fc2_b == nullptr) or + (_fc1_b != nullptr and _fc2_b != nullptr)); + + // Check values of fc1 and fc2 bias (if bias exists) + if (_fc1_b != nullptr and _fc2_b != nullptr) + { + CHECK_OR_FALSE(_fc1_b->size() == + _fc2_b->size()); + for (uint32_t i = 0; i < _fc1_b->size(); i++) + { + const auto fc1_val = _fc1_b->at(i); + const auto fc2_val = _fc2_b->at(i); + + // fc2_b = fc1_b * sqrt(0.5) + CHECK_OR_FALSE(::same(fc1_val * sqrtf(0.5f), fc2_val)); + } + } + + return true; +} + +#undef CHECK_OR_FALSE + +class ReplaceWithFCGeluFC final +{ +public: + ReplaceWithFCGeluFC(const FCGeluFCPattern *p) : _p(p) {} + +public: + void apply(void); + +private: + // Create FC -> Gelu -> FC pattern and set front/back + void create_fc_gelu_fc(luci::CircleFullyConnected *&front, luci::CircleFullyConnected *&back); + +private: + const FCGeluFCPattern *_p; +}; + +void ReplaceWithFCGeluFC::create_fc_gelu_fc(luci::CircleFullyConnected *&front, + luci::CircleFullyConnected *&back) +{ + auto graph = _p->_fc1->graph(); + assert(graph); + + front = loco::must_cast(luci::clone_node(_p->_fc1, graph)); + front->weights(_p->_fc1->weights()); + front->bias(_p->_fc1->bias()); + luci::add_origin(front, luci::get_origin(_p->_fc1)); + + auto gelu = graph->nodes()->create(); + gelu->features(front); + // TODO Support approximate = True pattern + gelu->approximate(false); + gelu->name(_p->_erf->name() + "_gelu"); + std::vector> origin_vec{ + luci::get_origin(_p->_fc2), luci::get_origin(_p->_erf), luci::get_origin(_p->_add_one), + luci::get_origin(_p->_mul)}; + luci::add_origin(gelu, luci::composite_origin(origin_vec)); + + back = loco::must_cast(luci::clone_node(_p->_fc3, graph)); + back->input(gelu); + back->weights(multiply_const(_p->_fc3_w, 2.0f /* multiplier */)); + back->bias(_p->_fc3->bias()); + luci::add_origin(back, luci::get_origin(_p->_fc3)); +} + +void ReplaceWithFCGeluFC::apply() +{ + luci::CircleFullyConnected *front = nullptr; + luci::CircleFullyConnected *back = nullptr; + create_fc_gelu_fc(front, back); + + assert(front); // FIX_ME_UNLESS + assert(back); // FIX_ME_UNLESS + + front->input(_p->_ifm); + + replace(_p->_fc3).with(back); +} + +bool replace_fc_gelu_fc(luci::CircleFullyConnected *fc) +{ + assert(fc); + + FCGeluFCPattern pattern(fc); + if (pattern.matched()) + { + ReplaceWithFCGeluFC replace(&pattern); + replace.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool ReplaceWithFCGeluFCPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast(node); + if (not fc) + continue; + + if (replace_fc_gelu_fc(fc)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.test.cpp b/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.test.cpp new file mode 100644 index 000000000..4b82bfd0e --- 
/dev/null +++ b/compiler/luci/pass/src/ReplaceWithFCGeluFCPass.test.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ReplaceWithFCGeluFCPass.h" + +#include + +#include + +#include +#include + +namespace +{ + +using namespace luci::test; + +class FCGeluFCGraphlet +{ +public: + FCGeluFCGraphlet() = default; + + virtual ~FCGeluFCGraphlet() = default; + + void init(loco::Graph *g) + { + _fc1 = g->nodes()->create(); + _fc2 = g->nodes()->create(); + _fc3 = g->nodes()->create(); + _erf = g->nodes()->create(1, 1); + _erf_out = g->nodes()->create(); + _add_one = g->nodes()->create(); + _mul = g->nodes()->create(); + _const_one = g->nodes()->create(); + _fc1_w = g->nodes()->create(); + _fc2_w = g->nodes()->create(); + _fc3_w = g->nodes()->create(); + auto no_bias = g->nodes()->create(); + + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + _add_one->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc1->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc2->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc3->fusedActivationFunction(luci::FusedActFunc::NONE); + + _fc1->name("fc1"); + _fc2->name("fc2"); + _fc3->name("fc3"); + _erf->name("erf"); + _erf_out->name("erf_out"); + _add_one->name("add_one"); + _mul->name("mul"); + _const_one->name("const_one"); + _fc1_w->name("fc1_w"); + _fc2_w->name("fc2_w"); + _fc3_w->name("fc3_w"); + + _erf->custom_code("Erf"); + + _const_one->dtype(loco::DataType::FLOAT32); + _const_one->size(1); + _const_one->shape({1}); + _const_one->at(0) = 1.0; + _const_one->shape_status(luci::ShapeStatus::VALID); + + _fc1_w->dtype(loco::DataType::FLOAT32); + _fc1_w->size(16); + _fc1_w->shape({4, 4}); + for (uint32_t i = 0; i < 16; i++) + _fc1_w->at(i) = 1.0; + _fc1_w->shape_status(luci::ShapeStatus::VALID); + + _fc2_w->dtype(loco::DataType::FLOAT32); + _fc2_w->size(16); + _fc2_w->shape({4, 4}); + for (uint32_t i = 0; i < 16; i++) + _fc2_w->at(i) = sqrtf(0.5); + _fc2_w->shape_status(luci::ShapeStatus::VALID); + + _fc3_w->dtype(loco::DataType::FLOAT32); + _fc3_w->size(16); + _fc3_w->shape({4, 4}); + for (uint32_t i = 0; i < 16; i++) + _fc3_w->at(i) = 1.0; + _fc3_w->shape_status(luci::ShapeStatus::VALID); + + _fc1->dtype(loco::DataType::FLOAT32); + _fc2->dtype(loco::DataType::FLOAT32); + _fc3->dtype(loco::DataType::FLOAT32); + _erf->dtype(loco::DataType::FLOAT32); + _erf_out->dtype(loco::DataType::FLOAT32); + _add_one->dtype(loco::DataType::FLOAT32); + _mul->dtype(loco::DataType::FLOAT32); + + // Connect nodes + _fc1->weights(_fc1_w); + _fc1->bias(no_bias); + _fc2->weights(_fc2_w); + _fc2->bias(no_bias); + _erf->inputs(0, _fc2); + _erf_out->input(_erf); + _add_one->x(_erf_out); + _add_one->y(_const_one); + _mul->x(_fc1); + _mul->y(_add_one); + _fc3->input(_mul); + _fc3->weights(_fc3_w); + _fc3->bias(no_bias); + } + +protected: + luci::CircleFullyConnected *_fc1 = nullptr; + luci::CircleFullyConnected *_fc2 = nullptr; + 
luci::CircleFullyConnected *_fc3 = nullptr; + luci::CircleCustom *_erf = nullptr; + luci::CircleCustomOut *_erf_out = nullptr; + luci::CircleAdd *_add_one = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const_one = nullptr; + luci::CircleConst *_fc1_w = nullptr; + luci::CircleConst *_fc2_w = nullptr; + luci::CircleConst *_fc3_w = nullptr; +}; + +class FCGeluFCGraphletWithBias : public FCGeluFCGraphlet +{ +public: + FCGeluFCGraphletWithBias() = default; + + virtual ~FCGeluFCGraphletWithBias() = default; + + void init(loco::Graph *g) + { + // Create graphlet without bias + FCGeluFCGraphlet::init(g); + + // Set bias + _fc1_b = g->nodes()->create(); + _fc2_b = g->nodes()->create(); + _fc3_b = g->nodes()->create(); + + _fc1_b->name("fc1_b"); + _fc2_b->name("fc2_b"); + _fc3_b->name("fc3_b"); + + _fc1_b->dtype(loco::DataType::FLOAT32); + _fc1_b->size(4); + _fc1_b->shape({4}); + for (uint32_t i = 0; i < 4; i++) + _fc1_b->at(i) = 1.0; + _fc1_b->shape_status(luci::ShapeStatus::VALID); + + _fc2_b->dtype(loco::DataType::FLOAT32); + _fc2_b->size(4); + _fc2_b->shape({4}); + for (uint32_t i = 0; i < 4; i++) + _fc2_b->at(i) = sqrtf(0.5); + _fc2_b->shape_status(luci::ShapeStatus::VALID); + + _fc3_b->dtype(loco::DataType::FLOAT32); + _fc3_b->size(4); + _fc3_b->shape({4}); + for (uint32_t i = 0; i < 4; i++) + _fc3_b->at(i) = 1.0; + _fc3_b->shape_status(luci::ShapeStatus::VALID); + + // Connect nodes + _fc1->bias(_fc1_b); + _fc2->bias(_fc2_b); + _fc3->bias(_fc3_b); + } + +protected: + luci::CircleConst *_fc1_b = nullptr; + luci::CircleConst *_fc2_b = nullptr; + luci::CircleConst *_fc3_b = nullptr; +}; + +class ReplaceWithFCGeluFCTestGraph : public TestIOGraph, public FCGeluFCGraphlet +{ +public: + ReplaceWithFCGeluFCTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 4, 4}, {1, 4, 4}); + FCGeluFCGraphlet::init(g()); + + _fc1->input(input()); + _fc2->input(input()); + + output()->from(_fc3); + } +}; + +class ReplaceWithFCGeluFCTestNegGraph : public TestIOGraph, public FCGeluFCGraphlet +{ +public: + ReplaceWithFCGeluFCTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 4, 4}, {1, 4, 4}); + FCGeluFCGraphlet::init(g()); + _fc1->input(input()); + _fc2->input(_fc1); + + output()->from(_fc3); + } +}; + +class ReplaceWithFCGeluFCWithBiasTestGraph : public TestIOGraph, public FCGeluFCGraphletWithBias +{ +public: + ReplaceWithFCGeluFCWithBiasTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 4, 4}, {1, 4, 4}); + FCGeluFCGraphletWithBias::init(g()); + + _fc1->input(input()); + _fc2->input(input()); + + output()->from(_fc3); + } +}; + +class ReplaceWithFCGeluFCWithBiasNegTestGraph : public TestIOGraph, public FCGeluFCGraphletWithBias +{ +public: + ReplaceWithFCGeluFCWithBiasNegTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 4, 4}, {1, 4, 4}); + FCGeluFCGraphletWithBias::init(g()); + + _fc1->input(input()); + _fc2->input(_fc1); + + output()->from(_fc3); + } +}; + +} // namespace + +TEST(ReplaceWithFCGeluFCPassTest, basic) +{ + ReplaceWithFCGeluFCTestGraph g; + luci::ReplaceWithFCGeluFCPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(ReplaceWithFCGeluFCPassTest, wrong_pattern_NEG) +{ + ReplaceWithFCGeluFCTestNegGraph g; + luci::ReplaceWithFCGeluFCPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(ReplaceWithFCGeluFCPassTest, with_bias) +{ + ReplaceWithFCGeluFCWithBiasTestGraph g; + luci::ReplaceWithFCGeluFCPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + 
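+// NOTE (illustrative sketch, not part of the upstream patch) The matched
+// subgraph computes fc1(x) * (1 + erf(fc1(x) * sqrt(0.5))), which equals
+// 2 * gelu(fc1(x)) under gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); that is
+// why the pass doubles fc3's weights when folding the tail into FC-Gelu-FC.
+// A quick numeric self-check, assuming only <cmath>:
+//
+//   float gelu_exact(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }
+//   float pattern_form(float x) { return x * (1.0f + std::erf(x * std::sqrt(0.5f))); }
+//   // pattern_form(x) == 2.0f * gelu_exact(x) for all x, up to float rounding
+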
+TEST(ReplaceWithFCGeluFCPassTest, with_bias_wrong_pattern_NEG) +{ + ReplaceWithFCGeluFCWithBiasNegTestGraph g; + luci::ReplaceWithFCGeluFCPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp index 77c55324a..75e7e6a6b 100644 --- a/compiler/luci/pass/src/RequantizePass.cpp +++ b/compiler/luci/pass/src/RequantizePass.cpp @@ -145,6 +145,26 @@ bool RequantizePass::run(loco::Graph *g) return false; } + // Fix wrong quantized_dimension + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast(node); + + auto qparam = circle_node->quantparam(); + if (not qparam) + continue; + + if (circle_node->rank() != 1) + continue; + + if (qparam->quantized_dimension == 0) + continue; + + // For rank 1 node, quantized_dimension should be 0 + qparam->quantized_dimension = 0; + WARN(l) << "Wrong quantized_dimension is fixed (" << circle_node->name() << ")" << std::endl; + } + // Update output dtype auto graph_outputs = g->outputs(); for (auto node : loco::output_nodes(g)) diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp index 9f7e2f17d..730543c42 100644 --- a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp @@ -25,10 +25,42 @@ namespace { +/// @brief Returns the number of BroadcastTo among node's inputs +int32_t num_of_broadcast_to(const luci::CircleNode *node) +{ + int32_t bt_cnt = 0; + for (uint32_t idx = 0; idx < node->arity(); idx++) + { + auto input = loco::must_cast(node->arg(idx)); + switch (input->opcode()) + { + case luci::CircleOpcode::CIRCLECUSTOMOUT: + { + auto inputOut = loco::must_cast(input); + auto custom = loco::must_cast(inputOut->input()); + if (custom->custom_code() == "BroadcastTo") + ++bt_cnt; + break; + } + case luci::CircleOpcode::BROADCAST_TO: + { + ++bt_cnt; + break; + } + default: + break; + } + } + return bt_cnt; +} + /// @brief Returns the index of BroadcastTo node among cop's inputs. // NOTE This function assumes there is only one BroadcastTo node among its inputs. int32_t get_broadcastTo_index_among_inputs_of(luci::CircleCustom *cop) { + if (num_of_broadcast_to(cop) != 1) + return -1; + for (uint32_t idx = 0; idx < cop->numInputs(); idx++) { auto input = dynamic_cast(cop->inputs(idx)); @@ -38,12 +70,38 @@ int32_t get_broadcastTo_index_among_inputs_of(luci::CircleCustom *cop) if (broadcastTo->custom_code() == "BroadcastTo") return idx; } + else + { + auto broadcastTo = dynamic_cast(cop->inputs(idx)); + if (broadcastTo) + return idx; + } } return -1; } -/** BEFORE +// NOTE Broadcasting of input `Const` is skipped cause `Add` will do the broadcasting. +// TODO Implement broadcasting to the `Const` input. 
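+// NOTE (illustrative sketch, not upstream code) BroadcastTo can reach AddV2 in
+// two encodings: the builtin CircleBroadcastTo node, or the legacy
+// Custom("BroadcastTo") whose value is exposed through a CircleCustomOut.
+// The per-input check above boils down to a predicate like:
+//
+//   bool is_broadcast_to(const luci::CircleNode *n)
+//   {
+//     if (n->opcode() == luci::CircleOpcode::BROADCAST_TO)
+//       return true;
+//     if (auto out = dynamic_cast<const luci::CircleCustomOut *>(n))
+//     {
+//       auto custom = loco::must_cast<const luci::CircleCustom *>(out->input());
+//       return custom->custom_code() == "BroadcastTo";
+//     }
+//     return false;
+//   }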
+/** + * [Pattern1] + * BEFORE + * [CircleConst] + * | + * [CircleNode] [BroadcastTo(Builtin)] + * \ / + * \ / + * \ / + * [AddV2(CircleCustom)] + * AFTER + * + * [CircleConst] [CircleNode] + * \ / + * \ / + * [CircleAdd] + * + * [Pattern2] + * BEFORE * [CircleConst] * | * [CircleNode] [BroadcastTo(CircleCustom)] @@ -65,19 +123,41 @@ bool resolve_with_BroadcastTo(luci::CircleCustom *addv2) if (broadcastTo_idx == -1) return false; - auto input = loco::must_cast(addv2->inputs(broadcastTo_idx)); - auto broadcastTo = loco::must_cast(input->input()); - auto name = addv2->name(); assert(name.length() > 0); + auto input = loco::must_cast(addv2->inputs(broadcastTo_idx)); + luci::CircleNode *bc = nullptr; + loco::Node *bc_input = nullptr; + + switch (input->opcode()) + { + case luci::CircleOpcode::BROADCAST_TO: + { + auto broadcastTo = loco::must_cast(addv2->inputs(broadcastTo_idx)); + bc = broadcastTo; + bc_input = broadcastTo->input(); + break; + } + case luci::CircleOpcode::CIRCLECUSTOMOUT: + { + auto inputOut = + loco::must_cast(addv2->inputs(broadcastTo_idx)); + auto braodcastTo = loco::must_cast(inputOut->input()); + bc = braodcastTo; + bc_input = braodcastTo->inputs(0); + break; + } + default: + return false; + } + auto add = addv2->graph()->nodes()->create(); add->fusedActivationFunction(luci::FusedActFunc::NONE); add->x(addv2->inputs(1 - broadcastTo_idx)); - add->y(broadcastTo->inputs(0)); + add->y(bc_input); add->name(name + "/Add"); - luci::add_origin( - add, luci::composite_origin({luci::get_origin(broadcastTo), luci::get_origin(addv2)})); + luci::add_origin(add, luci::composite_origin({luci::get_origin(bc), luci::get_origin(addv2)})); auto customOut = loco::succs(addv2); assert(customOut.size() == 1); @@ -103,7 +183,9 @@ bool resolve_custom_op(luci::CircleCustom *addv2) auto input = loco::must_cast(addv2->inputs(i)); switch (input->dtype()) { + case loco::DataType::U4: case loco::DataType::U8: + case loco::DataType::S4: case loco::DataType::S8: case loco::DataType::S16: case loco::DataType::S32: diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp index 31c245b0e..2e591469c 100644 --- a/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp @@ -17,10 +17,213 @@ #include "luci/Pass/ResolveCustomOpAddPass.h" #include +#include -TEST(ResolveCustomOpAddPassTest, name) +#include + +using namespace luci::test; + +namespace +{ + +/** + * Test graph with Custom(AddV2) to resolve + * + * [Pattern 1] + * [Input] [BroadcastTo] + * \ / + * [Custom(AddV2)] + * | + * [CustomOut] + * | + * [Output] + * + * [Pattern 2] + * [Input] [Custom(BroadcastTo)] + * \ / + * [Custom(AddV2)] + * | + * [CustomOut] + * | + * [Output] + */ +class BroadcastToAddGraphlet { - luci::ResolveCustomOpAddPass pass; - auto const name = pass.name(); +public: + BroadcastToAddGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _addV2 = g->nodes()->create(2, 1); + _addV2->custom_code("AddV2"); + _addV2->shape({2, 3}); + _addV2->dtype(loco::DataType::FLOAT32); + _addV2->name("addV2"); + + // Const as BroadcastTo input + _broadcastTo_input = g->nodes()->create(); + _broadcastTo_input->dtype(loco::DataType::FLOAT32); + _broadcastTo_input->shape({1, 3}); + _broadcastTo_input->size(3); + _broadcastTo_input->at(0) = 1; + _broadcastTo_input->at(1) = 2; + _broadcastTo_input->at(2) = 3; + + // Const as BroadcastTo shape + auto broadcastTo_shape = g->nodes()->create(); + 
broadcastTo_shape->dtype(loco::DataType::S32); + broadcastTo_shape->shape({2}); + broadcastTo_shape->size(2); + broadcastTo_shape->at(0) = 2; + broadcastTo_shape->at(1) = 3; + + _custom_broadcastTo = g->nodes()->create(2, 1); + _custom_broadcastTo->custom_code("BroadcastTo"); + _custom_broadcastTo->dtype(loco::DataType::FLOAT32); + _custom_broadcastTo->shape({2, 3}); + _custom_broadcastTo->name("BroadcastTo"); + + _custom_broadcastTo->inputs(0, _broadcastTo_input); + _custom_broadcastTo->inputs(1, broadcastTo_shape); + + _custom_broadcastTo_out = g->nodes()->create(); + _custom_broadcastTo->custom_code("BroadcastTo"); + _custom_broadcastTo_out->shape({2, 3}); + _custom_broadcastTo_out->dtype(loco::DataType::FLOAT32); + _custom_broadcastTo_out->index(0); + _custom_broadcastTo_out->input(_custom_broadcastTo); + + _builtin_broadcastTo = g->nodes()->create(); + _builtin_broadcastTo->dtype(loco::DataType::FLOAT32); + _builtin_broadcastTo->name("BroadcastTo"); + + _builtin_broadcastTo->input(_broadcastTo_input); + _builtin_broadcastTo->shape(broadcastTo_shape); + + _addV2_out = g->nodes()->create(); + _addV2_out->shape({2, 3}); + _addV2_out->dtype(loco::DataType::FLOAT32); + _addV2_out->index(0); + _addV2_out->input(_addV2); + } + +public: + luci::CircleCustom *addV2() { return _addV2; } + luci::CircleBroadcastTo *builtin_broadcastTo() { return _builtin_broadcastTo; } + +protected: + luci::CircleCustom *_addV2 = nullptr; + luci::CircleCustomOut *_addV2_out = nullptr; + luci::CircleCustom *_custom_broadcastTo = nullptr; + luci::CircleBroadcastTo *_builtin_broadcastTo = nullptr; + luci::CircleCustomOut *_custom_broadcastTo_out = nullptr; + luci::CircleConst *_broadcastTo_input = nullptr; +}; + +class BroadcastToAddV2Graph : public TestIGraphlet, + public TestOsGraphlet<1>, + public BroadcastToAddGraphlet +{ +public: + BroadcastToAddV2Graph() = default; + + void init(const bool &isCustomBroadcastTo) + { + TestIGraphlet::init(g(), {2, 3}); + TestOsGraphlet<1>::init(g(), {{2, 3}}); + BroadcastToAddGraphlet::init(g()); + + // connect Input and Output to test graph + _addV2->inputs(0, input()); + + if (isCustomBroadcastTo) + _addV2->inputs(1, _custom_broadcastTo_out); + else + _addV2->inputs(1, _builtin_broadcastTo); + + _addV2_out->input(_addV2); + output(0)->from(_addV2_out); + } +}; + +class ResolveCustomOpAddPassTest : public ::testing::Test +{ +public: + BroadcastToAddV2Graph _g; + luci::ResolveCustomOpAddPass _pass; +}; + +} // namespace + +TEST_F(ResolveCustomOpAddPassTest, name) +{ + auto const name = _pass.name(); ASSERT_NE(nullptr, name); } + +TEST_F(ResolveCustomOpAddPassTest, simple_test_CustomBroadcastTo) +{ + _g.init(true); + + // check if Custom(AddV2) exists before the pass + auto addV2_out = dynamic_cast(_g.output(0)->from()); + EXPECT_NE(nullptr, addV2_out); + auto addV2 = dynamic_cast(addV2_out->input()); + EXPECT_NE(nullptr, addV2); + EXPECT_TRUE("AddV2" == addV2->custom_code()); + + auto ret = _pass.run(_g.g()); + EXPECT_EQ(true, ret); + + // check if Custom(AddV2) is converted to Add + auto add = dynamic_cast(_g.output(0)->from()); + EXPECT_NE(nullptr, add); + + // check if Custom(BroadcastTo) is removed + auto input_y = dynamic_cast(add->y()); + EXPECT_NE(nullptr, input_y); +} + +TEST_F(ResolveCustomOpAddPassTest, simple_test_BuiltinBroadcastTo) +{ + _g.init(false); + + // check if Custom(AddV2) exists before the pass + auto addV2_out = dynamic_cast(_g.output(0)->from()); + EXPECT_NE(nullptr, addV2_out); + auto addV2 = dynamic_cast(addV2_out->input()); + EXPECT_NE(nullptr, 
+ EXPECT_TRUE("AddV2" == addV2->custom_code());
+
+ auto ret = _pass.run(_g.g());
+ EXPECT_EQ(true, ret);
+
+ // check if Custom(AddV2) is converted to Add
+ auto add = dynamic_cast(_g.output(0)->from());
+ EXPECT_NE(nullptr, add);
+
+ // check if BroadcastTo is removed
+ auto input_y = dynamic_cast(add->y());
+ EXPECT_NE(nullptr, input_y);
+}
+
+TEST_F(ResolveCustomOpAddPassTest, wrong_custom_code_NEG)
+{
+ _g.init(false);
+
+ _g.addV2()->custom_code("UNSUPPORTED_CUSTOM_CODE");
+
+ auto ret = _pass.run(_g.g());
+ EXPECT_EQ(false, ret);
+}
+
+TEST_F(ResolveCustomOpAddPassTest, wrong_input_type_NEG)
+{
+ _g.init(false);
+
+ _g.builtin_broadcastTo()->dtype(loco::DataType::BOOL);
+
+ auto ret = _pass.run(_g.g());
+ EXPECT_EQ(false, ret);
+}
diff --git a/compiler/luci/pass/src/ResolveFormerCustomOpPass.cpp b/compiler/luci/pass/src/ResolveFormerCustomOpPass.cpp
new file mode 100644
index 000000000..75e74bb34
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveFormerCustomOpPass.cpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveFormerCustomOpPass.h"
+
+#include
+#include
+#include
+
+#include
+
+namespace
+{
+
+bool resolve_with_BroadcastTo(luci::CircleCustom *node)
+{
+ // check if the number of inputs is 2.
+ if (node->numInputs() != 2)
+ return false;
+
+ auto input = loco::must_cast(node->inputs(0));
+
+ // check if the shape input has a supported data type
+ auto shape = loco::must_cast(node->inputs(1));
+ if (shape->dtype() != loco::DataType::S32 && shape->dtype() != loco::DataType::S64)
+ return false;
+
+ auto customOut = loco::succs(node);
+ assert(customOut.size() == 1);
+
+ // check if the output data type matches that of the input feature map
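+ // (BroadcastTo only replicates values, so a dtype mismatch means this
+ // custom op is not a plain broadcast and must be left untouched)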
+ auto output = loco::must_cast(*customOut.begin()); + if (input->dtype() != output->dtype()) + return false; + + auto name = node->name(); + assert(name.length() > 0); + + auto broadcastTo = node->graph()->nodes()->create(); + broadcastTo->input(input); + broadcastTo->shape(shape); + broadcastTo->name(name); + luci::add_origin(broadcastTo, luci::get_origin(node)); + + replace(*customOut.begin()).with(broadcastTo); + + return true; +} + +bool resolve_custom_op(luci::CircleCustom *node) +{ + const std::string custom_code = node->custom_code(); + + if (custom_code == "BroadcastTo") + { + return resolve_with_BroadcastTo(node); + } + // TODO add more custom codes + + return false; +} + +} // namespace + +namespace luci +{ + +bool ResolveFormerCustomOpPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto cop = dynamic_cast(node); + if (not cop) + continue; + + if (resolve_custom_op(cop)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ResolveFormerCustomOpPass.test.cpp b/compiler/luci/pass/src/ResolveFormerCustomOpPass.test.cpp new file mode 100644 index 000000000..dc8cc7ead --- /dev/null +++ b/compiler/luci/pass/src/ResolveFormerCustomOpPass.test.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/ResolveFormerCustomOpPass.h" + +#include + +#include +#include + +using namespace luci::test; + +namespace +{ + +/** + * graph having Custom operator BroadcastTo + * + * [Const(Input)] [Const(Shape)] + * \ / + * [Custom(BroadcastTo)] + * | + * [CustomOut] + * | + * [Output] + */ +class CustomBroadcastToGraphlet +{ +public: + CustomBroadcastToGraphlet() = default; + +public: + void init(loco::Graph *g) + { + // CircleCustom(BroadcastTo) + _broadcastTo = g->nodes()->create(2, 1); + _broadcastTo->custom_code("BroadcastTo"); + _broadcastTo->dtype(loco::DataType::FLOAT32); + _broadcastTo->shape({2, 3}); + _broadcastTo->name("BroadcastTo"); + + // CircleConst(BroadcastTo-input) + _input = g->nodes()->create(); + _input->dtype(loco::DataType::FLOAT32); + _input->shape({1, 3}); + _input->size(3); + _input->at(0) = 1; + _input->at(1) = 2; + _input->at(2) = 3; + + // CircleConst(BroadcastTo-shape) + _shape = g->nodes()->create(); + _shape->dtype(loco::DataType::S32); + _shape->shape({2}); + _shape->size(2); + _shape->at(0) = 2; + _shape->at(1) = 3; + + _broadcastTo->inputs(0, _input); + _broadcastTo->inputs(1, _shape); + + // CircleCustomOut + _broadcastTo_out = g->nodes()->create(); + _broadcastTo_out->shape({2, 3}); + _broadcastTo_out->dtype(loco::DataType::FLOAT32); + _broadcastTo_out->index(0); + _broadcastTo_out->input(_broadcastTo); + } + +public: + luci::CircleCustom *broadcastTo() { return _broadcastTo; } + luci::CircleConst *shape() { return _shape; } + luci::CircleCustomOut *broadcastTo_out() { return _broadcastTo_out; } + +protected: + luci::CircleCustom *_broadcastTo = nullptr; + luci::CircleCustomOut *_broadcastTo_out = nullptr; + luci::CircleConst *_input = nullptr; + luci::CircleConst *_shape = nullptr; +}; + +class BroadcastToGraph : public TestIGraphlet, + public TestOsGraphlet<1>, + public CustomBroadcastToGraphlet +{ +public: + BroadcastToGraph() = default; + + void init(void) + { + TestOsGraphlet<1>::init(g(), {{1, 2, 3, 1, 2, 3}}); + CustomBroadcastToGraphlet::init(g()); + + output(0)->from(_broadcastTo_out); + } +}; + +class FormerCustomOpGraphTest : public ::testing::Test +{ +public: + BroadcastToGraph _g; + luci::ResolveFormerCustomOpPass _pass; +}; + +} // namespace + +TEST(ResolveFormerCustomOpPassTest, name) +{ + luci::ResolveFormerCustomOpPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FormerCustomOpGraphTest, simple_test_BroadcastTo) +{ + _g.init(); + + auto ret = _pass.run(_g.g()); + EXPECT_EQ(true, ret); + + auto broadcastTo = dynamic_cast(_g.output(0)->from()); + EXPECT_NE(nullptr, broadcastTo); + + auto input = dynamic_cast(broadcastTo->input()); + EXPECT_NE(nullptr, input); + EXPECT_EQ(1, input->at(0)); + EXPECT_EQ(2, input->at(1)); + EXPECT_EQ(3, input->at(2)); + + auto shape = dynamic_cast(broadcastTo->shape()); + EXPECT_NE(nullptr, shape); + EXPECT_EQ(true, (shape->dtype() == loco::DataType::S32)); + EXPECT_EQ(2, shape->at(0)); + EXPECT_EQ(3, shape->at(1)); +} + +TEST_F(FormerCustomOpGraphTest, wrong_op_NEG) +{ + _g.init(); + + _g.broadcastTo()->custom_code("UNSUPORTED_CUSTOM_CODE"); + + auto ret = _pass.run(_g.g()); + EXPECT_EQ(false, ret); +} + +TEST_F(FormerCustomOpGraphTest, wrong_shape_type_NEG) +{ + // the data type of shape should be S32 or S64. 
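+ // Setting it to FLOAT32 below violates that precondition, so the pass
+ // is expected to make no change and return false.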
+ _g.init(); + + _g.shape()->dtype(loco::DataType::FLOAT32); + + auto ret = _pass.run(_g.g()); + EXPECT_EQ(false, ret); +} + +TEST_F(FormerCustomOpGraphTest, unequal_input_output_type_NEG) +{ + _g.init(); + + _g.broadcastTo_out()->dtype(loco::DataType::S32); + + auto ret = _pass.run(_g.g()); + EXPECT_EQ(false, ret); +} diff --git a/compiler/luci/pass/src/SubstitutePadV2ToPadPass.cpp b/compiler/luci/pass/src/SubstitutePadV2ToPadPass.cpp index 549ed22ec..d756d919b 100644 --- a/compiler/luci/pass/src/SubstitutePadV2ToPadPass.cpp +++ b/compiler/luci/pass/src/SubstitutePadV2ToPadPass.cpp @@ -215,9 +215,9 @@ bool positive_or_zero(loco::Node *ifm) // Since Relu.output[i] >= 0 if (dynamic_cast(ifm)) return true; - if (auto conv = dynamic_cast(ifm)) + if (auto node = dynamic_cast *>(ifm)) { - if (conv->fusedActivationFunction() == luci::FusedActFunc::RELU) + if (node->fusedActivationFunction() == luci::FusedActFunc::RELU) return true; // Add more FusedActFunc } diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp index df7266df9..9bc764f92 100644 --- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp @@ -119,7 +119,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze) if (squeeze->shape_status() != luci::ShapeStatus::VALID) return false; - auto squeeze_dims = squeeze->squeeze_dims(); + auto &squeeze_dims = squeeze->squeeze_dims(); if (not is_valid_input(input, squeeze_dims)) throw std::runtime_error("Invalid values in squeeze_dims: " + squeeze->name()); diff --git a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp index 9d1dfc1e3..d196c5997 100644 --- a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp +++ b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp @@ -17,7 +17,6 @@ #include "luci/Pass/TransformMinMaxToRelu6Pass.h" #include "helpers/NodeFiller.h" -#include "helpers/TypeMapper.h" #include #include diff --git a/compiler/luci/pass/src/TransformMinReluToRelu6Pass.cpp b/compiler/luci/pass/src/TransformMinReluToRelu6Pass.cpp index cccc0134c..f5ec3df20 100644 --- a/compiler/luci/pass/src/TransformMinReluToRelu6Pass.cpp +++ b/compiler/luci/pass/src/TransformMinReluToRelu6Pass.cpp @@ -17,7 +17,6 @@ #include "luci/Pass/TransformMinReluToRelu6Pass.h" #include "helpers/NodeFiller.h" -#include "helpers/TypeMapper.h" #include #include diff --git a/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.cpp b/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.cpp new file mode 100644 index 000000000..c21319d5d --- /dev/null +++ b/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/TransformSqrtDivToRsqrtMulPass.h" + +#include "helpers/NodeFiller.h" + +#include +#include + +namespace +{ + +/** + * BEFORE + * [CircleNode] [CircleNode] + * | | + * | [CircleSqrt] + * | | + * [CircleDiv] + * | + * [CircleNode] + * + * AFTER + * [CircleNode] [CircleNode] + * | | + * | [CircleRsqrt] [CircleSqrt] + * | | | + * [CircleMul] [CircleDiv] + * | + * [CircleNode] + * + */ + +bool transform_sqrtdiv_to_rsqrtmul(luci::CircleDiv *div) +{ + assert(div != nullptr); + + // skip if x is const, for FuseRsqrtPass + auto *const_node = dynamic_cast(div->x()); + if (const_node != nullptr) + return false; + + auto *sqrt = dynamic_cast(div->y()); + if (sqrt == nullptr) + return false; + + auto *graph = div->graph(); + + auto *rsqrt = graph->nodes()->create(); + rsqrt->x(sqrt->x()); + rsqrt->name(sqrt->name() + "_rsqrt"); + luci::add_origin(rsqrt, luci::get_origin(sqrt)); + + auto *mul = graph->nodes()->create(); + mul->x(div->x()); + mul->y(rsqrt); + mul->fusedActivationFunction(div->fusedActivationFunction()); + mul->name(div->name() + "_mul"); + luci::add_origin(mul, luci::get_origin(div)); + + replace(div).with(mul); + + return true; +} + +} // namespace + +namespace luci +{ + +bool TransformSqrtDivToRsqrtMulPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto div = dynamic_cast(node)) + { + if (transform_sqrtdiv_to_rsqrtmul(div)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.test.cpp b/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.test.cpp new file mode 100644 index 000000000..4f3bc2c3c --- /dev/null +++ b/compiler/luci/pass/src/TransformSqrtDivToRsqrtMulPass.test.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/TransformSqrtDivToRsqrtMulPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class SqrtDivGraphlet +{ +public: + SqrtDivGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _div = g->nodes()->create(); + _div->name("div"); + + _sqrt = g->nodes()->create(); + _sqrt->name("sqrt"); + } + +protected: + luci::CircleDiv *_div = nullptr; + luci::CircleSqrt *_sqrt = nullptr; +}; + +class SqrtDivGraph : public TestIOGraph, public SqrtDivGraphlet +{ +public: + SqrtDivGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1, 2, 3}, {1, 2, 3}); + SqrtDivGraphlet::init(g()); + + _div->x(input()); + _div->y(_sqrt); + + _sqrt->x(input()); + + output()->from(_div); + } +}; + +// For negative test: Div input order does not match +class SqrtDivOrderGraph : public TestIOGraph, public SqrtDivGraphlet +{ +public: + SqrtDivOrderGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1, 2, 3}, {1, 2, 3}); + SqrtDivGraphlet::init(g()); + + _div->x(_sqrt); + _div->y(input()); + + _sqrt->x(input()); + + output()->from(_div); + } +}; + +// For negative test: Div input x is Const +class SqrtDivConstGraph : public TestIOGraph, public SqrtDivGraphlet +{ +public: + SqrtDivConstGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1, 2, 3}, {1, 2, 3}); + SqrtDivGraphlet::init(g()); + + _const = g()->nodes()->create(); + _const->name("const"); + + _div->x(_const); + _div->y(_sqrt); + + _sqrt->x(input()); + + output()->from(_div); + } + +protected: + luci::CircleConst *_const = nullptr; +}; + +class TransformSqrtDivToRsqrtMulPassTest : public ::testing::Test +{ +protected: + luci::TransformSqrtDivToRsqrtMulPass _pass; +}; + +} // namespace + +TEST_F(TransformSqrtDivToRsqrtMulPassTest, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(TransformSqrtDivToRsqrtMulPassTest, simple_run) +{ + SqrtDivGraph graph; + graph.init(); + + EXPECT_TRUE(_pass.run(graph.g())); + + // success pass will transform Div to Mul + auto mul_node = dynamic_cast(graph.output()->from()); + ASSERT_NE(nullptr, mul_node); +} + +TEST_F(TransformSqrtDivToRsqrtMulPassTest, div_input_order_NEG) +{ + SqrtDivOrderGraph graph; + graph.init(); + + EXPECT_FALSE(_pass.run(graph.g())); +} + +TEST_F(TransformSqrtDivToRsqrtMulPassTest, div_input_const_NEG) +{ + SqrtDivConstGraph graph; + graph.init(); + + EXPECT_FALSE(_pass.run(graph.g())); +} diff --git a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp index b73efafa5..0b736e7db 100644 --- a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp +++ b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp @@ -16,9 +16,6 @@ #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" -#include "helpers/NodeFiller.h" -#include "helpers/TypeMapper.h" - #include #include diff --git a/compiler/luci/pass/src/XpSepActFromTransposeConvPass.cpp b/compiler/luci/pass/src/XpSepActFromTransposeConvPass.cpp new file mode 100644 index 000000000..74f5535fd --- /dev/null +++ b/compiler/luci/pass/src/XpSepActFromTransposeConvPass.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/XpSepActFromTransposeConvPass.h"
+
+#include
+#include
+#include
+
+namespace luci
+{
+
+/**
+ * XpSepActFromTransposeConvPass
+ * - Experimental Separate Activation From TransposeConv
+ * - This pass exists temporarily to separate the activation function from
+ * TransposeConv, for backends that do not support the fused form.
+ * - This pass will be removed when all backends support fused activation.
+ *
+ * BEFORE
+ * [Node]
+ * |
+ * [TransposeConv] (w/ Act)
+ * |
+ * [Node]
+ *
+ * AFTER
+ *
+ * [Node]
+ * |
+ * [TransposeConv]
+ * |
+ * [ReLU/ReLU6/...]
+ * |
+ * [Node]
+ *
+ */
+
+namespace
+{
+
+bool separate_activation_function(luci::CircleTransposeConv *trconv)
+{
+ // cannot separate for quantized state: support F32 for now
+ // TODO revise this to a better way
+ if (trconv->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto fused_act = trconv->fusedActivationFunction();
+ if (fused_act == luci::FusedActFunc::NONE)
+ return false;
+ if (fused_act == luci::FusedActFunc::UNDEFINED)
+ throw std::runtime_error("XpSepActFromTransposeConvPass Activation is undefined");
+
+ // NOTE features() is called after replace().with(): if actnode were
+ // already a user of trconv, loco::replace(trconv).with(actnode) would
+ // rewire actnode as well, leaving actnode with itself as its input.
+ // This happens because the TransposeConv is re-used, not replaced with
+ // a new node.
+
+ auto name = trconv->name();
+ luci::CircleNode *actnode = nullptr;
+ switch (fused_act)
+ {
+ case luci::FusedActFunc::RELU:
+ {
+ auto af = trconv->graph()->nodes()->create();
+ loco::replace(trconv).with(af);
+ af->features(trconv);
+ af->name(name + "/Relu");
+ actnode = af;
+ }
+ break;
+ case luci::FusedActFunc::RELU6:
+ {
+ auto af = trconv->graph()->nodes()->create();
+ loco::replace(trconv).with(af);
+ af->features(trconv);
+ af->name(name + "/Relu6");
+ actnode = af;
+ }
+ break;
+ // TODO support more
+ default:
+ return false;
+ }
+ assert(actnode != nullptr);
+ actnode->dtype(trconv->dtype());
+ luci::add_origin(actnode, luci::get_origin(trconv));
+
+ trconv->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ return true;
+}
+
+} // namespace
+
+bool XpSepActFromTransposeConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto trconv = dynamic_cast(node);
+ if (trconv != nullptr)
+ {
+ if (separate_activation_function(trconv))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/XpSepActFromTransposeConvPass.test.cpp b/compiler/luci/pass/src/XpSepActFromTransposeConvPass.test.cpp
new file mode 100644
index 000000000..73824ba8d
--- /dev/null
+++ b/compiler/luci/pass/src/XpSepActFromTransposeConvPass.test.cpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/XpSepActFromTransposeConvPass.h" + +#include + +#include +#include "test/TestFirstNode.h" + +#include + +namespace +{ + +using namespace luci::test; + +class TrConvGraphlet +{ +public: + TrConvGraphlet() = default; + +public: + void init(loco::Graph *g, ShapeU32 wshape) + { + const uint32_t elements_num = num_elements(wshape); + + // trconv inputSizes + auto wshape_size = static_cast(wshape.size()); + _inpsize = g->nodes()->create(); + _inpsize->dtype(loco::DataType::S32); + _inpsize->shape({wshape_size}); + _inpsize->size(wshape_size); + auto wsp = wshape.begin(); + for (uint32_t idx = 0; idx < wshape_size; idx++) + { + _inpsize->at(idx) = int32_t(*wsp++); + } + _inpsize->name("inpsize"); + + // trconv filter + _filter = g->nodes()->create(); + _filter->dtype(loco::DataType::FLOAT32); + _filter->shape(wshape); + _filter->size(elements_num); + for (uint32_t idx = 0; idx < elements_num; idx++) + { + _filter->at(idx) = float(idx); + } + _filter->name("filter"); + + // trconv + _tc = g->nodes()->create(); + _tc->dtype(loco::DataType::FLOAT32); + _tc->name("trconv"); + } + +protected: + luci::CircleTransposeConv *_tc = nullptr; + luci::CircleConst *_filter = nullptr; + luci::CircleConst *_inpsize = nullptr; +}; + +class TrConvGraph : public TestIGraphlet, public TestOGraphlet, public TrConvGraphlet +{ +public: + TrConvGraph() = default; + + void init(const ShapeU32 shape) + { + TestIGraphlet::init(g(), shape); + TestOGraphlet::init(g(), shape); + TrConvGraphlet::init(g(), shape); + + // connect graph + _tc->inputSizes(_inpsize); + _tc->filter(_filter); + _tc->outBackprop(input()); + + output()->from(_tc); + } +}; + +} // namespace + +TEST(XpSepActFromTransposeConvPassTest, name) +{ + luci::XpSepActFromTransposeConvPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(XpSepActFromTransposeConvPassTest, separation_ok) +{ + TrConvGraph g; + + g.init({1, 4, 4, 3}); + + auto tc_node = luci::test::first_node(g.g()); + ASSERT_NE(tc_node, nullptr); + tc_node->fusedActivationFunction(luci::FusedActFunc::RELU); + + luci::XpSepActFromTransposeConvPass pass; + EXPECT_EQ(pass.run(g.g()), true); + + auto la_node = dynamic_cast(g.output()->from()); + ASSERT_NE(la_node, nullptr); + auto la_tc_node = dynamic_cast(la_node->features()); + ASSERT_NE(la_tc_node, nullptr); + ASSERT_EQ(la_tc_node->fusedActivationFunction(), luci::FusedActFunc::NONE); +} + +TEST(XpSepActFromTransposeConvPassTest, none_act_NEG) +{ + TrConvGraph g; + + g.init({1, 4, 4, 3}); + + auto tc_node = luci::test::first_node(g.g()); + ASSERT_NE(tc_node, nullptr); + tc_node->fusedActivationFunction(luci::FusedActFunc::NONE); + + luci::XpSepActFromTransposeConvPass pass; + EXPECT_NE(pass.run(g.g()), true); +} + +TEST(XpSepActFromTransposeConvPassTest, invalid_act_NEG) +{ + TrConvGraph g; + + g.init({1, 4, 4, 3}); + + auto tc_node = luci::test::first_node(g.g()); + ASSERT_NE(tc_node, nullptr); + // leave activation as undefined + + luci::XpSepActFromTransposeConvPass pass; + EXPECT_ANY_THROW(pass.run(g.g())); +} + +TEST(XpSepActFromTransposeConvPassTest, invalid_dtype_NEG) +{ 
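+ // The pass separates activations only from FLOAT32 TransposeConv
+ // (quantized graphs are skipped), so the S16 dtype set below should
+ // leave the graph unchanged and the pass should report no change.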
+ TrConvGraph g; + + g.init({1, 4, 4, 3}); + + auto tc_node = luci::test::first_node(g.g()); + ASSERT_NE(tc_node, nullptr); + tc_node->dtype(loco::DataType::S16); + + luci::XpSepActFromTransposeConvPass pass; + EXPECT_NE(pass.run(g.g()), true); +} diff --git a/compiler/luci/pass/src/helpers/Compute.cpp b/compiler/luci/pass/src/helpers/Compute.cpp new file mode 100644 index 000000000..6c9e85547 --- /dev/null +++ b/compiler/luci/pass/src/helpers/Compute.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Compute.h" + +namespace luci +{ + +bool to_compute(const Padding padding, compute::PaddingType &padding_type) +{ + switch (padding) + { + case Padding::SAME: + padding_type = compute::PaddingType::kSame; + break; + + case Padding::VALID: + padding_type = compute::PaddingType::kValid; + break; + + default: + return false; + } + return true; +} + +bool to_compute(const FusedActFunc act, compute::FusedActFunc &act_func) +{ + switch (act) + { + case FusedActFunc::NONE: + act_func = compute::FusedActFunc::NONE; + break; + + case FusedActFunc::TANH: + act_func = compute::FusedActFunc::TANH; + break; + + case FusedActFunc::RELU: + act_func = compute::FusedActFunc::RELU; + break; + + case FusedActFunc::RELU_N1_TO_1: + act_func = compute::FusedActFunc::RELU_N1_TO_1; + break; + + case FusedActFunc::RELU6: + act_func = compute::FusedActFunc::RELU6; + break; + + default: + return false; + } + return true; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/Compute.h b/compiler/luci/pass/src/helpers/Compute.h new file mode 100644 index 000000000..9034adf0f --- /dev/null +++ b/compiler/luci/pass/src/helpers/Compute.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_COMPUTE_H__ +#define __LUCI_PASS_HELPERS_COMPUTE_H__ + +#include + +#include + +namespace luci +{ + +// Convert luci::XX to luci::compute::XX +// Return true if conversion is valid. +// Return false otherwise (undefined behavior). 
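+//
+// A minimal call-site sketch (illustrative only; 'node' stands for any
+// luci node that carries a Padding attribute):
+//
+// luci::compute::PaddingType padding_type;
+// if (not luci::to_compute(node->padding(), padding_type))
+// return false; // unsupported padding, caller must bail out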
+bool to_compute(const Padding padding, compute::PaddingType &padding_type); +bool to_compute(const FusedActFunc act, compute::FusedActFunc &act_func); + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_COMPUTE_H__ diff --git a/compiler/luci/pass/src/helpers/ExpressionCache.cpp b/compiler/luci/pass/src/helpers/ExpressionCache.cpp new file mode 100644 index 000000000..b51f1bb91 --- /dev/null +++ b/compiler/luci/pass/src/helpers/ExpressionCache.cpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ExpressionCache.h" + +namespace +{ + +#define RETURN_FALSE_UNLESS(cond) \ + if (not(cond)) \ + return false; + +// Check common (non-op-specific) attributes of lhs and rhs +bool same_common_attributes(const luci::CircleNode *lhs, const luci::CircleNode *rhs) +{ + // Opcode + if (lhs->opcode() != rhs->opcode()) + return false; + + // Shape + if (lhs->rank() != rhs->rank()) + return false; + + for (uint32_t i = 0; i < lhs->rank(); i++) + { + if (lhs->dim(i).known() != rhs->dim(i).known()) + return false; + + if (lhs->dim(i).value() != rhs->dim(i).value()) + return false; + } + + // Data type + if (lhs->dtype() != rhs->dtype()) + return false; + + // Op version + if (lhs->op_version() != rhs->op_version()) + return false; + + // QuantParam + const auto lhs_qparam = lhs->quantparam(); + const auto rhs_qparam = rhs->quantparam(); + + if (lhs_qparam == nullptr and rhs_qparam != nullptr) + return false; + + if (lhs_qparam != nullptr and rhs_qparam == nullptr) + return false; + + if (lhs_qparam) + { + assert(rhs_qparam); // FIX_ME_UNLESS + + if (lhs_qparam->scale != rhs_qparam->scale) + return false; + + if (lhs_qparam->zerop != rhs_qparam->zerop) + return false; + + if (lhs_qparam->min != rhs_qparam->min) + return false; + + if (lhs_qparam->max != rhs_qparam->max) + return false; + } + + return true; +} + +// Return true if two constants are the same +bool same_const(const luci::CircleConst *x, const luci::CircleConst *y) +{ + assert(x != nullptr); // FIX_CALLER_UNLESS + assert(y != nullptr); // FIX_CALLER_UNLESS + + RETURN_FALSE_UNLESS(same_common_attributes(x, y)); + + switch (x->dtype()) + { + case loco::DataType::S32: + { + const auto size_x = x->size(); + const auto size_y = y->size(); + RETURN_FALSE_UNLESS(size_x == size_y); + + for (uint32_t i = 0; i < size_x; i++) + { + RETURN_FALSE_UNLESS(x->at(i) == y->at(i)); + } + return true; + } + // TODO Support more dtypes + default: + // Simply return false + return false; + } + + return true; +} + +// Return true if x and y are semantically equal +bool same_attributes(const luci::CircleTranspose *x, luci::CircleTranspose *y) +{ + assert(x != nullptr); // FIX_CALLER_UNLESS + assert(y != nullptr); // FIX_CALLER_UNLESS + + assert(same_common_attributes(x, y)); // FIX_CALLER_UNLESS + + const auto perm_x = dynamic_cast(x->perm()); + const auto perm_y = dynamic_cast(y->perm()); + + RETURN_FALSE_UNLESS(perm_x); + 
RETURN_FALSE_UNLESS(perm_y); + + // Check perm_x and perm_y are the same + RETURN_FALSE_UNLESS(same_const(perm_x, perm_y)); + + return true; +} + +// Use a similar approach as boost's hash_combine +template inline void hash_combine(std::size_t &s, const T v) +{ + std::hash h; + s ^= h(v) + 0x9e3779b9 + (s << 6) + (s >> 2); +} + +template <> inline void hash_combine(std::size_t &s, luci::CircleNode *node) +{ + // Shape + hash_combine(s, node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + hash_combine(s, node->dim(i).value()); + + // Data type + hash_combine(s, static_cast(node->dtype())); + + // Op version + hash_combine(s, node->op_version()); + + // Op code + hash_combine(s, node->opcode()); + + // QuantParam + // Let's skip quantparam to reduce burden of hash function +} + +} // namespace + +namespace luci +{ +namespace pass +{ + +Expression Expression::build(luci::CircleNode *node) +{ + if (node == nullptr) + throw std::invalid_argument("node"); + + Expression key; + { + switch (node->opcode()) + { + case luci::CircleOpcode::QUANTIZE: + case luci::CircleOpcode::TRANSPOSE: + key.inputs.emplace_back(node->arg(0)); + break; + // TODO Add more Ops + default: + // NYI. Return invalid expression + key.op = nullptr; + return key; + } + + key.op = node; + } + + return key; +} + +bool operator==(const Expression &x, const Expression &y) +{ + if (x.inputs != y.inputs) + return false; + + // Check general (non-op-specific) attributes + if (not same_common_attributes(x.op, y.op)) + return false; + + assert(x.op->opcode() == y.op->opcode()); // FIX_ME_UNLESS + + // Check op-specific attributes + switch (x.op->opcode()) + { + case luci::CircleOpcode::QUANTIZE: + { + // This Op has no op-specific attribute. + // same_common_attributes is enough. + return true; + } + case luci::CircleOpcode::TRANSPOSE: + { + const auto trans_x = loco::must_cast(x.op); + const auto trans_y = loco::must_cast(y.op); + + return same_attributes(trans_x, trans_y); + } + // TODO Implement more operators + default: + // NYI: Unsupported operators + return false; + } + + return true; +} + +std::size_t Expression::Hash::call(const Expression &k) const noexcept +{ + std::size_t res = 0; + for (const auto input : k.inputs) + hash_combine(res, input); + + hash_combine(res, k.op); + + return res; +} + +} // namespace pass +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/ExpressionCache.h b/compiler/luci/pass/src/helpers/ExpressionCache.h new file mode 100644 index 000000000..bf9460859 --- /dev/null +++ b/compiler/luci/pass/src/helpers/ExpressionCache.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_PASS_HELPERS_EXPRESSION_CACHE_H__ +#define __LUCI_PASS_HELPERS_EXPRESSION_CACHE_H__ + +#include + +#include +#include + +namespace luci +{ +namespace pass +{ + +// Expression is defined as a circle node (operator) and its input feature maps +struct Expression final +{ +private: + // Prevent default ctor + Expression() = default; + +public: + // Input feature maps + std::vector inputs; + luci::CircleNode *op = nullptr; + + // Hash function for Expression (used for std::unordered_map) + struct Hash final + { + std::size_t call(const Expression &k) const noexcept; + std::size_t operator()(const Expression &k) const noexcept { return call(k); } + }; + + // Build expression from a circle node + // Returned Expression.op == nullptr if Expression is invalid + static Expression build(luci::CircleNode *node); +}; + +// Return true if two expressions are the same +// This is a core logic for common subexpression elimination +bool operator==(const Expression &x, const Expression &y); + +// Cache for Expression object +class ExpressionCache final +{ +public: + using Key = Expression; + using Value = luci::CircleNode *; + +private: + std::unordered_map _content; + +public: + // Return value for the corresponding key + // Return nullptr if there is no item with the key + Value get(const Key &k) const + { + auto item = _content.find(k); + if (item == _content.end()) + return nullptr; + + return item->second; + } + + // Save circle node for the corresponding key + void put(const Key &k, const Value v) { _content[k] = v; } +}; + +} // namespace pass +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_EXPRESSION_CACHE_H__ diff --git a/compiler/luci/pass/src/helpers/ExpressionCache.test.cpp b/compiler/luci/pass/src/helpers/ExpressionCache.test.cpp new file mode 100644 index 000000000..308324b29 --- /dev/null +++ b/compiler/luci/pass/src/helpers/ExpressionCache.test.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include "ExpressionCache.h" + +using namespace luci::pass; + +TEST(ExpressionCacheTest, simple_test) +{ + luci::CircleInput in; + luci::CircleQuantize node; + node.input(&in); + + auto expr = Expression::build(&node); + + ExpressionCache cache; + + cache.put(expr, &node); + + EXPECT_NE(nullptr, cache.get(expr)); +} + +TEST(ExpressionCacheTest, null_expr_NEG) { EXPECT_ANY_THROW(Expression::build(nullptr)); } + +TEST(ExpressionCacheTest, invalid_expr_NEG) +{ + luci::CircleInput in; + + auto expr = Expression::build(&in); + + // Input is a virtual Op, thus return invalid expr + EXPECT_EQ(nullptr, expr.op); +} diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp index ac07f9ec9..37d8e18e9 100644 --- a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp +++ b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp @@ -145,7 +145,7 @@ LayerInfoMap layer_info_map(loco::Graph *g, std::vector &layers_info) for (auto &&info : layers_info) { - auto name = info.name; + auto &name = info.name; bool found = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { diff --git a/compiler/luci/pass/src/helpers/Shape.cpp b/compiler/luci/pass/src/helpers/Shape.cpp new file mode 100644 index 000000000..76b718835 --- /dev/null +++ b/compiler/luci/pass/src/helpers/Shape.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Shape.h" + +namespace luci +{ + +bool is_same_shape(const luci::CircleNode *node, const loco::TensorShape &shape) +{ + if (node == nullptr) + return false; + + if (node->shape_status() != luci::ShapeStatus::VALID) + return false; + + if (node->rank() != shape.rank()) + return false; + + for (uint32_t i = 0; i < node->rank(); ++i) + { + if (node->dim(i).known() != shape.dim(i).known()) + return false; + + if (node->dim(i).value() != shape.dim(i).value()) + return false; + } + + return true; +} + +bool is_same_shape(const luci::CircleNode *node, const std::initializer_list shape) +{ + if (node == nullptr) + return false; + + if (node->rank() != shape.size()) + return false; + + uint32_t i = 0; + for (auto it = shape.begin(); it != shape.end(); ++it, ++i) + { + if (node->dim(i).value() != *it) + return false; + } + return true; +} + +bool has_dynamic_shape(const loco::Node *node) +{ + const auto circle_node = loco::must_cast(node); + for (uint32_t i = 0; i < circle_node->rank(); ++i) + if (!circle_node->dim(i).known()) + return true; + return false; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/Shape.h b/compiler/luci/pass/src/helpers/Shape.h new file mode 100644 index 000000000..69ea50e0a --- /dev/null +++ b/compiler/luci/pass/src/helpers/Shape.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_SHAPE_H__ +#define __LUCI_PASS_HELPERS_SHAPE_H__ + +#include + +namespace luci +{ + +bool is_same_shape(const luci::CircleNode *node, const loco::TensorShape &shape); +bool is_same_shape(const luci::CircleNode *node, const std::initializer_list shape); + +bool has_dynamic_shape(const loco::Node *node); + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_SHAPE_H__ diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp index 72b7d60ff..c15df2a6c 100644 --- a/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp +++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp @@ -38,7 +38,9 @@ uint64_t GetFlattenedIndex(const std::vector &indices, const std::vector= 0; i--) { - index += indices[i] * sub_elements; + assert(indices[i] >= 0); + assert(sub_elements >= 0); + index += static_cast(indices[i]) * static_cast(sub_elements); sub_elements *= shape[i]; } return index; diff --git a/compiler/luci/pass/src/helpers/Strings.cpp b/compiler/luci/pass/src/helpers/Strings.cpp index 2628726c1..cb7f8a12b 100644 --- a/compiler/luci/pass/src/helpers/Strings.cpp +++ b/compiler/luci/pass/src/helpers/Strings.cpp @@ -46,6 +46,8 @@ std::string to_lower_case(std::string s) loco::DataType str_to_dtype(const std::string &str) { + if (to_lower_case(str).compare("uint4") == 0) + return loco::DataType::U4; if (to_lower_case(str).compare("uint8") == 0) return loco::DataType::U8; if (to_lower_case(str).compare("uint16") == 0) @@ -55,6 +57,8 @@ loco::DataType str_to_dtype(const std::string &str) if (to_lower_case(str).compare("uint64") == 0) return loco::DataType::U64; + if (to_lower_case(str).compare("int4") == 0) + return loco::DataType::S4; if (to_lower_case(str).compare("int8") == 0) return loco::DataType::S8; if (to_lower_case(str).compare("int16") == 0) diff --git a/compiler/luci/pass/src/helpers/Strings.test.cpp b/compiler/luci/pass/src/helpers/Strings.test.cpp index 6d854ad4f..a8423aa9a 100644 --- a/compiler/luci/pass/src/helpers/Strings.test.cpp +++ b/compiler/luci/pass/src/helpers/Strings.test.cpp @@ -22,11 +22,13 @@ TEST(StringsTest, str_to_dtype) { + ASSERT_EQ(loco::DataType::U4, luci::str_to_dtype("uint4")); ASSERT_EQ(loco::DataType::U8, luci::str_to_dtype("uint8")); ASSERT_EQ(loco::DataType::U16, luci::str_to_dtype("uint16")); ASSERT_EQ(loco::DataType::U32, luci::str_to_dtype("uint32")); ASSERT_EQ(loco::DataType::U64, luci::str_to_dtype("uint64")); + ASSERT_EQ(loco::DataType::S4, luci::str_to_dtype("int4")); ASSERT_EQ(loco::DataType::S8, luci::str_to_dtype("int8")); ASSERT_EQ(loco::DataType::S16, luci::str_to_dtype("int16")); ASSERT_EQ(loco::DataType::S32, luci::str_to_dtype("int32")); diff --git a/compiler/luci/requires.cmake b/compiler/luci/requires.cmake index a71d4482c..7fd58df1b 100644 --- a/compiler/luci/requires.cmake +++ b/compiler/luci/requires.cmake @@ -4,7 +4,7 @@ require("loco") 
require("locop") require("logo") require("logo-core") -require("mio-circle06") +require("mio-circle08") require("luci-compute") require("oops") require("hermes") diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 92c5fb04c..e8c7266dd 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -119,6 +119,7 @@ public: // loco::TensorShape visit(const luci::CircleReduceMin *node) final; // loco::TensorShape visit(const luci::CircleReduceProd *node) final; // loco::TensorShape visit(const luci::CircleRelu *node) final; + // loco::TensorShape visit(const luci::CircleRelu0To1 *node) final; // loco::TensorShape visit(const luci::CircleRelu6 *node) final; // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final; // loco::TensorShape visit(const luci::CircleReshape *node) final; diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index 4f4ab0f34..e725722a9 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -118,6 +118,7 @@ public: // loco::DataType visit(const luci::CircleReduceMin *node) final; // loco::DataType visit(const luci::CircleReduceProd *node) final; // loco::DataType visit(const luci::CircleRelu *node) final; + // loco::DataType visit(const luci::CircleRelu0To1 *node) final; // loco::DataType visit(const luci::CircleRelu6 *node) final; // loco::DataType visit(const luci::CircleReluN1To1 *node) final; // loco::DataType visit(const luci::CircleReshape *node) final; diff --git a/compiler/luci/service/src/ChangeOutputs.cpp b/compiler/luci/service/src/ChangeOutputs.cpp index 1f8000061..65175530c 100644 --- a/compiler/luci/service/src/ChangeOutputs.cpp +++ b/compiler/luci/service/src/ChangeOutputs.cpp @@ -72,7 +72,7 @@ void change_outputs(loco::Graph *graph, const std::vector &new_outp auto output = luci::output_node(graph, out); // output is CircleOutput assert(output != nullptr); - auto node_name = new_outputs.at(out); + auto &node_name = new_outputs.at(out); auto node = named_nodes[node_name]; assert(node != nullptr); diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index e0b4dbc41..e2f61e1eb 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -52,12 +52,14 @@ public: luci::CircleNode *visit(const luci::CircleAveragePool2D *) final; luci::CircleNode *visit(const luci::CircleBatchMatMul *) final; luci::CircleNode *visit(const luci::CircleBatchToSpaceND *) final; + luci::CircleNode *visit(const luci::CircleBroadcastTo *) final; luci::CircleNode *visit(const luci::CircleCast *) final; luci::CircleNode *visit(const luci::CircleCeil *) final; luci::CircleNode *visit(const luci::CircleConcatenation *) final; luci::CircleNode *visit(const luci::CircleConst *) final; luci::CircleNode *visit(const luci::CircleConv2D *) final; luci::CircleNode *visit(const luci::CircleCos *) final; + luci::CircleNode *visit(const luci::CircleCumSum *) final; luci::CircleNode *visit(const luci::CircleCustom *) final; luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; } @@ -171,6 +173,7 @@ public: luci::CircleNode *visit(const luci::CircleReduceMin *) final; luci::CircleNode *visit(const 
luci::CircleReduceProd *) final; luci::CircleNode *visit(const luci::CircleRelu *) final; + luci::CircleNode *visit(const luci::CircleRelu0To1 *) final; luci::CircleNode *visit(const luci::CircleRelu6 *) final; luci::CircleNode *visit(const luci::CircleReluN1To1 *) final; luci::CircleNode *visit(const luci::CircleReshape *) final; @@ -255,6 +258,7 @@ public: luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final; luci::CircleNode *visit(const luci::CircleBCQGather *) final; luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; + luci::CircleNode *visit(const luci::CircleGRU *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index d56886c97..e8febc58f 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -492,6 +492,36 @@ loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape, return loco::NodeShape{output_shape}; } +loco::NodeShape infer_broadcast_to(const luci::CircleBroadcastTo *node) +{ + const loco::DataType S32 = loco::DataType::S32; + + loco::TensorShape shape_by_input; + { + LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + + // Only support node's shape() is CircleConst with S32 + auto const_shape_node = dynamic_cast(node->shape()); + if (const_shape_node != nullptr) + { + LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); + + shape_by_input.rank(const_shape_node->size()); + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + shape_by_input.dim(axis) = const_shape_node->at(axis); + } + } + else + { + // We use shape from the node itself + shape_by_input = own_shape(node); + } + } + + return loco::NodeShape{shape_by_input}; +} + loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node) { // TODO Support when CircleConcatenation has 0 input @@ -514,6 +544,8 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node) for (uint32_t i = 1; i < node->numValues(); ++i) { auto input_shape = luci::shape_get(node->values(i)).as(); + if (input_shape.rank() != output_shape.rank()) + INTERNAL_EXN_V("Input has incompatible shape", node->name()); for (uint32_t j = 0; j < output_shape.rank(); ++j) { @@ -1575,7 +1607,9 @@ loco::NodeShape infer_transpose(const luci::CircleTranspose *node) loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node) { // TransposeConv's output shape is written in its 'inputSizes' argument - auto input_sizes_const = loco::must_cast(node->inputSizes()); + auto input_sizes_const = dynamic_cast(node->inputSizes()); + if (not input_sizes_const) + return use_own(node); // TODO support non-const type LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype") LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4, @@ -1710,6 +1744,28 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node) return loco::NodeShape{output_shape}; } +loco::NodeShape infer_circle_gru(const luci::CircleGRU *node) +{ + loco::TensorShape output_shape; + + const auto input_shape = luci::shape_get(node->input()).as(); + const auto state_shape = luci::shape_get(node->state()).as(); + + auto rank = input_shape.rank(); + assert(rank > 1); + output_shape.rank(rank); + for (uint32_t i = 0; i < rank - 1; i++) + { + 
output_shape.dim(i) = input_shape.dim(i); + } + output_shape.dim(rank - 1) = state_shape.dim(1); + + if (not node->returnSequences()) + output_shape.dim(0) = 1; + + return loco::NodeShape{output_shape}; +} + // Virtual loco::NodeShape infer_input(const luci::CircleInput *node) { @@ -1966,7 +2022,7 @@ loco::NodeShape infer_while_out(const luci::CircleWhileOut *node) auto cond_graph_inputs = cond_graph->inputs(); auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); - auto cond_graph_input_shape = *cond_graph_input->shape(); + const auto &cond_graph_input_shape = *cond_graph_input->shape(); auto this_shape = own_shape(node); if (!(this_shape == cond_graph_input_shape)) @@ -2015,6 +2071,11 @@ public: return infer_batch_to_space_nd(node); } + loco::NodeShape visit(const luci::CircleBroadcastTo *node) final + { + return infer_broadcast_to(node); + } + loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); } loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); } @@ -2030,6 +2091,8 @@ public: loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); } + loco::NodeShape visit(const luci::CircleCumSum *node) final { return use_input(node); } + loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); } loco::NodeShape visit(const luci::CircleDensify *node) final { return use_input(node); } @@ -2252,6 +2315,13 @@ public: return loco::NodeShape{input_shape}; } + loco::NodeShape visit(const luci::CircleRelu0To1 *node) final + { + auto input_shape = luci::shape_get(node->features()).as(); + + return loco::NodeShape{input_shape}; + } + loco::NodeShape visit(const luci::CircleRelu6 *node) final { auto input_shape = luci::shape_get(node->features()).as(); @@ -2439,6 +2509,8 @@ public: return loco::NodeShape{input_shape}; } + loco::NodeShape visit(const luci::CircleGRU *node) final { return infer_circle_gru(node); } + // Virtual loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); } diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index bd3feb977..78dde1004 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -69,6 +69,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorinput()); } + loco::DataType visit(const luci::CircleBroadcastTo *node) final + { + return luci::dtype_get(node->input()); + } + loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); } loco::DataType visit(const luci::CircleCeil *node) final { return luci::dtype_get(node->x()); } @@ -93,6 +98,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorx()); } + loco::DataType visit(const luci::CircleCumSum *node) final + { + return luci::dtype_get(node->input()); + } + loco::DataType visit(const luci::CircleCustom *node) final { if (node->custom_code() == "BatchMatMulV2") @@ -186,6 +196,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorfeatures()); } + loco::DataType visit(const luci::CircleGRU *node) final { return luci::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleIf *node) final { // Type of If is not used. 
Just use input 0 @@ -370,6 +382,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorfeatures()); } + loco::DataType visit(const luci::CircleRelu0To1 *node) final + { + return luci::dtype_get(node->features()); + } + loco::DataType visit(const luci::CircleRelu6 *node) final { return luci::dtype_get(node->features()); diff --git a/compiler/luci/service/src/Nodes/CircleBroadcastTo.cpp b/compiler/luci/service/src/Nodes/CircleBroadcastTo.cpp new file mode 100644 index 000000000..ca0702510 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBroadcastTo.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNodeLet::visit(const luci::CircleBroadcastTo *) +{ + return _graph->nodes()->create(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleBroadcastTo.test.cpp b/compiler/luci/service/src/Nodes/CircleBroadcastTo.test.cpp new file mode 100644 index 000000000..07ff7defe --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleBroadcastTo.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Service/CircleNodeClone.h" + +#include + +TEST(CloneNodeTest, clone_BroadcastTo) +{ + auto g = loco::make_graph(); + auto node_broadcastTo = g->nodes()->create(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_broadcastTo, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_broadcastTo = dynamic_cast(cloned); + ASSERT_NE(nullptr, cloned_broadcastTo); +} diff --git a/compiler/luci/service/src/Nodes/CircleConst.cpp b/compiler/luci/service/src/Nodes/CircleConst.cpp index 017dcc8ad..c8f43c444 100644 --- a/compiler/luci/service/src/Nodes/CircleConst.cpp +++ b/compiler/luci/service/src/Nodes/CircleConst.cpp @@ -59,10 +59,18 @@ luci::CircleConst *clone_circleconst(const luci::CircleConst *node, loco::Graph copy_values(node, cloned); break; + case loco::DataType::U4: + copy_values(node, cloned); + break; + case loco::DataType::U8: copy_values(node, cloned); break; + case loco::DataType::S4: + copy_values(node, cloned); + break; + case loco::DataType::S8: copy_values(node, cloned); break; diff --git a/compiler/luci/service/src/Nodes/CircleConst.test.cpp b/compiler/luci/service/src/Nodes/CircleConst.test.cpp index 5d94798f4..3d9acd983 100644 --- a/compiler/luci/service/src/Nodes/CircleConst.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleConst.test.cpp @@ -98,6 +98,20 @@ TEST(CircleConstTest, clone) ASSERT_NE(nullptr, const_cloned->sparsityparam()); } +TEST(CircleConstTest, clone_U4) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::U4, const_cloned->dtype()); +} + TEST(CircleConstTest, clone_U8) { auto g = loco::make_graph(); @@ -112,6 +126,20 @@ TEST(CircleConstTest, clone_U8) ASSERT_EQ(loco::DataType::U8, const_cloned->dtype()); } +TEST(CircleConstTest, clone_S4) +{ + auto g = loco::make_graph(); + + // prepare source CircleConst + auto circle_const = new_empty_const(g.get()); + + // make a clone + auto const_cloned = luci::clone(circle_const); + + // check attributes + ASSERT_EQ(loco::DataType::S4, const_cloned->dtype()); +} + TEST(CircleConstTest, clone_S8) { auto g = loco::make_graph(); diff --git a/compiler/luci/service/src/Nodes/CircleCumSum.cpp b/compiler/luci/service/src/Nodes/CircleCumSum.cpp new file mode 100644 index 000000000..5f8e6f0a0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleCumSum.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleCumSum *node)
+{
+  auto cloned = _graph->nodes()->create<luci::CircleCumSum>();
+  assert(cloned);
+  cloned->exclusive(node->exclusive());
+  cloned->reverse(node->reverse());
+  return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCumSum.test.cpp b/compiler/luci/service/src/Nodes/CircleCumSum.test.cpp
new file mode 100644
index 000000000..66601e84d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCumSum.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_CumSum)
+{
+  auto g = loco::make_graph();
+  auto node_cumsum = g->nodes()->create<luci::CircleCumSum>();
+  node_cumsum->exclusive(false);
+  node_cumsum->reverse(false);
+
+  auto gc = loco::make_graph();
+  auto cloned = luci::clone_node(node_cumsum, gc.get());
+  ASSERT_NE(nullptr, cloned);
+  ASSERT_EQ(gc.get(), cloned->graph());
+
+  auto cloned_cumsum = dynamic_cast<luci::CircleCumSum *>(cloned);
+  ASSERT_NE(nullptr, cloned_cumsum);
+  ASSERT_EQ(node_cumsum->exclusive(), cloned_cumsum->exclusive());
+  ASSERT_EQ(node_cumsum->reverse(), cloned_cumsum->reverse());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleGRU.cpp b/compiler/luci/service/src/Nodes/CircleGRU.cpp
new file mode 100644
index 000000000..f39e4aaae
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGRU.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleGRU *node)
+{
+  if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+    return nullptr;
+
+  auto *cloned = _graph->nodes()->create<luci::CircleGRU>();
+  if (cloned != nullptr)
+  {
+    cloned->fusedActivationFunction(node->fusedActivationFunction());
+    cloned->returnSequences(node->returnSequences());
+    cloned->timeMajor(node->timeMajor());
+  }
+  return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleGRU.test.cpp b/compiler/luci/service/src/Nodes/CircleGRU.test.cpp
new file mode 100644
index 000000000..ae684d938
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGRU.test.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+#include <luci/Service/CircleTypeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, simple_circle_gru)
+{
+  luci::CircleInput input;
+  luci::CircleConst hidden_hidden;
+  luci::CircleConst hidden_hidden_bias;
+  luci::CircleConst hidden_input;
+  luci::CircleConst hidden_input_bias;
+  luci::CircleConst state;
+  luci::CircleGRU circle_gru;
+
+  input.shape({10, 1, 4});
+  input.shape_status(luci::ShapeStatus::VALID);
+
+  hidden_hidden.shape({7, 32});
+  hidden_hidden.shape_status(luci::ShapeStatus::VALID);
+
+  hidden_hidden_bias.shape({7});
+  hidden_hidden_bias.shape_status(luci::ShapeStatus::VALID);
+
+  hidden_input.shape({7, 4});
+  hidden_input.shape_status(luci::ShapeStatus::VALID);
+
+  hidden_input_bias.shape({7});
+  hidden_input_bias.shape_status(luci::ShapeStatus::VALID);
+
+  state.shape({1, 32});
+  state.shape_status(luci::ShapeStatus::VALID);
+
+  circle_gru.input(&input);
+  circle_gru.hidden_hidden(&hidden_hidden);
+  circle_gru.hidden_hidden_bias(&hidden_hidden_bias);
+  circle_gru.hidden_input(&hidden_input);
+  circle_gru.hidden_input_bias(&hidden_input_bias);
+  circle_gru.state(&state);
+
+  loco::TensorShape shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  ASSERT_TRUE(shape_inf_rule.infer(&circle_gru, shape));
+  ASSERT_EQ(3, shape.rank());
+  ASSERT_EQ(1, shape.dim(0).value());
+  ASSERT_EQ(1, shape.dim(1).value());
+  ASSERT_EQ(32, shape.dim(2).value());
+}
+
+TEST(DataTypeRuleTest, simple_circle_gru)
+{
+  luci::CircleInput input;
+  luci::CircleConst hidden_hidden;
+  luci::CircleConst hidden_hidden_bias;
+  luci::CircleConst hidden_input;
+  luci::CircleConst hidden_input_bias;
+  luci::CircleConst state;
+  luci::CircleGRU circle_gru;
+
+  input.dtype(loco::DataType::FLOAT32);
+  hidden_hidden.dtype(loco::DataType::FLOAT32);
+  hidden_hidden_bias.dtype(loco::DataType::FLOAT32);
+  hidden_input.dtype(loco::DataType::FLOAT32);
+  hidden_input_bias.dtype(loco::DataType::FLOAT32);
+  state.dtype(loco::DataType::FLOAT32);
+
+  circle_gru.input(&input);
+  circle_gru.hidden_hidden(&hidden_hidden);
+  circle_gru.hidden_hidden_bias(&hidden_hidden_bias);
+  circle_gru.hidden_input(&hidden_input);
+  circle_gru.hidden_input_bias(&hidden_input_bias);
+  circle_gru.state(&state);
+
+  loco::DataType dtype;
+  luci::tinf::Rule type_inf_rule;
+
+  ASSERT_TRUE(type_inf_rule.infer(&circle_gru, dtype));
+  ASSERT_EQ(loco::DataType::FLOAT32, dtype);
+}
+
+TEST(CloneNodeTest, clone_circle_gru)
+{
+  auto g = loco::make_graph();
+  auto node_circle_gru = g->nodes()->create<luci::CircleGRU>();
+  node_circle_gru->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+  auto gc = loco::make_graph();
+  auto cloned = luci::clone_node(node_circle_gru, gc.get());
+  ASSERT_NE(nullptr, cloned);
+  ASSERT_EQ(gc.get(), cloned->graph());
+
+  auto cloned_circle_gru = dynamic_cast<luci::CircleGRU *>(cloned);
+  ASSERT_NE(nullptr, cloned_circle_gru);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRelu0To1.cpp b/compiler/luci/service/src/Nodes/CircleRelu0To1.cpp
new file mode 100644
index 000000000..bdc320176
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu0To1.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRelu0To1 *)
+{
+  return _graph->nodes()->create<luci::CircleRelu0To1>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu0To1.test.cpp b/compiler/luci/service/src/Nodes/CircleRelu0To1.test.cpp
new file mode 100644
index 000000000..9bda5846d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu0To1.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Relu0To1)
+{
+  auto g = loco::make_graph();
+  auto node_relu0to1 = g->nodes()->create<luci::CircleRelu0To1>();
+
+  auto gc = loco::make_graph();
+  auto cloned = luci::clone_node(node_relu0to1, gc.get());
+  ASSERT_NE(nullptr, cloned);
+  ASSERT_EQ(gc.get(), cloned->graph());
+
+  auto cloned_relu0to1 = dynamic_cast<luci::CircleRelu0To1 *>(cloned);
+  ASSERT_NE(nullptr, cloned_relu0to1);
+}
diff --git a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
index bdd27739a..5a22da319 100644
--- a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
+++ b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
@@ -124,7 +124,7 @@ inline int64_t StartForAxis(const StridedSliceParams &params, const loco::TensorShape
   int64_t start = start_indices[axis];
 
   // begin_mask override
-  if (begin_mask & (1 << axis))
+  if (begin_mask & (1LL << axis))
   {
     if (strides[axis] > 0)
     {
@@ -180,7 +180,7 @@ inline int64_t StopForAxis(const StridedSliceParams &params, const loco::TensorShape
   }
 
   // Begin with the specified index
-  const bool shrink_axis = shrink_axis_mask & (1 << axis);
+  const bool shrink_axis = shrink_axis_mask & (1LL << axis);
   int64_t stop = stop_indices[axis];
 
   // When shrinking an axis, the end position does not matter (and can be
@@ -193,7 +193,7 @@ inline int64_t StopForAxis(const StridedSliceParams &params, const loco::TensorShape
   }
 
   // end_mask override
-  if (end_mask & (1 << axis))
+  if (end_mask & (1LL << axis))
   {
     if (strides[axis] > 0)
     {
@@ -249,8 +249,8 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context)
   int64_t num_add_axis = 0;
   for (int64_t i = 0; i < begin_count; ++i)
   {
-    if (!((1 << i) & op_context->params.ellipsis_mask) &&
-        ((1 << i) & op_context->params.new_axis_mask))
+    if (!((1LL << i) & op_context->params.ellipsis_mask) &&
+        ((1LL << i) & op_context->params.new_axis_mask))
     {
       num_add_axis++;
     }
@@ -268,7 +268,7 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context)
   int64_t ellipsis_start_idx = effective_dims, expanded_ellipsis = 0;
   for (int64_t i = 0; i < effective_dims;)
   {
-    if ((1 << i) & op_context->params.ellipsis_mask)
+    if ((1LL << i) & op_context->params.ellipsis_mask)
     {
       ellipsis_start_idx = i;
       int64_t ellipsis_end_idx =
@@ -279,14 +279,14 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context)
       // Set bit for effective_ellipsis_mask.
       for (; i < ellipsis_end_idx; ++i)
       {
-        effective_ellipsis_mask |= (1 << i);
+        effective_ellipsis_mask |= (1LL << i);
       }
       continue;
     }
 
-    if ((1 << (i - expanded_ellipsis)) & op_context->params.new_axis_mask)
+    if ((1LL << (i - expanded_ellipsis)) & op_context->params.new_axis_mask)
     {
-      effective_new_axis_mask |= (1 << i);
+      effective_new_axis_mask |= (1LL << i);
    }
     ++i;
   }
@@ -298,17 +298,17 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context)
 
   for (int64_t i = 0; i < effective_dims; ++i)
   {
-    if ((1 << i) & effective_ellipsis_mask)
+    if ((1LL << i) & effective_ellipsis_mask)
     {
       // If ellipsis_mask, set the begin_mask and end_mask at that index.
added_ellipsis = std::max(int64_t(0), i - ellipsis_start_idx); assert(i < 16); - op_params.begin_mask |= (1 << i); - op_params.end_mask |= (1 << i); + op_params.begin_mask |= (1LL << i); + op_params.end_mask |= (1LL << i); op_params.strides[i] = 1; op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises); } - else if ((1 << i) & effective_new_axis_mask) + else if ((1LL << i) & effective_new_axis_mask) { // If new_axis_mask is set, it is equivalent to adding a new dim of 1 to // input tensor. Store added shape to effective_input_shape. @@ -324,8 +324,8 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context) op_params.stop_indices[i] = 0; op_params.strides[i] = 1; assert(i < 16); - op_params.begin_mask |= (1 << i); - op_params.end_mask |= (1 << i); + op_params.begin_mask |= (1LL << i); + op_params.end_mask |= (1LL << i); op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises); } else @@ -334,20 +334,20 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context) op_params.start_indices[i] = op_context->begin->at(orig_idx); op_params.stop_indices[i] = op_context->end->at(orig_idx); op_params.strides[i] = op_context->strides->at(orig_idx); - if (op_context->params.begin_mask & (1 << orig_idx)) + if (op_context->params.begin_mask & (1LL << orig_idx)) { assert(i < 16); - op_params.begin_mask |= (1 << i); + op_params.begin_mask |= (1LL << i); } - if (op_context->params.end_mask & (1 << orig_idx)) + if (op_context->params.end_mask & (1LL << orig_idx)) { assert(i < 16); - op_params.end_mask |= (1 << i); + op_params.end_mask |= (1LL << i); } - if (op_context->params.shrink_axis_mask & (1 << orig_idx)) + if (op_context->params.shrink_axis_mask & (1LL << orig_idx)) { assert(i < 16); - op_params.shrink_axis_mask |= (1 << i); + op_params.shrink_axis_mask |= (1LL << i); } op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises); } @@ -398,7 +398,7 @@ loco::TensorShape infer_output_shape(const CircleStridedSlice *node) StridedSliceContext op_context(node); auto op_params = BuildStridedSliceParams(&op_context); - auto effective_input_shape = op_context.effective_input_shape; + auto &effective_input_shape = op_context.effective_input_shape; std::vector output_shape_vector; for (int32_t idx = effective_input_shape.rank() - 1; idx >= 0; --idx) diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst index a88661db3..a9ded24b6 100644 --- a/compiler/luci/tests/test.lst +++ b/compiler/luci/tests/test.lst @@ -27,6 +27,7 @@ addread(BatchMatMul_000) addread(BatchMatMulV2_000) addread(BatchMatMulV2_001) addread(BatchToSpaceND_000) +addread(BroadcastTo_001) addread(Cast_000) addread(Cast_001) addread(Ceil_000) @@ -39,6 +40,7 @@ addread(Conv2D_003) addread(Conv2D_U8_000) addread(Conv2D_U8_001) addread(Cos_000) +addread(CumSum_000) addread(Densify_000) addread(DepthToSpace_000) addread(DepthwiseConv2D_000) @@ -134,6 +136,7 @@ addread(ReduceProd_001) addread(ReduceProd_002) addread(ReduceProd_003) addread(ReLU_000) +addread(ReLU0To1_000) addread(ReLU6_000) addread(ReLUN1To1_000) addread(Reshape_000) @@ -257,6 +260,7 @@ addwrite(BatchMatMul_000) addwrite(BatchMatMulV2_000) addwrite(BatchMatMulV2_001) addwrite(BatchToSpaceND_000) +addwrite(BroadcastTo_001) addwrite(Cast_000) addwrite(Cast_001) addwrite(Ceil_000) @@ -269,6 +273,7 @@ addwrite(Conv2D_003) addwrite(Conv2D_U8_000) addwrite(Conv2D_U8_001) addwrite(Cos_000) +addwrite(CumSum_000) addwrite(Densify_000) addwrite(DepthToSpace_000) 
addwrite(DepthwiseConv2D_000)
@@ -363,6 +368,7 @@ addwrite(ReduceProd_001)
 addwrite(ReduceProd_002)
 addwrite(ReduceProd_003)
 addwrite(ReLU_000)
+addwrite(ReLU0To1_000)
 addwrite(ReLU6_000)
 addwrite(ReLUN1To1_000)
 addwrite(Reshape_000)
diff --git a/compiler/minmax-embedder-value-test/CMakeLists.txt b/compiler/minmax-embedder-value-test/CMakeLists.txt
new file mode 100644
index 000000000..f97034f1a
--- /dev/null
+++ b/compiler/minmax-embedder-value-test/CMakeLists.txt
@@ -0,0 +1,53 @@
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+# Disable test if minmax-embedder does not exist
+if (NOT TARGET minmax_embedder)
+  message(STATUS "minmax-embedder-value-test is disabled as minmax-embedder was not built.")
+  return()
+endif(NOT TARGET minmax_embedder)
+
+add_subdirectory(gen)
+
+# Add tests
+unset(MINMAX_EMBEDDER_VALUE_TESTS)
+
+macro(addTest NAME)
+  list(APPEND MINMAX_EMBEDDER_VALUE_TESTS ${NAME})
+endmacro(addTest)
+
+include("test.lst")
+include("test.local.lst" OPTIONAL)
+
+unset(TEST_DEPS)
+
+# Generate test.config
+set(TEST_CONFIG "${CMAKE_CURRENT_BINARY_DIR}/test.config")
+
+add_custom_command(
+  OUTPUT ${TEST_CONFIG}
+  COMMAND ${CMAKE_COMMAND} -E remove -f ${TEST_CONFIG}
+  COMMAND ${CMAKE_COMMAND} -E echo 'ARTIFACTS_PATH=\"$\"' >> ${TEST_CONFIG}
+  COMMAND ${CMAKE_COMMAND} -E echo 'MINMAX_DATA_GEN=\"$\"' >> ${TEST_CONFIG}
+  COMMAND ${CMAKE_COMMAND} -E echo 'MINMAX_EMBEDDER=\"$\"' >> ${TEST_CONFIG}
+  COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLEDUMP=\"$\"' >> ${TEST_CONFIG}
+  DEPENDS testDataGenerator
+  DEPENDS minmax_data_gen
+  DEPENDS minmax_embedder_driver
+  DEPENDS circledump
+  COMMENT "Generate test configuration"
+)
+
+list(APPEND TEST_DEPS "${TEST_CONFIG}")
+
+# This enforces CMake to generate all the dependencies during "build" phase
+add_custom_target(minmax_embedder_value_test_deps ALL DEPENDS ${TEST_DEPS})
+
+# Run tests
+add_test(
+  NAME minmax_embedder_value_test
+  COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/testall.sh"
+          "${TEST_CONFIG}"
+          ${MINMAX_EMBEDDER_VALUE_TESTS}
+)
diff --git a/compiler/minmax-embedder-value-test/README.md b/compiler/minmax-embedder-value-test/README.md
new file mode 100644
index 000000000..dcf5b1cbd
--- /dev/null
+++ b/compiler/minmax-embedder-value-test/README.md
@@ -0,0 +1,30 @@
+## minmax-embedder-value-test
+
+minmax-embedder-value-test aims to test the minmax-embedder tool.
+
+It generates minmax data (min and max values encoded from the run index and the op/input index).
+
+Then, it checks whether the data is correctly embedded into the circle model.
+
+minmax-embedder is supposed to be executed on a device.
+
+Thus, the test is also implemented so that it can run on a device (especially
+on a Tizen device). For example, it does not use Python.
+
+### minmax-data-gen
+
+`minmax-data-gen` generates minmax data for the test.
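+
+For reference, the generated (min, max) pairs are deterministic in the run index
+and the object (model input or op output) index; `gen/src/DataGen.cpp` in this
+directory implements the encoding. A minimal sketch of the same rule, so that
+expected values can be recomputed without reading the HDF5 file back:
+
+```
+#include <cstdint>
+
+// Mirrors DataGen::operator(): (run, obj) -> (min, max).
+float expected_min(uint32_t run, uint32_t obj) { return run * 10'000 + obj * 10; }
+float expected_max(uint32_t run, uint32_t obj) { return run * 10'000 + obj * 10 + 7; }
+
+// For example, run 1 and obj 2 give min 10020 and max 10027.
+```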
+ +#### Usage + +``` +Usage: ./minmax-data-gen [-h] [--num_inputs NUM_INPUTS] [--num_ops NUM_OPS] minmax + +[Positional argument] +minmax path to generated minmax data + +[Optional argument] +-h, --help Show help message and exit +--num_inputs number of input layers (default:1) +--num_ops number of operators (default:1) +``` diff --git a/compiler/minmax-embedder-value-test/gen/CMakeLists.txt b/compiler/minmax-embedder-value-test/gen/CMakeLists.txt new file mode 100644 index 000000000..79e70a5c8 --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/CMakeLists.txt @@ -0,0 +1,16 @@ +# Build minmax-data-gen +nnas_find_package(HDF5 COMPONENTS STATIC QUIET) + +if(NOT HDF5_FOUND) + message(STATUS "Build minmax-datagen: FAILED (missing HDF5)") + return() +endif(NOT HDF5_FOUND) + +file(GLOB_RECURSE SOURCES "src/*.cpp") + +add_executable(minmax_data_gen "${SOURCES}") +target_include_directories(minmax_data_gen PUBLIC ${HDF5_INCLUDE_DIRS}) + +target_link_libraries(minmax_data_gen ${HDF5_CXX_LIBRARIES}) +target_link_libraries(minmax_data_gen safemain) +target_link_libraries(minmax_data_gen arser) diff --git a/compiler/minmax-embedder-value-test/gen/src/Cast.h b/compiler/minmax-embedder-value-test/gen/src/Cast.h new file mode 100644 index 000000000..8885cc94b --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/src/Cast.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MINMAX_EMBEDDER_TEST_CAST_H__ +#define __MINMAX_EMBEDDER_TEST_CAST_H__ + +#include +#include + +namespace minmax_embedder_test +{ +uint32_t to_u32(uint64_t v) +{ + if (v > UINT32_MAX) + throw std::overflow_error("to_u32 gets a value bigger than uint32 max."); + return static_cast(v); +} + +} // end of namespace minmax_embedder_test + +#endif // __MINMAX_EMBEDDER_TEST_CAST_H__ diff --git a/compiler/minmax-embedder-value-test/gen/src/DataGen.cpp b/compiler/minmax-embedder-value-test/gen/src/DataGen.cpp new file mode 100644 index 000000000..c1cef8933 --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/src/DataGen.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "DataGen.h" + +namespace minmax_embedder_test +{ +MinMax DataGen::operator()(uint32_t run, uint32_t obj) const +{ + MinMax r; + r.min() = (run * 10'000) + (obj * 10); + r.max() = (run * 10'000) + (obj * 10) + 7; + return r; +} +} // namespace minmax_embedder_test diff --git a/compiler/minmax-embedder-value-test/gen/src/DataGen.h b/compiler/minmax-embedder-value-test/gen/src/DataGen.h new file mode 100644 index 000000000..731ae9131 --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/src/DataGen.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MINMAX_EMBEDDER_TEST_DATA_GEN_H__ +#define __MINMAX_EMBEDDER_TEST_DATA_GEN_H__ + +#include + +namespace minmax_embedder_test +{ +class MinMax +{ +public: + float &min() { return v[0]; } + float &max() { return v[1]; } + float *data() { return &v[0]; } + +private: + float v[2]; +}; + +/** + * generates (min,max) for obj (= model input or op output) index at run + */ +class DataGen +{ +public: + MinMax operator()(uint32_t run, uint32_t obj) const; +}; +} // end of namespace minmax_embedder_test + +#endif // __MINMAX_EMBEDDER_TEST_DATA_GEN_H__ diff --git a/compiler/minmax-embedder-value-test/gen/src/Driver.cpp b/compiler/minmax-embedder-value-test/gen/src/Driver.cpp new file mode 100644 index 000000000..479a568df --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/src/Driver.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "H5Writer.h"
+
+#include <arser/arser.h>
+
+using namespace minmax_embedder_test;
+
+int entry(const int argc, char **argv)
+{
+  arser::Arser arser("Generate min/max data to test minmax-embedder");
+  arser.add_argument("--num_inputs")
+    .type(arser::DataType::INT32)
+    .default_value(1)
+    .help("number of input layers (default:1)");
+  arser.add_argument("--num_ops")
+    .type(arser::DataType::INT32)
+    .default_value(1)
+    .help("number of operators (default:1)");
+  arser.add_argument("minmax").help("path to generated minmax data");
+
+  try
+  {
+    arser.parse(argc, argv);
+  }
+  catch (const std::runtime_error &err)
+  {
+    std::cout << err.what() << std::endl;
+    std::cout << arser;
+    return 255;
+  }
+
+  auto num_inputs = arser.get<int32_t>("--num_inputs");
+  auto num_ops = arser.get<int32_t>("--num_ops");
+  auto data_output_path = arser.get<std::string>("minmax");
+
+  ModelSpec mspec;
+  {
+    if (num_inputs <= 0 || num_ops <= 0)
+    {
+      std::cout << "num_inputs and num_ops must be positive integers." << std::endl;
+      return 255;
+    }
+    mspec.n_inputs = num_inputs;
+    mspec.n_ops = num_ops;
+  }
+  H5Writer writer(mspec, data_output_path);
+  writer.dump();
+
+  return EXIT_SUCCESS;
+}
diff --git a/compiler/minmax-embedder-value-test/gen/src/H5Writer.cpp b/compiler/minmax-embedder-value-test/gen/src/H5Writer.cpp
new file mode 100644
index 000000000..bf83ef1f8
--- /dev/null
+++ b/compiler/minmax-embedder-value-test/gen/src/H5Writer.cpp
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "H5Writer.h"
+#include "Cast.h"
+
+#include
+#include
+#include
+#include
+
+namespace minmax_embedder_test
+{
+/*
+ * ensure the group named child exists in parent
+ */
+H5::Group ensureGroup(H5::Group parent, const std::string &child)
+{
+  H5::Exception::dontPrint();
+  try
+  {
+    return parent.openGroup(child.c_str());
+  }
+  catch (H5::Exception &e)
+  {
+    return parent.createGroup(child.c_str());
+  }
+}
+
+static const char *h5_value_grpname = "value";
+
+H5Writer::H5Writer(const ModelSpec &md_spec, const std::string &filepath)
+  : _md_spec{md_spec}, _filepath{filepath}
+{
+}
+
+void H5Writer::dump()
+{
+  // NOTE: H5Writer
+  H5::H5File h5file{_filepath, H5F_ACC_CREAT | H5F_ACC_RDWR};
+  auto root_grp = h5file.openGroup("/");
+  ensureGroup(root_grp, h5_value_grpname);
+  auto val_grp = h5file.openGroup(h5_value_grpname);
+  // NOTE: Writer
+  uint32_t num_run = to_u32(val_grp.getNumObjs());
+  auto run_grp = val_grp.createGroup(("run_") + std::to_string(num_run));
+  // Assumption: single subgraph
+  auto model_grp = ensureGroup(run_grp, std::string("model_") + "0");
+  hsize_t dims[] = {2};
+  H5::DataSpace dspace(1, dims); // rank=1, dim(0)=2, {min, max}
+  DataGen gen;
+  // dump input minmax
+  for (uint32_t i = 0; i < _md_spec.n_inputs; ++i)
+  {
+    const auto subg_idx = 0;
+    auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx));
+    auto input_dset = subg_grp.createDataSet(std::string("input_") + std::to_string(i),
+                                             H5::PredType::IEEE_F32BE, dspace);
+    auto minmax = gen(num_run, i);
+    input_dset.write(minmax.data(), H5::PredType::NATIVE_FLOAT);
+  }
+  // dump op minmax
+  for (uint32_t op = 0; op < _md_spec.n_ops; ++op)
+  {
+    const auto subg_idx = 0;
+    auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx));
+    auto op_dset = subg_grp.createDataSet(std::string("op_") + std::to_string(op),
+                                          H5::PredType::IEEE_F32BE, dspace);
+    auto minmax = gen(num_run, op);
+    op_dset.write(minmax.data(), H5::PredType::NATIVE_FLOAT);
+  }
+}
+} // end of namespace minmax_embedder_test
diff --git a/compiler/minmax-embedder-value-test/gen/src/H5Writer.h b/compiler/minmax-embedder-value-test/gen/src/H5Writer.h
new file mode 100644
index 000000000..bef679a3c
--- /dev/null
+++ b/compiler/minmax-embedder-value-test/gen/src/H5Writer.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef __MINMAX_EMBEDDER_TEST_H5WRITER_H__ +#define __MINMAX_EMBEDDER_TEST_H5WRITER_H__ + +#include "ModelSpec.h" +#include "DataGen.h" + +#include + +namespace minmax_embedder_test +{ +// It must be same to onert/core/src/dumper/h5/MinMaxDumper.h +// +// GROUP / +// GROUP value +// └── GROUP run_{idx} +// └── GROUP model_{idx} +// └── GROUP subg_{idx} +// ├── DATASET op_{idx} +// │ DATATYPE Float32 +// │ DATASPACE (2) +// │ DATA { min, max } +// └── DATASET input_{idx} +// DATATYPE Float32 +// DATASPACE (2) +// DATA { min, max } +class H5Writer +{ +public: + H5Writer(const ModelSpec &md_spec, const std::string &filepath); + void dump(); + +private: + ModelSpec _md_spec; + std::string _filepath; +}; +} // namespace minmax_embedder_test + +#endif // __MINMAX_EMBEDDER_TEST_H5WRITER_H__ diff --git a/compiler/minmax-embedder-value-test/gen/src/ModelSpec.h b/compiler/minmax-embedder-value-test/gen/src/ModelSpec.h new file mode 100644 index 000000000..10cd17a7b --- /dev/null +++ b/compiler/minmax-embedder-value-test/gen/src/ModelSpec.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MINMAX_EMBEDDER_TEST_MODEL_SPEC_H__ +#define __MINMAX_EMBEDDER_TEST_MODEL_SPEC_H__ + +#include + +namespace minmax_embedder_test +{ +struct ModelSpec +{ + /** number of model inputs */ + uint32_t n_inputs; + /** number of operators*/ + uint32_t n_ops; +}; +} // end of namespace minmax_embedder_test + +#endif // __MINMAX_EMBEDDER_TEST_MODEL_SPEC_H__ diff --git a/compiler/minmax-embedder-value-test/requires.cmake b/compiler/minmax-embedder-value-test/requires.cmake new file mode 100644 index 000000000..87c107e2b --- /dev/null +++ b/compiler/minmax-embedder-value-test/requires.cmake @@ -0,0 +1,2 @@ +require("common-artifacts") +require("minmax-embedder") diff --git a/compiler/minmax-embedder-value-test/test.lst b/compiler/minmax-embedder-value-test/test.lst new file mode 100644 index 000000000..dcdf94253 --- /dev/null +++ b/compiler/minmax-embedder-value-test/test.lst @@ -0,0 +1,3 @@ +addTest(Abs_000) # input: 1, op: 1, output: 1 +addTest(Add_000) # input: 2, op: 1, output: 1 +addTest(Net_Conv_Relu6_000) # input: 1, op: 3, output: 1 diff --git a/compiler/minmax-embedder-value-test/testall.sh b/compiler/minmax-embedder-value-test/testall.sh new file mode 100755 index 000000000..80a529fb5 --- /dev/null +++ b/compiler/minmax-embedder-value-test/testall.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +# This script tests the parallel behavior of minmax-embedder +# +# HOW TO USE +# +# ./testall.sh ... 
+# +# test.config must contains the following variables: +# - ARTIFACTS_PATH: path to test models +# - MINMAX_DATA_GEN: path to minmax_data_gen +# - MINMAX_EMBEDDER: path to minmax_embedder_driver +# - CIRCLEDUMP: path to circledump + +# is the name of model under ARTIFACTS_PATH + +CONFIG_PATH="$1"; shift; source "${CONFIG_PATH}" +WORK_DIR=$(dirname "${CONFIG_PATH}") # For temporary and report outputs + +echo "-- Found ARTIFACTS_PATH: ${ARTIFACTS_PATH}" +echo "-- Found MINMAX_DATA_GEN: ${MINMAX_DATA_GEN}" +echo "-- Found MINMAX_EMBEDDER: ${MINMAX_EMBEDDER}" +echo "-- Found CIRCLEDUMP: ${CIRCLEDUMP}" +echo "-- Found CONFIG_PATH: ${CONFIG_PATH}" + +TESTED=() +PASSED=() +FAILED=() + +pushd "${WORK_DIR}" +for TESTCASE in "$@"; do + TESTED+=("${TESTCASE}") + + TESTCASE_FILE="${ARTIFACTS_PATH}/${TESTCASE}" + + PASSED_TAG="${WORK_DIR}/${TESTCASE}.passed" + rm -f "${PASSED_TAG}" + + cat > "${WORK_DIR}/${TESTCASE}.log" <( + exec 2>&1 + set -ex + + # Get model input tensor names + INPUT_NAMES=( $("${CIRCLEDUMP}" "${TESTCASE_FILE}.circle" | grep -aP '^I T' | grep -oE '[^ ]+$') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO GET MODEL INPUT TENSOR INDEX" + continue + fi + + # Get op output tensor names + OP_OUT_NAMES=( $("${CIRCLEDUMP}" "${TESTCASE_FILE}.circle" | grep -aP ' O T\(\d+:' | grep -oE '[^ ]+$') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO GET OP OUTPUT TENSOR INDEX" + continue + fi + + # Run minmax-embedder-data-gen + RUNS=2 + for (( RUN=1; RUN<=RUNS; RUN++ )); do + "${MINMAX_DATA_GEN}" --num_inputs ${#INPUT_NAMES[@]} --num_ops ${#OP_OUT_NAMES[@]} "${TESTCASE}.minmax" + if [[ $? -ne 0 ]]; then + echo "FAILED TO GENERATE MINMAX DATA" + continue + fi + done + + # Run minmax-embedder + "${MINMAX_EMBEDDER}" \ + --min_percentile 0 --max_percentile 100 \ + -o "${TESTCASE}.circle+minmax" \ + "${TESTCASE_FILE}.circle" \ + "${TESTCASE}.minmax" + if [[ $? -ne 0 ]]; then + echo "FAILED TO EMBED MINMAX INTO CIRCLE" + continue + fi + + rm -f "${TESTCASE}.minmax" + + # Read min/max from circle+minmax + MD_MIN=() + MD_MAX=() + for NAME in "${INPUT_NAMES[@]}"; do + MD_MIN+=( $("${CIRCLEDUMP}" "${TESTCASE}.circle+minmax" | grep -aP "^T.*${NAME}$" -A 1 | tail -1 | grep -oP '(?<=min\()[+-]?[\d]+') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO PARSE MODEL INPUT MIN FROM CIRCLE" + continue + fi + MD_MAX+=( $("${CIRCLEDUMP}" "${TESTCASE}.circle+minmax" | grep -aP "^T.*${NAME}$" -A 1 | tail -1 | grep -oP '(?<=max\()[+-]?[\d]+') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO PARSE MODEL INPUT MAX FROM CIRCLE" + continue + fi + done + + OP_MIN=() + OP_MAX=() + for NAME in "${OP_OUT_NAMES[@]}"; do + OP_MIN+=( $("${CIRCLEDUMP}" "${TESTCASE}.circle+minmax" | grep -aP "^T.*${NAME}$" -A 1 | tail -1 | grep -oP '(?<=min\()[+-]?[\d]+') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO PARSE OP MIN FROM CIRCLE" + continue + fi + OP_MAX+=( $("${CIRCLEDUMP}" "${TESTCASE}.circle+minmax" | grep -aP "^T.*${NAME}$" -A 1 | tail -1 | grep -oP '(?<=max\()[+-]?[\d]+') ) + if [[ $? -ne 0 ]]; then + echo "FAILED TO PARSE OP MAX FROM CIRCLE" + continue + fi + done + + # check model input + for i in "${!MD_MIN[@]}"; do + # Be sure it is synced with minmax-embedder-data-gen + EXPECTED_MIN=$((i*10)) + EXPECTED_MAX=$(((RUNS-1)*10000+i*10+7)) + if [[ "${MD_MIN[i]}" != "$EXPECTED_MIN" ]]; then + echo "Min at model input $i does not equal." + continue + fi + if [[ "${MD_MAX[i]}" != "$EXPECTED_MAX" ]]; then + echo "Max at model input $i does not equal." 
+ continue + fi + done + + # check op output + for i in "${!OP_MIN[@]}"; do + # Be sure it is synced with minmax-embedder-data-gen + EXPECTED_MIN=$((i*10)) + EXPECTED_MAX=$(((RUNS-1)*10000+i*10+7)) + if [[ "${OP_MIN[i]}" != "$EXPECTED_MIN" ]]; then + echo "Min at op $i does not equal." + continue + fi + if [[ "${OP_MAX[i]}" != "$EXPECTED_MAX" ]]; then + echo "Max at op $i does not equal." + continue + fi + done + touch "${PASSED_TAG}" + ) + + if [[ -f "${PASSED_TAG}" ]]; then + PASSED+=("$TESTCASE") + else + FAILED+=("$TESTCASE") + fi +done +popd + +if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then + echo "FAILED" + for TEST in "${FAILED[@]}" + do + echo "- ${TEST}" + done + exit 255 +fi + +echo "PASSED" +exit 0 diff --git a/compiler/minmax-embedder/CMakeLists.txt b/compiler/minmax-embedder/CMakeLists.txt new file mode 100644 index 000000000..3fd3bc719 --- /dev/null +++ b/compiler/minmax-embedder/CMakeLists.txt @@ -0,0 +1,56 @@ +nnas_find_package(HDF5 COMPONENTS STATIC QUIET) + +if(NOT HDF5_FOUND) + message(STATUS "Build minmax_embedder: FAILED (missing HDF5)") + return() +endif(NOT HDF5_FOUND) + +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE DRIVER "driver/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${DRIVER}) +list(REMOVE_ITEM SOURCES ${TESTS}) + +# +# Library +# +add_library(minmax_embedder "${SOURCES}") +target_include_directories(minmax_embedder PUBLIC ${HDF5_INCLUDE_DIRS}) +target_include_directories(minmax_embedder PRIVATE include) + +target_link_libraries(minmax_embedder ${HDF5_CXX_LIBRARIES}) +target_link_libraries(minmax_embedder loco) +target_link_libraries(minmax_embedder luci_import) +target_link_libraries(minmax_embedder luci_service) +target_link_libraries(minmax_embedder luci_pass) +target_link_libraries(minmax_embedder luci_export) +target_link_libraries(minmax_embedder luci_env) + +install(TARGETS minmax_embedder DESTINATION lib) +install(DIRECTORY include/ DESTINATION include + FILES_MATCHING PATTERN "*.h") +# +# GTest +# +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(minmax_embedder_test ${TESTS}) +target_include_directories(minmax_embedder_test PRIVATE include) +target_link_libraries(minmax_embedder_test minmax_embedder) + +# +# Driver +# +add_executable(minmax_embedder_driver "${DRIVER}") +target_link_libraries(minmax_embedder_driver minmax_embedder) +target_link_libraries(minmax_embedder_driver arser) +target_link_libraries(minmax_embedder_driver safemain) +target_link_libraries(minmax_embedder_driver vconone) +target_include_directories(minmax_embedder_driver PRIVATE include) +set_target_properties(minmax_embedder_driver PROPERTIES OUTPUT_NAME minmax_embedder) + +install(TARGETS minmax_embedder_driver DESTINATION bin) diff --git a/compiler/minmax-embedder/README.md b/compiler/minmax-embedder/README.md new file mode 100644 index 000000000..ddd787869 --- /dev/null +++ b/compiler/minmax-embedder/README.md @@ -0,0 +1,19 @@ +# minmax-embedder + +_minmax-embedder_ embeds minmax to circle. 
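+
+Besides the command line driver below, the embedding is available as a library
+through `Embedder::embed()`. A minimal sketch of a caller, based on
+`include/minmax-embedder/Embedder.h` from this directory (the file names here
+are placeholders):
+
+```
+#include "minmax-embedder/Embedder.h"
+
+int main()
+{
+  using namespace minmax_embedder;
+
+  // Same percentiles as the driver defaults below.
+  EmbedderOptions opt{/*min_percentile=*/1.0f, /*max_percentile=*/99.0f};
+
+  // Reads in.circle and minmax.h5, writes out.circle with min/max embedded
+  // as quantization parameters; throws std::runtime_error on failure.
+  Embedder().embed("out.circle", "in.circle", "minmax.h5", opt);
+  return 0;
+}
+```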
+ +### Usage +``` +Usage: ./minmax_embedder [-h] [--version] [--min_percentile MIN_PERCENTILE] [--max_percentile MAX_PERCENTILE] [-o O] circle minmax + +[Positional argument] +circle Path to input circle model +minmax Path to minmax data in hdf5 + +[Optional argument] +-h, --help Show help message and exit +--version Show version information and exit +--min_percentile Set min percentile (default: 1) +--max_percentile Set max percentile (default: 99) +-o Path to output circle model +``` diff --git a/compiler/minmax-embedder/driver/Driver.cpp b/compiler/minmax-embedder/driver/Driver.cpp new file mode 100644 index 000000000..a3e03de60 --- /dev/null +++ b/compiler/minmax-embedder/driver/Driver.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minmax-embedder/Embedder.h" + +#include +#include + +#include + +using namespace minmax_embedder; + +void print_version(void) +{ + std::cout << "minmax-embedder version " << vconone::get_string() << std::endl; + std::cout << vconone::get_copyright() << std::endl; +} + +int entry(const int argc, char **argv) +{ + arser::Arser arser("minmax-embedder embeds given minmax into circle"); + arser::Helper::add_version(arser, print_version); + // named args + arser.add_argument("--min_percentile") + .type(arser::DataType::FLOAT) + .default_value(1.f) + .help("Set min percentile (default: 1)"); + arser.add_argument("--max_percentile") + .type(arser::DataType::FLOAT) + .default_value(99.f) + .help("Set max percentile (default: 99)"); + arser.add_argument("-o").default_value("out.circle").help("Path to output circle model"); + // positional args: minmax(h5), input(circle) + arser.add_argument("circle").help("Path to input circle model"); + arser.add_argument("minmax").help("Path to minmax data in hdf5"); + try + { + arser.parse(argc, argv); + } + catch (const std::runtime_error &err) + { + std::cout << err.what() << std::endl; + std::cout << arser; + return EXIT_FAILURE; + } + + std::string minmax_path = arser.get("minmax"); + std::string circle_path = arser.get("circle"); + std::string output_path = arser.get("-o"); + float min_percentile = arser.get("--min_percentile"); + float max_percentile = arser.get("--max_percentile"); + + EmbedderOptions opt{min_percentile, max_percentile}; + try + { + Embedder().embed(output_path, circle_path, minmax_path, opt); + } + catch (const std::runtime_error &err) + { + std::cout << err.what() << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/compiler/minmax-embedder/include/minmax-embedder/Embedder.h b/compiler/minmax-embedder/include/minmax-embedder/Embedder.h new file mode 100644 index 000000000..316148abe --- /dev/null +++ b/compiler/minmax-embedder/include/minmax-embedder/Embedder.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MINMAX_EMBEDDER_EMBEDDER_H__ +#define __MINMAX_EMBEDDER_EMBEDDER_H__ + +#include + +namespace minmax_embedder +{ + +struct EmbedderOptions +{ + float min_percentile = 0.0f; // dummy initial value to make SE tool happy + float max_percentile = 0.0f; // dummy initial value To make SE tool happy +}; + +class Embedder +{ +public: + void embed(const std::string &output_path, const std::string &input_path, + const std::string &minmax_path, const EmbedderOptions &); +}; +} // namespace minmax_embedder + +#endif // __MINMAX_EMBEDDER_EMBEDDER_H__ diff --git a/compiler/minmax-embedder/requires.cmake b/compiler/minmax-embedder/requires.cmake new file mode 100644 index 000000000..186f96646 --- /dev/null +++ b/compiler/minmax-embedder/requires.cmake @@ -0,0 +1,5 @@ +require("arser") +require("loco") +require("luci") +require("safemain") +require("vconone") diff --git a/compiler/minmax-embedder/src/Embedder.cpp b/compiler/minmax-embedder/src/Embedder.cpp new file mode 100644 index 000000000..46734ff05 --- /dev/null +++ b/compiler/minmax-embedder/src/Embedder.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "minmax-embedder/Embedder.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "h5/Reader.h"
+
+#include
+#include <cmath> // for std::floor
+#include
+#include
+
+namespace
+{
+
+/* NOTE: getNthPercentile is copied from compiler/record-minmax/include/RecordFunction.h */
+/**
+ * @brief getNthPercentile calculates the n-th percentile of input vector (0.0 <= n <= 100.0)
+ *        linear interpolation is used when the desired percentile lies between two data points
+ */
+float getNthPercentile(std::vector<float> &vector, float percentile)
+{
+  if (percentile < 0 || percentile > 100)
+    throw std::runtime_error("Percentile must be ranged from 0 to 100");
+
+  if (vector.empty())
+    throw std::runtime_error("Percentile must take a non-empty vector as an argument");
+
+  if (vector.size() == 1)
+    return vector[0];
+
+  std::vector<float> copy;
+  copy.assign(vector.begin(), vector.end());
+  std::sort(copy.begin(), copy.end());
+
+  if (percentile == 0.0)
+    return copy.front();
+
+  if (percentile == 100.0)
+    return copy.back();
+
+  int index = static_cast<int>(std::floor((copy.size() - 1) * percentile / 100.0));
+
+  float percent_i = static_cast<float>(index) / static_cast<float>(copy.size() - 1);
+  float fraction =
+    (percentile / 100.0 - percent_i) / ((index + 1.0) / (copy.size() - 1.0) - percent_i);
+  float res = copy[index] + fraction * (copy[index + 1] - copy[index]);
+  return res;
+}
+
+} // namespace
+
+namespace minmax_embedder
+{
+
+void Embedder::embed(const std::string &output_path, const std::string &input_path,
+                     const std::string &minmax_path, const EmbedderOptions &opt)
+{
+  // Load model from the file
+  luci::ImporterEx importerex;
+  auto module = importerex.importVerifyModule(input_path);
+  if (module.get() == nullptr)
+    throw std::runtime_error{"Input circle is invalid"};
+
+  h5::Reader mmr{minmax_path};
+
+  for (size_t idx = 0; idx < module->size(); ++idx)
+  {
+    auto graph = module->graph(idx);
+
+    /* read subgraph inputs */
+    const auto input_nodes = loco::input_nodes(graph);
+    const auto n_inputs = input_nodes.size();
+    for (size_t input_idx = 0; input_idx < n_inputs; ++input_idx)
+    {
+      const auto *circle_input = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
+      if (circle_input->index() != input_idx)
+        throw std::runtime_error("Input order in minmax recording does not match to circle");
+
+      auto minmax = mmr.read_input(0, idx, input_idx);
+      auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
+      auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
+      auto quantparam = std::make_unique<luci::CircleQuantParam>();
+      quantparam->min.push_back(min);
+      quantparam->max.push_back(max);
+      const auto *circle_node = loco::must_cast<const luci::CircleNode *>(input_nodes[input_idx]);
+      auto mutable_node = const_cast<luci::CircleNode *>(circle_node);
+      mutable_node->quantparam(std::move(quantparam));
+    }
+
+    /* read op outputs */
+    uint32_t n_nodes = graph->nodes()->size();
+    for (uint32_t i = 0; i < n_nodes; ++i)
+    {
+      auto node = loco::must_cast<luci::CircleNode *>(graph->nodes()->at(i));
+      if (not luci::has_node_id(node)) // Skip non-op nodes (e.g. input/const/output)
+        continue;
+      auto op_idx = luci::get_node_id(node);
+      auto minmax = mmr.read(0, idx, op_idx);
+      auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
+      auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
+      auto quantparam = std::make_unique<luci::CircleQuantParam>();
+      quantparam->min.push_back(min);
+      quantparam->max.push_back(max);
+      auto mutable_node = const_cast<luci::CircleNode *>(node);
+      mutable_node->quantparam(std::move(quantparam));
+    }
+
+    if (!luci::validate(graph))
+      throw std::runtime_error{"Circle after embedding minmax is invalid"};
+  }
+
+  // Export to output Circle file
+  luci::CircleExporter exporter;
+
+  luci::CircleFileExpContract contract(module.get(), output_path);
+
+  if (!exporter.invoke(&contract))
+    throw std::runtime_error{"Failed to export circle"};
+}
+
+} // namespace minmax_embedder
diff --git a/compiler/minmax-embedder/src/Embedder.test.cpp b/compiler/minmax-embedder/src/Embedder.test.cpp
new file mode 100644
index 000000000..c747e2ce8
--- /dev/null
+++ b/compiler/minmax-embedder/src/Embedder.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minmax-embedder/Embedder.h"
+
+#include <gtest/gtest.h>
+
+using namespace minmax_embedder;
+
+namespace
+{
+struct MinMaxEmbedderTest : public ::testing::Test
+{
+  EmbedderOptions _opt{0, 100};
+};
+
+} // namespace
+
+TEST_F(MinMaxEmbedderTest, invalid_input_NEG)
+{
+  Embedder embedder;
+  EXPECT_THROW(embedder.embed("", "not_existing", "", _opt), std::runtime_error);
+}
diff --git a/compiler/minmax-embedder/src/h5/Reader.cpp b/compiler/minmax-embedder/src/h5/Reader.cpp
new file mode 100644
index 000000000..b0bb8c393
--- /dev/null
+++ b/compiler/minmax-embedder/src/h5/Reader.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "Reader.h" + +#include +#include + +namespace +{ +bool exists(hid_t id, const char *path) { return H5Lexists(id, path, H5P_DEFAULT) > 0; } +} // namespace + +namespace minmax_embedder +{ +namespace h5 +{ +static const char *h5_value_grpname = "value"; + +Reader::Reader(const std::string &filepath) : _file(filepath, H5F_ACC_RDONLY) +{ + _val_grp = _file.openGroup(h5_value_grpname); +} + +// TODO: Handle multiple output +MinMaxVectors Reader::read(int model_idx, int subg_idx, int op_idx) const +{ + MinMaxVectors mmv; + float minmax[2]; + auto num_run = _val_grp.getNumObjs(); + for (uint32_t r = 0; r < num_run; ++r) + { + // check whether minmax exists + char path[128]; // 128 is enough to print "/value/run_%d/model_%d/subg_%d/op_%d" + null + snprintf(path, 128, "/value/run_%d/model_%d/subg_%d/op_%d", r, model_idx, subg_idx, op_idx); + if (!exists(_file.getId(), path)) + continue; + auto run_grp = _val_grp.openGroup(std::string("run_") + std::to_string(r)); + auto model_grp = run_grp.openGroup(std::string("model_") + std::to_string(model_idx)); + auto subg_grp = model_grp.openGroup(std::string("subg_") + std::to_string(subg_idx)); + auto op_dset = subg_grp.openDataSet(std::string("op_") + std::to_string(op_idx)); + H5::DataType dtype = op_dset.getDataType(); + if (not(dtype == H5::PredType::IEEE_F32BE || dtype == H5::PredType::IEEE_F32LE)) + throw std::runtime_error{"dtype of min, max in h5 is not float."}; + op_dset.read(minmax, H5::PredType::NATIVE_FLOAT); + mmv.min_vector.emplace_back(minmax[0]); + mmv.max_vector.emplace_back(minmax[1]); + } + return mmv; +} + +MinMaxVectors Reader::read_input(int model_idx, int subg_idx, int input_idx) const +{ + MinMaxVectors mmv; + float minmax[2]; + auto num_run = _val_grp.getNumObjs(); + for (uint32_t r = 0; r < num_run; ++r) + { + // check whether minmax exists + char path[128]; // 128 is enough to print "/value/run_%d/model_%d/subg_%d/input_%d" + null + snprintf(path, 128, "/value/run_%d/model_%d/subg_%d/input_%d", r, model_idx, subg_idx, + input_idx); + if (!exists(_file.getId(), path)) + continue; + auto run_grp = _val_grp.openGroup(std::string("run_") + std::to_string(r)); + auto model_grp = run_grp.openGroup(std::string("model_") + std::to_string(model_idx)); + auto subg_grp = model_grp.openGroup(std::string("subg_") + std::to_string(subg_idx)); + auto op_dset = subg_grp.openDataSet(std::string("input_") + std::to_string(input_idx)); + + H5::DataType dtype = op_dset.getDataType(); + if (not(dtype == H5::PredType::IEEE_F32BE || dtype == H5::PredType::IEEE_F32LE)) + throw std::runtime_error{"dtype of min, max in h5 is not float."}; + op_dset.read(minmax, H5::PredType::NATIVE_FLOAT); + mmv.min_vector.emplace_back(minmax[0]); + mmv.max_vector.emplace_back(minmax[1]); + } + return mmv; +} + +} // namespace h5 +} // namespace minmax_embedder diff --git a/compiler/minmax-embedder/src/h5/Reader.h b/compiler/minmax-embedder/src/h5/Reader.h new file mode 100644 index 000000000..600a38719 --- /dev/null +++ b/compiler/minmax-embedder/src/h5/Reader.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MINMAX_EMBEDDER_H5_READER_H__ +#define __MINMAX_EMBEDDER_H5_READER_H__ + +#include +#include +#include +#include + +namespace minmax_embedder +{ +namespace h5 +{ +// The hierachy of single model minmax h5 file +// +// GROUP / +// GROUP value +// └── GROUP run_{idx} +// └── GROUP model_{idx} +// └── GROUP subg_{idx} +// ├── DATASET op_{idx} +// │ DATATYPE Float32 +// │ DATASPACE (2) +// │ DATA { min, max } +// └── DATASET input_{idx} +// DATATYPE Float32 +// DATASPACE (2) +// DATA { min, max } +// GROUP name (optional, for debug) +// └── GROUP model_{idx} +// └── GROUP subg_{idx} +// ├── ATTRIBUTE op_{idx} +// │ DATATYPE String +// │ DATA { "op/name"} +// └── ATTRIBUTE input_{idx} +// DATATYPE String +// DATA { "input/name"} +struct MinMaxVectors +{ + std::vector min_vector; + std::vector max_vector; +}; + +class Reader +{ +public: + Reader(const std::string &filepath); + /** + * @brief Returns minmax recording for op {model_idx, subg_idx, op_idx} + * + * @return MinMaxVectors + */ + MinMaxVectors read(int model_idx, int subg_idx, int op_idx) const; + /** + * @brief Returns minmax recording for input {model_idx, subg_idx, input_idx} + * + * @return MinMaxVectors + */ + MinMaxVectors read_input(int model_idx, int subg_idx, int input_idx) const; + +private: + H5::H5File _file; + H5::Group _val_grp; +}; + +} // namespace h5 +} // namespace minmax_embedder + +#endif // __MINMAX_EMBEDDER_H5_READER_H__ diff --git a/compiler/mio-circle/CMakeLists.txt b/compiler/mio-circle/CMakeLists.txt index d24717343..d2a037eb4 100644 --- a/compiler/mio-circle/CMakeLists.txt +++ b/compiler/mio-circle/CMakeLists.txt @@ -1,7 +1,7 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) if(NOT FlatBuffers_FOUND) - message(STATUS "mio-circle skip: FlatBuffers 2.0 NOT FOUND") + message(STATUS "mio-circle skip: FlatBuffers 23.5.26 NOT FOUND") return() endif(NOT FlatBuffers_FOUND) diff --git a/compiler/mio-circle/exclude.me b/compiler/mio-circle/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/mio-circle04/CMakeLists.txt b/compiler/mio-circle04/CMakeLists.txt index 8ee6da44c..cf93f2bf3 100644 --- a/compiler/mio-circle04/CMakeLists.txt +++ b/compiler/mio-circle04/CMakeLists.txt @@ -1,7 +1,7 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) if(NOT FlatBuffers_FOUND) - message(STATUS "mio-circle04 skip: FlatBuffers 2.0 NOT FOUND") + message(STATUS "mio-circle04 skip: FlatBuffers 23.5.26 NOT FOUND") return() endif(NOT FlatBuffers_FOUND) diff --git a/compiler/mio-circle04/exclude.me b/compiler/mio-circle04/exclude.me new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/mio-circle05/CMakeLists.txt b/compiler/mio-circle05/CMakeLists.txt index dfd359eaa..a1211e1a6 100644 --- a/compiler/mio-circle05/CMakeLists.txt +++ b/compiler/mio-circle05/CMakeLists.txt @@ -1,7 +1,7 @@ -nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET) if(NOT FlatBuffers_FOUND) - message(STATUS "mio-circle05 skip: FlatBuffers 2.0 
+  message(STATUS "mio-circle05 skip: FlatBuffers 23.5.26 NOT FOUND")
   return()
 endif(NOT FlatBuffers_FOUND)
diff --git a/compiler/mio-circle05/exclude.me b/compiler/mio-circle05/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-circle06/CMakeLists.txt b/compiler/mio-circle06/CMakeLists.txt
index 2ccd8059c..d4f6a4d17 100644
--- a/compiler/mio-circle06/CMakeLists.txt
+++ b/compiler/mio-circle06/CMakeLists.txt
@@ -1,7 +1,7 @@
-nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
 
 if(NOT FlatBuffers_FOUND)
-  message(STATUS "mio-circle06 skip: FlatBuffers 2.0 NOT FOUND")
+  message(STATUS "mio-circle06 skip: FlatBuffers 23.5.26 NOT FOUND")
   return()
 endif(NOT FlatBuffers_FOUND)
diff --git a/compiler/mio-circle06/exclude.me b/compiler/mio-circle06/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-circle07/CMakeLists.txt b/compiler/mio-circle07/CMakeLists.txt
new file mode 100644
index 000000000..aee9ef620
--- /dev/null
+++ b/compiler/mio-circle07/CMakeLists.txt
@@ -0,0 +1,52 @@
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
+
+if(NOT FlatBuffers_FOUND)
+  message(STATUS "mio-circle07 skip: FlatBuffers 23.5.26 NOT FOUND")
+  return()
+endif(NOT FlatBuffers_FOUND)
+
+message(STATUS "Build mio-circle07: TRUE")
+
+# TODO Find a better way
+# TODO use nnpackage
+# set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/nnpackage/schema/circle_schema.fbs")
+set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.7/circle_schema.fbs")
+
+# NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
+add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
+                   COMMAND ${CMAKE_COMMAND} -E copy "${SCHEMA_FILE}" schema.fbs
+                   WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+                   DEPENDS "${SCHEMA_FILE}"
+)
+
+FlatBuffers_Target(mio_circle07
+  OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen/mio/circle"
+  INCLUDE_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen"
+  SCHEMA_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+  SCHEMA_FILES "schema.fbs"
+)
+
+# This example shows how to use "mio-circle07" library
+add_executable(mio_circle07_example example.cpp)
+target_link_libraries(mio_circle07_example mio_circle07)
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(mio_circle07_helper STATIC ${SOURCES})
+set_target_properties(mio_circle07_helper PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(mio_circle07_helper PRIVATE src)
+target_include_directories(mio_circle07_helper PUBLIC include)
+target_link_libraries(mio_circle07_helper mio_circle07)
+
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(mio_circle07_helper_test ${TESTS})
+target_include_directories(mio_circle07_helper_test PRIVATE src)
+target_link_libraries(mio_circle07_helper_test mio_circle07)
+target_link_libraries(mio_circle07_helper_test mio_circle07_helper)
diff --git a/compiler/mio-circle07/README.md b/compiler/mio-circle07/README.md
new file mode 100644
index 000000000..2152ad7f9
--- /dev/null
+++ b/compiler/mio-circle07/README.md
@@ -0,0 +1,3 @@
+# mio-circle07
+
+Let's make it easy to read and write Circle models.
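For orientation, the mio-circle07 pieces added below (example.cpp, Helper, Reader) compose roughly as follows. This is an illustrative sketch only, not part of the patch: it assumes the generated schema header is exposed as <mio/circle/schema_generated.h> (matching the FlatBuffers_Target OUTPUT_DIR above), and it skips verification for brevity. The Reader only stores pointers into the model buffer, so the buffer must outlive it.

#include <mio/circle/schema_generated.h> // assumed generated-header path
#include "mio_circle/Reader.h"

#include <fstream>
#include <iostream>
#include <vector>

// Sketch: print each operator of subgraph 0 with its output tensor names.
int main(int argc, char **argv)
{
  if (argc < 2)
    return 1;

  // Reader keeps raw pointers into this buffer; keep it alive while reading.
  std::ifstream ifs(argv[1], std::ios_base::binary);
  std::vector<char> buf(std::istreambuf_iterator<char>{ifs}, std::istreambuf_iterator<char>{});

  mio::circle::Reader reader(circle::GetModel(buf.data()));
  if (!reader.select_subgraph(0))
    return 1;

  for (const auto *op : *reader.operators())
  {
    std::cout << reader.opcode_name(op);
    for (int32_t out : reader.outputs(op))
      std::cout << " -> " << reader.tensor_name(reader.tensors()->Get(out));
    std::cout << "\n";
  }
  return 0;
}

In production code the buffer should first be checked with circle::VerifyModelBuffer, exactly as example.cpp below does before touching the model.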
diff --git a/compiler/mio-circle07/example.cpp b/compiler/mio-circle07/example.cpp
new file mode 100644
index 000000000..f524fb3cc
--- /dev/null
+++ b/compiler/mio-circle07/example.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// This example shows how to include and use "mio-circle07"
+//
+#include <mio/circle/schema_generated.h>
+
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+int main(int argc, char **argv)
+{
+  std::ifstream ifs(argv[1], std::ios_base::binary);
+  std::vector<char> buf(std::istreambuf_iterator<char>{ifs}, std::istreambuf_iterator<char>{});
+
+  flatbuffers::Verifier verifier{reinterpret_cast<const uint8_t *>(buf.data()), buf.size()};
+
+  if (!circle::VerifyModelBuffer(verifier))
+  {
+    std::cout << "Fail" << std::endl;
+    return 255;
+  }
+
+  std::cout << "Pass" << std::endl;
+  return 0;
+}
diff --git a/compiler/mio-circle07/exclude.me b/compiler/mio-circle07/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-circle07/include/mio_circle/Helper.h b/compiler/mio-circle07/include/mio_circle/Helper.h
new file mode 100644
index 000000000..d44a6c1dc
--- /dev/null
+++ b/compiler/mio-circle07/include/mio_circle/Helper.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_CIRCLE07_HELPER_H__
+#define __MIO_CIRCLE07_HELPER_H__
+
+#include <mio/circle/schema_generated.h>
+
+#include <string>
+
+namespace mio
+{
+namespace circle
+{
+
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode);
+bool is_valid(const ::circle::OperatorCode *opcode);
+bool is_custom(const ::circle::OperatorCode *opcode);
+std::string opcode_name(const ::circle::OperatorCode *opcode);
+const char *tensor_type(const ::circle::Tensor *tensor);
+const char *tensor_name(const ::circle::Tensor *tensor);
+
+template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T> *flat_array)
+{
+  if (flat_array == nullptr)
+  {
+    throw std::runtime_error("flat array is nullptr");
+  }
+
+  std::vector<T> ret(flat_array->size());
+  for (uint32_t i = 0; i < flat_array->size(); i++)
+  {
+    ret[i] = flat_array->Get(i);
+  }
+  return ret;
+}
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE07_HELPER_H__
diff --git a/compiler/mio-circle07/include/mio_circle/Reader.h b/compiler/mio-circle07/include/mio_circle/Reader.h
new file mode 100644
index 000000000..441157929
--- /dev/null
+++ b/compiler/mio-circle07/include/mio_circle/Reader.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_CIRCLE07_READER_H__
+#define __MIO_CIRCLE07_READER_H__
+
+#include <mio/circle/schema_generated.h>
+
+#include <map>
+#include <string>
+#include <vector>
+
+// NOTE Reader class originated from circledump and circle-tensordump,
+// where this class has more work to be done for stability,
+// as the tools are for developers, not customers.
+
+namespace mio
+{
+namespace circle
+{
+
+/**
+ * @brief Loads Circle file and provides helpers to access attributes
+ */
+class Reader
+{
+private:
+  using CircleSubGraphs_t = flatbuffers::Vector<flatbuffers::Offset<::circle::SubGraph>>;
+  using CircleBuffers_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Buffer>>;
+  using CircleTensors_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Tensor>>;
+  using CircleOperators_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Operator>>;
+  using CircleMetadata_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Metadata>>;
+  using CircleSignatureDef_t = flatbuffers::Vector<flatbuffers::Offset<::circle::SignatureDef>>;
+
+public:
+  Reader(const ::circle::Model *model);
+
+  Reader() = delete;
+
+public:
+  uint32_t version() const { return _version; }
+
+  const std::vector<const ::circle::OperatorCode *> &opcodes() { return _op_codes; }
+  const CircleBuffers_t *buffers() { return _buffers; }
+  const CircleTensors_t *tensors() { return _tensors; }
+  const CircleOperators_t *operators() { return _operators; }
+  const std::vector<int32_t> &inputs() const { return _inputs; }
+  const std::vector<int32_t> &outputs() const { return _outputs; }
+  const CircleMetadata_t *metadata() const { return _metadata; }
+  const CircleSignatureDef_t *signature_defs() const { return _signature_defs; }
+
+  uint32_t num_subgraph() const { return _subgraphs->size(); }
+
+  size_t buffer_info(uint32_t buf_idx, const uint8_t **buff_data);
+  ::circle::BuiltinOperator builtin_code(const ::circle::Operator *op) const;
+  std::string opcode_name(const ::circle::Operator *op) const;
+  std::vector<int32_t> outputs(const ::circle::Operator *op) const;
+  std::string tensor_name(const ::circle::Tensor *tensor) const;
+  std::string tensor_dtype(const ::circle::Tensor *tensor) const;
+
+public:
+  bool select_subgraph(uint32_t subgraph);
+  const std::string &subgraph_name(void) const { return _subgraph_name; }
+  uint32_t subgraph_index(void) const { return _subgraph_index; }
+
+private:
+  uint32_t _version;
+
+  const CircleSubGraphs_t *_subgraphs{nullptr};
+  const CircleBuffers_t *_buffers{nullptr};
+  const CircleTensors_t *_tensors{nullptr};
+  const CircleOperators_t *_operators{nullptr};
+  const CircleMetadata_t *_metadata{nullptr};
+  const CircleSignatureDef_t *_signature_defs{nullptr};
+
+  uint32_t _subgraph_index = 0;
+  std::string _subgraph_name;
+  std::vector<const ::circle::OperatorCode *> _op_codes;
+  std::vector<int32_t> _inputs;
+  std::vector<int32_t> _outputs;
+};
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE07_READER_H__
diff --git a/compiler/mio-circle07/src/Helper.cpp b/compiler/mio-circle07/src/Helper.cpp
new file mode 100644
index 000000000..bbfa2041a
--- /dev/null
+++ b/compiler/mio-circle07/src/Helper.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <cassert>
+#include <sstream>
+
+namespace mio
+{
+namespace circle
+{
+
+/**
+ * This will provide v3/v3a/v3b format neutral BuiltinOperator.
+ * NOTE circle has negative opcode values (252~254 as uint8_t), so
+ * we cannot use std::max() like tflite: deprecated_builtin_code can be
+ * negative while builtin_code is 0 for v0.3 files.
+ */
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode)
+{
+  assert(opcode != nullptr);
+  if (opcode->deprecated_builtin_code() == 127)
+  {
+    assert(opcode->builtin_code() >= 127);
+    return opcode->builtin_code();
+  }
+  // There was no 255(-1) value in v0.3
+  assert(opcode->deprecated_builtin_code() != -1);
+  return static_cast<::circle::BuiltinOperator>(opcode->deprecated_builtin_code());
+}
+
+bool is_valid(const ::circle::OperatorCode *opcode)
+{
+  // Valid Range : BuiltinOperator_MIN <= deprecated_builtin_code <= 127
+  const int8_t deprecated_builtin_code = opcode->deprecated_builtin_code();
+  if (deprecated_builtin_code < ::circle::BuiltinOperator_MIN)
+    return false;
+  // There was no 255(-1) value in v0.3
+  if (deprecated_builtin_code == -1)
+    return false;
+
+  const ::circle::BuiltinOperator builtin_code = opcode->builtin_code();
+  if (!(::circle::BuiltinOperator_MIN <= builtin_code &&
+        builtin_code <= ::circle::BuiltinOperator_MAX))
+    return false;
+
+  return true;
+}
+
+bool is_custom(const ::circle::OperatorCode *opcode)
+{
+  ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+  return (code == ::circle::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::circle::OperatorCode *opcode)
+{
+  assert(opcode);
+
+  if (!is_valid(opcode))
+  {
+    std::ostringstream oss;
+    oss << "(invalid)";
+    return oss.str();
+  }
+
+  if (is_custom(opcode))
+  {
+    if (!opcode->custom_code())
+      return "(invalid custom)";
+
+    std::string custom_op = "CUSTOM(";
+    custom_op += opcode->custom_code()->c_str();
+    custom_op += ")";
+    return custom_op;
+  }
+
+  ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+  return ::circle::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::circle::Tensor *tensor)
+{
+  return ::circle::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::circle::Tensor *tensor)
+{
+  if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+    return "(noname)";
+
+  return tensor->name()->c_str();
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle07/src/Helper.test.cpp b/compiler/mio-circle07/src/Helper.test.cpp
new file mode 100644
index 000000000..687fb03d0
--- /dev/null
+++ b/compiler/mio-circle07/src/Helper.test.cpp
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+class mio_circle07_helper_test : public ::testing::Test
+{
+protected:
+  void initialization_finish(void)
+  {
+    _fbb.Finish(circle::CreateModelDirect(_fbb, 0, &_opcodes_vec));
+  }
+
+protected:
+  void add_operator_code(int8_t deprecated_builtin_code, const char *custom_code,
+                         circle::BuiltinOperator builtin_code)
+  {
+    _opcodes_vec.push_back(circle::CreateOperatorCodeDirect(
+      _fbb, deprecated_builtin_code, custom_code, 1 /* version */, builtin_code));
+  }
+
+  const circle::OperatorCode *get_operator_code(uint8_t idx)
+  {
+    return circle::GetModel(_fbb.GetBufferPointer())->operator_codes()->Get(idx);
+  }
+
+private:
+  flatbuffers::FlatBufferBuilder _fbb;
+  std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes_vec;
+};
+
+TEST_F(mio_circle07_helper_test, v07)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CONV_2D = 3
+  add_operator_code(3, "", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CONV_2D);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_custom_old)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CUSTOM = 32
+  add_operator_code(32, "custom", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUSTOM);
+  ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_NEG)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_under127)
+{
+  // BuiltinOperator_CONV_2D = 3
+  add_operator_code(3, "", circle::BuiltinOperator_CONV_2D);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CONV_2D);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_under127_NEG)
+{
+  // BuiltinOperator_CONV_2D = 3
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_CONV_2D);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_custom)
+{
+  // BuiltinOperator_CUSTOM = 32
+  add_operator_code(32, "custom", circle::BuiltinOperator_CUSTOM);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUSTOM);
+  ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_custom_NEG)
+{
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "custom", circle::BuiltinOperator_CUSTOM);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_over127)
+{
+  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127
+  // BuiltinOperator_CUMSUM = 128
+  add_operator_code(127, "", circle::BuiltinOperator_CUMSUM);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUMSUM);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle07_helper_test, v07_over127_NEG)
+{
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_CUMSUM);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
diff --git a/compiler/mio-circle07/src/Reader.cpp b/compiler/mio-circle07/src/Reader.cpp
new file mode 100644
index 000000000..114eaf622
--- /dev/null
+++ b/compiler/mio-circle07/src/Reader.cpp
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Reader.h"
+#include "mio_circle/Helper.h"
+
+#include <sstream>
+#include <string>
+
+namespace mio
+{
+namespace circle
+{
+
+Reader::Reader(const ::circle::Model *model)
+{
+  if (model == nullptr)
+  {
+    throw std::runtime_error("Invalid model");
+  }
+
+  _version = model->version();
+  _subgraphs = model->subgraphs();
+  _buffers = model->buffers();
+  _metadata = model->metadata();
+  _signature_defs = model->signature_defs();
+
+  auto opcodes = model->operator_codes();
+  for (const ::circle::OperatorCode *opcode : *opcodes)
+  {
+    _op_codes.push_back(opcode);
+  }
+}
+
+size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
+{
+  if (buff_data != nullptr)
+  {
+    *buff_data = nullptr;
+  }
+
+  if (buf_idx == 0)
+    return 0;
+
+  if (auto *buffer = (*_buffers)[buf_idx])
+  {
+    if (auto *array = buffer->data())
+    {
+      if (size_t size = array->size())
+      {
+        if (buff_data != nullptr)
+        {
+          *buff_data = reinterpret_cast<const uint8_t *>(array->data());
+        }
+        return size;
+      }
+    }
+  }
+
+  return 0;
+}
+
+::circle::BuiltinOperator Reader::builtin_code(const ::circle::Operator *op) const
+{
+  uint32_t index = op->opcode_index();
+  assert(index < _op_codes.size());
+  const ::circle::OperatorCode *opcode = _op_codes.at(index);
+
+  return mio::circle::builtin_code_neutral(opcode);
+}
+
+std::string Reader::opcode_name(const ::circle::Operator *op) const
+{
+  uint32_t index = op->opcode_index();
+  assert(index < _op_codes.size());
+  const ::circle::OperatorCode *opcode = _op_codes.at(index);
+
+  if (!mio::circle::is_valid(opcode))
+  {
+    std::ostringstream oss;
+    oss << "(invalid: " << index << ")";
+    return oss.str();
+  }
+
+  return mio::circle::opcode_name(opcode);
+}
+
+std::vector<int32_t> Reader::outputs(const ::circle::Operator *op) const
+{
+  return as_index_vector(op->outputs());
+}
+
+std::string Reader::tensor_name(const ::circle::Tensor *tensor) const
+{
+  return mio::circle::tensor_name(tensor);
+}
+
+std::string Reader::tensor_dtype(const ::circle::Tensor *tensor) const
+{
+  return mio::circle::tensor_type(tensor);
+}
+
+bool Reader::select_subgraph(uint32_t sgindex)
+{
+  _subgraph_index = sgindex;
+  _tensors = nullptr;
+  _operators = nullptr;
+
+  _inputs.clear();
+  _outputs.clear();
+
+  if (_subgraphs->size() <= sgindex)
+  {
+    assert(false);
+    return false;
+  }
+
+  const ::circle::SubGraph *subgraph = (*_subgraphs)[sgindex];
+
+  auto name = subgraph->name();
+  _subgraph_name = name ? name->c_str() : "(noname)";
+
+  _tensors = subgraph->tensors();
+  _operators = subgraph->operators();
+
+  _inputs = as_index_vector(subgraph->inputs());
+  _outputs = as_index_vector(subgraph->outputs());
+
+  return true;
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle07/src/Reader.test.cpp b/compiler/mio-circle07/src/Reader.test.cpp
new file mode 100644
index 000000000..132da013b
--- /dev/null
+++ b/compiler/mio-circle07/src/Reader.test.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Reader.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+class mio_circle07_reader_test : public ::testing::Test
+{
+protected:
+  void initialization_empty(void)
+  {
+    _model = circle::CreateModelDirect(_fbb, 0, &_opcodes_vec);
+    circle::FinishModelBuffer(_fbb, _model);
+  }
+
+  const circle::Model *circleModel(void)
+  {
+    auto ptr = _fbb.GetBufferPointer();
+    return circle::GetModel(ptr);
+  }
+
+private:
+  flatbuffers::FlatBufferBuilder _fbb;
+  flatbuffers::Offset<circle::Model> _model;
+  std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes_vec;
+};
+
+TEST_F(mio_circle07_reader_test, null_Model_NEG)
+{
+  EXPECT_THROW(mio::circle::Reader reader(nullptr), std::runtime_error);
+}
+
+TEST_F(mio_circle07_reader_test, empty_Model)
+{
+  initialization_empty();
+
+  const circle::Model *model = circleModel();
+  EXPECT_NE(nullptr, model);
+
+  mio::circle::Reader reader(model);
+
+  SUCCEED();
+}
+
+// TODO add more tests
diff --git a/compiler/mio-circle08/CMakeLists.txt b/compiler/mio-circle08/CMakeLists.txt
new file mode 100644
index 000000000..03e449d6e
--- /dev/null
+++ b/compiler/mio-circle08/CMakeLists.txt
@@ -0,0 +1,52 @@
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
+
+if(NOT FlatBuffers_FOUND)
+  message(STATUS "mio-circle08 skip: FlatBuffers 23.5.26 NOT FOUND")
+  return()
+endif(NOT FlatBuffers_FOUND)
+
+message(STATUS "Build mio-circle08: TRUE")
+
+# TODO Find a better way
+# TODO use nnpackage
+# set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/nnpackage/schema/circle_schema.fbs")
+set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.8/circle_schema.fbs")
+
+# NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
+add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
+                   COMMAND ${CMAKE_COMMAND} -E copy "${SCHEMA_FILE}" schema.fbs
+                   WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+                   DEPENDS "${SCHEMA_FILE}"
+)
+
+FlatBuffers_Target(mio_circle08
+  OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen/mio/circle"
+  INCLUDE_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen"
"${CMAKE_CURRENT_BINARY_DIR}" + SCHEMA_FILES "schema.fbs" +) + +# This example shows how to use "mio-circle08" library +add_executable(mio_circle08_example example.cpp) +target_link_libraries(mio_circle08_example mio_circle08) + +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) + +add_library(mio_circle08_helper STATIC ${SOURCES}) +set_target_properties(mio_circle08_helper PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(mio_circle08_helper PRIVATE src) +target_include_directories(mio_circle08_helper PUBLIC include) +target_link_libraries(mio_circle08_helper mio_circle08) + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(mio_circle08_helper_test ${TESTS}) +target_include_directories(mio_circle08_helper_test PRIVATE src) +target_link_libraries(mio_circle08_helper_test mio_circle08) +target_link_libraries(mio_circle08_helper_test mio_circle08_helper) diff --git a/compiler/mio-circle08/README.md b/compiler/mio-circle08/README.md new file mode 100644 index 000000000..a12fce860 --- /dev/null +++ b/compiler/mio-circle08/README.md @@ -0,0 +1,3 @@ +# mio-circle08 + +Let's make it easy to read and write Circle models. diff --git a/compiler/mio-circle08/example.cpp b/compiler/mio-circle08/example.cpp new file mode 100644 index 000000000..99fa86626 --- /dev/null +++ b/compiler/mio-circle08/example.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// This example shows how to include and use "mio-circle08" +// +#include + +#include +#include +#include + +int main(int argc, char **argv) +{ + std::ifstream ifs(argv[1], std::ios_base::binary); + std::vector buf(std::istreambuf_iterator{ifs}, std::istreambuf_iterator{}); + + flatbuffers::Verifier verifier{reinterpret_cast(buf.data()), buf.size()}; + + if (!circle::VerifyModelBuffer(verifier)) + { + std::cout << "Fail" << std::endl; + return 255; + } + + std::cout << "Pass" << std::endl; + return 0; +} diff --git a/compiler/mio-circle08/include/mio_circle/Helper.h b/compiler/mio-circle08/include/mio_circle/Helper.h new file mode 100644 index 000000000..ce6bb01e7 --- /dev/null +++ b/compiler/mio-circle08/include/mio_circle/Helper.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __MIO_CIRCLE08_HELPER_H__
+#define __MIO_CIRCLE08_HELPER_H__
+
+#include <mio/circle/schema_generated.h>
+
+#include <string>
+
+namespace mio
+{
+namespace circle
+{
+
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode);
+bool is_valid(const ::circle::OperatorCode *opcode);
+bool is_custom(const ::circle::OperatorCode *opcode);
+std::string opcode_name(const ::circle::OperatorCode *opcode);
+const char *tensor_type(const ::circle::Tensor *tensor);
+const char *tensor_name(const ::circle::Tensor *tensor);
+
+template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T> *flat_array)
+{
+  if (flat_array == nullptr)
+  {
+    throw std::runtime_error("flat array is nullptr");
+  }
+
+  std::vector<T> ret(flat_array->size());
+  for (uint32_t i = 0; i < flat_array->size(); i++)
+  {
+    ret[i] = flat_array->Get(i);
+  }
+  return ret;
+}
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE08_HELPER_H__
diff --git a/compiler/mio-circle08/include/mio_circle/Reader.h b/compiler/mio-circle08/include/mio_circle/Reader.h
new file mode 100644
index 000000000..723668f26
--- /dev/null
+++ b/compiler/mio-circle08/include/mio_circle/Reader.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_CIRCLE08_READER_H__
+#define __MIO_CIRCLE08_READER_H__
+
+#include <mio/circle/schema_generated.h>
+
+#include <map>
+#include <string>
+#include <vector>
+
+// NOTE Reader class originated from circledump and circle-tensordump,
+// where this class has more work to be done for stability,
+// as the tools are for developers, not customers.
+
+namespace mio
+{
+namespace circle
+{
+
+/**
+ * @brief Loads Circle file and provides helpers to access attributes
+ */
+class Reader
+{
+private:
+  using CircleSubGraphs_t = flatbuffers::Vector<flatbuffers::Offset<::circle::SubGraph>>;
+  using CircleBuffers_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Buffer>>;
+  using CircleTensors_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Tensor>>;
+  using CircleOperators_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Operator>>;
+  using CircleMetadata_t = flatbuffers::Vector<flatbuffers::Offset<::circle::Metadata>>;
+  using CircleSignatureDef_t = flatbuffers::Vector<flatbuffers::Offset<::circle::SignatureDef>>;
+
+public:
+  Reader(const ::circle::Model *model);
+
+  Reader() = delete;
+
+public:
+  uint32_t version() const { return _version; }
+
+  const std::vector<const ::circle::OperatorCode *> &opcodes() { return _op_codes; }
+  const CircleBuffers_t *buffers() { return _buffers; }
+  const CircleTensors_t *tensors() { return _tensors; }
+  const CircleOperators_t *operators() { return _operators; }
+  const std::vector<int32_t> &inputs() const { return _inputs; }
+  const std::vector<int32_t> &outputs() const { return _outputs; }
+  const CircleMetadata_t *metadata() const { return _metadata; }
+  const CircleSignatureDef_t *signature_defs() const { return _signature_defs; }
+
+  uint32_t num_subgraph() const { return _subgraphs->size(); }
+
+  size_t buffer_info(uint32_t buf_idx, const uint8_t **buff_data);
+  ::circle::BuiltinOperator builtin_code(const ::circle::Operator *op) const;
+  std::string opcode_name(const ::circle::Operator *op) const;
+  std::vector<int32_t> outputs(const ::circle::Operator *op) const;
+  std::string tensor_name(const ::circle::Tensor *tensor) const;
+  std::string tensor_dtype(const ::circle::Tensor *tensor) const;
+
+public:
+  bool select_subgraph(uint32_t subgraph);
+  const std::string &subgraph_name(void) const { return _subgraph_name; }
+  uint32_t subgraph_index(void) const { return _subgraph_index; }
+
+private:
+  uint32_t _version;
+
+  const CircleSubGraphs_t *_subgraphs{nullptr};
+  const CircleBuffers_t *_buffers{nullptr};
+  const CircleTensors_t *_tensors{nullptr};
+  const CircleOperators_t *_operators{nullptr};
+  const CircleMetadata_t *_metadata{nullptr};
+  const CircleSignatureDef_t *_signature_defs{nullptr};
+
+  uint32_t _subgraph_index = 0;
+  std::string _subgraph_name;
+  std::vector<const ::circle::OperatorCode *> _op_codes;
+  std::vector<int32_t> _inputs;
+  std::vector<int32_t> _outputs;
+};
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE08_READER_H__
diff --git a/compiler/mio-circle08/src/Helper.cpp b/compiler/mio-circle08/src/Helper.cpp
new file mode 100644
index 000000000..a7bbd23ea
--- /dev/null
+++ b/compiler/mio-circle08/src/Helper.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <cassert>
+#include <sstream>
+
+namespace mio
+{
+namespace circle
+{
+
+/**
+ * This will provide v3/v3a/v3b format neutral BuiltinOperator.
+ * NOTE circle has negative opcode values (252~254 as uint8_t), so
+ * we cannot use std::max() like tflite: deprecated_builtin_code can be
+ * negative while builtin_code is 0 for v0.3 files.
+ */
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode)
+{
+  assert(opcode != nullptr);
+  if (opcode->deprecated_builtin_code() == 127)
+  {
+    assert(opcode->builtin_code() >= 127);
+    return opcode->builtin_code();
+  }
+  // There was no 255(-1) value in v0.3
+  assert(opcode->deprecated_builtin_code() != -1);
+  return static_cast<::circle::BuiltinOperator>(opcode->deprecated_builtin_code());
+}
+
+bool is_valid(const ::circle::OperatorCode *opcode)
+{
+  // Valid Range : BuiltinOperator_MIN <= deprecated_builtin_code <= 127
+  const int8_t deprecated_builtin_code = opcode->deprecated_builtin_code();
+  if (deprecated_builtin_code < ::circle::BuiltinOperator_MIN)
+    return false;
+  // There was no 255(-1) value in v0.3
+  if (deprecated_builtin_code == -1)
+    return false;
+
+  const ::circle::BuiltinOperator builtin_code = opcode->builtin_code();
+  if (!(::circle::BuiltinOperator_MIN <= builtin_code &&
+        builtin_code <= ::circle::BuiltinOperator_MAX))
+    return false;
+
+  return true;
+}
+
+bool is_custom(const ::circle::OperatorCode *opcode)
+{
+  ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+  return (code == ::circle::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::circle::OperatorCode *opcode)
+{
+  assert(opcode);
+
+  if (!is_valid(opcode))
+  {
+    std::ostringstream oss;
+    oss << "(invalid)";
+    return oss.str();
+  }
+
+  if (is_custom(opcode))
+  {
+    if (!opcode->custom_code())
+      return "(invalid custom)";
+
+    std::string custom_op = "CUSTOM(";
+    custom_op += opcode->custom_code()->c_str();
+    custom_op += ")";
+    return custom_op;
+  }
+
+  ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+  return ::circle::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::circle::Tensor *tensor)
+{
+  return ::circle::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::circle::Tensor *tensor)
+{
+  if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+    return "(noname)";
+
+  return tensor->name()->c_str();
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle08/src/Helper.test.cpp b/compiler/mio-circle08/src/Helper.test.cpp
new file mode 100644
index 000000000..57e78c03b
--- /dev/null
+++ b/compiler/mio-circle08/src/Helper.test.cpp
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+class mio_circle08_helper_test : public ::testing::Test
+{
+protected:
+  void initialization_finish(void)
+  {
+    _fbb.Finish(circle::CreateModelDirect(_fbb, 0, &_opcodes_vec));
+  }
+
+protected:
+  void add_operator_code(int8_t deprecated_builtin_code, const char *custom_code,
+                         circle::BuiltinOperator builtin_code)
+  {
+    _opcodes_vec.push_back(circle::CreateOperatorCodeDirect(
+      _fbb, deprecated_builtin_code, custom_code, 1 /* version */, builtin_code));
+  }
+
+  const circle::OperatorCode *get_operator_code(uint8_t idx)
+  {
+    return circle::GetModel(_fbb.GetBufferPointer())->operator_codes()->Get(idx);
+  }
+
+private:
+  flatbuffers::FlatBufferBuilder _fbb;
+  std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes_vec;
+};
+
+TEST_F(mio_circle08_helper_test, v08)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CONV_2D = 3
+  add_operator_code(3, "", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CONV_2D);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_custom_old)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CUSTOM = 32
+  add_operator_code(32, "custom", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUSTOM);
+  ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_NEG)
+{
+  // BuiltinOperator_ADD = 0
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_ADD);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_under127)
+{
+  // BuiltinOperator_CONV_2D = 3
+  add_operator_code(3, "", circle::BuiltinOperator_CONV_2D);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CONV_2D);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_under127_NEG)
+{
+  // BuiltinOperator_CONV_2D = 3
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_CONV_2D);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_custom)
+{
+  // BuiltinOperator_CUSTOM = 32
+  add_operator_code(32, "custom", circle::BuiltinOperator_CUSTOM);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUSTOM);
+  ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_custom_NEG)
+{
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "custom", circle::BuiltinOperator_CUSTOM);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_over127)
+{
+  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127
+  // BuiltinOperator_CUMSUM = 128
+  add_operator_code(127, "", circle::BuiltinOperator_CUMSUM);
+  initialization_finish();
+
+  ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+  ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+            circle::BuiltinOperator_CUMSUM);
+  ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle08_helper_test, v08_over127_NEG)
+{
+  // BuiltinOperator_CUMSUM = 128
+  // deprecated_builtin_code cannot be a negative value
+  add_operator_code(128, "", circle::BuiltinOperator_CUMSUM);
+  initialization_finish();
+
+  ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
diff --git a/compiler/mio-circle08/src/Reader.cpp b/compiler/mio-circle08/src/Reader.cpp
new file mode 100644
index 000000000..e4df6d04d
--- /dev/null
+++ b/compiler/mio-circle08/src/Reader.cpp
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Reader.h"
+#include "mio_circle/Helper.h"
+
+#include <sstream>
+#include <string>
+
+namespace mio
+{
+namespace circle
+{
+
+Reader::Reader(const ::circle::Model *model)
+{
+  if (model == nullptr)
+  {
+    throw std::runtime_error("Invalid model");
+  }
+
+  _version = model->version();
+  _subgraphs = model->subgraphs();
+  _buffers = model->buffers();
+  _metadata = model->metadata();
+  _signature_defs = model->signature_defs();
+
+  auto opcodes = model->operator_codes();
+  for (const ::circle::OperatorCode *opcode : *opcodes)
+  {
+    _op_codes.push_back(opcode);
+  }
+}
+
+size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
+{
+  if (buff_data != nullptr)
+  {
+    *buff_data = nullptr;
+  }
+
+  if (buf_idx == 0)
+    return 0;
+
+  if (auto *buffer = (*_buffers)[buf_idx])
+  {
+    if (auto *array = buffer->data())
+    {
+      if (size_t size = array->size())
+      {
+        if (buff_data != nullptr)
+        {
+          *buff_data = reinterpret_cast<const uint8_t *>(array->data());
+        }
+        return size;
+      }
+    }
+  }
+
+  return 0;
+}
+
+::circle::BuiltinOperator Reader::builtin_code(const ::circle::Operator *op) const
+{
+  uint32_t index = op->opcode_index();
+  assert(index < _op_codes.size());
+  const ::circle::OperatorCode *opcode = _op_codes.at(index);
+
+  return mio::circle::builtin_code_neutral(opcode);
+}
+
+std::string Reader::opcode_name(const ::circle::Operator *op) const
+{
+  uint32_t index = op->opcode_index();
+  assert(index < _op_codes.size());
+  const ::circle::OperatorCode *opcode = _op_codes.at(index);
+
+  if (!mio::circle::is_valid(opcode))
+  {
+    std::ostringstream oss;
+    oss << "(invalid: " << index << ")";
+    return oss.str();
+  }
+
+  return mio::circle::opcode_name(opcode);
+}
+
+std::vector<int32_t> Reader::outputs(const ::circle::Operator *op) const
+{
+  return as_index_vector(op->outputs());
+}
+
+std::string Reader::tensor_name(const ::circle::Tensor *tensor) const
+{
+  return mio::circle::tensor_name(tensor);
+}
+
+std::string Reader::tensor_dtype(const ::circle::Tensor *tensor) const
+{
+  return mio::circle::tensor_type(tensor);
+}
+
+bool Reader::select_subgraph(uint32_t sgindex)
+{
+  _subgraph_index = sgindex;
+  _tensors = nullptr;
+  _operators = nullptr;
+
+  _inputs.clear();
+  _outputs.clear();
+
+  if (_subgraphs->size() <= sgindex)
+  {
+    assert(false);
+    return false;
+  }
+
+  const ::circle::SubGraph *subgraph = (*_subgraphs)[sgindex];
+
+  auto name = subgraph->name();
+  _subgraph_name = name ? name->c_str() : "(noname)";
+
+  _tensors = subgraph->tensors();
+  _operators = subgraph->operators();
+
+  _inputs = as_index_vector(subgraph->inputs());
+  _outputs = as_index_vector(subgraph->outputs());
+
+  return true;
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle08/src/Reader.test.cpp b/compiler/mio-circle08/src/Reader.test.cpp
new file mode 100644
index 000000000..7ead63e9d
--- /dev/null
+++ b/compiler/mio-circle08/src/Reader.test.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Reader.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+class mio_circle08_reader_test : public ::testing::Test
+{
+protected:
+  void initialization_empty(void)
+  {
+    _model = circle::CreateModelDirect(_fbb, 0, &_opcodes_vec);
+    circle::FinishModelBuffer(_fbb, _model);
+  }
+
+  const circle::Model *circleModel(void)
+  {
+    auto ptr = _fbb.GetBufferPointer();
+    return circle::GetModel(ptr);
+  }
+
+private:
+  flatbuffers::FlatBufferBuilder _fbb;
+  flatbuffers::Offset<circle::Model> _model;
+  std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes_vec;
+};
+
+TEST_F(mio_circle08_reader_test, null_Model_NEG)
+{
+  EXPECT_THROW(mio::circle::Reader reader(nullptr), std::runtime_error);
+}
+
+TEST_F(mio_circle08_reader_test, empty_Model)
+{
+  initialization_empty();
+
+  const circle::Model *model = circleModel();
+  EXPECT_NE(nullptr, model);
+
+  mio::circle::Reader reader(model);
+
+  SUCCEED();
+}
+
+// TODO add more tests
diff --git a/compiler/mio-tf/exclude.me b/compiler/mio-tf/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-tflite/CMakeLists.txt b/compiler/mio-tflite/CMakeLists.txt
index 90187b037..a5f9d44d0 100644
--- a/compiler/mio-tflite/CMakeLists.txt
+++ b/compiler/mio-tflite/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
 
 if(NOT FlatBuffers_FOUND)
   message(STATUS "Build mio-tflite: FAILED (missing Flatbuffers)")
diff --git a/compiler/mio-tflite/exclude.me b/compiler/mio-tflite/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-tflite2121/CMakeLists.txt b/compiler/mio-tflite2121/CMakeLists.txt
index 1ca8e7581..371118be8 100644
--- a/compiler/mio-tflite2121/CMakeLists.txt
+++ b/compiler/mio-tflite2121/CMakeLists.txt
@@ -1,7 +1,7 @@
-nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
 
 if(NOT FlatBuffers_FOUND)
-  message(STATUS "Build mio-tflite2121: FAILED (missing Flatbuffers 2.0)")
+  message(STATUS "Build mio-tflite2121: FAILED (missing Flatbuffers 23.5.26)")
   return()
 endif(NOT FlatBuffers_FOUND)
diff --git a/compiler/mio-tflite260/CMakeLists.txt b/compiler/mio-tflite260/CMakeLists.txt
index f2cfeafcc..d34d6cca2 100644
--- a/compiler/mio-tflite260/CMakeLists.txt
+++ b/compiler/mio-tflite260/CMakeLists.txt
@@ -1,7 +1,7 @@
-nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
 
 if(NOT FlatBuffers_FOUND)
-  message(STATUS "Build mio-tflite260: FAILED (missing Flatbuffers 2.0)")
+  message(STATUS "Build mio-tflite260: FAILED (missing Flatbuffers 23.5.26)")
   return()
 endif(NOT FlatBuffers_FOUND)
diff --git a/compiler/mio-tflite260/exclude.me b/compiler/mio-tflite260/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mio-tflite280/CMakeLists.txt b/compiler/mio-tflite280/CMakeLists.txt
index edf75f479..7c5d328e6 100644
--- a/compiler/mio-tflite280/CMakeLists.txt
+++ b/compiler/mio-tflite280/CMakeLists.txt
@@ -1,7 +1,7 @@
-nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(FlatBuffers EXACT 23.5.26 QUIET)
 
 if(NOT FlatBuffers_FOUND)
-  message(STATUS "Build mio-tflite280: FAILED (missing Flatbuffers 2.0)")
+  message(STATUS "Build mio-tflite280: FAILED (missing Flatbuffers 23.5.26)")
   return()
 endif(NOT FlatBuffers_FOUND)
diff --git a/compiler/mio-tflite280/exclude.me b/compiler/mio-tflite280/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mir-interpreter/exclude.me b/compiler/mir-interpreter/exclude.me
new file mode 100644
index 000000000..e69de29bb
diff --git a/compiler/mir-interpreter/src/ops/Common.h b/compiler/mir-interpreter/src/ops/Common.h
index 43336216e..f7edf44ab 100644
--- a/compiler/mir-interpreter/src/ops/Common.h
+++ b/compiler/mir-interpreter/src/ops/Common.h
@@ -27,7 +27,7 @@ namespace mir_interpreter
 {
 template