}
};
+/*! \brief Attributes for FixedPointMultiply operator */
+struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
+ int32_t multiplier;
+ int32_t shift;
+
+ TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
+ TVM_ATTR_FIELD(multiplier)
+ .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)");
+ TVM_ATTR_FIELD(shift).describe(
+ "Shift of a fixed floating point number described as multiplier*2^(shift)");
+ }
+};
+
/*! \brief Attributes for LayoutTransform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
TVM_DLL const Op& large_uint_imm();
/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s
+ * The default rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round (x.5) = x+1)
+ */
+TVM_DLL const Op& q_multiply_shift();
+
+/*!
* \brief See pesudo code
*
* Handle address_of(Load *op) {
*/
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s. The mathematical expression is:
+ *
+ * out = round(x*y*2^-s)
+ *
+ * Please note that the two Q-numbers x and y are supposed to have
+ * the same number of fractional bits q.
+ *
+ * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+ *
+ * The rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round (x.5) = x+1)
+ * \param x first Q-number
+ * \param y second Q-number
+ * \param q number of fractional bits in x and y. Needs to be > 0
+ * \param s integer right shift
+ * \return The constructed expression.
+ */
+TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);
+
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
register_injective_schedule("clip")
+# fixed point multiply
+@register_compute("fixed_point_multiply")
+def fixed_point_multiply_compute(attrs, inputs, output_type):
+ assert len(inputs) == 1
+ return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]
+
+register_injective_schedule("fixed_point_multiply")
+
# full
@script
def _full_shape_func(shape):
"""
return _make.clip(a, a_min, a_max)
+def fixed_point_multiply(data, multiplier, shift):
+ """Fixed point multiplication between data and a fixed point
+ constant expressed as multiplier * 2^(-shift), where multiplier
+ is a Q-number with 31 fractional bits
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input tensor.
+ multiplier : int
+ The integer multiplier of the fixed point constant.
+ a_max : float
+ The integer shift of the fixed point constant.
+
+ Returns
+ -------
+ result : relay.Expr
+ The output of the fixed point multiplication
+ """
+ return _make.fixed_point_multiply(data, multiplier, shift)
+
def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
from .op import isnan, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
+from .op import q_multiply_shift
from . import ir_builder
from . import transform
"""
return call_intrin(x.dtype, "tir.popcount", x)
+def q_multiply_shift(x, y, q, s):
+ """Execute a multiplication between two Q-numbers x and y
+ followed by a right shift s. The mathematical expression is:
+
+ out = round(x*y*2^-s)
+
+ More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+ The rounding rule is to the nearest value, rounding half up
+ (i.e., round(x.1) = x and round (x.5) = x+1)
+
+ Parameters
+ ----------
+ x : PrimExpr
+ First Q-number
+ y : PrimExpr
+ Second Q-number
+ q : PrimExpr
+ Number of fractional bits in x and y. Needs to be > 0
+ s : PrimExpr
+ Integer shift
+
+ Returns
+ -------
+ y : PrimExpr
+ The result.
+ """
+ return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s)
+
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
.set_attrs_type<ClipAttrs>()
.set_support_level(3);
+// relay.fixed_point_multiply
+TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply")
+ .set_body_typed([](Expr a, int32_t multiplier, int32_t shift) {
+ auto attrs = make_object<FixedPointMultiplyAttrs>();
+ attrs->multiplier = multiplier;
+ attrs->shift = shift;
+ static const Op& op = Op::Get("fixed_point_multiply");
+ return Call(op, {a}, Attrs(attrs), {});
+ });
+
+RELAY_REGISTER_OP("fixed_point_multiply")
+ .describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("Identity", IdentityRel)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attrs_type<FixedPointMultiplyAttrs>()
+ .set_support_level(10);
+
RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
+ int32_t fixed_point_multiplier, shift;
+ std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+
+ const bool is_upward_rounding = (param->rounding == "UPWARD");
+
+ // When using upward rounding (i.e., x.5 rounded to x+1), leverage
+ // the FixedPointMultiply operator
scaled_int32_t =
- FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
+ (is_upward_rounding
+ ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
+ : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}
+
} else {
// This is per-channel (per=axis) quantization.
std::vector<double> double_multipliers;
namespace relay {
namespace qnn {
-/*
- * \brief Convert FP32 representation into fixed point representation.
- * \param double_multplier The input FP32 number.
- * \return The pair of multiplier and shift for fixed point representation.
- * \note Converts a floating point number so that it can be represented by
- * integers. The representation is
- * float_number = (significand) * 2^(exponent)
- *
- * The significand is a number between 0.5 and 1. This is represented by
- * an integer number. For example, if it is int32, then the decimal point
- * exists between bit 31 and 30 from LSB (or between first and second bit
- * from the left).
- *
- * Some examples are
- * 0.25 = (0.5) * 2^(-1)
- * 0.125 = (0.5) * 2^(-2)
- *
- * Credit to TFLite reference implementation.
- */
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
return std::make_pair(significand, exponent);
}
-Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
- const std::string& rounding) {
+Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
+ const Array<IndexExpr>& input_shape) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64);
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
Expr round_scalar;
- if (rounding == "UPWARD") {
- round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
- } else if (rounding == "TONEAREST") {
- auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
- auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
- auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
- auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
- auto zero_t = Zeros(input_shape, hp_dtype);
- round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
- } else {
- LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
- }
+ auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+ auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+ auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+ auto zero_t = Zeros(input_shape, hp_dtype);
+ round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
}
}
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ * integers. The representation is
+ * float_number = (significand) * 2^(exponent)
+ *
+ * The significand is a number between 0.5 and 1. This is represented by
+ * an integer number. For example, if it is int32, then the decimal point
+ * exists between bit 31 and 30 from LSB (or between first and second bit
+ * from the left).
+ *
+ * Some examples are
+ * 0.25 = (0.5) * 2^(-1)
+ * 0.125 = (0.5) * 2^(-2)
+ *
+ * Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);
+
Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
/*
* \brief Fixed point multiplication between integer tensor with floating point
- scalar.
+ * scalar. This implementation rounds to the nearest value when it is midway
+ * between two representable values.
* \param tensor The quantized input tensor of dtype int64.
* \param multiplier The scalar multiplier.
* \param input_shape Shape of the input tensor.
- * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
- is midway between" "two representable values.
- * \return The sequence of Relay ops for fixed point multiplication.
+ * \return The sequence of Relay ops for fixed point multiplication with TONEARES rounding.
* \note Original compuation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
* 2) Round the result.
* 3) Right shift the result
*/
-Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
- const std::string& rounding);
+Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
+ const Array<IndexExpr>& input_shape);
/*
* \brief Fixed point multiplication between integer tensor with floating point
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
- data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
+ if (cfg->rounding == "UPWARD") {
+ int32_t fixed_point_multiplier, shift;
+ std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
+ data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
+ } else {
+ data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
+ }
+
return Cast(data, dtype);
}
}
return QRealizeIntExpr(data, dom_scale, n->dtype);
} else {
data = Cast(data, DataType::Int(64));
- data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
- ref_call->type_as<TensorTypeNode>()->shape, cfg->rounding);
+ if (cfg->rounding == "UPWARD") {
+ int32_t fixed_point_multiplier, shift;
+ std::tie(fixed_point_multiplier, shift) =
+ qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
+ data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
+ } else {
+ data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
+ ref_call->type_as<TensorTypeNode>()->shape);
+ }
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); }
+inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
+ static const Op& op = Op::Get("fixed_point_multiply");
+ auto attrs = make_object<FixedPointMultiplyAttrs>();
+ attrs->multiplier = multiplier;
+ attrs->shift = shift;
+ return Call(op, {x}, Attrs(attrs), {});
+}
+
inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return Call(op, {lhs, rhs}, Attrs(), {});
*rv = isinf(call->args[0]);
});
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ using tir::make_const;
+
+ PrimExpr e = args[0];
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ CHECK(call != nullptr);
+
+ PrimExpr x = call->args[0];
+ PrimExpr y = call->args[1];
+ PrimExpr q = call->args[2];
+ PrimExpr s = call->args[3];
+
+ // Only int32 types are supported (any number of lanes is allowed)
+ CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
+ CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
+
+ DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
+ DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+
+ // 1) Calculating the integer multiplier and integer shift
+ PrimExpr zero = make_const(s.dtype(), 0);
+ PrimExpr left_shift = tir::Select(s > zero, s, zero);
+ PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+
+ // 2) Cast and Multiply the integer multiplier
+ PrimExpr one = make_const(hp_dtype, 1);
+ x = cast(hp_dtype, x);
+ y = cast(hp_dtype, y);
+ x = tir::Select(left_shift != zero, x << left_shift, x);
+
+ // 3) Perform the multiplication in higher precision.
+ x = x * y;
+
+ // 4) Find the rounding scalar
+ PrimExpr total_right_shift = right_shift + q;
+ PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
+ x = x + pos_rounding_value;
+
+ // 5) Simply right shift the result to get the final output.
+ x = x >> total_right_shift;
+
+ // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+ *rv = cast(lp_dtype, x);
+ });
+
} // namespace intrin
} // namespace codegen
} // namespace tvm
.set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift)
+ .set_num_inputs(3)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_attr<TVectorizable>("TVectorizable", true);
+
TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
{make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)});
}
+// Q-multiplication
+PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) {
+ return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(),
+ {x, y, q, s});
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*)
if (lhs.dtype() == rhs.dtype()) return;
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
- IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) {
+ IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
+ : IRMutatorWithAnalyzer(analyzer) {
patterns_.push_back("tvm.intrin.rule." + target + ".");
+
+ bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
+ if (is_llvm_aarch64) {
+ patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+ }
+
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
if (target == "stackvm") {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
- n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body));
+ auto mtriple = target.value()->GetAttr<runtime::String>("mtriple", "");
+ n->body =
+ IntrinInjecter(&analyzer, target.value()->id->name, mtriple.value())(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
ref_res = np.clip(data, 1., 4.)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+def test_fixed_point_multiply():
+ # Test 23 * 1/16
+ # [m,s] = [0.5, -3] = frexp(1/16)
+ # M = 0.5*2^31 = 1073741824
+ # so M = 1073741824 and s = -3
+
+ a = relay.var("a", relay.TensorType((10, 4), "int32"))
+ y = relay.fixed_point_multiply(a, 1073741824, -3)
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((10, 4), "int32")
+
+ data = 23*np.ones((10, 4)).astype('int32')
+ intrp = create_executor()
+ op_res = intrp.evaluate(y, { a: relay.const(data) })
+ ref_res = np.ones((10, 4)).astype('int32')
+ np.testing.assert_allclose(op_res.asnumpy(), ref_res, atol=1)
def test_reinterpret():
a = relay.var("a", relay.TensorType((1000, 4), "float32"))
test_isinf()
test_unravel_index()
test_sparse_to_dense()
+ test_fixed_point_multiply()
C = te.compute((batches, M, N),
lambda b, x, y:
C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
- name="C", tag='injective')
+ name="C")
# --- Produce the conv output
out_shape = (batches, OH, OW, OC)
return out
# Schedules
-def schedule_conv2d_gemm(cfg, s, out):
+def schedule_conv2d_gemm(cfg, s, out, final_out):
"""Create schedule for tensors"""
C = out.op.input_tensors[0]
C_interleaved = C.op.input_tensors[0]
s[C_interleaved].tensorize(yi, gem_v_dotprod)
# Output transform
- N, OH, OW, OC = out.shape
- s[C].split(C.op.axis[1], OW)
- s[C].compute_at(s[out], out.op.axis[3])
+ if out != final_out:
+ n, h, w, c = out.op.axis
+ _, inner = s[out].split(c, 4)
+ s[C].compute_at(s[out], inner)
+ s[out].vectorize(inner)
+
return s
def schedule_conv2d_NHWC_quantized(cfg, outs):
"""Create schedule for tensors"""
s = te.create_schedule([x.op for x in outs])
+ # Vectorize the output and then inline all the rest
+ out = outs[0]
+ n, h, w, c = out.op.axis
+ outer, inner = s[out].split(c, 4)
+ s[out].vectorize(inner)
def _callback(op):
"""Traverse operators from computation graph"""
if op.name == "conv2d_gemm_output":
- schedule_conv2d_gemm(cfg, s, op.output(0))
+ conv_out = op.output(0)
+ schedule_conv2d_gemm(cfg, s, conv_out, out)
+ if out != conv_out:
+ s[conv_out].compute_at(s[out], inner)
+ else:
+ C = conv_out.op.input_tensors[0]
+ s[C].compute_at(s[out], inner)
+
traverse_inline(s, outs[0].op, _callback)
return s
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
x = outs[0]
+
if list(s[x].op.axis):
# do not vectorize for broadcast
- (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
+ (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4)
s[x].vectorize(ii)
tvm.te.schedule.AutoInlineInjective(s)
return te.decl_tensor_intrin(
C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
default_buffer_params=buffer_params)
+
+def _q_multiply_shift_arm(op):
+ """
+ Implementation of q_multiply_shift_arm through arm intrinsics
+ sqrdmulh and srshl when q == 31.
+
+ Please note that this is introducing a small round-up error for
+ some corner cases. This is because we are rounding twice instead
+ than only once. I.e.:
+
+ * original q_multiply_shift: round(x*y*2^-s)
+ * arm q_multiply_shift: round(round(x*y)*2^-s)
+ """
+ x = op.args[0]
+ y = op.args[1]
+ q = op.args[2]
+ s = op.args[3]
+
+ # Don't use this intrinsic if we don't have a int32x4 vector
+ # or if we are not multiplying q31 numbers
+ if x.dtype != "int32x4" or q.value != 31:
+ return op
+
+ # Case 1, shift is negative
+ sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype,
+ 'llvm.aarch64.neon.sqrdmulh',
+ tvm.tir.const(2, 'uint32'),
+ x,
+ y)
+
+ fixup = (sqrdmulh & (-s)) >> 31
+ fixed_up_x = (sqrdmulh + fixup)
+ out_1 = tvm.tir.call_llvm_intrin(op.dtype,
+ 'llvm.aarch64.neon.srshl',
+ tvm.tir.const(2, 'uint32'),
+ sqrdmulh,
+ s)
+
+ # Case 2, shift is positive
+ x = x * (1 << (s))
+ out_2 = tvm.tir.call_llvm_intrin(op.dtype,
+ 'llvm.aarch64.neon.sqrdmulh',
+ tvm.tir.const(2, 'uint32'),
+ x,
+ y)
+
+ # Select depending on the shift
+ return tvm.tir.Select(s < 0, out_1, out_2)
+
+tvm.target.intrin.register_intrin_rule("llvm.aarch64",
+ "q_multiply_shift",
+ _q_multiply_shift_arm, override=True)
return tvm.te.max(tvm.te.min(value, const_max), const_min)
return te.compute(x.shape, _compute)
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def fixed_point_multiply(x, multiplier, shift):
+ """Fixed point multiplication between data and a fixed point
+ constant expressed as multiplier * 2^(-shift), where multiplier
+ is a Q-number with 31 fractional bits
+
+ Parameters
+ ----------
+ x : tvm.te.Tensor or Expr
+ Input argument.
+ multiplier : int
+ Multiplier of a fixed floating point number described as multiplier*2^(-shift).
+ shift : int
+ Shift of a fixed floating point number described as multiplier*2^(-shift).
+
+ Returns
+ -------
+ y : tvm.te.Tensor
+ The result.
+ """
+ def _compute(*indices):
+ value = x(*indices)
+ return tvm.tir.q_multiply_shift(value,
+ tvm.tir.const(multiplier, 'int32'),
+ tvm.tir.const(31, 'int32'),
+ tvm.tir.const(shift, 'int32'))
+ return te.compute(x.shape, _compute)
def cast(x, dtype):
"""Cast input to specified data type.