codegen_spirv support Call::reinterpret (#3795)
authorAndrew Tulloch <andrew@tullo.ch>
Fri, 30 Aug 2019 00:25:07 +0000 (17:25 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 30 Aug 2019 00:25:07 +0000 (08:25 +0800)
src/codegen/spirv/codegen_spirv.cc
src/codegen/spirv/ir_builder.cc
tests/python/unittest/test_codegen_vulkan.py [new file with mode: 0644]

index 7686250..7caf3a2 100644 (file)
@@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
     } else {
       return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
     }
+  } else if (op->is_intrinsic(Call::reinterpret)) {
+    return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type),
+                               MakeValue(op->args[0]));
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return this->CreateStorageSync(op);
   } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
index d6ba9e4..6afd311 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -522,17 +522,17 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
     }                                                             \
   }
 
-#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op)               \
-  Value IRBuilder::_OpName(Value a, Value b) {                    \
-    CHECK_EQ(a.stype.id, b.stype.id);                             \
-    if (a.stype.type.is_int()) {                                   \
-      return MakeValue(spv::OpS ## _Op, a.stype, a, b);            \
-    } else if (a.stype.type.is_uint()) {                           \
-      return MakeValue(spv::OpU ## _Op, a.stype, a, b);            \
-    } else {                                                       \
-      CHECK(a.stype.type.is_float());                              \
-      return MakeValue(spv::OpF ## _Op, a.stype, a, b);            \
-    }                                                              \
+#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op)   \
+  Value IRBuilder::_OpName(Value a, Value b) {        \
+    CHECK_EQ(a.stype.id, b.stype.id);                 \
+    if (a.stype.type.is_int()) {                      \
+      return MakeValue(spv::OpS##_Op, a.stype, a, b); \
+    } else if (a.stype.type.is_uint()) {              \
+      return MakeValue(spv::OpU##_Op, a.stype, a, b); \
+    } else {                                          \
+      CHECK(a.stype.type.is_float());                 \
+      return MakeValue(spv::OpF##_Op, a.stype, a, b); \
+    }                                                 \
   }
 
 DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
@@ -552,21 +552,19 @@ Value IRBuilder::Mod(Value a, Value b) {
   }
 }
 
-
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                        \
-  Value IRBuilder:: _OpName(Value a, Value b) {                    \
-    CHECK_EQ(a.stype.id, b.stype.id);                              \
-    if (t_bool_.id == 0) {                                         \
-      t_bool_ = DeclareType(UInt(1));                              \
-    }                                                              \
-    if (a.stype.type.is_int()) {                                   \
-      return MakeValue(spv::OpS ## _Op, t_bool_, a, b);            \
-    } else if (a.stype.type.is_uint()) {                           \
-      return MakeValue(spv::OpU ## _Op, t_bool_, a, b);            \
-    } else {                                                       \
-      CHECK(a.stype.type.is_float());                              \
-      return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b);         \
-    }                                                              \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                           \
+  Value IRBuilder::_OpName(Value a, Value b) {                                        \
+    CHECK_EQ(a.stype.id, b.stype.id);                                                 \
+    CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                             \
+    const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int()) {                                                      \
+      return MakeValue(spv::OpS##_Op, bool_type, a, b);                               \
+    } else if (a.stype.type.is_uint()) {                                              \
+      return MakeValue(spv::OpU##_Op, bool_type, a, b);                               \
+    } else {                                                                          \
+      CHECK(a.stype.type.is_float());                                                 \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                            \
+    }                                                                                 \
   }
 
 DEFINE_BUILDER_CMP_OP(LT, LessThan);
@@ -574,18 +572,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
 DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
 DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
 
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                       \
-  Value IRBuilder:: _OpName(Value a, Value b) {                    \
-    CHECK_EQ(a.stype.id, b.stype.id);                              \
-    if (t_bool_.id == 0) {                                         \
-      t_bool_ = DeclareType(UInt(1));                              \
-    }                                                              \
-    if (a.stype.type.is_int() || a.stype.type.is_uint()) {         \
-      return MakeValue(spv::OpI ## _Op, t_bool_, a, b);            \
-    } else {                                                       \
-      CHECK(a.stype.type.is_float());                              \
-      return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b);         \
-    }                                                              \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                          \
+  Value IRBuilder::_OpName(Value a, Value b) {                                        \
+    CHECK_EQ(a.stype.id, b.stype.id);                                                 \
+    CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                             \
+    const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                            \
+      return MakeValue(spv::OpI##_Op, bool_type, a, b);                               \
+    } else {                                                                          \
+      CHECK(a.stype.type.is_float());                                                 \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                            \
+    }                                                                                 \
   }
 
 DEFINE_BUILDER_CMP_UOP(EQ, Equal);
@@ -593,7 +590,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
 
 Value IRBuilder::Select(Value cond, Value a, Value b) {
   CHECK_EQ(a.stype.id, b.stype.id);
-  CHECK_EQ(cond.stype.type, UInt(1));
+  CHECK_EQ(cond.stype.type.element_of(), UInt(1));
   return MakeValue(spv::OpSelect, a.stype, cond, a, b);
 }
 
diff --git a/tests/python/unittest/test_codegen_vulkan.py b/tests/python/unittest/test_codegen_vulkan.py
new file mode 100644 (file)
index 0000000..2d7edff
--- /dev/null
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import re
+
+
+def test_vector_comparison():
+    if not tvm.module.enabled("vulkan"):
+        print("Skipping due to no Vulkan module")
+        return
+
+    target = 'vulkan'
+
+    def check_correct_assembly(dtype):
+        n = (1024,)
+        A = tvm.placeholder(n, dtype=dtype, name='A')
+        B = tvm.compute(
+            A.shape,
+            lambda i: tvm.expr.Select(
+                A[i] >= 0, A[i] + tvm.const(1, dtype),
+                tvm.const(0, dtype)), name='B')
+        s = tvm.create_schedule(B.op)
+
+        (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
+        (tx, vx) = s[B].split(tx, factor=4)
+        s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+        s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[B].vectorize(vx)
+        f = tvm.build(s, [A, B], target)
+
+        # Verify we generate the boolx4 type declaration and the OpSelect
+        # v4{float,half,int} instruction
+        assembly = f.imported_modules[0].get_source()
+        matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
+        assert len(matches) == 1
+        matches = re.findall("OpSelect %v4.*", assembly)
+        assert len(matches) == 1
+    check_correct_assembly('float32')
+    check_correct_assembly('int32')
+    check_correct_assembly('float16')
+
+
+if __name__ == "__main__":
+    test_vector_comparison()