From: Giuseppe Rossini Date: Fri, 17 Jul 2020 16:14:49 +0000 (+0100) Subject: Fixed point multiplication improvements for AArch64 (#5980) X-Git-Tag: upstream/0.7.0~386 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ccacb1ec13597b0dd9b5f3ffcc599ac3b3957ae4;p=platform%2Fupstream%2Ftvm.git Fixed point multiplication improvements for AArch64 (#5980) * Fixed point multiplication improvements for AArch64 Change-Id: Ib3c10348d4c0eac11fa92b39cc6e792560e9eba4 * Fix python linting errors Change-Id: I4cf5ac18aa24b39374b83805dcc8e1663e173909 * Fix doxygen errors Change-Id: Ie3c861f8ead3f1ea5b30d5e9d7d94e222299d407 * Fix arm_cpu injective tests Change-Id: I6ad9da61b61e6bd737627f26fba59767418c07cd * Fix python linting errors - 2 Change-Id: Ic864a235aa5da5786393cbf6146dd815c121df5e * Fix arm_cpu injective tests - 2 Change-Id: If9ca1cc3d947b1656c836c7f88de90470d92f979 * Redesign: introduce a qmuls (q-multiply and shift) general intrinsic Change-Id: I1966fef9aee32eab50e4b984bbe81018488c8c02 * Fix python linting errors - 3 Change-Id: Ib87a19a8ee2d532954a7db1eb5793666e7aef366 * Addressing review comments Change-Id: Ie82e75204e5a421d17660f381f3e31fc325cd26c * Fixing test failures Change-Id: I74cc675764cf8d260fe68a41e770b1ec7e84729a * Renaming qmuls to q_multiply_shift Change-Id: I5a8ed60ba855208040304fcdf6e1ea28061f06ad --- diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index eb73427..a03e15a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -306,6 +306,19 @@ struct ClipAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for FixedPointMultiply operator */ +struct FixedPointMultiplyAttrs : public tvm::AttrsNode { + int32_t multiplier; + int32_t shift; + + TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") { + TVM_ATTR_FIELD(multiplier) + .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)"); + TVM_ATTR_FIELD(shift).describe( + "Shift of a fixed floating point number described as multiplier*2^(shift)"); + } +}; + /*! \brief Attributes for LayoutTransform operator */ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string src_layout; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 464ce6c..bea5313 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -93,6 +93,14 @@ TVM_DLL const Op& shift_right(); TVM_DLL const Op& large_uint_imm(); /*! + * \brief Execute a multiplication between two Q-numbers x and y + * followed by a right shift s + * The default rounding rule is to the nearest value, rounding half up + * (i.e., round(x.1) = x and round (x.5) = x+1) + */ +TVM_DLL const Op& q_multiply_shift(); + +/*! * \brief See pesudo code * * Handle address_of(Load *op) { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 31ce13c..68ca266 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -552,6 +552,27 @@ TVM_DLL PrimExpr trunc(PrimExpr x); */ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); +/*! + * \brief Execute a multiplication between two Q-numbers x and y + * followed by a right shift s. The mathematical expression is: + * + * out = round(x*y*2^-s) + * + * Please note that the two Q-numbers x and y are supposed to have + * the same number of fractional bits q. + * + * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) + * + * The rounding rule is to the nearest value, rounding half up + * (i.e., round(x.1) = x and round (x.5) = x+1) + * \param x first Q-number + * \param y second Q-number + * \param q number of fractional bits in x and y. Needs to be > 0 + * \param s integer right shift + * \return The constructed expression. + */ +TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index d4911d9..feeec1f 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -131,6 +131,14 @@ def clip_compute(attrs, inputs, output_type): register_injective_schedule("clip") +# fixed point multiply +@register_compute("fixed_point_multiply") +def fixed_point_multiply_compute(attrs, inputs, output_type): + assert len(inputs) == 1 + return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)] + +register_injective_schedule("fixed_point_multiply") + # full @script def _full_shape_func(shape): diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index a02e08d..c002c8b 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1034,6 +1034,27 @@ def clip(a, a_min, a_max): """ return _make.clip(a, a_min, a_max) +def fixed_point_multiply(data, multiplier, shift): + """Fixed point multiplication between data and a fixed point + constant expressed as multiplier * 2^(-shift), where multiplier + is a Q-number with 31 fractional bits + + Parameters + ---------- + data : relay.Expr + The input tensor. + multiplier : int + The integer multiplier of the fixed point constant. + a_max : float + The integer shift of the fixed point constant. + + Returns + ------- + result : relay.Expr + The output of the fixed point multiplication + """ + return _make.fixed_point_multiply(data, multiplier, shift) + def concatenate(data, axis): """Concatenate the input tensors along the given axis. diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9dbdc07..1aac55f 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,6 +45,7 @@ from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum +from .op import q_multiply_shift from . import ir_builder from . import transform diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index cbbd59f..1078376 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -965,6 +965,34 @@ def popcount(x): """ return call_intrin(x.dtype, "tir.popcount", x) +def q_multiply_shift(x, y, q, s): + """Execute a multiplication between two Q-numbers x and y + followed by a right shift s. The mathematical expression is: + + out = round(x*y*2^-s) + + More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) + The rounding rule is to the nearest value, rounding half up + (i.e., round(x.1) = x and round (x.5) = x+1) + + Parameters + ---------- + x : PrimExpr + First Q-number + y : PrimExpr + Second Q-number + q : PrimExpr + Number of fractional bits in x and y. Needs to be > 0 + s : PrimExpr + Integer shift + + Returns + ------- + y : PrimExpr + The result. + """ + return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s) + def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 958b8b5..fc61661 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -290,6 +290,29 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma .set_attrs_type() .set_support_level(3); +// relay.fixed_point_multiply +TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs); + +TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply") + .set_body_typed([](Expr a, int32_t multiplier, int32_t shift) { + auto attrs = make_object(); + attrs->multiplier = multiplier; + attrs->shift = shift; + static const Op& op = Op::Get("fixed_point_multiply"); + return Call(op, {a}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("fixed_point_multiply") + .describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kElemWise) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attrs_type() + .set_support_level(10); + RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index bdeaf05..222d910 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, static_cast(input_scale_float) / static_cast(output_scale_float); // Skip if input and output scales are same. if (!IsEqualScalar(input_scale, output_scale)) { + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); + + const bool is_upward_rounding = (param->rounding == "UPWARD"); + + // When using upward rounding (i.e., x.5 rounded to x+1), leverage + // the FixedPointMultiply operator scaled_int32_t = - FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding); + (is_upward_rounding + ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) + : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); } + } else { // This is per-channel (per=axis) quantization. std::vector double_multipliers; diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 4daa5c9..113038e 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -30,25 +30,6 @@ namespace tvm { namespace relay { namespace qnn { -/* - * \brief Convert FP32 representation into fixed point representation. - * \param double_multplier The input FP32 number. - * \return The pair of multiplier and shift for fixed point representation. - * \note Converts a floating point number so that it can be represented by - * integers. The representation is - * float_number = (significand) * 2^(exponent) - * - * The significand is a number between 0.5 and 1. This is represented by - * an integer number. For example, if it is int32, then the decimal point - * exists between bit 31 and 30 from LSB (or between first and second bit - * from the left). - * - * Some examples are - * 0.25 = (0.5) * 2^(-1) - * 0.125 = (0.5) * 2^(-2) - * - * Credit to TFLite reference implementation. - */ std::pair GetFixedPointMultiplierShift(double double_multiplier) { int32_t significand, exponent; if (double_multiplier == 0.) { @@ -75,8 +56,8 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie return std::make_pair(significand, exponent); } -Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, - const std::string& rounding) { +Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, + const Array& input_shape) { // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. DataType hp_dtype = DataType::Int(64); @@ -109,19 +90,15 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); Expr round_scalar; - if (rounding == "UPWARD") { - round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); - } else if (rounding == "TONEAREST") { - auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); - auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); - auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); - auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); - auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); - } else { - LOG(FATAL) << "Rounding mode " << rounding << " not supported."; - } + auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); + auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); + auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); + auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); + + auto zero_t = Zeros(input_shape, hp_dtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + // Add the rounding scalar. tensor = Add(tensor, round_scalar); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 736b736..72eb2a4 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -70,6 +70,27 @@ static inline int32_t GetQmax(const DataType& dtype) { } } +/* + * \brief Convert FP32 representation into fixed point representation. + * \param double_multplier The input FP32 number. + * \return The pair of multiplier and shift for fixed point representation. + * \note Converts a floating point number so that it can be represented by + * integers. The representation is + * float_number = (significand) * 2^(exponent) + * + * The significand is a number between 0.5 and 1. This is represented by + * an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit + * from the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + * + * Credit to TFLite reference implementation. + */ +std::pair GetFixedPointMultiplierShift(double double_multiplier); + Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, const RequantizeAttrs* param, @@ -94,13 +115,12 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { /* * \brief Fixed point multiplication between integer tensor with floating point - scalar. + * scalar. This implementation rounds to the nearest value when it is midway + * between two representable values. * \param tensor The quantized input tensor of dtype int64. * \param multiplier The scalar multiplier. * \param input_shape Shape of the input tensor. - * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value - is midway between" "two representable values. - * \return The sequence of Relay ops for fixed point multiplication. + * \return The sequence of Relay ops for fixed point multiplication with TONEARES rounding. * \note Original compuation is scale_fp32 * quantized_tensor. To convert into * integer computation, the multiplication with fp32 scalar can be @@ -114,8 +134,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { * 2) Round the result. * 3) Right shift the result */ -Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, - const std::string& rounding); +Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, + const Array& input_shape); /* * \brief Fixed point multiplication between integer tensor with floating point diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index ddf945a..ace2c24 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -113,7 +113,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding); + if (cfg->rounding == "UPWARD") { + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor); + data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); + } else { + data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape); + } + return Cast(data, dtype); } } @@ -164,8 +171,15 @@ Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const Ob return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); - data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, - ref_call->type_as()->shape, cfg->rounding); + if (cfg->rounding == "UPWARD") { + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = + qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm); + data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); + } else { + data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm, + ref_call->type_as()->shape); + } data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype); return QRealizeIntExpr(data, dom_scale, n->dtype); } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index adbd1bd..b3e3681 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -495,6 +495,14 @@ inline Expr Round(Expr x) { inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } +inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { + static const Op& op = Op::Get("fixed_point_multiply"); + auto attrs = make_object(); + attrs->multiplier = multiplier; + attrs->shift = shift; + return Call(op, {x}, Attrs(attrs), {}); +} + inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 31fadf1..fa0ee38 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -115,6 +115,52 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") *rv = isinf(call->args[0]); }); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + using tir::make_const; + + PrimExpr e = args[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + + PrimExpr x = call->args[0]; + PrimExpr y = call->args[1]; + PrimExpr q = call->args[2]; + PrimExpr s = call->args[3]; + + // Only int32 types are supported (any number of lanes is allowed) + CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); + + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + + // 1) Calculating the integer multiplier and integer shift + PrimExpr zero = make_const(s.dtype(), 0); + PrimExpr left_shift = tir::Select(s > zero, s, zero); + PrimExpr right_shift = tir::Select(s > zero, zero, -s); + + // 2) Cast and Multiply the integer multiplier + PrimExpr one = make_const(hp_dtype, 1); + x = cast(hp_dtype, x); + y = cast(hp_dtype, y); + x = tir::Select(left_shift != zero, x << left_shift, x); + + // 3) Perform the multiplication in higher precision. + x = x * y; + + // 4) Find the rounding scalar + PrimExpr total_right_shift = right_shift + q; + PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); + x = x + pos_rounding_value; + + // 5) Simply right shift the result to get the final output. + x = x >> total_right_shift; + + // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. + *rv = cast(lp_dtype, x); + }); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index d23662c..3afb881 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -89,6 +89,11 @@ TIR_DEFINE_BUILTIN_FUNC(if_then_else) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift) + .set_num_inputs(3) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TVectorizable", true); + TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index a0ba8d6..75a483c 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -90,6 +90,12 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}); } +// Q-multiplication +PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) { + return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(), + {x, y, q, s}); +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 5372ef8..1c529d8 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -40,8 +40,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) { + IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") + : IRMutatorWithAnalyzer(analyzer) { patterns_.push_back("tvm.intrin.rule." + target + "."); + + bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); + if (is_llvm_aarch64) { + patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64."); + } + patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); if (target == "stackvm") { @@ -287,7 +294,9 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body)); + auto mtriple = target.value()->GetAttr("mtriple", ""); + n->body = + IntrinInjecter(&analyzer, target.value()->id->name, mtriple.value())(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 0445c98..76f10d6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -84,6 +84,22 @@ def test_clip(): ref_res = np.clip(data, 1., 4.) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) +def test_fixed_point_multiply(): + # Test 23 * 1/16 + # [m,s] = [0.5, -3] = frexp(1/16) + # M = 0.5*2^31 = 1073741824 + # so M = 1073741824 and s = -3 + + a = relay.var("a", relay.TensorType((10, 4), "int32")) + y = relay.fixed_point_multiply(a, 1073741824, -3) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((10, 4), "int32") + + data = 23*np.ones((10, 4)).astype('int32') + intrp = create_executor() + op_res = intrp.evaluate(y, { a: relay.const(data) }) + ref_res = np.ones((10, 4)).astype('int32') + np.testing.assert_allclose(op_res.asnumpy(), ref_res, atol=1) def test_reinterpret(): a = relay.var("a", relay.TensorType((1000, 4), "float32")) @@ -1079,3 +1095,4 @@ if __name__ == "__main__": test_isinf() test_unravel_index() test_sparse_to_dense() + test_fixed_point_multiply() diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index 63d96bb..e97de56 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -119,7 +119,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, C = te.compute((batches, M, N), lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)], - name="C", tag='injective') + name="C") # --- Produce the conv output out_shape = (batches, OH, OW, OC) @@ -129,7 +129,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, return out # Schedules -def schedule_conv2d_gemm(cfg, s, out): +def schedule_conv2d_gemm(cfg, s, out, final_out): """Create schedule for tensors""" C = out.op.input_tensors[0] C_interleaved = C.op.input_tensors[0] @@ -172,8 +172,11 @@ def schedule_conv2d_gemm(cfg, s, out): s[C_interleaved].tensorize(yi, gem_v_dotprod) # Output transform - N, OH, OW, OC = out.shape - s[C].split(C.op.axis[1], OW) - s[C].compute_at(s[out], out.op.axis[3]) + if out != final_out: + n, h, w, c = out.op.axis + _, inner = s[out].split(c, 4) + s[C].compute_at(s[out], inner) + s[out].vectorize(inner) + return s diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py index 5a895c0..89a37fa 100644 --- a/topi/python/topi/arm_cpu/conv2d_int8.py +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -137,11 +137,23 @@ def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, paddi def schedule_conv2d_NHWC_quantized(cfg, outs): """Create schedule for tensors""" s = te.create_schedule([x.op for x in outs]) + # Vectorize the output and then inline all the rest + out = outs[0] + n, h, w, c = out.op.axis + outer, inner = s[out].split(c, 4) + s[out].vectorize(inner) def _callback(op): """Traverse operators from computation graph""" if op.name == "conv2d_gemm_output": - schedule_conv2d_gemm(cfg, s, op.output(0)) + conv_out = op.output(0) + schedule_conv2d_gemm(cfg, s, conv_out, out) + if out != conv_out: + s[conv_out].compute_at(s[out], inner) + else: + C = conv_out.op.input_tensors[0] + s[C].compute_at(s[out], inner) + traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 9665200..3e3c73d 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -62,9 +62,10 @@ def schedule_injective(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) x = outs[0] + if list(s[x].op.axis): # do not vectorize for broadcast - (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) + (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index dfa2f05..270bfbe 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -451,3 +451,55 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'): return te.decl_tensor_intrin( C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, default_buffer_params=buffer_params) + +def _q_multiply_shift_arm(op): + """ + Implementation of q_multiply_shift_arm through arm intrinsics + sqrdmulh and srshl when q == 31. + + Please note that this is introducing a small round-up error for + some corner cases. This is because we are rounding twice instead + than only once. I.e.: + + * original q_multiply_shift: round(x*y*2^-s) + * arm q_multiply_shift: round(round(x*y)*2^-s) + """ + x = op.args[0] + y = op.args[1] + q = op.args[2] + s = op.args[3] + + # Don't use this intrinsic if we don't have a int32x4 vector + # or if we are not multiplying q31 numbers + if x.dtype != "int32x4" or q.value != 31: + return op + + # Case 1, shift is negative + sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.sqrdmulh', + tvm.tir.const(2, 'uint32'), + x, + y) + + fixup = (sqrdmulh & (-s)) >> 31 + fixed_up_x = (sqrdmulh + fixup) + out_1 = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.srshl', + tvm.tir.const(2, 'uint32'), + sqrdmulh, + s) + + # Case 2, shift is positive + x = x * (1 << (s)) + out_2 = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.sqrdmulh', + tvm.tir.const(2, 'uint32'), + x, + y) + + # Select depending on the shift + return tvm.tir.Select(s < 0, out_1, out_2) + +tvm.target.intrin.register_intrin_rule("llvm.aarch64", + "q_multiply_shift", + _q_multiply_shift_arm, override=True) diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index b4228a4..046b103 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -612,6 +612,33 @@ def clip(x, a_min, a_max): return tvm.te.max(tvm.te.min(value, const_max), const_min) return te.compute(x.shape, _compute) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def fixed_point_multiply(x, multiplier, shift): + """Fixed point multiplication between data and a fixed point + constant expressed as multiplier * 2^(-shift), where multiplier + is a Q-number with 31 fractional bits + + Parameters + ---------- + x : tvm.te.Tensor or Expr + Input argument. + multiplier : int + Multiplier of a fixed floating point number described as multiplier*2^(-shift). + shift : int + Shift of a fixed floating point number described as multiplier*2^(-shift). + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + def _compute(*indices): + value = x(*indices) + return tvm.tir.q_multiply_shift(value, + tvm.tir.const(multiplier, 'int32'), + tvm.tir.const(31, 'int32'), + tvm.tir.const(shift, 'int32')) + return te.compute(x.shape, _compute) def cast(x, dtype): """Cast input to specified data type.