# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=no-else-return,invalid-name,len-as-condition
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""
A pass for manifesting explicit memory allocations.
"""
new_args = [self.visit(arg) for arg in call.args]
ins = expr.Tuple(new_args)
ret_type = call.checked_type
+ view = LinearizeRetType(ret_type)
+ out_types = view.unpack()
is_dynamic = ret_type.is_dynamic()
# TODO(@jroesch): restore this code, more complex then it seems
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
if is_dynamic:
- assert isinstance(ret_type, ty.TensorType)
shape_func_ins = []
engine = compile_engine.get()
cfunc = engine.lower_shape_func(call.op, self.target_host)
input_states = cfunc.shape_func_param_states
is_inputs = []
+ input_pos = 0
for i, (arg, state) in enumerate(zip(new_args, input_states)):
state = int(state)
# Pass Shapes
if state == 2:
- sh_of = self.visit(self.shape_of(arg))
- shape_func_ins.append(
- scope.let("in_shape_{0}".format(i), sh_of))
+ if isinstance(arg.type_annotation, ty.TupleType):
+ for j in range(len(arg.type_annotation.fields)):
+ let_in_arg = scope.let("in_arg_{0}".format(input_pos + j),
+ expr.TupleGetItem(arg, j))
+ sh_of = self.visit(self.shape_of(let_in_arg))
+ shape_func_ins.append(
+ scope.let("in_shape_{0}".format(input_pos + j), sh_of))
+ input_pos += len(arg.type_annotation.fields)
+ else:
+ sh_of = self.visit(self.shape_of(arg))
+ shape_func_ins.append(
+ scope.let("in_shape_{0}".format(input_pos), sh_of))
+ input_pos += 1
is_inputs.append(0)
# Pass Inputs
elif state == 1:
new_arg = self.visit(arg)
shape_func_ins.append(
- scope.let("in_shape_{0}".format(i), new_arg))
+ scope.let("in_shape_{0}".format(input_pos), new_arg))
+ input_pos += 1
is_inputs.append(1)
# TODO(@jroesch): handle 3rd case
else:
scope.let("shape_func", shape_call)
- out_types = []
- out_types.append(call.checked_type)
-
storages = []
for out_shape, out_type in zip(out_shapes, out_types):
size = self.compute_storage_in_relay(
alloc = scope.let("out_{i}".format(i=i), alloc)
outs.append(alloc)
- invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs))
+ tuple_outs = expr.Tuple(outs)
+ invoke = self.invoke_tvm(call.op, ins, tuple_outs)
scope.let("", invoke)
- return outs[0]
+ return outs[0] if len(outs) == 1 else tuple_outs
else:
- view = LinearizeRetType(ret_type)
- out_tys = view.unpack()
-
outs = []
- for i, out_ty in enumerate(out_tys):
+ for i, out_ty in enumerate(out_types):
out = self.make_static_allocation(scope, out_ty, i)
outs.append(out)
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
+from topi.util import get_const_tuple
from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern
from ...hybrid import script
+from ...api import convert
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]
-# shape func
@script
-def _full_shape_func(x):
- out_ndim = len(x)
+def _full_shape_func(shape):
+ out_ndim = len(shape)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
- out[i] = x[i]
+ out[i] = int64(shape[i])
return out
def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
"""
- return [_full_shape_func(*inputs)]
+ shape = get_const_tuple(attrs.shape)
+ return [_full_shape_func(convert(shape))]
@script
def _broadcast_shape_func(x, y, ndim):
register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
-register_shape_func("zeros_like", False, full_shape_func)
+register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
-register_shape_func("ones_like", False, full_shape_func)
+register_shape_func("ones_like", False, elemwise_shape_func)
+register_shape_func("full", False, full_shape_func)
+register_shape_func("full_like", False, elemwise_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
+register_shape_func("maximum", False, broadcast_shape_func)
+register_shape_func("minimum", False, broadcast_shape_func)
register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
+register_shape_func("exp", False, elemwise_shape_func)
@script
def _squeeze_shape_func(data_shape, keep_axes):
out = output_tensor((len(keep_axes),), "int64")
- if len(keep_axes) == 0:
- out_size = 0
- for i in const_range(data_shape.shape[0]):
- if data_shape[i] != 1:
- out_size += 1
-
- if out_size == 0:
- out_size = 1
- out = output_tensor((out_size,), "int64")
- out[0] = int64(1)
- pos = 0
- for i in const_range(data_shape.shape[0]):
- if data_shape[i] != 1:
- out[pos] = data_shape[i]
- pos += 1
- else:
- for i in const_range(len(keep_axes)):
- out[i] = data_shape[keep_axes[i]]
+ for i in const_range(len(keep_axes)):
+ out[i] = data_shape[keep_axes[i]]
return out
if i not in axis:
keep_axes.append(i)
- return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
+ # Due to current relay type system, it is possible even
+ # a static kernel function needs shape function. To handle
+ # this case, we allow axis to be None in squeeze shape func
+ # for now.
+ # TODO(kevinthesun): Enhance relay type system to avoid this.
+ if keep_axes:
+ out = _squeeze_shape_func(inputs[0], convert(keep_axes))
+ else:
+ out = tvm.compute((), lambda *indices: 0)
+ return [out]
@script
def _reshape_like_shape_func(target_shape):
@_reg.register_shape_func("tile", False)
def tile_shape_func(attrs, inputs, _):
+ """
+ Shape function for tile op.
+ """
reps = get_const_tuple(attrs.reps)
ndim = inputs[0].shape[0].value
rndim = len(reps)
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
convert(tndim), convert(rndim))]
+
+@script
+def _split_shape_func(data_shape, index, indices_or_sections, axis):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ if len(indices_or_sections) == 1:
+ for i in const_range(data_shape.shape[0]):
+ if i == axis:
+ out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
+ else:
+ out[i] = data_shape[i]
+ else:
+ start = int64(0)
+ if index > 0:
+ start = int64(indices_or_sections[index - 1])
+ end = data_shape[axis]
+ if index < len(indices_or_sections):
+ end = int64(indices_or_sections[index])
+ for i in const_range(data_shape.shape[0]):
+ if i == axis:
+ out[i] = end - start
+ else:
+ out[i] = data_shape[i]
+ return out
+
+@_reg.register_shape_func("split", False)
+def split_shape_func(attrs, inputs, _):
+ """
+ Shape function for split op.
+ """
+ if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
+ indices_or_sections = get_const_int(attrs.indices_or_sections)
+ else:
+ indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+
+ axis = get_const_int(attrs.axis)
+
+ num_out = indices_or_sections if isinstance(indices_or_sections, int) \
+ else len(indices_or_sections) + 1
+ if isinstance(indices_or_sections, int):
+ indices_or_sections = [indices_or_sections]
+ return [_split_shape_func(inputs[0],
+ convert(i),
+ convert(indices_or_sections),
+ convert(axis)) for i in range(num_out)]
verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
+ dtype = 'float32'
+ x = relay.var('x', shape=x_shape, dtype=dtype)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], op(x))
+ x_np = np.random.uniform(size=x_np_shape).astype(dtype)
+ res_np = np_op(x_np)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(x_np)
+ tvm.testing.assert_allclose(result.asnumpy(), res_np)
+
+def test_any_elemwise():
+ verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
+ verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
+ verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)
def test_any_broadcast_fail():
# Test broadcast with incompatible values at runtime
def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32")
- z = relay.op.concatenate([x, y], axis=0)
+ xx = x - relay.expr.const(3.0)
+ yy = y * relay.expr.const(5.0)
+ z = relay.op.concatenate([xx, yy], axis=0)
mod = relay.module.Module()
mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32')
- ref = np.concatenate([x_np, y_np], axis=0)
+ ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
"NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
+def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.split(data, indices_or_sections, axis)
+ mod["main"] = relay.Function([data], y.astuple())
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ for ret, ref_ret in zip(result, ref_out_shape):
+ assert ret.asnumpy().shape == ref_ret, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
+
+def test_any_split():
+ verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+ verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+
def test_any_batch_flatten():
mod = relay.Module()
dtype = "float32"
if __name__ == "__main__":
test_any_full()
test_any_broadcast()
+ test_any_elemwise()
test_any_broadcast_fail()
test_any_concat()
test_any_reshape()
test_any_take()
test_any_tile()
+ test_any_split()
test_any_shape_of()
test_any_reduce()
test_any_layout_transform()