* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
} \
}
-#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- if (a.stype.type.is_int()) { \
- return MakeValue(spv::OpS ## _Op, a.stype, a, b); \
- } else if (a.stype.type.is_uint()) { \
- return MakeValue(spv::OpU ## _Op, a.stype, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
- } \
+#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ if (a.stype.type.is_int()) { \
+ return MakeValue(spv::OpS##_Op, a.stype, a, b); \
+ } else if (a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpU##_Op, a.stype, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpF##_Op, a.stype, a, b); \
+ } \
}
DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
}
}
-
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
- Value IRBuilder:: _OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- if (t_bool_.id == 0) { \
- t_bool_ = DeclareType(UInt(1)); \
- } \
- if (a.stype.type.is_int()) { \
- return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \
- } else if (a.stype.type.is_uint()) { \
- return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
- } \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+ const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+ if (a.stype.type.is_int()) { \
+ return MakeValue(spv::OpS##_Op, bool_type, a, b); \
+ } else if (a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpU##_Op, bool_type, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_OP(LT, LessThan);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
- Value IRBuilder:: _OpName(Value a, Value b) { \
- CHECK_EQ(a.stype.id, b.stype.id); \
- if (t_bool_.id == 0) { \
- t_bool_ = DeclareType(UInt(1)); \
- } \
- if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
- return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \
- } else { \
- CHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
- } \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ CHECK_EQ(a.stype.id, b.stype.id); \
+ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+ const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpI##_Op, bool_type, a, b); \
+ } else { \
+ CHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
- CHECK_EQ(cond.stype.type, UInt(1));
+ CHECK_EQ(cond.stype.type.element_of(), UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import re
+
+
+def test_vector_comparison():
+ if not tvm.module.enabled("vulkan"):
+ print("Skipping due to no Vulkan module")
+ return
+
+ target = 'vulkan'
+
+ def check_correct_assembly(dtype):
+ n = (1024,)
+ A = tvm.placeholder(n, dtype=dtype, name='A')
+ B = tvm.compute(
+ A.shape,
+ lambda i: tvm.expr.Select(
+ A[i] >= 0, A[i] + tvm.const(1, dtype),
+ tvm.const(0, dtype)), name='B')
+ s = tvm.create_schedule(B.op)
+
+ (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
+ (tx, vx) = s[B].split(tx, factor=4)
+ s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+ s[B].vectorize(vx)
+ f = tvm.build(s, [A, B], target)
+
+ # Verify we generate the boolx4 type declaration and the OpSelect
+ # v4{float,half,int} instruction
+ assembly = f.imported_modules[0].get_source()
+ matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
+ assert len(matches) == 1
+ matches = re.findall("OpSelect %v4.*", assembly)
+ assert len(matches) == 1
+ check_correct_assembly('float32')
+ check_correct_assembly('int32')
+ check_correct_assembly('float16')
+
+
+if __name__ == "__main__":
+ test_vector_comparison()