[Relay] Handle float16 constants & fix BatchNorm (#3260)
author     Balint Cristian <cristian.balint@gmail.com>
           Fri, 31 May 2019 02:12:56 +0000 (05:12 +0300)
committer  Wuwei Lin <vincentl13x@gmail.com>
           Fri, 31 May 2019 02:12:56 +0000 (10:12 +0800)
src/relay/pass/pattern_util.h
src/relay/pass/simplify_inference.cc
tests/python/relay/test_pass_simplify_inference.py

src/relay/pass/pattern_util.h
index b44bb68..b709f28 100644 (file)
@@ -27,6 +27,7 @@
 #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
 #define TVM_RELAY_PASS_PATTERN_UTIL_H_
 
+#include <builtin_fp16.h>
 #include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
@@ -49,6 +50,9 @@ namespace relay {
   } else if (type == Float(32)) {                       \
     typedef float DType;                                \
     {__VA_ARGS__}                                       \
+  } else if (type == Float(16)) {                       \
+    typedef uint16_t DType;                             \
+    {__VA_ARGS__}                                       \
   } else if (type == Int(64)) {                         \
     typedef int64_t DType;                              \
     {__VA_ARGS__}                                       \
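
The Float(16) case deliberately maps DType to uint16_t: standard C++ (pre-C++23) has no portable half-precision type, so TVM handles float16 payloads as raw 16-bit patterns. A minimal NumPy illustration of that storage convention (NumPy appears here only for demonstration; the pass itself is pure C++):

    import numpy as np

    half = np.float16(1.5)
    raw = half.view(np.uint16)     # the bit pattern MakeConstantScalar writes below
    print(hex(int(raw)))           # 0x3e00
    print(raw.view(np.float16))    # reinterpreting the 16 bits recovers 1.5
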
@@ -204,7 +208,14 @@ template<typename T>
 inline Constant MakeConstantScalar(DataType dtype, T value) {
   runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
   TVM_DTYPE_DISPATCH(dtype, DType, {
-    *static_cast<DType*>(arr->data) = value;
+    if (dtype == Float(16)) {
+      // float16 is stored as a raw uint16_t bit pattern, so convert the
+      // value with compiler-rt's float32 -> float16 truncation helper
+      *static_cast<DType*>(arr->data) =
+        __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
+    } else {
+      *static_cast<DType*>(arr->data) = value;
+    }
   })
   return ConstantNode::make(arr);
 }
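
__truncXfYf2__ comes from the compiler-rt derived header included above (<builtin_fp16.h> in TVM's 3rdparty tree); its template arguments describe the source format (float, uint32_t representation, 23 significand bits) and the destination (half stored as uint16_t, 10 significand bits). Below is a rough Python sketch of the bit repacking it performs for normal values; the real helper also rounds to nearest-even and handles subnormals, which this sketch elides, so treat it as illustrative only:

    import numpy as np

    def f32_to_f16_bits(x):
        bits = int(np.float32(x).view(np.uint32))
        sign = (bits >> 16) & 0x8000          # relocate sign to bit 15
        exp = (bits >> 23) & 0xFF             # 8-bit float32 exponent
        mant = bits & 0x7FFFFF                # 23-bit float32 mantissa
        if exp == 0xFF:                       # Inf/NaN: saturated exponent
            return sign | 0x7C00 | (0x200 if mant else 0)
        e = exp - 127 + 15                    # rebias exponent (127 -> 15)
        if e >= 0x1F:                         # too large for half -> +/-Inf
            return sign | 0x7C00
        if e <= 0:                            # too small -> zero (subnormals elided)
            return sign
        return sign | (e << 10) | (mant >> 13)  # truncate mantissa 23 -> 10 bits

    # values chosen so plain truncation agrees with NumPy's rounded conversion
    for v in (1.0, 0.5, 65504.0, 3.14159):
        print(hex(f32_to_f16_bits(v)), hex(int(np.float16(v).view(np.uint16))))
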
src/relay/pass/simplify_inference.cc
index cecebc5..8dab0c3 100644 (file)
@@ -36,11 +36,13 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
                             Expr moving_mean,
                             Expr moving_var,
                             Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
   const auto param = attrs.as<BatchNormAttrs>();
-  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
   Expr var_add_eps = Add(moving_var, epsilon);
   Expr sqrt_var = Sqrt(var_add_eps);
-  Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);
+  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
 
   if (param->scale) {
     scale = Multiply(scale, gamma);
@@ -52,8 +54,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   }
 
   int axis = param->axis;
-  auto ttype = tdata.as<TensorTypeNode>();
-  CHECK(ttype);
   auto ndim = ttype->shape.size();
   scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
   shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
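
For reference, SimplifyInference rewrites batch_norm into plain scale/shift arithmetic; the fix above makes the two synthesized constants (epsilon and 1.0) take the input tensor's dtype instead of hard-coded float32, so float16 graphs no longer mix precisions. A NumPy sketch of the emitted computation (the helper name, shapes, and the unconditional scale/center handling are illustrative):

    import numpy as np

    def simplify_bn(x, gamma, beta, mean, var, eps=1e-3, axis=1):
        # all constants created in the input dtype, mirroring the fixed pass
        one = np.asarray(1, x.dtype)
        epsc = np.asarray(eps, x.dtype)
        scale = one / np.sqrt(var + epsc) * gamma    # applied when param->scale
        shift = -mean * scale + beta                 # applied when param->center
        bshape = [1] * x.ndim
        bshape[axis] = -1                            # cf. ExpandBiasToMatchAxis
        return x * scale.reshape(bshape) + shift.reshape(bshape)

    x = np.random.rand(2, 4, 8, 8).astype('float16')
    g, b, m, v = (np.random.rand(4).astype('float16') for _ in range(4))
    print(simplify_bn(x, g, b, m, v).dtype)  # float16: no accidental upcast
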
tests/python/relay/test_pass_simplify_inference.py
index 1387f27..aad1d9f 100644 (file)
 from tvm import relay as rly
 from tvm.relay.ir_pass import simplify_inference, alpha_equal
 
-def test_simplify_batchnorm():
+def test_simplify_batchnorm(dtype='float32'):
     def simple_bn(x, gamma, beta, moving_mean, moving_var,
                   axis=1, epsilon=1e-5, shape=None):
         # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
-        scale = rly.multiply(rly.const(1, 'float32') /
-                rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
+        scale = rly.multiply(rly.const(1, dtype) /
+                rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma)
         shift = rly.add(
             rly.multiply(rly.negative(moving_mean), scale), beta)
         num_newaxis = len(shape) - (axis + 1)
@@ -33,8 +33,8 @@ def test_simplify_batchnorm():
 
     def check(dim, axis, nstep):
         eps = 0.01
-        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
-        ttype2 = rly.TensorType((10,), 'float32')
+        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
+        ttype2 = rly.TensorType((10,), dtype)
         x = rly.var("x", ttype1)
         beta = rly.var("beta", ttype2)
         gamma = rly.var("gamma", ttype2)
@@ -43,10 +43,10 @@ def test_simplify_batchnorm():
         y1, y2 = x, x
 
         for _ in range(nstep):
-            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
+            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
                 gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
             y1 = rly.nn.dropout(y1)
-            y2 = simple_bn(y2 + rly.const(1, 'float32'),
+            y2 = simple_bn(y2 + rly.const(1, dtype),
                            gamma, beta, moving_mean, moving_var,
                            epsilon=eps, axis=axis, shape=ttype1.shape)
         y1 = rly.ir_pass.infer_type(y1)
@@ -60,4 +60,5 @@ def test_simplify_batchnorm():
 
 
 if __name__ == "__main__":
-    test_simplify_batchnorm()
+    test_simplify_batchnorm(dtype='float32')
+    test_simplify_batchnorm(dtype='float16')
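
A quick end-to-end check, written against the same pre-0.6 tvm.relay.ir_pass API the test file uses (a sketch, not part of the test suite):

    from tvm import relay as rly
    from tvm.relay.ir_pass import infer_type, simplify_inference

    x = rly.var("x", rly.TensorType((1, 4, 8, 8), "float16"))
    gamma = rly.var("gamma", rly.TensorType((4,), "float16"))
    beta = rly.var("beta", rly.TensorType((4,), "float16"))
    mean = rly.var("mean", rly.TensorType((4,), "float16"))
    var = rly.var("var", rly.TensorType((4,), "float16"))
    y, _, _ = rly.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-3)
    y = simplify_inference(infer_type(y))
    print(y)  # the unpacked constants (1.0, epsilon) should all be float16
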