#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
+#include <builtin_fp16.h>
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
} else if (type == Float(32)) { \
typedef float DType; \
{__VA_ARGS__} \
+ } else if (type == Float(16)) { \
+ typedef uint16_t DType; \
+ {__VA_ARGS__} \
} else if (type == Int(64)) { \
typedef int64_t DType; \
{__VA_ARGS__} \
template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
- *static_cast<DType*>(arr->data) = value;
+ if (dtype == Float(16)) {
+ // float16 is stored as raw uint16_t bits (DType is uint16_t here), so the
+ // value must be rounded from float32 to a half bit pattern, not assigned directly.
+ *static_cast<DType*>(arr->data) =
+ __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
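+ // e.g. value = 1.0 is stored as the IEEE-754 half bit pattern 0x3C00.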
+ } else {
+ *static_cast<DType*>(arr->data) = value;
+ }
})
return ConstantNode::make(arr);
}
Expr moving_mean,
Expr moving_var,
Type tdata) {
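+ // Fetch the input tensor type up front so the constants below can use its dtype.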
+ auto ttype = tdata.as<TensorTypeNode>();
+ CHECK(ttype);
const auto param = attrs.as<BatchNormAttrs>();
- Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+ Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
Expr var_add_eps = Add(moving_var, epsilon);
Expr sqrt_var = Sqrt(var_add_eps);
- Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);
+ Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
if (param->scale) {
scale = Multiply(scale, gamma);
}
int axis = param->axis;
- auto ttype = tdata.as<TensorTypeNode>();
- CHECK(ttype);
auto ndim = ttype->shape.size();
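+ // Reshape the 1-D scale/shift so they broadcast against the data along `axis`.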
scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
from tvm import relay as rly
from tvm.relay.ir_pass import simplify_inference, alpha_equal
-def test_simplify_batchnorm():
+def test_simplify_batchnorm(dtype='float32'):
def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, shape=None):
# expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
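+ # which simplify_inference folds into x * scale + shift, where
+ #   scale = gamma / sqrt(moving_var + eps)
+ #   shift = beta - moving_mean * scale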
- scale = rly.multiply(rly.const(1, 'float32') /
- rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
+ scale = rly.multiply(rly.const(1, dtype) /
+ rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma)
shift = rly.add(
rly.multiply(rly.negative(moving_mean), scale), beta)
num_newaxis = len(shape) - (axis + 1)
def check(dim, axis, nstep):
eps = 0.01
- ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
- ttype2 = rly.TensorType((10,), 'float32')
+ ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
+ ttype2 = rly.TensorType((10,), dtype)
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
y1, y2 = x, x
for _ in range(nstep):
- y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
+ y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = rly.nn.dropout(y1)
- y2 = simple_bn(y2 + rly.const(1, 'float32'),
+ y2 = simple_bn(y2 + rly.const(1, dtype),
gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ttype1.shape)
y1 = rly.ir_pass.infer_type(y1)
if __name__ == "__main__":
- test_simplify_batchnorm()
+ test_simplify_batchnorm(dtype='float32')
+ test_simplify_batchnorm(dtype='float16')