*/
#ifdef TVM_LLVM_VERSION
// Parts of the code are adapted from Halide's CodeGen_LLVM
-
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
+
+#include <algorithm>
+
#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../../pass/ir_util.h"
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec;
- CHECK_LE(begin + extent, num_elems);
- std::vector<unsigned> indices;
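+  // Indices outside [0, num_elems) are filled with undef below, so a slice may
+  // deliberately extend past the end of the source vector (which is why the old
+  // CHECK_LE is removed); callers such as CallVectorIntrin in the x86-64 codegen
+  // rely on this when the total lane count is not a multiple of the intrinsic width.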
+ std::vector<llvm::Constant*> indices;
+ indices.reserve(extent);
for (int i = 0; i < extent; ++i) {
- indices.push_back(begin + i);
+ if (begin + i >= 0 && begin + i < num_elems) {
+ indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
+ } else {
+ indices.push_back(llvm::UndefValue::get(t_int32_));
+ }
}
- return builder_->CreateShuffleVector(vec, vec, indices);
+ return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
}
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // concat vectors via a tree-shaped pairwise reduction
  int total_lanes = 0;
  for (llvm::Value* v : vecs) {
    total_lanes += static_cast<int>(v->getType()->getVectorNumElements());
  }
while (vecs.size() > 1) {
- for (size_t i = 0; i < vecs.size(); i+=2) {
- if (i + 1 >= vecs.size()) {
- vecs[i / 2] = vecs[i]; continue;
- }
+ std::vector<llvm::Value*> new_vecs;
+ for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
- int lanes = static_cast<int>(std::max(
- lhs->getType()->getVectorNumElements(),
- rhs->getType()->getVectorNumElements()));
- lhs = CreateVecPad(lhs, lanes);
- rhs = CreateVecPad(lhs, lanes);
+ const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
+ const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
+ if (lhs_lanes < rhs_lanes) {
+ lhs = CreateVecPad(lhs, rhs_lanes);
+ } else if (rhs_lanes < lhs_lanes) {
+ rhs = CreateVecPad(rhs, lhs_lanes);
+ }
+ const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
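+      // Build the shuffle mask: indices [0, lhs_lanes) select from lhs. LLVM numbers
+      // the second operand's lanes starting at the first operand's lane count, so
+      // after padding, lane j of rhs is selected by index shared_lanes + j.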
std::vector<unsigned> mask;
- for (int i = 0; i < lanes * 2; ++i) {
+ for (size_t i = 0; i < lhs_lanes; ++i) {
mask.push_back(i);
}
- vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
+ for (size_t i = 0; i < rhs_lanes; ++i) {
+ mask.push_back(shared_lanes + i);
+ }
+ new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
+ }
+ if (vecs.size() % 2 != 0) {
+ new_vecs.push_back(vecs.back());
}
- vecs.resize((vecs.size() + 1) / 2);
+ vecs.swap(new_vecs);
}
return CreateVecSlice(vecs[0], 0, total_lanes);
}
--- /dev/null
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file codegen_x86_64.cc
+ * \brief X86-64 specific code generator
+ */
+#ifdef TVM_LLVM_VERSION
+#include "codegen_cpu.h"
+
+#include "llvm/MC/MCSubtargetInfo.h"
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
+  // MCSubtargetInfo::checkFeatures was added in LLVM 6.0
+#if TVM_LLVM_VERSION >= 60
+ const auto* MCInfo = tm.getMCSubtargetInfo();
+ return MCInfo->checkFeatures(std::string("+") + feature);
+#else
+ return false;
+ // TODO(tulloch) - enable this block, need to figure out how to reimplement
+ // this given visibility constraints, similar to
+ // https://github.com/rust-lang/rust/pull/31709
+
+ // Copied from
+ // https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88.
+
+ // auto checkFeatures = [&](const std::string FS) {
+ // llvm::SubtargetFeatures T(FS);
+ // llvm::FeatureBitset Set, All;
+ // for (std::string F : T.getFeatures()) {
+ // llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures);
+ // if (F[0] == '-') {
+ // F[0] = '+';
+ // }
+ // llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures);
+ // }
+ // return (MCInfo->getFeatureBits() & All) == Set;
+ // };
+  // return checkFeatures(std::string("+") + feature);
+#endif
+}
+} // namespace
+
+class CodeGenX86_64 final : public CodeGenCPU {
+ public:
+ llvm::Value* VisitExpr_(const Cast* op) override;
+
+ private:
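+  // Emit a call to an x86 SIMD intrinsic that processes intrin_lanes lanes per call,
+  // splitting wider vectors into intrin_lanes-wide pieces and concatenating the results.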
+ llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
+ const std::vector<llvm::Value*>& args);
+};
+
+llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
+  // LLVM does not automatically generate the correct instruction sequences for
+  // half -> float conversion (i.e. the AVX2/AVX-512 vectorized variants of
+  // vcvtph2ps), so we generate them explicitly ourselves.
+ const auto from = op->value.type();
+ const auto to = op->type;
+ if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
+ CHECK_EQ(from.lanes(), to.lanes());
+ CHECK_NOTNULL(target_machine_);
+
+ const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
+ const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f");
+
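+    // Prefer the 512-bit masked AVX-512 form when there are at least 16 lanes and
+    // AVX-512F is available; otherwise fall back to the 256-bit F16C form for 8 or
+    // more lanes. Narrower casts fall through to the generic CodeGenCPU lowering.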
+ if (from.lanes() >= 16 && has_avx512) {
+ return CallVectorIntrin(
+ ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())),
+ {
+ MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+ ir::Call::PureIntrinsic)),
+ MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())),
+ /*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)),
+ /*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)),
+ });
+ }
+
+ if (from.lanes() >= 8 && has_f16c) {
+ return CallVectorIntrin(
+ ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())),
+ {MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+ ir::Call::PureIntrinsic))});
+ }
+ }
+
+ return CodeGenCPU::VisitExpr_(op);
+}
+
+llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
+                                             llvm::Type* result_ty,
+                                             const std::vector<llvm::Value*>& args) {
+ llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
+ if (intrin_lanes == result_ty->getVectorNumElements()) {
+ return builder_->CreateCall(f, args);
+ }
+
+  // Otherwise, we split the vector into intrin_lanes-sized pieces (widening where
+  // necessary), compute each result, and then concatenate the vectors (slicing the
+  // result if necessary).
+ CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
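+  // The last piece may extend past the end of the input vectors; CreateVecSlice
+  // fills those out-of-range lanes with undef, and the final CreateVecSlice below
+  // trims the concatenated result back to the requested number of lanes.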
+ std::vector<llvm::Value*> split_results;
+ for (size_t i = 0;
+ i < static_cast<size_t>(result_ty->getVectorNumElements());
+ i += intrin_lanes) {
+ std::vector<llvm::Value*> split_args;
+ for (const auto& v : args) {
+ if (v->getType()->isVectorTy()) {
+ CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
+ split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
+ } else {
+ split_args.push_back(v);
+ }
+ }
+ split_results.push_back(CallVectorIntrin(
+ id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
+ split_args));
+ }
+ return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
+}
+
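+// Register this code generator for LLVM targets whose architecture is x86-64.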
+TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
+.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+ CodeGenLLVM* cg = new CodeGenX86_64();
+ *rv = static_cast<void*>(cg);
+ });
+
+} // namespace codegen
+} // namespace tvm
+#endif // TVM_LLVM_VERSION
--- /dev/null
+import tvm
+import re
+
+
+def test_fp16_to_fp32():
+ if tvm.codegen.llvm_version_major() < 6:
+ print("Skipping due to LLVM version being {} < 6".format(
+ tvm.codegen.llvm_version_major()))
+ return
+
+ def fp16_to_fp32(target, width, match=None, not_match=None):
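+        # Build a vectorized fp16 -> fp32 cast for the given target and lane width,
+        # then grep the generated assembly for the expected (or forbidden) vcvtph2ps form.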
+ elements = 64
+ n = tvm.convert(elements)
+ A = tvm.placeholder((n, width), dtype="float16", name='A')
+ B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
+ s = tvm.create_schedule(B.op)
+ s[B].vectorize(s[B].op.axis[1])
+ f = tvm.build(s, [A, B], target)
+
+ assembly = f.get_source('asm').splitlines()
+ if match:
+ matches = [l for l in assembly if re.search(match, l)]
+ assert matches
+ if not_match:
+ not_matches = [l for l in assembly if re.search(not_match, l)]
+ assert not not_matches
+
+
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512', 15,
+ match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm")
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512', 16,
+ match="vcvtph2ps.*zmm")
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512', 17,
+ match="vcvtph2ps.*zmm")
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512', 49,
+ match="vcvtph2ps.*zmm")
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512 -mattr=-avx512f', 49,
+ match="vcvtph2ps.*ymm",
+ not_match="vcvtph2ps.*zmm")
+ fp16_to_fp32(
+ 'llvm -mcpu=skylake-avx512 -mattr=-f16c,-avx512f', 49,
+ not_match="vcvtph2ps")
+ fp16_to_fp32(
+ 'llvm -mcpu=core-avx2', 8,
+ match="vcvtph2ps.*ymm")
+ fp16_to_fp32(
+ 'llvm -mcpu=core-avx2', 9,
+ match="vcvtph2ps.*ymm")
+ fp16_to_fp32(
+ 'llvm', 9,
+ not_match="vcvtph2ps")
+
+
+if __name__ == "__main__":
+ test_fp16_to_fp32()