* Lower cast between bf16 and fp32
* Lower bf16 FloatImm to int16
*/
-class BF16LowerRewriter : StmtExprMutator {
+class BF16LowerRewriter : public StmtExprMutator {
public:
BF16LowerRewriter() {}
- std::unordered_map<const BufferNode*, Buffer> buffer_remap;
- std::unordered_map<const VarNode*, Var> var_remap;
-
- Stmt operator()(Stmt s) { return VisitStmt(s); }
+ using StmtExprMutator::operator();
PrimExpr VisitExpr_(const CastNode* op) final {
auto op_val = StmtExprMutator::VisitExpr(op->value);
auto uint32_v = Cast(uint32_dtype, op_val);
// to be endian invariant.
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
-
} else if (op->dtype.is_bfloat16()) {
// if is cast_to_bf16, check if op->value is fp32
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
}
PrimExpr VisitExpr_(const VarNode* op) final {
- auto itr = var_remap.find(op);
- if (itr != var_remap.end()) {
+ Var var = GetRef<Var>(op);  // var_remap_ is keyed by Var handles, so wrap the raw node first
+
+ auto itr = var_remap_.find(var);
+ if (itr != var_remap_.end()) {
return itr->second;
+ } else {  // entries are recorded by the Allocate visitor and AlterBuffers, not created here
+ return std::move(var);  // no replacement recorded: hand the variable back as-is
}
- if (op->dtype.is_bfloat16()) {
- CHECK(!op->type_annotation.defined());
- auto ret = Var(op->name_hint, op->dtype);
- var_remap[op] = ret;
- return std::move(ret);
- }
- return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const AllocateNode* op) final {
- Stmt node_holder;
- const AllocateNode* newop;
if (op->dtype.is_bfloat16()) {
- auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
- op->condition, op->body);
- node_holder = v;
- newop = static_cast<const AllocateNode*>(v.operator->());
+ DataType dtype = DataType::UInt(16, op->dtype.lanes());  // bf16 storage becomes uint16, same lane count
+ Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));  // new var, same name, pointer type matching the new element type
+ var_remap_[op->buffer_var] = buffer_var;  // recorded before visiting so uses of the old var in the body get replaced
+ return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));  // visit the re-typed node so its children are rewritten
} else {
- newop = op;
+ return StmtExprMutator::VisitStmt_(op);  // non-bf16 allocation: default traversal
}
- return StmtExprMutator::VisitStmt_(newop);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
- auto itr = buffer_remap.find(op->buffer.operator->());
- const BufferStoreNode* newop;
- BufferStore newop_holder;
- if (itr != buffer_remap.end()) {
- newop_holder = BufferStore(itr->second, op->value, op->indices);
- newop = newop_holder.operator->();
+ Stmt ret = StmtExprMutator::VisitStmt_(op);  // run the base mutator first, then patch the buffer on its result
+ op = ret.as<BufferStoreNode>();  // NOTE(review): assumes the base visit still yields a BufferStoreNode - confirm
+
+ auto it = buffer_remap_.find(op->buffer);  // lookup is by Buffer handle
+ if (it != buffer_remap_.end()) {
+ return BufferStore(it->second, op->value, op->indices);  // same value/indices, replacement buffer
} else {
- newop = op;
+ return ret;  // buffer has no replacement: keep the base-mutated statement
}
- return StmtExprMutator::VisitStmt_(newop);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- const AttrStmtNode* newop = op;
- Stmt newop_holder;
- if (auto buffer = op->node.as<BufferNode>()) {
- auto itr = buffer_remap.find(buffer);
- if (itr != buffer_remap.end()) {
- newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
- newop = newop_holder.as<AttrStmtNode>();
+ Stmt ret = StmtExprMutator::VisitStmt_(op);  // run the base mutator first; value/body below come from its result
+ op = ret.as<AttrStmtNode>();  // NOTE(review): assumes the base visit still yields an AttrStmtNode - confirm
+
+ if (auto* buffer = op->node.as<BufferNode>()) {  // case 1: the attribute is attached to a Buffer
+ auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
+ if (it != buffer_remap_.end()) {
+ return AttrStmt(it->second, op->attr_key, op->value, op->body);  // re-attach to the replacement buffer
}
- } else if (auto buffer = op->node.as<VarNode>()) {
- auto itr = var_remap.find(buffer);
- if (itr != var_remap.end()) {
- newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
- newop = newop_holder.as<AttrStmtNode>();
+ } else if (auto* var = op->node.as<VarNode>()) {  // case 2: the attribute is attached to a Var
+ auto it = var_remap_.find(GetRef<Var>(var));
+ if (it != var_remap_.end()) {
+ return AttrStmt(it->second, op->attr_key, op->value, op->body);  // re-attach to the replacement var
}
}
- return StmtExprMutator::VisitStmt_(newop);
+ return ret;  // node is neither a remapped Buffer nor a remapped Var
}
Stmt VisitStmt_(const BufferRealizeNode* op) final {
- auto itr = buffer_remap.find(op->buffer.operator->());
- const BufferRealizeNode* newop;
- Stmt newop_holder;
- if (itr != buffer_remap.end()) {
- auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
- newop_holder = v;
- newop = v.operator->();
+ Stmt ret = StmtExprMutator::VisitStmt_(op);  // run the base mutator first, then patch the buffer on its result
+ op = ret.as<BufferRealizeNode>();  // NOTE(review): assumes the base visit still yields a BufferRealizeNode - confirm
+
+ auto it = buffer_remap_.find(op->buffer);
+ if (it != buffer_remap_.end()) {
+ return BufferRealize(it->second, op->bounds, op->condition, op->body);  // same bounds/condition/body, replacement buffer
} else {
- newop = op;
+ return ret;  // buffer has no replacement: keep the base-mutated statement
+ }
+ }
+
+ Stmt VisitStmt_(const StoreNode* op) final {
+ // NOTE: we do not explicitly recursively mutate op->buffer_var
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<StoreNode>();  // NOTE(review): assumes the base visit still yields a StoreNode - confirm
+
+ auto it = var_remap_.find(op->buffer_var);  // so the buffer var is looked up and swapped by hand here
+ if (it != var_remap_.end()) {
+ return Store(it->second, op->value, op->index, op->predicate);  // same value/index/predicate, replacement var
+ } else {
+ return ret;  // buffer var has no replacement: keep the base-mutated statement
}
- return StmtExprMutator::VisitStmt_(newop);
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- auto itr = buffer_remap.find(op->buffer.operator->());
- const BufferLoadNode* newop;
- BufferLoad newop_holder;
- if (itr != buffer_remap.end()) {
- newop_holder = BufferLoad(itr->second, op->indices);
- newop = newop_holder.operator->();
+ PrimExpr ret = StmtExprMutator::VisitExpr_(op);  // run the base mutator first, then patch the buffer on its result
+ op = ret.as<BufferLoadNode>();  // NOTE(review): assumes the base visit still yields a BufferLoadNode - confirm
+
+ auto it = buffer_remap_.find(op->buffer);
+ if (it != buffer_remap_.end()) {
+ return BufferLoad(it->second, op->indices);  // same indices, replacement buffer
} else {
- newop = op;
+ return ret;  // buffer has no replacement: keep the base-mutated expression
}
- return StmtExprMutator::VisitExpr_(newop);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
- bool is_bf16 = false;
+ PrimExpr ret = StmtExprMutator::VisitExpr_(op);  // base mutator handles the children; index/predicate below come from its result
+ op = ret.as<LoadNode>();  // NOTE(review): assumes the base visit still yields a LoadNode - confirm
+
if (op->dtype.is_bfloat16()) {
- is_bf16 = true;
- }
- PrimExpr index = this->VisitExpr(op->index);
- PrimExpr predicate = this->VisitExpr(op->predicate);
- if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
- return GetRef<PrimExpr>(op);
+ auto it = var_remap_.find(op->buffer_var);
+ CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";  // a bf16 load from a var with no recorded replacement is a hard failure
+ return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);  // uint16 load, same lanes, from the replacement var
} else {
- return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
- index, predicate);
+ return ret;  // non-bf16 load: the base result is returned; var_remap_ is not consulted on this path
}
}
void AlterBuffers(PrimFuncNode* op) {
std::vector<std::pair<Var, Buffer>> changes;
+
for (auto& itr : op->buffer_map) {
auto oldbuf = itr.second;
if (oldbuf->dtype.is_bfloat16()) {
- auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
- oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
- oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
- buffer_remap[oldbuf.operator->()] = newbuf;
+ DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());  // bf16 elements become uint16, same lane count
+ Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));  // new data var (same name) instead of reusing oldbuf->data
+ auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
+ oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
+ oldbuf->offset_factor, oldbuf->buffer_type);
+ buffer_remap_[oldbuf] = newbuf;  // consulted by the BufferLoad/BufferStore/BufferRealize/AttrStmt visitors
+ var_remap_[oldbuf->data] = buffer_var;  // consulted by the Var/Load/Store/AttrStmt visitors
changes.emplace_back(itr.first, newbuf);
+ } else {
+ changes.emplace_back(itr);  // carry unchanged entries over: buffer_map is rebuilt solely from `changes` below
}
}
- if (buffer_remap.size() != 0) {
+
+ if (buffer_remap_.size() != 0) {  // the map is only replaced when buffer_remap_ has entries
op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
}
}
+
+ private:
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;  // original bf16 Buffer -> its uint16 replacement; filled by AlterBuffers
+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;  // original buffer/data Var -> replacement Var; filled by AlterBuffers and the Allocate visitor
};
namespace transform {