Handle vectorize for LE statement (#3137)
authorWei Chen <ipondering.weic@gmail.com>
Wed, 8 May 2019 03:52:24 +0000 (20:52 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 8 May 2019 03:52:24 +0000 (23:52 -0400)
* Handle vectorize for LE statement

Fix a new cases introduced by commit 7afbca5691fdb599cd90b043d5a5036e55cae2d6

* Add test

src/pass/vectorize_loop.cc
tests/python/unittest/test_pass_vectorize.py

index bd0a91c..f87e80c 100644 (file)
@@ -166,6 +166,9 @@ class Vectorizer : public IRMutator {
   Expr Mutate_(const LT* op, const Expr &e) final {
     return BinaryVec(op, e);
   }
+  Expr Mutate_(const LE* op, const Expr &e) final {
+    return BinaryVec(op, e);
+  }
   Expr Mutate_(const GT* op, const Expr &e) final {
     return BinaryVec(op, e);
   }
index 0351687..fca22a1 100644 (file)
@@ -69,6 +69,28 @@ def test_vectorize_with_if():
     assert stmt.then_case.value.dtype == "float32x4"
     assert isinstance(stmt.else_case, tvm.stmt.For)
 
+def test_vectorize_with_le_cond():
+    n = tvm.var('n')
+    ib = tvm.ir_builder.create()
+    A = ib.pointer("float32", name="A")
+    with ib.for_range(0, 4, for_type="vectorize") as i:
+        with ib.if_scope(i <= n):
+            A[i] = A[i] + 1
+    stmt = ib.get()
+    stmt = tvm.ir_pass.VectorizeLoop(stmt)
+    assert isinstance(stmt, tvm.stmt.For)
+
+def test_vectorize_with_ge_cond():
+    n = tvm.var('n')
+    ib = tvm.ir_builder.create()
+    A = ib.pointer("float32", name="A")
+    with ib.for_range(0, 4, for_type="vectorize") as i:
+        with ib.if_scope(i >= n):
+            A[i] = A[i] + 1
+    stmt = ib.get()
+    stmt = tvm.ir_pass.VectorizeLoop(stmt)
+    assert isinstance(stmt, tvm.stmt.For)
+
 def test_vectorize_if_then_else():
     n = tvm.var('n')
     x = tvm.var('x')
@@ -102,3 +124,5 @@ if __name__ == "__main__":
     test_vectorize_with_if()
     test_vectorize_loop()
     test_vectorize_if_then_else()
+    test_vectorize_with_le_cond()
+    test_vectorize_with_ge_cond()