From aa035f4650926f5e714b02cbab6d974f0a17352f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 1 Jul 2020 19:14:33 -0700 Subject: [PATCH] [Relay/TOPI][OP] Add meshgrid op in Relay, TOPI, Pytorch frontend (#5961) * Add meshgrid op with pytorch importer * Fix c++ lint * Fix pylint * Meshgrid: add scalar test for pytorch, add topi python wrapper * Add indexing mode attr. * Add MeshgridAttrs python binding * c++ lint --- docs/api/python/topi.rst | 2 + docs/langref/relay_op.rst | 1 + include/tvm/relay/attrs/transform.h | 13 ++++ python/tvm/relay/frontend/pytorch.py | 7 +++ python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/op_attrs.py | 5 ++ python/tvm/relay/op/transform.py | 41 ++++++++++++- src/relay/op/tensor/transform.cc | 87 +++++++++++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 20 ++++++ tests/python/relay/test_op_level3.py | 35 ++++++++++- topi/include/topi/transform.h | 32 ++++++++++ topi/python/topi/transform.py | 19 ++++++ topi/src/transform.cc | 4 ++ 13 files changed, 265 insertions(+), 2 deletions(-) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 53f2f3c..09c3318 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -104,6 +104,7 @@ List of operators topi.logical_not topi.logical_xor topi.arange + topi.meshgrid topi.stack topi.repeat topi.tile @@ -187,6 +188,7 @@ topi .. autofunction:: topi.greater .. autofunction:: topi.less .. autofunction:: topi.arange +.. autofunction:: topi.meshgrid .. autofunction:: topi.stack .. autofunction:: topi.repeat .. autofunction:: topi.tile diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 86e0c0d..febe542 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -128,6 +128,7 @@ This level enables additional math and transform operators. 
tvm.relay.reinterpret tvm.relay.split tvm.relay.arange + tvm.relay.meshgrid tvm.relay.stack tvm.relay.repeat tvm.relay.tile diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 8af9f63..b0c8108 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -153,6 +153,19 @@ struct ArangeAttrs : public tvm::AttrsNode { } }; // struct ArangeAttrs +/*! \brief Attributes used in meshgrid operators */ +struct MeshgridAttrs : public tvm::AttrsNode { + std::string indexing; + + TVM_DECLARE_ATTRS(MeshgridAttrs, "relay.attrs.MeshgridAttrs") { + TVM_ATTR_FIELD(indexing) + .describe( + "Indexing mode, either \"ij\" for matrix or \"xy\" for cartesian in which first two" + " dimensions are swapped.") + .set_default("ij"); + } +}; // struct MeshgridAttrs + /*! \brief Attributes used in stack operators */ struct StackAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 84b0907..f0ad87f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1727,6 +1727,12 @@ def _one_hot(): return _impl +def _meshgrid(): + def _impl(inputs, input_types): + data = inputs[0] + return _op.meshgrid(data, indexing="ij") + return _impl + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -1869,6 +1875,7 @@ def _get_convert_map(prelude): "aten::mul_" : _elemwise("multiply"), "aten::pow" : _elemwise("power"), "aten::arange" : _arange(), + "aten::meshgrid" : _meshgrid(), "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), "aten::floor_divide" : _elemwise("floor_divide"), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a3f2e08..878b82a 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -40,6 +40,7 @@
_reg.register_injective_schedule("reshape_like") _reg.register_injective_schedule("full") _reg.register_injective_schedule("full_like") _reg.register_injective_schedule("arange") +_reg.register_injective_schedule("meshgrid") _reg.register_injective_schedule("reverse") _reg.register_injective_schedule("reverse_sequence") _reg.register_injective_schedule("cast") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 486d63c..32540a5 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -208,6 +208,11 @@ class ArangeAttrs(Attrs): """Attributes used in arange operators""" +@tvm._ffi.register_object("relay.attrs.MeshgridAttrs") +class MeshgridAttrs(Attrs): + """Attributes used in meshgrid operators""" + + @tvm._ffi.register_object("relay.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes used in stack operators""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9dc96f5..188cd5c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -20,7 +20,7 @@ from . import _make from .dyn import _make as _dyn_make -from ..expr import TupleWrapper, const, Expr +from ..expr import TupleWrapper, const, Expr, Tuple from ...tir import expr as _expr @@ -418,6 +418,45 @@ def arange(start, stop=None, step=None, dtype="float32"): return _make.arange(start, stop, step, dtype) +def meshgrid(data, indexing="ij"): + """Create coordinate matrices from coordinate vectors. + + .. note:: + Similar to ``numpy.meshgrid``. + + Parameters + ---------- + data : Union(List[relay.Expr], Tuple[relay.Expr]) + A list of tensors, which must be either scalars or 1-D vectors. + + indexing : str + Indexing mode, either "ij" for matrix indexing or "xy" for Cartesian indexing. + + Returns + ------- + ret : relay.Tuple([relay.Expr, relay.Expr]) + The computed result. + + Examples + -------- + ..
code-block:: python + + x = [1, 2, 3] + y = [4, 5] + + gx, gy = relay.meshgrid([x, y]) + + gx = [[1., 1.], + [2., 2.], + [3., 3.]] + + gy = [[4., 5.], + [4., 5.], + [4., 5.]] + """ + data = list(data) + ret_size = len(data) + return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size) def repeat(data, repeats, axis): """Repeats elements of an array. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b44ddf4..b1c2d8b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1247,6 +1247,93 @@ RELAY_REGISTER_OP("repeat") .set_attr("FTVMCompute", RepeatCompute) .set_attr("TOpPattern", kBroadcast); +// meshgrid operator +TVM_REGISTER_NODE_TYPE(MeshgridAttrs); + +bool MeshgridRel(const Array& types, int num_inputs, const Attrs& raw_attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + const MeshgridAttrs* attrs = raw_attrs.as(); + const auto* tensor_tuple = types[0].as(); + if (tensor_tuple == nullptr) { + throw Error( + ErrorBuilder() << "meshgrid requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); + } else if (types[0].as() != nullptr) { + return false; + } + const int data_length = static_cast(tensor_tuple->fields.size()); + + // Get first dtype. + const auto& first = Downcast(tensor_tuple->fields[0]); + const DataType dtype = first->dtype; + + // Get size of output grid. 
+ std::vector grid_shape; + grid_shape.reserve(data_length); + for (const Type& ele : tensor_tuple->fields) { + if (ele.as()) { + return false; + } + const auto& e = Downcast(ele); + int e_ndim = static_cast(e->shape.size()); + const DataType& e_dtype = e->dtype; + if (e_dtype != dtype) { + throw Error("relay.meshgrid requires all tensors have the same dtype"); + } + if (e_ndim == 0) { + grid_shape.emplace_back(1); + } else if (e_ndim == 1) { + grid_shape.emplace_back(e->shape[0]); + } else { + throw Error("relay.meshgrid requires all tensors be either scalars or 1-D vectors."); + } + } + + // "xy" mode swaps first two dimensions + if (attrs->indexing == "xy" && grid_shape.size() >= 2) { + std::swap(grid_shape[0], grid_shape[1]); + } + + // There is one output grid for each input, all with same shape. + std::vector grids; + grids.reserve(data_length); + for (int i = 0; i < data_length; i++) { + grids.emplace_back(TensorType(grid_shape, dtype)); + } + reporter->Assign(types[1], TupleType(Array(grids))); + return true; +} + +Array MeshgridCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const MeshgridAttrs* param = attrs.as(); + CHECK(param != nullptr); + return {topi::meshgrid(inputs, param->indexing)}; +} + +Expr MakeMeshgrid(Expr data, String indexing) { + auto attrs = make_object(); + attrs->indexing = std::move(indexing); + static const Op& op = Op::Get("meshgrid"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.meshgrid").set_body_typed(MakeMeshgrid); + +RELAY_REGISTER_OP("meshgrid") + .describe(R"code(Create coordinate matrices from coordinate vectors. 
+ +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(3) + .add_type_rel("Meshgrid", MeshgridRel) + .set_attr("FTVMCompute", MeshgridCompute) + .set_attr("TOpPattern", kInjective); + // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0694fa5..6b731f4 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -466,6 +466,25 @@ def test_forward_arange(): verify_model(Arange11().float().eval()) verify_model(Arange12().float().eval()) +def test_forward_mesh_grid(): + torch.set_grad_enabled(False) + + class MeshGrid1(Module): + def forward(self, *args): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6]) + grid_x, grid_y = torch.meshgrid([x, y]) + return grid_x, grid_y + + class MeshGrid2(Module): + def forward(self, *args): + x = torch.tensor([1, 2, 3], dtype=torch.float32) + y = torch.add(torch.tensor(5, dtype=torch.float32), 1) + grid_x, grid_y = torch.meshgrid([x, y]) + return grid_x, grid_y + + verify_model(MeshGrid1().float().eval()) + verify_model(MeshGrid2().float().eval()) def test_forward_abs(): torch.set_grad_enabled(False) @@ -2677,6 +2696,7 @@ if __name__ == "__main__": test_forward_full_like() test_forward_linspace() test_forward_arange() + test_forward_mesh_grid() test_forward_chunk() test_forward_split() test_upsample() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f3e28db..115900f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -588,6 +588,39 @@ def test_arange(): # arange doesnt' support floating point right now, see type relation # verify_arange(20, 1, -1.5) +def test_meshgrid(): + def verify_meshgrid(lengths, indexing="ij"): + input_vars = [] + input_data = [] + for i, length in 
enumerate(lengths): + input_name = "x_{}".format(i) + if length == 0: + # Scalar + input_vars.append(relay.var(input_name, relay.scalar_type("float32"))) + input_data.append(np.array(1, "float32")) + else: + input_vars.append(relay.var(input_name, relay.TensorType((length,), "float32"))) + input_data.append(np.arange(length).astype("float32")) + + z = relay.meshgrid(input_vars, indexing=indexing).astuple() + func = relay.Function(input_vars, z) + # Get ref + ref_res = np.meshgrid(*input_data, indexing=indexing) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(*input_data) + assert len(op_res) == len(ref_res) + for i in range(len(op_res)): + tvm.testing.assert_allclose(op_res[i].asnumpy(), ref_res[i], rtol=1e-5) + verify_meshgrid([3, 5]) + verify_meshgrid([4, 2], indexing="xy") + verify_meshgrid([3, 5, 2]) + verify_meshgrid([3, 1, 5], indexing="xy") + # Length 0 signifies scalar. + verify_meshgrid([3, 5, 0]) + def test_tile(): def verify_tile(dshape, reps): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -968,7 +1001,6 @@ def test_sparse_to_dense(): #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) if __name__ == "__main__": - test_arange() test_cast() test_zeros_ones() test_unary_identity() @@ -992,6 +1024,7 @@ if __name__ == "__main__": test_squeeze_bad_axes_infer_type() test_split_infer_type() test_arange() + test_meshgrid() test_reverse() test_stack() test_tile() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e0e4556..0da1d71 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1258,6 +1258,38 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr } /*! 
+ * \brief Produce grids by expanding input over dimensions defined by other inputs + * + * \param inputs The input tensors + * \param indexing The indexing mode, either "xy" or "ij" + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the meshgrid operation + */ +inline Array meshgrid(const Array& inputs, const std::string& indexing, + std::string name = "T_meshgrid", std::string tag = kInjective) { + const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2; + Array out_shape; + for (size_t i = 0; i < inputs.size(); ++i) { + const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; + out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]); + } + Array result; + for (size_t i = 0; i < inputs.size(); ++i) { + result.push_back(compute( + out_shape, + [&](const Array& indices) { + const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; + Array real_indices = {indices[src_index]}; + return inputs[i](real_indices); + }, + name, tag)); + } + return result; +} + +/*! * \brief Transform the layout according to \p src_layout and \p dst_layout * \param src the source input. * \param src_layout the source layout. diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index a8c8b14..159412f 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -524,6 +524,25 @@ def arange(start, stop=None, step=1, dtype="float32"): return cpp.arange(start, stop, step, dtype) +def meshgrid(a_tuple, indexing): + """Create coordinate matrices from coordinate vectors. + + Parameters + ---------- + a_tuple : tuple of tvm.te.Tensor + The coordinate vectors or scalars. + + indexing : str + Indexing mode, either "ij" or "xy". + + Returns + ------- + result : tuple of tvm.te.Tensor + The resulting grids for each axis. 
+ """ + return cpp.meshgrid(a_tuple, indexing) + + def repeat(a, repeats, axis): """Repeats elements of an array. diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 4308784..ab39a5e 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -109,6 +109,10 @@ TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = arange(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.meshgrid").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = meshgrid(args[0], args[1]); +}); + TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = repeat(args[0], args[1], args[2]); }); -- 2.7.4