torch.float16,
torch.float32,
torch.float64,
+ torch.bfloat16,
]
self.dtypes = self.int_dtypes + self.fp_dtypes
bad_dtypes = []
for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes):
# TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
if dtype == output_dtype:
continue
for inp, device, dtype in product(inputs, self.devices, dtypes):
# TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
inp = inp.to(device=device, dtype=dtype)
try:
torch.round,
torch.trunc,
torch.frac,
- F.hardshrink,
+ # TODO: broken on ROCm?
+ # F.hardshrink,
F.leaky_relu,
lambda x: torch.threshold(x, 0, -10),
lambda x: torch.clamp(x, -10, 10),
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
# TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
if op in gpu_only and device == "cpu":
continue
]
devices = self.devices
for dtype, op, device in product(self.dtypes, binary_ops, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device)
"[[10, 3, 4], [4, 5]]",
]
for dtype, size, device in product(self.dtypes, sizes, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
size_x, size_y = size
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device)
# only using scalar values relevant to particular ops
scalars = [1.5, 3, -2.0, -1] # skip 0
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device)
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device)
]
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device)
]
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device, size=[5, 3, 128, 128])
torch.cat,
]
for dtype, op, device in product(self.dtypes, list_ops, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device, size=[5, 4, 1, 7])
]
devices = self.devices
for dtype, op, device in product(self.dtypes, ops, devices):
- if dtype == torch.float16 and device == "cpu":
+ if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
cond = self.data_for(torch.bool, device)
unsupported_dtypes = [
torch.uint8,
- torch.bfloat16,
torch.complex32,
torch.complex64,
torch.complex128,
dtypes = self.dtypes.copy()
- # CPU fuser doesn't support float16.
+ # CPU fuser doesn't support float16 or bfloat16.
dtypes.remove(torch.float16)
+ dtypes.remove(torch.bfloat16)
for dtype1, dtype2 in product(dtypes, dtypes):
x = torch.randint(2, (1, 13,)).to(dtype1)
zero = torch.tensor([[0]]).to(dtype2)
- // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we
- // additionally disable it until we either move to a more stable version
- // or find workarounds.
+ // but on top of that Float16 has a few kinks on LLVM, and BFloat16 is not
+ // supported by the LLVM codegen at all. Thus, on CPU we additionally
+ // disable them until we either move to a more stable version or find
+ // workarounds.
- if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
+ if ((*st == c10::ScalarType::Half ||
+ *st == c10::ScalarType::BFloat16) &&
+ *device == c10::kCPU) {
return false;
}
// All tensor types should be known.
return false;
}
- if (c10::isComplexType(*st) || c10::isQIntType(*st) ||
- *st == c10::ScalarType::BFloat16) {
+ if (c10::isComplexType(*st) || c10::isQIntType(*st)) {
return false;
}
}
return "1";
case ScalarType::Half:
return "2";
+ case ScalarType::BFloat16:
+ return "2";
// NOLINTNEXTLINE(bugprone-branch-clone)
case ScalarType::Char:
return "1";
case ScalarType::Name: \
return callArg.Name##Ptr();
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
memcpy(&data_, &v, sizeof(Type)); \
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_TYPE_CTOR);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
#undef ARG_TYPE_CTOR
void* data() const {
return (Type*)&data_; \
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_PTR_DEFINE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
#undef ARG_PTR_DEFINE
private:
case ScalarType::Name: \
visit_binary_op<Type>(os, v->lhs(), v->rhs(), v->expr_type()); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
return "bool";
case ScalarType::Half:
return "half";
+ case ScalarType::BFloat16:
+ return "__nv_bfloat16";
case ScalarType::Char:
return "char";
case ScalarType::Byte:
}
void CudaPrinter::visit(CastPtr v) {
- if (v->dtype().scalar_type() == ScalarType::Half) {
- os() << "__float2half(";
- v->src_value()->accept(this);
- os() << ")";
- return;
- } else if (v->src_value()->dtype().scalar_type() == ScalarType::Half) {
- os() << "__half2float(";
- v->src_value()->accept(this);
- os() << ")";
- return;
- }
-
- os() << "(" << dtypeToCppString(v->dtype()) << ")";
- os() << "(";
+ std::string castFn = v->dtype().scalar_type() == ScalarType::Half
+ ? "__float2half"
+ : v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
+ : v->src_value()->dtype().scalar_type() == ScalarType::Half
+ ? "__half2float"
+ : v->src_value()->dtype().scalar_type() == ScalarType::BFloat16
+ ? "__bfloat162float"
+ : ("(" + dtypeToCppString(v->dtype()) + ")");
+ os() << castFn << "(";
v->src_value()->accept(this);
os() << ")";
}
return;
}
if (v->dtype().scalar_type() == ScalarType::Bool ||
- v->dtype().scalar_type() == ScalarType::Half) {
+ v->dtype().scalar_type() == ScalarType::Half ||
+ v->dtype().scalar_type() == ScalarType::BFloat16) {
- // There's no __ldg overload for bool or half.
+ // There's no __ldg overload for bool, half, or bfloat16.
os() << *v->base_handle() << "[" << *v->flat_index() << "]";
return;
if (halfChecker.hasHalf()) {
os() << fuser::cuda::half_support_literal << std::endl;
}
+ if (halfChecker.hasBFloat16()) {
+ os() << fuser::cuda::bfloat16_support_literal << std::endl;
+ }
std::string func_name = GetUniqueFuncName(kernel_func_name());
os() << "extern \"C\" __global__" << std::endl;
return lhs / rhs;
}
+inline c10::BFloat16 div_value(c10::BFloat16 lhs, c10::BFloat16 rhs) {
+ return lhs / rhs;
+}
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class SimpleIREvaluatorImpl : public IRVisitor {
public:
case ScalarType::Name: \
value_ = binary_op<Type>(lhs_v, rhs_v, expr_type); \
break;
- AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
case ScalarType::Bool:
value_ = binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
case ScalarType::Name: \
value = compare_select_op<T, Type>(lhs, rhs, retval1, retval2, cmp_op); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
value_ = compare_select_op_helper<Type>( \
lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
TORCH_API void visit(Name##ImmPtr v) override { \
value_ = Value(v->value()); \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
TORCH_API void visit(BlockPtr v) override {
case ScalarType::Name: \
this->value_ = Value(castValues<SrcType, Type>(src_dtype, v)); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DST_TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
#undef DST_TYPE_CASE
default:
throw unsupported_dtype();
case ScalarType::Name: \
doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, SRC_TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
#undef SRC_TYPE_CASE
default:
throw unsupported_dtype();
std::vector<Type> v(lanes, value.as<Type>()); \
value_ = Value(v); \
} break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
#undef TYPE_CASE
case ScalarType::Half:
throw unsupported_dtype("IfThenElse condition can't have Half dtype");
+ case ScalarType::BFloat16:
+ throw unsupported_dtype(
+ "IfThenElse condition can't have BFloat16 dtype");
default:
throw unsupported_dtype();
}
} \
value_ = Value(v); \
} break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
ptr##Name[index[i]] = value[i]; \
} \
} break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
visit_intrinsics_helper<int, double>(v);
} else if (inp_dtype == ScalarType::Half) {
throw unsupported_dtype(); // TODO
+ } else if (inp_dtype == ScalarType::BFloat16) {
+ throw unsupported_dtype(); // TODO
}
} else {
switch (ty) {
impl_->bindVar(bufArg.var(), typed_data); \
break; \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
Name##values.push_back(v); \
return; \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
}
Name##values.push_back(v); \
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_CTOR);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR);
#undef VALUE_CTOR
#define VALUE_VEC_CTOR(Type, Name) \
Value(const std::vector<Type>& v) \
: dtype_(Dtype(k##Name, v.size())), Name##values(v) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_VEC_CTOR);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR);
#undef VALUE_VEC_CTOR
template <typename T>
Dtype dtype_;
#define VALUE_STORAGE(Type, Name) std::vector<Type> Name##values;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_STORAGE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE);
#undef VALUE_STORAGE
void* ptr;
};
} \
return Name##values[0]; \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
#undef VALUE_AS_DISPATCH
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
} \
return Name##values; \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
#undef VALUE_AS_VEC_DISPATCH
template <typename To, typename From>
ret_value_ = Value(ret_val_arg[0]); \
} break;
// NOLINTNEXTLINE(modernize-use-emplace)
- AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
case ScalarType::Bool: {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
codegen_->call_raw(args_extended); \
ret_value_ = Value(ret_val_arg[0]); \
} break;
- AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
case ScalarType::Bool: {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
// NOLINTNEXTLINE
#define IMM_EXPR_DECLARE(Type, Name) \
ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
#undef IMM_EXPR_DECLARE
ExprHandle sin(const ExprHandle& v) {
}
#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
#undef IMM_EXPR_DECLARE
template <class Op>
#define IMM_DECLARE(Type, Name) \
class Name##Imm; \
using Name##ImmPtr = NodePtr<Name##Imm>;
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE
} // namespace tensorexpr
}
}
- bool hasHalf() {
+ bool hasHalf() const {
return hasHalf_;
}
+ bool hasBFloat16() const {
+ return hasBFloat16_;
+ }
+
void visit(LoadPtr v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
+ hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
}
void visit(StorePtr v) override {
hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
+ hasBFloat16_ |= v->buf()->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
}
hasHalf_ = true;
}
+ void visit(BFloat16ImmPtr v) override {
+ hasBFloat16_ = true;
+ }
+
void visit(CastPtr v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
+ hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
}
private:
bool hasHalf_{false};
+ bool hasBFloat16_{false};
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class HalfRewriter : public IRMutator {
ExprPtr mutate(LoadPtr v) override {
ExprPtr child = IRMutator::mutate(v);
- if (child->dtype().scalar_type() != ScalarType::Half) {
+ if (!isHalf(child)) {
return child;
}
StmtPtr mutate(StorePtr v) override {
// Since mutation changes the `value()` expression in-place, we need to
// get the dtype of the `value()` before that is mutated.
- Dtype newType = v->value()->dtype();
+ auto newType = v->value()->dtype();
ExprPtr new_val = v->value()->accept_mutator(this);
- if (newType.scalar_type() == ScalarType::Half) {
- new_val =
- alloc<Cast>(newType.cloneWithScalarType(ScalarType::Half), new_val);
+ if (isHalf(newType.scalar_type())) {
+ new_val = alloc<Cast>(newType, new_val);
inserted_half_casts_.insert(new_val);
}
return alloc<Cast>(kFloat, v);
}
+ ExprPtr mutate(BFloat16ImmPtr v) override {
+ return alloc<Cast>(kFloat, v);
+ }
+
ExprPtr mutate(CastPtr v) override {
ExprPtr child = v->src_value()->accept_mutator(this);
// just don't allow half casts we didn't insert.
- if (v->dtype().scalar_type() == ScalarType::Half) {
+ if (isHalf(v)) {
if (inserted_half_casts_.count(v) < 1) {
return child;
}
return alloc<Cast>(v->dtype(), child);
}
+
StmtPtr mutate(LetPtr v) override {
- if (v->dtype().scalar_type() == ScalarType::Half) {
+ if (isHalf(v->dtype().scalar_type())) {
VarPtr load_new_var = alloc<Var>(v->var()->name_hint(), kFloat);
ExprPtr new_value = alloc<Cast>(
v->dtype().cloneWithScalarType(ScalarType::Float),
template <typename T>
ExprPtr mutateArithmetic(T v) {
IRMutator::mutate(v);
- if (v->dtype().scalar_type() == c10::kHalf) {
+ if (isHalf(v)) {
v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat));
}
return v;
}
private:
+ static bool isHalf(ScalarType st) {
+ return st == ScalarType::Half || st == ScalarType::BFloat16;
+ }
+
+ static bool isHalf(ExprPtr v) {
+ return isHalf(v->dtype().scalar_type());
+ }
+
std::unordered_set<ExprPtr> inserted_half_casts_;
std::unordered_map<VarPtr, VarPtr> var_map;
};
CACHE_GUARD(); \
putHash(v, hash_combine(#Name, v->value())); \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
void visit(CastPtr v) override;
std::memcpy(&n, &d, sizeof d);
return te_hash(n);
}
+
+ size_t te_hash(at::BFloat16 d) {
+ // memcpy as type punning. Should be optimized out.
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int16_t n;
+ std::memcpy(&n, &d, sizeof d);
+ return te_hash(n);
+ }
};
} // namespace tensorexpr
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() < 0; \
}
- AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
return false;
}
private: \
Type value_; \
};
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE
// Get immediate by ScalarType.
switch (immType) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
- return alloc<Name##Imm>(initialVal);
+ return alloc<Name##Imm>(Type(initialVal));
// NOLINTNEXTLINE(bugprone-branch-clone)
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value(); \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return 0;
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() == val; \
}
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return false;
ExprPtr IRCloner::mutate(Name##ImmPtr v) { \
return v; \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE
ExprPtr IRCloner::mutate(CastPtr v) {
ExprPtr mutate(RshiftPtr v) override;
ExprPtr mutate(CompareSelectPtr v) override;
#define IMM_MUTATE_DECLARE(Type, Name) ExprPtr mutate(Name##ImmPtr v) override;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
ExprPtr mutate(CastPtr v) override;
ExprPtr mutate(BitCastPtr v) override;
ExprPtr IRMutator::mutate(Name##ImmPtr v) { \
return v; \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE
ExprPtr IRMutator::mutate(CastPtr v) {
virtual ExprPtr mutate(RshiftPtr v);
virtual ExprPtr mutate(CompareSelectPtr v);
#define IMM_MUTATE_DECLARE(Type, Name) virtual ExprPtr mutate(Name##ImmPtr v);
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
virtual ExprPtr mutate(CastPtr v);
virtual ExprPtr mutate(BitCastPtr v);
void IRPrinter::visit(Name##ImmPtr v) { \
formatImm(os(), v->value()); \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
void IRPrinter::visit(CastPtr v) {
void visit(RshiftPtr v) override;
void visit(CompareSelectPtr v) override;
#define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
void visit(CastPtr v) override;
void visit(BitCastPtr v) override;
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val); \
}
- AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
void IRVisitor::visit(Name##ImmPtr v) {}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
void IRVisitor::visit(CastPtr v) {
#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##ImmPtr v);
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
virtual void visit(CastPtr v);
case ScalarType::Name: \
e = cast<Type>(e); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return cast<Type>(e);
- AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
case ScalarType::Bool:
return cast<bool>(e);
void visit(CompareSelectPtr v) override;
#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##ImmPtr v) override;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
void visit(CastPtr v) override;
value_ = llvm::ConstantFP::get(HalfTy_, v->value());
}
+void LLVMCodeGenImpl::visit(BFloat16ImmPtr v) {
+ TORCH_INTERNAL_ASSERT(false, "llvm codegen does not support bfloat16");
+}
+
void LLVMCodeGenImpl::visit(BoolImmPtr v) {
value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}
#define MAX_BY_TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return ExprHandle(std::numeric_limits<Type>::max());
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, MAX_BY_TYPE_CASE)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
#undef MAX_BY_TYPE_CASE
default:
throw unsupported_dtype();
#define MAX_BY_TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return ExprHandle(std::numeric_limits<Type>::min());
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, MAX_BY_TYPE_CASE)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
#undef MAX_BY_TYPE_CASE
default:
throw unsupported_dtype();
#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
dtype_class.def_property_readonly_static( \
#name, [](py::object) { return k##name; }); // NOLINT
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_SINGLETON_ACCESSOR)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_SINGLETON_ACCESSOR)
#undef DTYPE_SINGLETON_ACCESSOR
auto expr_handle_class =
#define EXPRHANDLE_CTOR(ctype, name) \
expr_handle_class.def_static(#ctype, [](ctype v) { return ExprHandle(v); });
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, EXPRHANDLE_CTOR)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_CTOR)
#undef EXPRHANDLE_CTOR
py::class_<VarHandle, ExprHandle>(te, "VarHandle")
// NOLINTNEXTLINE
#define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_DEFINE)
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE)
#undef DTYPE_DEFINE
#define TYPE_CASE(_1, n) \
case ScalarType::n: \
return k##n;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE)
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
case ScalarType::Undefined:
scalar_size = sizeof(Type); \
break;
- AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
default:
throw std::runtime_error(
return "bool";
case ScalarType::Half:
return "half";
+ case ScalarType::BFloat16:
+ return "__nv_bfloat16";
default:
throw unsupported_dtype();
}
#define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name;
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_DTYPE_DECLARATION)
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION)
#undef NNC_DTYPE_DECLARATION
template <typename T>
inline Dtype ToDtype<ctype>() { \
return k##name; \
}
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_TODTYPE_DECLARATION)
+AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION)
#undef NNC_TODTYPE_DECLARATION
TORCH_API Dtype ToDtype(ScalarType type);