From 8541e2554392cbfab8eeae7bbe40c10f4c730088 Mon Sep 17 00:00:00 2001 From: YixinBao Date: Mon, 16 Dec 2019 13:46:21 +0800 Subject: [PATCH] add bfloat16 typeflag support (#4525) --- nnvm/include/nnvm/top/tensor.h | 15 ++++++++++----- nnvm/src/pass/plan_memory.cc | 1 + 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 8ecdf0c..51cb6e7 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -100,10 +100,14 @@ enum TypeFlag { kInt32 = 4, kInt8 = 5, kInt64 = 6, - kInt16 = 7, - kUint16 = 8, - kUint32 = 9, - kUint64 = 10, + // kBool = 7, + // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in + // https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L314 + kInt16 = 8, + kUint16 = 9, + kUint32 = 10, + kUint64 = 11, + kBfloat16 = 12, }; enum IndicatorRuleFlag { @@ -125,7 +129,8 @@ enum IndicatorRuleFlag { .add_enum("int8", kInt8) \ .add_enum("int16", kInt16) \ .add_enum("int32", kInt32) \ - .add_enum("int64", kInt64) + .add_enum("int64", kInt64) \ + .add_enum("bfloat16", kBfloat16) struct CastParam : public dmlc::Parameter { int dtype; diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index fa48bdd..de8bc94 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -40,6 +40,7 @@ static int GetDTypeSize(int type_flag) { case kInt8: return 1; case kFloat16: + case kBfloat16: case kInt16: case kUint16: return 2; -- 2.7.4