From: Tianqi Chen
Date: Fri, 27 Sep 2019 14:45:25 +0000 (-0700)
Subject: [ARITH] Use explicit div mode in python. (#4014)
X-Git-Tag: upstream/0.7.0~1858
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2ded2d8caf3279c2fb9dbe16c275f17f8a61d6f2;p=platform%2Fupstream%2Ftvm.git

[ARITH] Use explicit div mode in python. (#4014)
---

diff --git a/docs/api/python/tvm.rst b/docs/api/python/tvm.rst
index a04502b..b517195 100644
--- a/docs/api/python/tvm.rst
+++ b/docs/api/python/tvm.rst
@@ -35,6 +35,13 @@ The user facing API for computation declaration.
    tvm.thread_axis
    tvm.comm_reducer
    tvm.sum
+   tvm.div
+   tvm.indexdiv
+   tvm.indexmod
+   tvm.truncdiv
+   tvm.truncmod
+   tvm.floordiv
+   tvm.floormod
    tvm.min
    tvm.max
    tvm.tag_scope
@@ -53,6 +60,13 @@ The user facing API for computation declaration.
 .. autofunction:: tvm.thread_axis
 .. autofunction:: tvm.comm_reducer
 .. autofunction:: tvm.sum
+.. autofunction:: tvm.div
+.. autofunction:: tvm.indexdiv
+.. autofunction:: tvm.indexmod
+.. autofunction:: tvm.truncdiv
+.. autofunction:: tvm.truncmod
+.. autofunction:: tvm.floordiv
+.. autofunction:: tvm.floormod
 .. autofunction:: tvm.min
 .. autofunction:: tvm.max
 .. autofunction:: tvm.tag_scope

diff --git a/python/tvm/api.py b/python/tvm/api.py
index b54d364..e7523bd 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     reducer.__doc__ = doc_str.format(name)
     return reducer
 
+def div(a, b):
+    """Compute a / b as in C/C++ semantics.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+
+    Note
+    ----
+    When operands are integers, returns truncdiv(a, b).
+    """
+    return _make._OpDiv(a, b)
+
+
+def indexdiv(a, b):
+    """Compute floor(a / b) where a and b are non-negative.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativity.
+    """
+    return _make._OpIndexDiv(a, b)
+
+
+def indexmod(a, b):
+    """Compute the remainder of indexdiv. a and b are non-negative.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativity.
+    """
+    return _make._OpIndexMod(a, b)
+
 
 def truncdiv(a, b):
     """Compute the truncdiv of two expressions.
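A minimal sketch of the two conventions the helpers above make explicit, using plain Python integers rather than TVM expressions (the c_truncdiv/c_truncmod names are illustrative only; truncdiv follows C's round-toward-zero rule, floordiv follows Python's `//`):

    def c_truncdiv(a, b):
        # C/C++ integer division truncates toward zero (what tvm.truncdiv builds).
        q = abs(a) // abs(b)
        return q if (a < 0) == (b < 0) else -q

    def c_truncmod(a, b):
        # Remainder satisfies a == q*b + r and takes the dividend's sign.
        return a - c_truncdiv(a, b) * b

    # The two modes agree on non-negative operands ...
    assert c_truncdiv(7, 2) == 7 // 2 == 3
    # ... and differ once the dividend is negative.
    assert c_truncdiv(-7, 2) == -3 and -7 // 2 == -4
    assert c_truncmod(-7, 2) == -1 and -7 % 2 == 1

indexdiv/indexmod restrict themselves to non-negative operands, where the conventions coincide, which is what lets the arithmetic passes pick whichever lowering is cheaper.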
diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py
index c6530dd..aceab6d 100644
--- a/python/tvm/contrib/nnpack.py
+++ b/python/tvm/contrib/nnpack.py
@@ -101,8 +101,11 @@ def convolution_inference(
     assert isinstance(stride, list) and len(stride) == 2
     batch, _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+    idxdiv = _api.indexdiv
+    output_height = idxdiv(
+        input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
+    output_width = idxdiv(
+        input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
 
     return _api.extern(
         (batch, output_channels, output_height, output_width),
@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
     batch, _, input_height, input_width = data.shape
     output_channels, _, _, _ = transformed_kernel.shape
     kernel_height, kernel_width = (3, 3)
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+    idxdiv = _api.indexdiv
+    output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
+    output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
 
     return _api.extern(
         (batch, output_channels, output_height, output_width),
diff --git a/python/tvm/expr.py b/python/tvm/expr.py
index 0d70ac9..a8bd651 100644
--- a/python/tvm/expr.py
+++ b/python/tvm/expr.py
@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
 # pylint: disable=missing-docstring
 from __future__ import absolute_import as _abs
 from ._ffi.node import NodeBase, NodeGeneric, register_node
+from ._ffi.runtime_ctypes import TVMType, TypeCode
 from . import make as _make
 from . import generic as _generic
 from . import _api_internal
 
+
+def div_ambiguity_error():
+    return RuntimeError(
+        "TVM supports multiple types of integer divisions, "
+        "please call div, indexdiv/indexmod, floordiv/floormod "
+        "or truncdiv/truncmod directly to avoid ambiguity in the code.")
+
+def _dtype_is_int(value):
+    if isinstance(value, int):
+        return True
+    return (isinstance(value, ExprOp) and
+            TVMType(value.dtype).type_code == TypeCode.INT)
+
+
 class ExprOp(object):
     def __add__(self, other):
         return _generic.add(self, other)
@@ -58,24 +72,35 @@ class ExprOp(object):
         return _generic.multiply(other, self)
 
     def __div__(self, other):
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
         return _generic.divide(self, other)
 
     def __rdiv__(self, other):
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
         return _generic.divide(other, self)
 
     def __truediv__(self, other):
-        return self.__div__(other)
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
+        return _generic.divide(self, other)
 
     def __rtruediv__(self, other):
-        return self.__rdiv__(other)
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
+        return _generic.divide(other, self)
 
     def __floordiv__(self, other):
-        return self.__div__(other)
+        # return _generic.floordiv(self, other)
+        return _generic.divide(self, other)
 
     def __rfloordiv__(self, other):
-        return self.__rdiv__(other)
+        # return _generic.floordiv(other, self)
+        return _generic.divide(other, self)
 
     def __mod__(self, other):
+        # raise div_ambiguity_error()
         return _make._OpMod(self, other)
 
     def __neg__(self):

diff --git a/python/tvm/generic.py b/python/tvm/generic.py
index 7f5c3d6..b7bea7f 100644
--- a/python/tvm/generic.py
+++ b/python/tvm/generic.py
@@ -25,6 +25,7 @@ from . import make as _make
 #Operator precedence used when overloading.
 __op_priority__ = 0
 
+
 def add(lhs, rhs):
     """Generic add operator.
 
@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
     """
     return _make._OpMul(lhs, rhs)
 
-
 def divide(lhs, rhs):
     """Generic divide operator.
 
@@ -96,6 +96,23 @@ def divide(lhs, rhs):
     """
     return _make._OpDiv(lhs, rhs)
 
+def floordiv(lhs, rhs):
+    """Generic floordiv operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of the floordiv operation.
+    """
+    return _make._OpFloorDiv(lhs, rhs)
+
 def cast(src, dtype):
     """Generic cast operator.
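The commented-out guards above show where this is headed: integer `/` and `%` on expressions are ambiguous, and call sites are expected to name a division mode. A sketch of the intended migration, assuming a build with this patch applied:

    import tvm

    x = tvm.var("x")
    ofs  = tvm.indexdiv(x, 16)   # was x / 16: non-negative index math
    lane = tvm.indexmod(x, 16)   # was x % 16
    q    = tvm.truncdiv(x, 16)   # C semantics; differs from the above for x < 0
    r    = tvm.floormod(x, 16)   # always in [0, 16) for a positive divisor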
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index 171a2f8..44db999 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -31,6 +31,7 @@ from . import util
 from .preprocessor import determine_variable_usage
 from ..api import all as _all
 from ..api import any as _any
+
 from ..container import Array
 from ..tensor import Tensor, Operation
 from .. import _api_internal as _tvm_internal
@@ -78,6 +79,18 @@ class Symbol(Enum):
     ThreadBind = 10
 
 
+def _floordiv(x, y):
+    if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
+        return _api.floordiv(x, y)
+    return operator.floordiv(x, y)
+
+
+def _floormod(x, y):
+    if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
+        return _api.floormod(x, y)
+    return operator.mod(x, y)
+
+
 class HybridParser(ast.NodeVisitor):
     """Python AST visitor pass which finally lowers it to HalideIR"""
 
@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
         ast.Sub     : operator.sub,
         ast.Mult    : operator.mul,
         ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
-        ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
-        ast.Mod     : operator.mod,
+        ast.FloorDiv: _floordiv,
+        ast.Mod     : _floormod,
         ast.BitOr   : operator.or_,
         ast.BitAnd  : operator.and_,
         ast.BitXor  : operator.xor,

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 5dddfc6..d1c0f09 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
 @script
 def _arange_shape_func(start, stop, step):
     out = output_tensor((1,), "int64")
-    out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
+    out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
     return out
 
 @_reg.register_shape_func("arange", True)
@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
             assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
             if newshape[i+1] == -1:
                 assert newshape[i+2] != -1, "Split dims cannot both be -1."
-                out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
                 out[dst_idx+1] = int64(newshape[i+2])
             else:
                 out[dst_idx] = int64(newshape[i+1])
                 if newshape[i+2] == -1:
-                    out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
+                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
                 else:
                     out[dst_idx+1] = int64(newshape[i+2])
             assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
         new_size = int64(1)
         for i in const_range(out.shape[0]):
             new_size *= out[i]
-        out[infer_idx] = old_size / new_size
+        out[infer_idx] = old_size // new_size
     return out
 
 @_reg.register_shape_func("reshape", False)

diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index d5fc1fa..b8ee144 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
 REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
 REGISTER_MAKE_BINARY_OP(_OpDiv, div);
 REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
 REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
 REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
 REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
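Each new Python helper is a thin FFI wrapper over these registrations: tvm.indexdiv dispatches through _make._OpIndexDiv to the C++ indexdiv, and likewise for the other modes. A hypothetical smoke test of the round trip, assuming a build with this patch:

    import tvm

    x, y = tvm.var("x"), tvm.var("y")
    print(tvm.indexdiv(x, y))   # expression built by _make._OpIndexDiv
    print(tvm.truncmod(x, y))   # expression built by _make._OpTruncMod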
You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) {  // NOLINT(*)
 void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "*", os, this);
 }
+
 void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) {  // NOLINT(*)
   if (op->type.is_int())
     PrintBinaryExpr(op, "//", os, this);
   else
     PrintBinaryExpr(op, "/", os, this);
 }
+
+void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) {  // NOLINT(*)
+  if (op->type.is_int())
+    PrintBinaryExpr(op, "//", os, this);
+  else
+    PrintBinaryExpr(op, "/", os, this);
+}
+
 void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "%", os, this);
 }
+
+void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) {  // NOLINT(*)
+  PrintBinaryExpr(op, "%", os, this);
+}
+
 void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "min", os, this);
 }

diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index e297556..498838f 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -100,6 +100,8 @@ class CodeGenHybrid :
   void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorDiv* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorMod* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
@@ -161,12 +163,12 @@ class CodeGenHybrid :
   std::string GetUniqueName(std::string prefix);
   /*! \brief The output code string builder. */
   std::stringstream stream;
-  /*! 
+  /*!
    * \brief Get or allocate the ID for the given variable.
    * \param v The given variable.
    */
   std::string GetVarID(const Variable *v);
-  /*! 
+  /*!
   * \brief Get or allocate the ID for the given tensor.
   * \param func The tensor to allocate a name.
   * \param value_index The value index of the given tensor.
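The hybrid printer can render integer Div and FloorDiv as Python's `//` because Python's operator floor-divides. A numeric check of the FloorDiv/FloorMod pair it relies on (plain Python, no TVM):

    # Python's // and % implement floor division: the remainder is
    # non-negative for a positive divisor, and q*b + r == a always holds.
    for a in range(-20, 20):
        q, r = a // 4, a % 4
        assert q * 4 + r == a and 0 <= r < 4

For the truncated Div case the `//` spelling matches only when operands are non-negative, which index expressions satisfy by construction.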
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
index e502736..46a0737 100644
--- a/src/lang/expr_operator.cc
+++ b/src/lang/expr_operator.cc
@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
 }
 
 Expr floordiv(Expr a, Expr b) {
+  CHECK(a.type().is_int() || a.type().is_uint());
+  CHECK(b.type().is_int() || b.type().is_uint());
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
   if (ret.defined()) return ret;
@@ -223,6 +225,8 @@
 Expr floormod(Expr a, Expr b) {
+  CHECK(a.type().is_int() || a.type().is_uint());
+  CHECK(b.type().is_int() || b.type().is_uint());
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
   if (ret.defined()) return ret;

diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc
index 813037a..bbc3c57 100644
--- a/src/pass/lower_intrin.cc
+++ b/src/pass/lower_intrin.cc
@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     if (op == nullptr) return ret;
     int shift;
     const DataType& dtype = op->type;
-    if (dtype.is_float()) {
-      return floor(Div::make(op->a, op->b));
-    }
     CHECK(dtype.is_int() || !dtype.is_uint());
 
     if (is_const_power_of_two_integer(op->b, &shift)) {
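With the float special case removed, the remaining lowering is integer-only, and the power-of-two branch that follows depends on floor division agreeing with an arithmetic right shift. A quick check of that identity in plain Python:

    # Arithmetic right shift rounds toward negative infinity, exactly
    # like floordiv, so floordiv(x, 2**k) can lower to x >> k.
    for x in range(-64, 64):
        assert x >> 3 == x // 8

Truncated division does not share this property ((-1) >> 3 is -1, while C's -1/8 is 0), which is why the truncdiv path needs extra sign-correction logic.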
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py
index 7f444d2..28a4388 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
               x * 13 + z * 4 + y * 4 +6)
     ck.verify(x * 3 - 4 * x + 1, 1 - x)
     ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # truncdiv
-    ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
-    ck.verify((x + y + x + y * 3) % 2, 0)
+    ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x)
+    ck.verify(tmod(x + y + x + y * 3, 2), 0)
 
     # floordiv
     fld = tvm.floordiv
@@ -51,28 +53,31 @@ def test_split_index_simplify():
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
 
     # truncdiv
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+
     # split div const
-    ck.verify((x/3) *3 + x % 3, x)
-    ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
-    ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
-    ck.verify((x % 2) / 8, 0)
-    ck.verify((x % 2) / 7, 0)
-    ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6)
+    ck.verify(tdiv(x, 3) *3 + tmod(x, 3), x)
+    ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
+    ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
+    ck.verify(tdiv(tmod(x, 2), 8), 0)
+    ck.verify(tdiv(tmod(x, 2), 7), 0)
+    ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))
 
     # split mod const
-    ck.verify((x * 8) % 16, (x % 2) * 8)
-    ck.verify((x * 8) % 2, 0)
+    ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)
+    ck.verify(tmod(x * 8, 2), 0)
 
     # simplify then fold
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
-    ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
+    ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
     # complex fold
-    ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
+    ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)
 
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
     ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
-    ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
+    ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
 
     # floordiv
     fld = tvm.floordiv
@@ -85,23 +90,24 @@ def test_split_index_simplify():
     ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
 
     # cannot simplify mixed case, unless we canonicalize into one mode.
-    ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
+    ck.verify(tdiv(x,6) * 2 + tmod(fld(x,3), 2), tdiv(x,6) * 2 + tmod(fld(x,3), 2))
 
 
 def test_div_simplify():
     ck = CanonicalChecker()
     x = tvm.var("x")
+    tdiv = tvm.truncdiv
 
     # truncdiv
-    ck.verify((16+48*x)/16, x*3 + 1)
+    ck.verify(tdiv(16+48*x,16), x*3 + 1)
     # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
     # (17+48*x)/16 != 1+3*x
-    ck.verify((17+48*x)/16, (x * 48 + 17) / 16)
+    ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
     # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
-    ck.verify((17+48*x)/16, x * 3 + 1)
+    ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)
 
     # Trying expressions that are not simplifiable for any values of the variables
-    ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
+    ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
 
     # floordiv
     fld = tvm.floordiv
@@ -124,8 +130,10 @@ def test_canonical_mixed():
     ck = CanonicalChecker()
     x = tvm.var("x")
     z = tvm.const(3, "int32")
-    ck.verify(x / (z*z) - x / (z*z), 0)
-    ck.verify(x / (z+z) - x / (z+z), 0)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+    ck.verify(tdiv(x, (z*z)) - tdiv(x, (z*z)), 0)
+    ck.verify(tdiv(x, (z+z)) - tdiv(x, (z+z)), 0)
     ck.verify(x - 2 < 3, x < 5)
     ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
     ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
@@ -207,42 +215,44 @@ def test_reduce_simplify():
                   tvm.sum(k + j, [k, j]))
     ck.verify(tvm.sum(A[3], []), A[3])
     # The rule below is not typical, removed for now
-    ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))
+    ck.verify(tvm.sum(tvm.div(k, 10), k), tvm.sum(tvm.const(0, "int32"), k))
 
 
 def test_simplify_if_then_else():
     ck = CanonicalChecker()
     x = tvm.var("x")
     y = tvm.var("y")
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # simplification that takes condition into account.
     res = tvm.if_then_else((x * 4 + y) >= 466036,
-                           tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
-                                            (((((x*4) + y) - 466036) % 24528) -24512) % 16,
+                           tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
+                                            tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
                                             x), y)
     res2 = tvm.if_then_else((x * 4) >= 466036 - y,
-                            tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
-                                             (((((x*4) + y) - 466036) % 24528) -24512) % 16,
+                            tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
+                                             tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
                                              x), y)
     expected = tvm.if_then_else(
         tvm.expr.LE(466036, (x * 4 + y)),
-        tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
-                         (((x*4) + y) - 4) % 16,
+        tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
+                         tmod(((x*4) + y) - 4, 16),
                          x), y)
     ck.verify(res, expected)
     ck.verify(res2, expected)
     # can only simplify if condition
-    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
-    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
+    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
+    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
     res = tvm.expr.Select(x >= 10,
-                          tvm.if_then_else(x / 3 > 2, x, 0), 0)
+                          tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
     expected = tvm.expr.Select(x >= 10, x, 0)
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
     res = tvm.expr.Select(x >= 10,
-                          tvm.if_then_else(x / 3 < 2, x, 0), 0)
+                          tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
     ck.verify(res, 0)
 
 
@@ -250,20 +260,20 @@ def test_complex_cases():
     ck = CanonicalChecker()
     x = tvm.var("x")
     y = tvm.var("y")
-    res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) +
-             ((((((x*128) + y) % 36)*2) + 1)/2)) -
-            (((x*128) + y) % 1296)) + 1)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+    res2 = (tdiv(tdiv(tmod(x*128 + y, 1296),36)*2 + 1,2)*36 +
+            tdiv(tmod((x*128) + y, 36)*2 + 1,2) -
+            tmod((x*128) + y, 1296) + 1)
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
     ck.verify(res2, 1)
 
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
-    res3 = ((((((((((x*1024) + y)/65536) + ((((x*1024) + y) % 65536)/256))
-                + ((((x*1024) + y) % 256)/16)) + (((x*1024) + y) % 16)) - (y/256)) -
-              ((y % 256)/16)) - (y % 16)) - (x*4))
-    ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))
-
-
+    res3 = (tdiv(x*1024 + y,65536) + tdiv(tmod(x*1024 + y, 65536),256) +
+            tdiv(tmod(x*1024 + y, 256),16) + tmod(x*1024 + y, 16) - tdiv(y,256) -
+            tdiv(tmod(y, 256),16) - tmod(y, 16) - (x*4))
    ck.verify(res3, tdiv((x*1024) + y, 256) - tdiv(y,256) - (x*4))
 
 
 if __name__ == "__main__":
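The split-index rewrites exercised above all reduce to one algebraic fact, checked here numerically for the truncated mode (plain Python models of tdiv/tmod):

    def tdiv(a, b):
        q = abs(a) // abs(b)
        return q if (a < 0) == (b < 0) else -q

    def tmod(a, b):
        return a - tdiv(a, b) * b

    # tdiv(x, c)*c + tmod(x, c) == x for every x, which is what lets the
    # canonical simplifier collapse tdiv(x, 3)*3 + tmod(x, 3) to x.
    for x in range(-100, 100):
        assert tdiv(x, 3) * 3 + tmod(x, 3) == x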
diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py
index 9b41bdc..4944861 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -38,12 +38,13 @@ def test_dtype_bound():
 
 def test_cast_bound():
     analyzer = tvm.arith.Analyzer()
     x = tvm.var("x", dtype="int8")
-    bd = analyzer.const_int_bound((x % 3).astype("uint32"))
+    tmod = tvm.truncmod
+    bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32"))
     assert bd.min_value == 0
     assert bd.max_value == 2
 
     bd = analyzer.const_int_bound(
-        (x % 3).astype("float32").astype("int32"))
+        tmod(x, 3).astype("float32").astype("int32"))
     assert bd.min_value == -2
     assert bd.max_value == 2
 
@@ -98,47 +99,50 @@ def test_mul_bound():
     assert bd.max_value == bd.POS_INF
 
 
-def test_div_bound():
+def test_truncdiv_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
 
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == -2
 
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == -4
     assert bd.max_value == 9
 
     analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == bd.NEG_INF
     assert bd.max_value == bd.POS_INF
 
 
-def test_mod_bound():
+def test_truncmod_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
 
+    tmod = tvm.truncmod
+
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == -9
     assert bd.max_value == 4
 
     analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == -9
     assert bd.max_value == 9
 
     analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == 0
     assert bd.max_value == 9
 
@@ -253,9 +257,12 @@ def test_shift_and_bound():
 def test_mix_index_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+
     analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
     analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))
-    bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8)
+    bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8)
     assert bd.min_value == 0
     assert bd.max_value == 24 - 1
 
@@ -263,7 +270,7 @@ def test_mix_index_bound():
     assert bd.min_value == 0
     assert bd.max_value == 24 * 3 - 1
 
-    bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7)
+    bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7)
     assert bd.min_value == 0
     assert bd.max_value == (23 // 7) * 7 + 6
 
@@ -273,8 +280,8 @@ if __name__ == "__main__":
     test_cast_bound()
     test_add_sub_bound()
     test_mul_bound()
-    test_div_bound()
-    test_mod_bound()
+    test_truncdiv_bound()
+    test_truncmod_bound()
     test_floordiv_bound()
     test_floormod_bound()
     test_min_max_bound()
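The truncmod bounds asserted in test_truncmod_bound can be cross-checked numerically: with x in [-9, 4] and y in [4, 10], the C-style remainder stays within [-9, 4] (plain Python model):

    def tmod(a, b):
        # truncated remainder; sign follows the dividend
        q = abs(a) // abs(b)
        q = q if (a < 0) == (b < 0) else -q
        return a - q * b

    rs = {tmod(x, y) for x in range(-9, 5) for y in range(4, 11)}
    assert min(rs) == -9 and max(rs) == 4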
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index e41532d..235c935 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -35,9 +35,11 @@ def test_deduce():
     d_s = tvm.arith.IntervalSet(-3, -1)
     zero = tvm.const(0, "int32")
 
+    tdiv = tvm.truncdiv
+
     e0 = (-b)*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d - c) /(b*-1) + (-1))
+    ans0 = (tdiv(d - c, b*-1) + (-1))
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
@@ -46,7 +48,7 @@ def test_deduce():
     e0 = d*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d-c)/d - 1)
+    ans0 = (tdiv(d-c,d) - 1)
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
@@ -56,7 +58,7 @@ def test_deduce():
 
     e1 = (a*4+b < c)
     res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    ans1 = (((c - b) + -1)/4 -1)
+    ans1 = (tdiv((c - b) + -1,4) -1)
     assert_expr_equal(res1.max_value, ans1)
 
 
@@ -79,7 +81,7 @@ def test_deduce():
     e3 = (-b)+a*c-d
     res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
-    ans3 = 2/c+1
+    ans3 = tdiv(2,c)+1
     assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
 
     res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})

diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index ebf50e1..17cc6f1 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -60,13 +60,14 @@ def test_add_sub():
 def test_mul_div():
     ck = IntSetChecker()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
     ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
     ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
     ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
 
-    ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
-    ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
+    ck.verify(tdiv(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
+    ck.verify(tdiv(x, 2), {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
 
     fld = tvm.floordiv
     ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
@@ -76,9 +77,10 @@ def test_mul_div():
 def test_mod():
     ck = IntSetChecker()
     x, y = tvm.var("x"), tvm.var("y")
+    tmod = tvm.truncmod
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
-    ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
-    ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
+    ck.verify(tmod(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
+    ck.verify(tmod(x, 10), {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
 
     flm = tvm.floormod
     ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py
index 80fff4e..1ce7197 100644
--- a/tests/python/unittest/test_arith_modular_set.py
+++ b/tests/python/unittest/test_arith_modular_set.py
@@ -54,7 +54,8 @@ def test_div_shift():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
     # not sure if x is non-negative
-    m = analyzer.modular_set((x * 4 + 2) / 2)
+    tdiv = tvm.truncdiv
+    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
     assert m.coeff == 1
     assert m.base == 0
     # right shift always round down so it is fine
@@ -67,7 +68,7 @@ def test_div_shift():
     assert m.base == 1
     # x is non-negative
     analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
-    m = analyzer.modular_set((x * 4 + 2) / 2)
+    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
     assert m.coeff == 2
     assert m.base == 1
 
@@ -92,6 +93,7 @@ def test_mix_index():
     a = tvm.var("a")
     b = tvm.var("b")
     analyzer = tvm.arith.Analyzer()
+    tdiv = tvm.truncdiv
     m = analyzer.modular_set(a * 4 + b * 6 + 7)
     assert m.coeff == 2
     assert m.base == 1
@@ -100,11 +102,11 @@ def test_mix_index():
     assert m.coeff == 4
     assert m.base == 3
 
-    m = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3))
+    m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
     assert m.coeff == 1
     assert m.base == 0
 
-    m = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4))
+    m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
     assert m.coeff == 2
     assert m.base == 0
 
@@ -121,11 +123,13 @@ def test_constraint_scope():
     a = tvm.var("a")
     b = tvm.var("b")
     analyzer = tvm.arith.Analyzer()
-    with analyzer.constraint_scope(b % 4 == 2):
+    tmod = tvm.truncmod
+
+    with analyzer.constraint_scope(tmod(b, 4) == 2):
         m = analyzer.modular_set(b + 1)
         assert m.coeff == 4
         assert m.base == 3
-        with analyzer.constraint_scope(a % 2 == 1):
+        with analyzer.constraint_scope(tmod(a, 2) == 1):
             m = analyzer.modular_set(b + a * 2)
             assert m.coeff == 4
             assert m.base == 0
@@ -140,15 +144,16 @@ def test_constraint_scope():
 def test_intersect():
     a = tvm.var("a")
     analyzer = tvm.arith.Analyzer()
-    with analyzer.constraint_scope(a % 4 == 1):
-        with analyzer.constraint_scope(a % 3 == 1):
+    tmod = tvm.truncmod
+    with analyzer.constraint_scope(tmod(a, 4) == 1):
+        with analyzer.constraint_scope(tmod(a, 3) == 1):
             m = analyzer.modular_set(a)
             assert m.coeff == 12
             assert m.base == 1
 
-    with analyzer.constraint_scope(a % 3 == 2):
-        with analyzer.constraint_scope(a % 5 == 3):
-            with analyzer.constraint_scope(a % 7 == 2):
+    with analyzer.constraint_scope(tmod(a, 3) == 2):
+        with analyzer.constraint_scope(tmod(a, 5) == 3):
+            with analyzer.constraint_scope(tmod(a, 7) == 2):
                 m = analyzer.modular_set(a)
                 assert m.coeff == 105
                 assert m.base == 23

diff --git a/tests/python/unittest/test_autotvm_flop_calculator.py b/tests/python/unittest/test_autotvm_flop_calculator.py
index 2507732..54ade9a 100644
--- a/tests/python/unittest/test_autotvm_flop_calculator.py
+++ b/tests/python/unittest/test_autotvm_flop_calculator.py
@@ -60,11 +60,14 @@ def test_pack_gemm():
         k = tvm.reduce_axis((0, L))
         bn = 4
 
+        fld = tvm.floordiv
+        flm = tvm.floormod
+
         A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
         C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
                              tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
-        C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])
+        C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
 
         s = tvm.create_schedule([C.op])
         assert compute_flop(s) == 2 * N * L * M
@@ -119,9 +122,11 @@ def test_average_pool():
     OH = (H - KH) + 1
     OW = (W - KW) + 1
 
+
     C = tvm.compute(
         (N, CO, OH, OW),
-        lambda n, co, h, w: tvm.sum(D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH), axis=[kh, kw]))
+        lambda n, co, h, w: tvm.sum(
+            tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))
 
     s = tvm.create_schedule([C.op])

diff --git a/tests/python/unittest/test_build_lower.py b/tests/python/unittest/test_build_lower.py
index 082b85f..090120c 100644
--- a/tests/python/unittest/test_build_lower.py
+++ b/tests/python/unittest/test_build_lower.py
@@ -35,7 +35,7 @@ def test_lower_rfactor():
 def test_dependent_output_shape():
     n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
     A = tvm.placeholder((n, m))
-    B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B')
+    B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B')
     s = tvm.create_schedule(B.op)
     mod = tvm.build(s, [A, B, x])
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index 06f1801..f4401d0 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -409,7 +409,7 @@ def test_llvm_div():
     """Check that the semantics of div and mod is the same as in C/C++"""
     def check_div(start, end, divisor, dtype):
         T = tvm.compute((end - start,),
-                        lambda i: tvm.expr.Cast(dtype, (start + i)) / tvm.const(divisor, dtype))
+                        lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
         s = tvm.create_schedule([T.op])
         f = tvm.build(s, [T], "llvm")
         a = tvm.nd.empty((end - start,), dtype)
@@ -418,8 +418,9 @@ def test_llvm_div():
         tvm.testing.assert_allclose(a.asnumpy(), ref)
 
     def check_mod(start, end, divisor, dtype):
+        tmod = tvm.truncmod
         T = tvm.compute((end - start,),
-                        lambda i: tvm.expr.Cast(dtype, (start + i)) % tvm.const(divisor, dtype))
+                        lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
         s = tvm.create_schedule([T.op])
         f = tvm.build(s, [T], "llvm")
         a = tvm.nd.empty((end - start,), dtype)
@@ -443,7 +444,7 @@ def test_llvm_div():
 def test_llvm_fp_math():
     def check_llvm_reciprocal(n):
         A = tvm.placeholder((n,), name='A')
-        B = tvm.compute((n,), lambda i: 1.0/(1e+37*A[i]), name='B')
+        B = tvm.compute((n,), lambda i: tvm.div(1.0,(1e+37*A[i])), name='B')
 
         s = tvm.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm")

diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py
index 322c37e..ef58174 100644
--- a/tests/python/unittest/test_ir_builder.py
+++ b/tests/python/unittest/test_ir_builder.py
@@ -41,8 +41,9 @@ def test_if():
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     A = ib.pointer("float32", name="A")
+    tmod = tvm.truncmod
     with ib.for_range(0, n, name="i") as i:
-        with ib.if_scope((i % 2) == 0):
+        with ib.if_scope(tmod(i, 2) == 0):
             A[i] = A[i] + 1
         with ib.else_scope():
             A[0] = A[i] + 2
@@ -108,13 +109,14 @@ def test_gpu():
     dtype = "float32"
     A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
+    fld = tvm.floordiv
     def test_device_ir(A, B, C):
         n = A.shape[0]
         max_threads = 32
         ib = tvm.ir_builder.create()
         bx = tvm.thread_axis("blockIdx.x")
         tx = tvm.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", (n+max_threads-1) // max_threads)
+        ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         Aptr = ib.buffer_ptr(A)
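The thread-extent expression above uses the standard ceil-division idiom: under floor semantics, (n + k - 1) // k equals ceil(n / k) for non-negative n. A spot check in plain Python:

    import math

    max_threads = 32
    for n in range(0, 1000):
        assert (n + max_threads - 1) // max_threads == math.ceil(n / max_threads)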
diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py
index 8c7e13a..9ad8b62 100644
--- a/tests/python/unittest/test_lang_buffer.py
+++ b/tests/python/unittest/test_lang_buffer.py
@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
     def assert_simplified_equal(index_simplified, index_direct):
         assert tvm.ir_pass.Equal(index_simplified, index_direct),\
             "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # Test Case1
-    index_simplified = A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1))
+    index_simplified = A_stride.vload(
+        (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
     index_direct = A_stride.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
+
     # Test Case2
-    index_simplified = A.vload(((k0 % (k1 / s)) / n,
-                                (k0 % (k1 / s)) % n + (k0 % k1)))
-    index_direct = A.vload((0, k0 % k1 + k0 % (k1 / s)))
+    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
+    index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n,
-                                ((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n))
+    index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
+                                idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
+                                idxmod(idxmod(k0, idxdiv(k1, s)), n)))
     index_direct = A.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload(((k0 % (k1 / s)) / n,
-                                (k0 % (k1 / n)) % n + (k0 % k1)))
-    index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
+    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
+    index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
+                            (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
     assert_simplified_equal(index_simplified, index_direct)
 
 
@@ -143,14 +150,14 @@ def test_buffer_broadcast():
     check()
 
 
-def test_bbuffer_roadcast_expr():
+def test_buffer_broadcast_expr():
     n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
     n1, m1 = tvm.var('n1'), tvm.var('m1')
     o0, o1 = tvm.var('o0'), tvm.var('o1')
     A = tvm.placeholder((m0, n0), name='A')
     B = tvm.placeholder((m1, n1), name='B')
-    C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C')
+    C = tvm.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')
 
     Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
     Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")

diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py
index 2d30d29..c57f4a1 100644
--- a/tests/python/unittest/test_lang_operator.py
+++ b/tests/python/unittest/test_lang_operator.py
@@ -32,10 +32,11 @@ def test_const_fold():
         if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
             raise ValueError("check error: %s vs %s " % (x, y))
 
+    tmod = tvm.truncmod
     check(lambda x, y: x + y, 3, 4)
     check(lambda x, y: x * y, 3, 12)
     check(lambda x, y: x * y - 10, 3, 12)
-    check(lambda x, y: x - y % 10, 3, 12)
+    check(lambda x, y: x - tmod(y, 10), 3, 12)
     check(lambda x, y: x // y + 10, 100, 12)
     check(lambda x, y: x & y + 10, 112, 128)
     check(lambda x, y: x > y, 112, 128)
@@ -47,13 +48,15 @@ def test_const_fold():
 
 def test_const_fold2():
     x = tvm.var("x")
+    tmod = tvm.truncmod
+    tdiv = tvm.truncdiv
     assert (x + 0).same_as(x)
     assert (0 + x).same_as(x)
     assert (x - 0).same_as(x)
-    assert (x % 1).value == 0
+    assert tmod(x, 1).value == 0
     assert (x * 1).same_as(x)
     assert (1 * x).same_as(x)
-    assert isinstance((1 / x), tvm.expr.Div)
+    assert isinstance(tdiv(1, x), tvm.expr.Div)
 
 def test_const_fold3():
     # Test that using ints with logic operations is forbidden
@@ -88,8 +91,9 @@ def test_const_fold4():
     x1 = tvm.const(4, "int32")
     x2 = x1 + 5
+    tdiv = tvm.truncdiv
     assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
-    x3 = x2 / 3
+    x3 = tdiv(x2, 3)
     assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
     x4 = x3 + 0.55
     assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
diff --git a/tests/python/unittest/test_lang_tensor_overload_op.py b/tests/python/unittest/test_lang_tensor_overload_op.py
index 7571df5..3cb6e65 100644
--- a/tests/python/unittest/test_lang_tensor_overload_op.py
+++ b/tests/python/unittest/test_lang_tensor_overload_op.py
@@ -72,7 +72,7 @@ def test_combination():
     A = tvm.placeholder((n, m), name='A')
     B = tvm.placeholder((n, m), name='B')
     C = tvm.placeholder((n, m), name='C')
-    D = k + A - B * C / x
+    D = k + A - B * C + x
     s = tvm.create_schedule(D.op)
     foo = tvm.build(s, [x, A, B, C, D], "llvm")
     ctx = tvm.cpu(0)
@@ -82,7 +82,7 @@ def test_combination():
     c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
     d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
     foo(x, a, b, c, d)
-    tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() / x)
+    tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
 
 
 def verify_tensor_scalar_bop(shape, typ="add"):

diff --git a/tests/python/unittest/test_pass_basic.py b/tests/python/unittest/test_pass_basic.py
index b05d75a..8f35611 100644
--- a/tests/python/unittest/test_pass_basic.py
+++ b/tests/python/unittest/test_pass_basic.py
@@ -17,13 +17,15 @@ import tvm
 
 def test_simplify():
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     x = tvm.var('x')
     e1 = tvm.ir_pass.Simplify(x + 2 + 1)
     assert(tvm.ir_pass.Equal(e1, x + 3))
     e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
     assert(tvm.ir_pass.Equal(e2, x * 8))
-    e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
-    assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
+    e3 = tvm.ir_pass.Simplify(x - tdiv(x, 3) * 3)
+    assert(tvm.ir_pass.Equal(e3, tmod(x, 3)))
 
 
 def test_verify_ssa():

diff --git a/tests/python/unittest/test_pass_equal.py b/tests/python/unittest/test_pass_equal.py
index 7b9fd7f..8bd491b 100644
--- a/tests/python/unittest/test_pass_equal.py
+++ b/tests/python/unittest/test_pass_equal.py
@@ -24,7 +24,7 @@ def test_equal_expr():
         return x + y + 1
 
     def func2():
-        return tvm.exp((x + y + 1) * y / 4)
+        return tvm.exp(tvm.truncdiv((x + y + 1) * y, 4))
 
     assert tvm.ir_pass.Equal(func1(), func1())
     assert tvm.ir_pass.Equal(func2(), func2())

diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py
index bb09966..b6fcfa3 100644
--- a/tests/python/unittest/test_pass_loop_partition.py
+++ b/tests/python/unittest/test_pass_loop_partition.py
@@ -162,7 +162,7 @@ def test_condition():
     ib = tvm.ir_builder.create()
     m = tvm.var('m')
     n = tvm.var('n')
-    with ib.for_range(0, ((n+3)/4), 'i') as i:
+    with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
         with ib.for_range(0, 4, 'j') as j:
             ib.emit(tvm.make.Evaluate(
                 tvm.make.Select(ib.likely(i*4+j < n), m, n)))

diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py
index f7e5f94..03fa999 100644
--- a/topi/python/topi/cuda/ssd/multibox.py
+++ b/topi/python/topi/cuda/ssd/multibox.py
@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     tid = bx * max_threads + tx
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid / num_anchors
+        i = tid // num_anchors
         j = tid % num_anchors
         with ib.if_scope(cls_id[tid] > 0):
             with ib.if_scope(tid == 0):
diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 2382127..45882b7 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
                     tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
                             axis=[r_a, r_b]), name='V')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     # batch gemm
     ci = tvm.reduce_axis((0, CI), name='c')
     M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
-                    tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
-                            V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')
+                    tvm.sum(U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)] *
+                            V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)], axis=ci), name='M')
 
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_b')
@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     # unpack output
     output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
-                         Y[co][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
+                         Y[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                           idxmod(h, m), idxmod(w, m)]
                          # The following hack term is used to make the padding in batch gemm ("M")
                          # effective, otherwise the padding will be eliminated by bound inference.
                          # Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.

diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py
index c04ff01..2faabf2 100644
--- a/topi/python/topi/nn/bitserial_conv2d.py
+++ b/topi/python/topi/nn/bitserial_conv2d.py
@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                       axis=[ci, dh, dw, b1, b2])
 
     conv = tvm.compute(ovshape, _conv, name='conv_out')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
-    return tvm.compute(oshape, lambda n, co, h, w:
-                       conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                       name='conv_vec', tag='spatial_bitserial_conv_nchw')
+    return tvm.compute(
+        oshape, lambda n, co, h, w:
+        conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
+            w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        name='conv_vec', tag='spatial_bitserial_conv_nchw')
 
 @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
 def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
 
     conv = tvm.compute(ovshape, _conv, name='conv')
 
-    return tvm.compute(oshape, lambda n, h, w, co:
-                       conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
-                       name='output_unpack', tag='spatial_bitserial_conv_nhwc')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+    return tvm.compute(
+        oshape, lambda n, h, w, co:
+        conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
+            co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
 
 @tvm.target.generic_func
 def bitserial_conv2d_legalize(attrs, inputs, types):
diff --git a/topi/python/topi/nn/bitserial_dense.py b/topi/python/topi/nn/bitserial_dense.py
index b28b3a4..d77a1b7 100644
--- a/topi/python/topi/nn/bitserial_dense.py
+++ b/topi/python/topi/nn/bitserial_dense.py
@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
     weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
                              weight_packed[xo*VX+vx][wb][k], name='weight_vec')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
-        (tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
-         tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
+        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
+         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
+        ).astype(out_dtype)
         << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
 
     matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
-        tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
+        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+                     ).astype(out_dtype)
         << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
 
     # binary ops

diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 590600a..904dd54 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                        tvm.sum(data_pad[n,
-                                        ic // ic_bn,
+                                        idxdiv(ic, ic_bn),
                                         oh * HSTR + kh * dilation_h,
                                         ow * WSTR + kw * dilation_w,
-                                        ic % ic_bn].astype(out_dtype)
+                                        idxmod(ic, ic_bn)].astype(out_dtype)
                                * kernel[oc_chunk,
-                                        ic // ic_bn,
+                                        idxdiv(ic, ic_bn),
                                         kh,
                                         kw,
-                                        ic % ic_bn,
+                                        idxmod(ic, ic_bn),
                                         oc_block],
                                axis=[ic, kh, kw]),
                        name='conv2d_NCHWc', tag="conv2d_NCHWc")
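All of these NCHWc-style computes rely on the same packing arithmetic: a flat channel c splits into a chunk idxdiv(c, bn) and a lane idxmod(c, bn), and the pair recomposes losslessly for non-negative c. In plain Python:

    bn = 8                       # block size, standing in for ic_bn/oc_bn above
    for c in range(4 * bn):
        chunk, lane = c // bn, c % bn
        assert chunk * bn + lane == c and 0 <= lane < bn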
diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py
index e703bec..f50e357 100644
--- a/topi/python/topi/nn/depthwise_conv2d.py
+++ b/topi/python/topi/nn/depthwise_conv2d.py
@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     pad_after = [0, 0, pad_down, pad_right]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     di = tvm.reduce_axis((0, filter_height), name='di')
     dj = tvm.reduce_axis((0, filter_width), name='dj')
     Output = tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda b, c, i, j: tvm.sum(
-            (PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h,
-                         j*stride_w+dj*dilation_w].astype(out_dtype) *
-             Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
+            (PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
+                         j*stride_w+dj*dilation_w].astype(out_dtype) *
+             Filter[idxdiv(c, channel_multiplier),
+                    idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
     return Output
@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
     pad_after = [0, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     di = tvm.reduce_axis((0, filter_height), name='di')
     dj = tvm.reduce_axis((0, filter_width), name='dj')
     Output = tvm.compute(
         (batch, out_height, out_width, out_channel),
         lambda b, i, j, c: tvm.sum(
             (PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
-                         c/channel_multiplier].astype(out_dtype) *
-             Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
+                         idxdiv(c, channel_multiplier)].astype(out_dtype) *
+             Filter[di, dj,
+                    idxdiv(c, channel_multiplier),
+                    idxmod(c, channel_multiplier)].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
     return Output
@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
     dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
     dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
     db = tvm.reduce_axis((0, batch), name='db')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
     Weight_grad = tvm.compute(
         (filter_h, filter_w, in_c, channel_multiplier),
         lambda fh, fw, c, m: tvm.sum(
-            Out_grad[db, dh, dw, c*channel_multiplier+m%channel_multiplier] *
+            Out_grad[db, dh, dw, c*channel_multiplier+idxmod(m, channel_multiplier)] *
             padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
         tag='depthwise_conv2d_backward_weight_nhwc')

diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py
index 24bf0d3..d952453 100644
--- a/topi/python/topi/nn/dilate.py
+++ b/topi/python/topi/nn/dilate.py
@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
     def _dilate(*indices):
         not_zero = []
         index_tuple = []
+        idxdiv = tvm.indexdiv
+        idxmod = tvm.indexmod
         for i in range(n):
             if not util.equal_const_int(strides[i], 1):
-                index_tuple.append(indices[i] / strides[i])
-                not_zero.append((indices[i] % strides[i]).equal(0))
+                index_tuple.append(idxdiv(indices[i], strides[i]))
+                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
             else:
                 index_tuple.append(indices[i])
         if not_zero:

diff --git a/topi/python/topi/nn/flatten.py b/topi/python/topi/nn/flatten.py
index eb07f5e..dba9b7c 100644
--- a/topi/python/topi/nn/flatten.py
+++ b/topi/python/topi/nn/flatten.py
@@ -38,12 +38,14 @@ def flatten(data):
     for i in range(1, len(ishape)):
         dim = dim * ishape[i]
     oshape = [ishape[0], dim]
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
     def unwrap(idx, shape):
         index = []
         for s in reversed(shape):
-            index.append(idx % s)
-            idx = idx / s
+            index.append(idxmod(idx, s))
+            idx = idxdiv(idx, s)
         return list(reversed(index))
 
     return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:])))
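The unwrap helper in flatten is an unravel-index: repeatedly peeling idxmod/idxdiv pairs converts a flat offset back to row-major coordinates. A standalone sketch with a small check:

    def unwrap(idx, shape):
        index = []
        for s in reversed(shape):
            index.append(idx % s)
            idx //= s
        return list(reversed(index))

    # offset 5 in a (2, 3) layout is row 1, column 2: 5 == 1*3 + 2
    assert unwrap(5, [2, 3]) == [1, 2]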
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 6565de2..e7c57e0 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
     ic = tvm.reduce_axis((0, in_channel), name='ic')
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
+    idxmod = tvm.indexmod
+    idxdiv = tvm.indexdiv
 
     conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                       tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
+                       tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
+                                        idxmod(ic, ic_bn),
                                         ow*WSTR+kw*dilation_w].astype(out_dtype) *
-                               kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
+                               kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
+                                          idxmod(ic, ic_bn),
                                           oc_block].astype(out_dtype),
                                axis=[ic, kh, kw]), name='conv')
 
     unpack = tvm.compute(unpack_shape,
-                         lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
+                         lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
                          .astype(out_dtype),
                          name='output_unpack',
                          tag='conv2d_nchw')
@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     cfg = get_config()
     _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # change shape with the value in config
     ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                            cfg["tile_ow"].size[-1])
-    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
+    new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
                       raw_data_shape[2], raw_data_shape[3], ic_bn)
     data_layout = "NCHW%dc" % ic_bn
     out_layout = "NCHW%dc" % oc_bn
-    new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
+    new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
+                        idxdiv(raw_kernel_shape[1], ic_bn),
                         raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
     new_data = tvm.placeholder(new_data_shape, data.dtype)
     new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, layout, dtype = workload
     batch_size, in_channel, in_height, in_width = data[:-1]
     out_channel, _, k_height, k_width = kernel[:-1]
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    idxdiv = tvm.indexdiv
+
+    out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1
+    out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1
     tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
+    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
     in_layout = "NCHW%dc" % tile_ic
-    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
+    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
     out_layout = "NCHW%dc" % tile_oc
     return ((in_shape, in_layout),), ((out_shape, out_layout),)

diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py
index 9b7624e..ec401bf 100644
--- a/topi/python/topi/x86/dense.py
+++ b/topi/python/topi/x86/dense.py
@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     packw = tvm.compute(packw_shape,
                         lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     k = tvm.reduce_axis((0, K), name="k")
     C = tvm.compute((M, N),
                     lambda y, x: tvm.sum(
                         data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                        packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
                         axis=k),
                     tag="dense_pack")
     if bias is not None:

diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py
index d713a54..8af41da 100644
--- a/topi/python/topi/x86/depthwise_conv2d.py
+++ b/topi/python/topi/x86/depthwise_conv2d.py
@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
         data_pad = data
 
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     kh = tvm.reduce_axis((0, filter_height), name='kh')
     kw = tvm.reduce_axis((0, filter_width), name='kw')
     Output = tvm.compute(
         (batch, out_channel_chunk, out_height, out_width, out_channel_block),
         lambda b, oco, oh, ow, oci: tvm.sum(
-            (data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block,
-                      oh*HSTR+kh, ow*WSTR+kw,
-                      ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
+            (data_pad[
+                b,
+                idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
+                oh*HSTR+kh, ow*WSTR+kw,
+                idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
             .astype(out_dtype) *
             kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
            axis=[kh, kw]),
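The _conv2d_infer_layout change above computes output spatial sizes with the usual convolution rule, out = (in + 2*pad - kernel) // stride + 1; a spot check in plain Python:

    # 224x224 input, 3x3 kernel, padding 1, stride 2 -> 112x112
    assert (224 + 2 * 1 - 3) // 2 + 1 == 112
    # stride 1 with the same padding preserves the size
    assert (224 + 2 * 1 - 3) // 1 + 1 == 224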