[DataType] Add bfloat16 (#5601)
authorMenooker <Menooker@users.noreply.github.com>
Fri, 19 Jun 2020 14:40:40 +0000 (22:40 +0800)
committerGitHub <noreply@github.com>
Fri, 19 Jun 2020 14:40:40 +0000 (07:40 -0700)
include/tvm/runtime/data_type.h
include/tvm/tir/op.h
include/tvm/tir/transform.h
python/tvm/_ffi/runtime_ctypes.py
python/tvm/driver/build_module.py
python/tvm/tir/transform/transform.py
src/driver/driver_api.cc
src/tir/transforms/bf16_legalize.cc [new file with mode: 0644]
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_tir_transform_bf16_legalize.py [new file with mode: 0644]

index b12938b..cb817a8 100644 (file)
@@ -53,6 +53,7 @@ class DataType {
     kUInt = kDLUInt,
     kFloat = kDLFloat,
     kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
+    kBFloat = kDLBfloat,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -72,6 +73,9 @@ class DataType {
     data_.code = static_cast<uint8_t>(code);
     data_.bits = static_cast<uint8_t>(bits);
     data_.lanes = static_cast<uint16_t>(lanes);
+    if (code == kBFloat) {
+      CHECK_EQ(bits, 16);
+    }
   }
   /*! \return The type code. */
   int code() const { return static_cast<int>(data_.code); }
@@ -89,6 +93,8 @@ class DataType {
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
+  /*! \return whether type is a bfloat16 type. */
+  bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
   /*! \return whether type is an int type. */
   bool is_int() const { return code() == DataType::kInt; }
   /*! \return whether type is an uint type. */
@@ -283,6 +289,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
       return "float";
     case DataType::kHandle:
       return "handle";
+    case kDLBfloat:
+      return "bfloat";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
       return "";
@@ -349,6 +357,9 @@ inline DLDataType String2DLDataType(std::string s) {
     t.bits = 1;
     t.lanes = 1;
     return t;
+  } else if (s.substr(0, 6) == "bfloat") {
+    t.code = DataType::kBFloat;
+    scan = s.c_str() + 6;
   } else if (s.substr(0, 6) == "custom") {
     t.code = ParseCustomDatatype(s, &scan);
   } else {
index 71e9ac4..2948bb2 100644 (file)
@@ -751,7 +751,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
       return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
     }
   }
-  if (t.is_float()) return FloatImm(t, static_cast<double>(value));
+  if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value));
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
index a794c12..5e04838 100644 (file)
@@ -322,6 +322,13 @@ TVM_DLL Pass CombineContextCall();
 TVM_DLL Pass NarrowDataType(int target_bits);
 
 /*!
+ * \brief Legalize bf16 typed Ops. Add a cast to fp32
+ *   before Ops, then add a cast back to bf16.
+ * \return The pass.
+ */
+TVM_DLL Pass BF16Legalize();
+
+/*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
  *  the most frequently accessed type for load/store
index 2e498e3..a7bfb32 100644 (file)
@@ -54,6 +54,7 @@ class DataTypeCode(object):
     UINT = 1
     FLOAT = 2
     HANDLE = 3
+    BFLOAT = 4
 
 
 class DataType(ctypes.Structure):
@@ -65,7 +66,8 @@ class DataType(ctypes.Structure):
         DataTypeCode.INT : 'int',
         DataTypeCode.UINT : 'uint',
         DataTypeCode.FLOAT : 'float',
-        DataTypeCode.HANDLE : 'handle'
+        DataTypeCode.HANDLE : 'handle',
+        DataTypeCode.BFLOAT : 'bfloat'
     }
     def __init__(self, type_str):
         super(DataType, self).__init__()
@@ -96,6 +98,9 @@ class DataType(ctypes.Structure):
             self.type_code = DataTypeCode.HANDLE
             bits = 64
             head = ""
+        elif head.startswith("bfloat"):
+            self.type_code = DataTypeCode.BFLOAT
+            head = head[6:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
index a19b097..47e9a81 100644 (file)
@@ -176,6 +176,7 @@ def lower(sch,
     pass_list += [
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
+        tvm.tir.transform.BF16Legalize(),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
     ]
index a5af353..86e7a33 100644 (file)
@@ -226,6 +226,56 @@ def RemoveNoOp():
     """
     return _ffi_api.RemoveNoOp()
 
def BF16Legalize():
    """Legalize bf16 typed Ops.

    Runs BF16Promote, BF16CastElimination and BF16TypeLowering
    in sequence, so that after this pass no bf16 type remains in
    the IR (bf16 storage is lowered to uint16).

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.BF16Legalize()
+
def BF16Promote():
    """Promote bf16 to fp32. Add a cast to fp32
    before Ops, then add a cast back to bf16.

    Comparison Ops keep their boolean result and only get the
    input casts; arithmetic Ops are additionally cast back to bf16.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.BF16Promote()
+
def BF16CastElimination():
    """Eliminate verbose casting between fp32 and bf16.

    Checks if the AST has the pattern:
    castto32(castto16(some_fp32_op(...)))
    The verbose casting is generated by BF16Promote for multiple
    bf16 Ops in a row. e.g.:
    X[i] + Y[i] + T[i] =>
    bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
    After this pass:
    bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.BF16CastElimination()
+
def BF16TypeLowering():
    """Replace all bf16 type with uint16. Also lower the casting
    between fp32 and bf16.

    After this pass the IR contains no bf16-typed nodes; bf16
    values are carried as their raw uint16 bit patterns.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.BF16TypeLowering()
 
 def RewriteUnsafeSelect():
     """Detect and rewrite unsafe select that contains memory access.
index 9d2a11c..e796f49 100644 (file)
@@ -162,6 +162,7 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
   pass_list.push_back(tir::transform::InjectPrefetch());
   pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
   // Phase 1
+  pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::LoopPartition());
diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
new file mode 100644 (file)
index 0000000..07f4775
--- /dev/null
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bf16_legalize.cc
+ * \brief legalize bf16 type by adding cast_to_fp32
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+
+#include <cmath>
+#include <tuple>
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+
/*!
 * \brief Promote bf16 binary Ops to fp32.
 *
 * For each supported binary Op whose operands are bf16, both operands are
 * cast to fp32 before the Op; arithmetic Ops get an additional cast of the
 * result back to bf16 (comparisons keep their boolean result).
 */
class BF16PromoteRewriter : public StmtExprMutator {
 public:
  BF16PromoteRewriter() {}

  // Entry point: rewrite a whole statement tree.
  Stmt operator()(Stmt s) { return VisitStmt(s); }

  /*!
   * \brief Visit both operands and, if they are bf16, cast them to fp32.
   * \param orig_a The left operand before mutation.
   * \param orig_b The right operand before mutation.
   * \param is_bfloat16 Output flag: set when the operands were bf16.
   * \return The (possibly cast) operand pair.
   * \note Mixed bf16/non-bf16 operands are rejected by the CHECKs — both
   *       sides must agree on bf16-ness.
   */
  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
    auto a = this->VisitExpr(orig_a);
    auto b = this->VisitExpr(orig_b);
    *is_bfloat16 = false;
    if (a->dtype.is_bfloat16()) {
      CHECK(b->dtype.is_bfloat16());
      *is_bfloat16 = true;
    } else if (b->dtype.is_bfloat16()) {
      CHECK(a->dtype.is_bfloat16());
      *is_bfloat16 = true;
    }

    if (*is_bfloat16) {
      DataType fp32ty(kDLFloat, 32, 1);
      a = Cast(fp32ty, a);
      b = Cast(fp32ty, b);
    }
    return std::make_tuple(a, b);
  }

  // Arithmetic Ops: promote operands to fp32, cast result back to bf16.
  PrimExpr VisitExpr_(const AddNode* op) final;
  PrimExpr VisitExpr_(const SubNode* op) final;
  PrimExpr VisitExpr_(const MulNode* op) final;
  PrimExpr VisitExpr_(const DivNode* op) final;
  PrimExpr VisitExpr_(const MinNode* op) final;
  PrimExpr VisitExpr_(const MaxNode* op) final;
  // Comparison Ops: promote operands only, result stays boolean.
  PrimExpr VisitExpr_(const LTNode* op) final;
  PrimExpr VisitExpr_(const LENode* op) final;
  PrimExpr VisitExpr_(const GTNode* op) final;
  PrimExpr VisitExpr_(const GENode* op) final;
};
+
// Defines a BF16PromoteRewriter visitor for an arithmetic binary Op:
// operands are promoted via DoCast, and when they were bf16 the fp32
// result is cast back to bf16.
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)  \
  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \
    PrimExpr a, b;                                         \
    bool is_bfloat16;                                      \
    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);   \
    if (a.same_as(op->a) && b.same_as(op->b)) {            \
      return GetRef<PrimExpr>(op);                         \
    } else {                                               \
      auto ret = FUNC(a, b);                               \
      if (!is_bfloat16)                                    \
        return ret;                                        \
      else                                                 \
        return Cast(DataType(kDLBfloat, 16, 1), ret);      \
    }                                                      \
  }

// Same as above but for comparison Ops: the boolean result must NOT be
// cast back to bf16.
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \
  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {        \
    PrimExpr a, b;                                                \
    bool is_bfloat16;                                             \
    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);          \
    if (a.same_as(op->a) && b.same_as(op->b)) {                   \
      return GetRef<PrimExpr>(op);                                \
    } else {                                                      \
      auto ret = FUNC(a, b);                                      \
      return ret;                                                 \
    }                                                             \
  }

// Arithmetic Ops — result cast back to bf16.
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
// Comparison Ops — boolean result, no cast back.
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=)  // NOLINT(*)
+
+/*
+ * Eliminate verbose casting between fp32 and bf16
+ * Checks if the AST has the pattern:
+ *     castto32(castto16(some_fp32_op(...)))
+ * The verbose casting is generated by BF16Promote for multiple
+ * bf16 Ops in a row. e.g.:
+ *  X[i] + Y[i] + T[i] =>
+ *  bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
+ * After this pass:
+ *  bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
+ */
class BF16CastEliminationRewriter : public StmtExprMutator {
 public:
  BF16CastEliminationRewriter() {}

  // Entry point: rewrite a whole statement tree.
  Stmt operator()(Stmt s) { return VisitStmt(s); }

  PrimExpr VisitExpr_(const CastNode* op) final {
    // Visit the inner expression first so nested cast pairs collapse
    // bottom-up.
    auto op_val = StmtExprMutator::VisitExpr(op->value);
    if (op->dtype.is_float() && op->dtype.bits() == 32) {
      // if is cast_to_fp32, check if op->value is cast_to_bf16
      // and op->value->value is a float32; if so the bf16 round-trip
      // is redundant and the original fp32 expression is returned.
      if (auto innercast = op_val.as<CastNode>()) {
        if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() &&
            innercast->value->dtype.bits() == 32) {
          return innercast->value;
        }
      }
    }
    // No elimination: keep the node, rebuilding only if the child changed.
    if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
    return Cast(op->dtype, op_val);
  }
};
+
// Retained for source compatibility; no longer used below. Reading the
// inactive member of a union is technically undefined behavior in C++,
// so the conversion now uses std::memcpy instead.
union FloatCaster {
  uint32_t u32;
  float f32;
};

/*!
 * \brief Convert an fp32 value to its bf16 bit pattern using
 *        round-to-nearest-even.
 * \param src The fp32 value.
 * \return The 16-bit bf16 representation. NaN is mapped to the canonical
 *         bf16 quiet NaN (0x7FC0) so that the rounding bias cannot turn a
 *         NaN payload into infinity.
 */
uint16_t RoundToNearestEven(float src) {
  if (std::isnan(src)) {
    return UINT16_C(0x7FC0);
  }
  // Bit-copy via memcpy: well-defined type punning that compilers lower
  // to a single register move.
  uint32_t bits;
  std::memcpy(&bits, &src, sizeof(bits));
  // Round-to-nearest-even: add 0x7FFF plus the lowest retained mantissa
  // bit, then truncate to the high 16 bits.
  uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
+
+/*
+ * Lower the bf16 type to int16
+ * Lower cast between bf16 and fp32
+ * Lower bf16 FloatImm to int16
+ */
/*!
 * \brief Lower the bf16 type to uint16 storage.
 *
 * Rewrites casts between bf16 and fp32 into integer bit manipulation,
 * lowers bf16 FloatImm to uint16 IntImm, and remaps bf16 buffers,
 * allocations, loads and stores to their uint16 equivalents.
 */
class BF16LowerRewriter : StmtExprMutator {
 public:
  BF16LowerRewriter() {}

  // Maps original bf16 buffers/vars to their uint16 replacements so every
  // reference resolves to the same rewritten object.
  std::unordered_map<const BufferNode*, Buffer> buffer_remap;
  std::unordered_map<const VarNode*, Var> var_remap;

  // Entry point: rewrite a whole statement tree (call AlterBuffers first
  // so buffer_remap is populated).
  Stmt operator()(Stmt s) { return VisitStmt(s); }

  PrimExpr VisitExpr_(const CastNode* op) final {
    auto op_val = StmtExprMutator::VisitExpr(op->value);
    if (op->value->dtype.is_bfloat16()) {
      // cast FROM bf16: only bf16 -> fp32 is supported. The (already
      // lowered, uint16) value is widened to uint32, shifted into the
      // high half and reinterpreted as fp32.
      CHECK(op->dtype.is_float() && op->dtype.bits() == 32);
      auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
      auto uint32_v = Cast(uint32_dtype, op_val);
      // to be endian invariant.
      return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic);

    } else if (op->dtype.is_bfloat16()) {
      // cast TO bf16: only fp32 -> bf16 is supported. Reinterpret the
      // fp32 bits as uint32, apply round-to-nearest-even, keep the top
      // 16 bits.
      CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
      auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
      auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic);
      auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes());
      /* the following TIR is equivalent to the C++ code below:
      uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
      return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/
      // NOTE(review): the constant is uint16 while the left operand is
      // uint32 — this relies on implicit type promotion in the TIR add
      // (the unit test expects exactly this form); confirm intended.
      auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF);
      // to be endian invariant.
      return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16});
    }
    // Not a bf16-related cast: keep the node, rebuilding only if the
    // child changed.
    if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
    return Cast(op->dtype, op_val);
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto itr = var_remap.find(op);
    if (itr != var_remap.end()) {
      return itr->second;
    }
    if (op->dtype.is_bfloat16()) {
      CHECK(!op->type_annotation.defined());
      // NOTE(review): the replacement Var keeps op->dtype (still bf16)
      // rather than UInt(16) — presumably the dtype should be lowered
      // here as well; verify against the rest of the lowering.
      auto ret = Var(op->name_hint, op->dtype);
      var_remap[op] = ret;
      return std::move(ret);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    // Rewrite bf16 allocations to uint16 of the same lanes/extents, then
    // let the base mutator visit the (possibly new) node's children.
    Stmt node_holder;
    const AllocateNode* newop;
    if (op->dtype.is_bfloat16()) {
      auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
                        op->condition, op->body);
      node_holder = v;
      newop = static_cast<const AllocateNode*>(v.operator->());
    } else {
      newop = op;
    }
    return StmtExprMutator::VisitStmt_(newop);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    // Redirect stores into remapped (bf16 -> uint16) buffers.
    auto itr = buffer_remap.find(op->buffer.operator->());
    const BufferStoreNode* newop;
    BufferStore newop_holder;
    if (itr != buffer_remap.end()) {
      newop_holder = BufferStore(itr->second, op->value, op->indices);
      newop = newop_holder.operator->();
    } else {
      newop = op;
    }
    return StmtExprMutator::VisitStmt_(newop);
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // Attribute statements may reference a remapped buffer or var as
    // their node; swap it before recursing.
    const AttrStmtNode* newop = op;
    Stmt newop_holder;
    if (auto buffer = op->node.as<BufferNode>()) {
      auto itr = buffer_remap.find(buffer);
      if (itr != buffer_remap.end()) {
        newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
        newop = newop_holder.as<AttrStmtNode>();
      }
    } else if (auto buffer = op->node.as<VarNode>()) {
      auto itr = var_remap.find(buffer);
      if (itr != var_remap.end()) {
        newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
        newop = newop_holder.as<AttrStmtNode>();
      }
    }
    return StmtExprMutator::VisitStmt_(newop);
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    // Redirect realizes of remapped buffers.
    auto itr = buffer_remap.find(op->buffer.operator->());
    const BufferRealizeNode* newop;
    Stmt newop_holder;
    if (itr != buffer_remap.end()) {
      auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
      newop_holder = v;
      newop = v.operator->();
    } else {
      newop = op;
    }
    return StmtExprMutator::VisitStmt_(newop);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    // Redirect loads from remapped buffers.
    auto itr = buffer_remap.find(op->buffer.operator->());
    const BufferLoadNode* newop;
    BufferLoad newop_holder;
    if (itr != buffer_remap.end()) {
      newop_holder = BufferLoad(itr->second, op->indices);
      newop = newop_holder.operator->();
    } else {
      newop = op;
    }
    return StmtExprMutator::VisitExpr_(newop);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    // Low-level loads: a bf16 load becomes a uint16 load of the same
    // lanes from the same buffer var.
    bool is_bf16 = false;
    if (op->dtype.is_bfloat16()) {
      is_bf16 = true;
    }
    PrimExpr index = this->VisitExpr(op->index);
    PrimExpr predicate = this->VisitExpr(op->predicate);
    if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
      return GetRef<PrimExpr>(op);
    } else {
      return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
                  index, predicate);
    }
  }

  PrimExpr VisitExpr_(const FloatImmNode* op) final {
    // bf16 constants are folded to their uint16 bit pattern at compile
    // time using round-to-nearest-even.
    if (op->dtype.is_bfloat16()) {
      return IntImm(DataType::UInt(16, op->dtype.lanes()),
                    RoundToNearestEven(static_cast<float>(op->value)));
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  /*!
   * \brief Rewrite the PrimFunc's buffer_map, replacing bf16 buffers with
   *        uint16 buffers of identical layout, and record the mapping in
   *        buffer_remap for the body rewrite.
   * \param op The function node to mutate in place.
   */
  void AlterBuffers(PrimFuncNode* op) {
    std::vector<std::pair<Var, Buffer>> changes;
    for (auto& itr : op->buffer_map) {
      auto oldbuf = itr.second;
      if (oldbuf->dtype.is_bfloat16()) {
        auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
                             oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
                             oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
        buffer_remap[oldbuf.operator->()] = newbuf;
        changes.emplace_back(itr.first, newbuf);
      }
    }
    // NOTE(review): `changes` holds only the rewritten (bf16) entries; if
    // Map::assign replaces the whole buffer_map, any non-bf16 buffers in
    // a mixed-dtype function would be dropped here — confirm semantics.
    if (buffer_remap.size() != 0) {
      op->buffer_map.assign(changes.begin(), changes.end());
    }
  }
};
+
namespace transform {

// Pass: promote bf16 binary Ops to fp32 (see BF16PromoteRewriter).
Pass BF16Promote() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = BF16PromoteRewriter()(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {});
}

TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote);

// Pass: remove redundant bf16<->fp32 cast pairs left by BF16Promote
// (see BF16CastEliminationRewriter).
Pass BF16CastElimination() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = BF16CastEliminationRewriter()(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {});
}

TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination);

// Pass: lower bf16 storage to uint16 and bf16 casts to bit manipulation.
// Buffers must be altered before the body is rewritten so that loads and
// stores see the remapped buffers.
Pass BF16TypeLowering() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    BF16LowerRewriter lowerer;
    lowerer.AlterBuffers(n);
    n->body = lowerer(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {});
}

TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering);

// Combined pass: the full bf16 legalization pipeline.
Pass BF16Legalize() {
  return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize");
}

TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize);

}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
index 1173b71..0b415b0 100644 (file)
@@ -737,6 +737,53 @@ def test_llvm_shuffle():
         module(a_, b_, c_)
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
def np_float2np_bf16(arr):
    '''Convert a numpy float32 array to a numpy uint16 array holding
    the bf16 bit patterns (round-to-nearest-even).'''
    bits = arr.view('<u4')
    # Round-to-nearest-even: the bias is 0x7FFF plus the lowest retained
    # mantissa bit; the top 16 bits of the biased value are the bf16 bits.
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding_bias) >> 16).astype('uint16')
+
def np_float2tvm_bf16(arr):
    '''Convert a numpy float32 array into a TVM ndarray carrying the
    bf16 bit patterns as uint16.'''
    converted = np_float2np_bf16(arr)
    out = tvm.nd.empty(converted.shape, 'uint16')
    return out.copyfrom(converted)
+
def np_bf162np_float(arr):
    '''Convert a numpy array of bf16 bit patterns (uint16) back to
    float32 by widening into the high half of a 32-bit word.'''
    widened = arr.astype('uint32') << 16
    return widened.view('<f4')
+
def np_bf16_cast_and_cast_back(arr):
    '''Round-trip a float32 array through bf16 and back, emulating the
    precision loss a bf16 kernel would introduce.'''
    as_bf16 = np_float2np_bf16(arr)
    return np_bf162np_float(as_bf16)
+
def test_llvm_bf16():
    """End-to-end bf16 codegen test.

    Builds ``A + B`` with bfloat16 inputs, both with and without
    vectorization, and checks the kernel output against a numpy
    emulation of bf16 arithmetic.
    """
    def dotest(do_vectorize):
        np.random.seed(122)
        A = te.placeholder((32, ), dtype='bfloat16')
        B = te.placeholder((32, ), dtype='bfloat16')
        d = te.compute((32, ), lambda x: A[x] + B[x])
        sch = te.create_schedule(d.op)
        # (removed leftover debug print of tvm.lower(...): unit tests
        # should not write to stdout)
        if do_vectorize:
            sch[d].vectorize(d.op.axis[0])
        module = tvm.build(sch, [A, B, d])
        npa = np.random.rand(32).astype('float32')
        npb = np.random.rand(32).astype('float32')
        # Round-trip the inputs through bf16 so the reference result sees
        # the same precision loss as the compiled kernel.
        va = np_bf16_cast_and_cast_back(npa)
        vb = np_bf16_cast_and_cast_back(npb)
        res = np_bf16_cast_and_cast_back(va + vb)
        a_ = np_float2tvm_bf16(npa)
        b_ = np_float2tvm_bf16(npb)
        c_ = tvm.nd.empty((32,), 'uint16')
        module(a_, b_, c_)
        tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res)
    dotest(True)
    dotest(False)
+    
 if __name__ == "__main__":
     test_multiple_func()
     test_llvm_large_uintimm()
@@ -759,3 +806,4 @@ if __name__ == "__main__":
     test_llvm_fp_math()
     test_dwarf_debug_information()
     test_llvm_shuffle()
+    test_llvm_bf16()
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py
new file mode 100644 (file)
index 0000000..77a0602
--- /dev/null
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import topi
+from tvm import te
+
+
def lower_stmt(sche, params, passfunc):
    """Form an IRModule from a schedule, run a single TIR pass over it,
    and return the body of the resulting "main" PrimFunc."""
    prim_func = tvm.driver.build_module.form_irmodule(
        sche, params, "main", None)["main"]
    transformed = passfunc()(tvm.IRModule.from_expr(prim_func))
    return transformed["main"].body
+
+
def test_promote():
    """BF16Promote must wrap every bf16 binary op in casts to/from fp32."""
    def runpass(op, passfunc):
        # Build a bf16 elementwise binary op and run the given pass.
        a = te.placeholder((100,), dtype='bfloat16')
        b = te.placeholder((100,), dtype='bfloat16')
        c = te.compute((100,), lambda i: op(a[i], b[i]))
        s = te.create_schedule(c.op)
        return lower_stmt(s, [a, b, c], passfunc)

    def get_promoted(op):
        # Reference IR: the same computation with the casts written out
        # explicitly, lowered without any pass.
        a = te.placeholder((100,), dtype='bfloat16')
        b = te.placeholder((100,), dtype='bfloat16')
        c = te.compute((100,), lambda i:
                       topi.cast(op(topi.cast(a[i], 'float'),
                                    topi.cast(b[i], 'float')), 'bfloat16')
                       )
        s = te.create_schedule(c.op)
        func = tvm.driver.build_module.form_irmodule(
            s, [a, b, c], "main", None)["main"]
        return func.body

    def test_promoted(op):
        # The pass output must be structurally identical to the reference.
        stmt = runpass(op, tvm.tir.transform.BF16Promote)
        tvm.ir.assert_structural_equal(stmt, get_promoted(op))
    test_promoted(topi.add)
    test_promoted(topi.subtract)
    test_promoted(topi.multiply)
    test_promoted(topi.divide)
+
+
def test_eliminate():
    """BF16CastElimination must collapse fp32->bf16->fp32 round-trips."""
    def to32(v):
        # Cast helper: widen to fp32.
        return topi.cast(v, 'float')

    def to16(v):
        # Cast helper: narrow to bf16.
        return topi.cast(v, 'bfloat16')

    def get_eliminated():
        # Input with redundant bf16 round-trips around each inner add,
        # run through the elimination pass.
        a = te.placeholder((100,), dtype='bfloat16')
        b = te.placeholder((100,), dtype='bfloat16')
        c = te.compute((100,), lambda i: to16(
            topi.add(
                to32(
                    to16(
                        topi.add(
                            to32(a[i]),
                            to32(b[i]),
                        )
                    )
                ),
                to32(
                    to16(
                        topi.add(
                            to32(a[i]),
                            to32(b[i]),
                        )
                    )
                )
            )
        ))
        s = te.create_schedule(c.op)
        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
        return stmt

    def get_target():
        # Expected IR: the same computation with the inner round-trips
        # removed, lowered without any pass.
        a = te.placeholder((100,), dtype='bfloat16')
        b = te.placeholder((100,), dtype='bfloat16')
        c = te.compute((100,), lambda i: to16(
            topi.add(topi.add(
                to32(a[i]),
                to32(b[i]),
            ),
                     topi.add(
                         to32(a[i]),
                         to32(b[i]),
                     )
                    )
        ))
        s = te.create_schedule(c.op)
        func = tvm.driver.build_module.form_irmodule(
            s, [a, b, c], "main", None)["main"]
        return func.body
    tvm.ir.assert_structural_equal(get_eliminated(), get_target())
+
+
def test_legalize():
    """The full BF16Legalize pipeline must produce uint16 arithmetic with
    explicit reinterpret/shift-based casts, matching a hand-built IR."""
    def to32(v):
        # Expected lowering of bf16 -> fp32: widen to uint32, shift into
        # the high half, reinterpret as float32.
        uint32_v = topi.cast(v, "uint32")
        uint32_v = tvm.tir.call_pure_intrin(
            "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32"))
        return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v)

    def to16(v):
        # Expected lowering of fp32 -> bf16: reinterpret as uint32, add the
        # round-to-nearest-even bias, keep the top 16 bits.
        uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v)
        rounding_bias = tvm.tir.call_pure_intrin(
            "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
        rounding_bias = tvm.tir.call_pure_intrin(
            "uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
        # NOTE: uint16 constant added to a uint32 value — mirrors the C++
        # lowering exactly so structural equality holds.
        rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
        uint32_v = uint32_v + rounding_bias
        uint32_v = tvm.tir.call_pure_intrin(
            "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
        return topi.cast(uint32_v, 'uint16')

    def check(fcompute_before, fcompute_after):
        # Run BF16Legalize on the bf16 computation ...
        a = te.placeholder((100,), dtype='bfloat16', name='A')
        b = te.placeholder((100,), dtype='bfloat16', name='B')
        c = te.compute((100,), fcompute_before(a, b), name='C')
        s = te.create_schedule(c.op)
        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize)

        # ... and compare against the hand-lowered uint16 version.
        a = te.placeholder((100,), dtype='uint16', name='A')
        b = te.placeholder((100,), dtype='uint16', name='B')
        c = te.compute((100,), fcompute_after(a, b), name='C')
        s = te.create_schedule(c.op)
        func = tvm.driver.build_module.form_irmodule(
            s, [a, b, c], "main", None)["main"]
        tvm.ir.assert_structural_equal(stmt, func.body)

    def orig1(a, b):
        return lambda i: a[i] + b[i] + a[99-i] + b[99-i]

    def after1(a, b):
        return lambda i: to16(to32(a[i]) + to32(b[i] ) + to32(a[99 - i]) + to32(b[99 - i]))

    def orig2(a, b):
        return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i]

    def after2(a, b):
        return lambda i: to16(to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + to32(a[i]))

    check(orig1, after1)
    check(orig2, after2)
+
+
if __name__ == "__main__":
    # Run all bf16 legalization unit tests when invoked directly.
    test_promote()
    test_eliminate()
    test_legalize()