From: Dhruva Ray Date: Thu, 4 Jun 2020 17:54:56 +0000 (+0530) Subject: [TOPI,RELAY][TFLITE] Sparse to dense operator (#5447) X-Git-Tag: upstream/0.7.0~611 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c2e248f07e8263c282a216748e35399a197ea08e;p=platform%2Fupstream%2Ftvm.git [TOPI,RELAY][TFLITE] Sparse to dense operator (#5447) * [Relay][Frontend][TFLite] Add parser support for shape and range Signed-off-by: Dhruva Ray * [TOPI,RELAY][TFLITE] Sparse to dense operator Signed-off-by: Dhruva Ray * use param name in documentation Signed-off-by: Dhruva Ray * sphinx doc errors fixed Signed-off-by: Dhruva Ray * incorporated review comments Signed-off-by: Dhruva Ray * Missing a blank line... Signed-off-by: Dhruva Ray * use get_tensor_expr Signed-off-by: Dhruva Ray * Accidently removed this function in the rebase... Signed-off-by: Dhruva Ray * support default value for default_value Signed-off-by: Dhruva Ray * clang format fixes Signed-off-by: Dhruva Ray * topi pylint fixes Signed-off-by: Dhruva Ray --- diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index cef2999..f93f82f 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -50,6 +50,7 @@ List of operators topi.expand_dims topi.reshape topi.unravel_index + topi.sparse_to_dense topi.squeeze topi.concatenate topi.split @@ -154,6 +155,7 @@ topi .. autofunction:: topi.expand_dims .. autofunction:: topi.reshape .. autofunction:: topi.unravel_index +.. autofunction:: topi.sparse_to_dense .. autofunction:: topi.squeeze .. autofunction:: topi.concatenate .. autofunction:: topi.split diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 798d440..7b20b3d 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -130,6 +130,7 @@ This level enables additional math and transform operators. tvm.relay.tile tvm.relay.reverse tvm.relay.unravel_index + tvm.relay.sparse_to_dense **Level 4: Broadcast and Reductions** diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ccf8e54..03605ee 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -274,6 +274,15 @@ struct SequenceMaskAttrs : public tvm::AttrsNode { } }; // struct SequenceMaskAttrs. +/*! \brief Attributes used in sparse_to_dense operator */ +struct SparseToDenseAttrs : public tvm::AttrsNode { + Array output_shape; + + TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") { + TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor"); + } +}; // struct SparseToDenseAttrs + /*! \brief Attributes for ndarray_size operator */ struct NdarraySizeAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9414314..15d0253 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,6 +33,7 @@ from .common import ExprTable from .common import infer_shape as _infer_shape from .tflite_flexbuffer import FlexBufferDecoder + __all__ = ['from_tflite'] class TensorWrapper(object): @@ -130,6 +131,7 @@ class OperatorConverter(object): 'SOFTMAX': self.convert_softmax, 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'SPACE_TO_DEPTH': self.convert_space_to_depth, + 'SPARSE_TO_DENSE': self.convert_sparse_to_dense, 'SPLIT': self.convert_split, 'SPLIT_V': self.convert_split_v, 'SQRT': self.convert_sqrt, @@ -2267,6 +2269,36 @@ class OperatorConverter(object): return out + def convert_sparse_to_dense(self, op): + """Convert TFLite SPARSE_TO_DENSE""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 4, "input tensors length should be 4" + + indices, values = input_tensors[0], input_tensors[2] + default_value = input_tensors[3] + output_shape = input_tensors[1] + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." + + for t in [indices, output_shape]: + t_type = t.tensor.Type() + assert t_type in (TensorType.INT32, TensorType.INT64) + + out = _op.sparse_to_dense( + self.get_tensor_expr(indices), + list(self.get_tensor_value(output_shape)), + self.get_tensor_expr(values), + self.get_tensor_expr(default_value) + ) + + return out + def convert_prelu(self, op): """Convert TFLite PReLU""" input_tensors = self.get_input_tensors(op) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e1c2bd7..c99be3c 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -55,6 +55,7 @@ _reg.register_injective_schedule("sequence_mask") _reg.register_injective_schedule("one_hot") _reg.register_reduce_schedule("collapse_sum_like") _reg.register_injective_schedule("unravel_index") +_reg.register_injective_schedule("sparse_to_dense") # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 1da58ae..a3c6517 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -890,3 +890,34 @@ def unravel_index(indices, shape): """ return _make.unravel_index(indices, shape) + +def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0): + """Converts a sparse representation into a dense tensor. + + Example:: + - sparse_to_dense([[0, 0], [1, 1]], [2, 2], [3, 3], 0) = [[3, 0], [0, 3]] + + Parameters + ---------- + sparse_indices : relay.Expr + A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values. + + output_shape : relay.Expr + A list of integers. Shape of the dense output tensor. + + sparse_values : relay.Expr + A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + default_value : relay.Expr + A 0-D tensor containing the default value for the remaining locations. + Defaults to 0. + + Returns + ------- + result : relay.Expr + Dense tensor of shape output_shape. Has the same type as sparse_values. + """ + + if default_value == 0: + default_value = const(0) + return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7282ac7..a80bb31 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2524,5 +2524,79 @@ Example:: .set_attr("FTVMCompute", UnRavelIndexCompute) .set_attr("TOpPattern", kInjective); +// sparse_to_dense +TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs); + +bool SparseToDenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + auto sparse_indices = types[0].as(); + auto sparse_values = types[1].as(); + auto default_value = types[2].as(); + CHECK(sparse_indices != nullptr && sparse_values != nullptr && default_value != nullptr); + + CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers"; + + CHECK_LE(sparse_indices->shape.size(), 3) + << "sparse_indices must be a tensor of either 0D, 1D or 2D"; + + CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a tensor of either 0D, 1D"; + + CHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar"; + + const auto* param = attrs.as(); + CHECK(param != nullptr); + + Array oshape; + for (auto i : param->output_shape) { + oshape.push_back(i); + } + reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype)); + return true; +} + +Array SparseToDenseCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + CHECK_EQ(inputs.size(), 3); + const auto* param = attrs.as(); + CHECK(param != nullptr); + return {topi::sparse_to_dense(inputs[0], param->output_shape, inputs[1], inputs[2]())}; +} + +TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense") + .set_body_typed([](Expr indices, Array output_shape, Expr values, Expr default_value) { + auto attrs = make_object(); + attrs->output_shape = std::move(output_shape); + static const Op& op = Op::Get("sparse_to_dense"); + return Call(op, {indices, values, default_value}, Attrs(attrs)); + }); + +RELAY_REGISTER_OP("sparse_to_dense") + .describe(R"code(A dense tensor from a sparse representation. + + - **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values + + - **output_shape**: A list of integers. Shape of the dense output tensor. + + - **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + - **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0. + + Example:: + - sparse_to_dense([0, 0], [1, 2]], [3, 4], [1, 2], 0) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]] + + )code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .set_support_level(3) + .set_attrs_type() + .add_argument("sparse_indices", "Tensor", "Contains sparse indices.") + .add_argument("sparse_values", "Tensor", "Contains values for sparse indices.") + .add_argument("default_value", "Tensor", "Value to set for non-sparse indices. Defaults to 0.") + .add_type_rel("SparseToDense", SparseToDenseRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", SparseToDenseCompute); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 24b82c6..d5dafd8 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -859,7 +859,6 @@ def test_all_resize(): if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()): _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False) - ####################################################################### # Concatenation # ------------- @@ -1863,6 +1862,80 @@ def test_forward_spacetodepth(): _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4) ####################################################################### +# Sparse To Dense +# --------------- +def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + with tf.Graph().as_default(): + indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices") + values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values") + oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)) + + if default_value == None: + output = tf.sparse_to_dense(indices, oshape, values) + compare_tflite_with_tvm( + [sparse_indices, sparse_values], + ["indices", "values"], + [indices, values], + [output] + ) + else: + dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value") + output = tf.sparse_to_dense(indices, oshape, values, dv) + compare_tflite_with_tvm( + [sparse_indices, sparse_values, default_value], + ["indices", "values", "default_value"], + [indices, values, dv], + [output] + ) + +def test_forward_sparse_to_dense(): + ''' + Works in tvm/topi/tensorflow. But tflite converter breaks this test case + _test_sparse_to_dense( + np.int32(1), + np.int32(3), + np.int32(0), + np.array([5]).astype("int32") + ) + ''' + # vector + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3, 3, 3]).astype("int32"), + np.int32(0), + np.array([5]).astype("int32") + ) + # vector nXd + _test_sparse_to_dense( + np.array([[0, 0], [1, 2]]).astype("int32"), + np.array([1, 2]).astype("int32"), + np.int32(0), + np.array([3, 4]).astype("int32") + ) + _test_sparse_to_dense( + np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"), + np.array([1, 2]).astype("int32"), + np.int32(4), + np.array([2, 3, 4]).astype("int32") + ) + # floats + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3.1, 3.1, 3.1]).astype("float32"), + np.float32(3.5), + np.array([5]).astype("int32") + ) + # default value not specified + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3.1, 3.1, 3.1]).astype("float32"), + None, + np.array([5]).astype("int32") + ) + +####################################################################### # Fully Connected # --------------- @@ -2305,6 +2378,7 @@ if __name__ == '__main__': test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_sparse_to_dense() test_forward_select() test_forward_quantize_dequantize() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 4deed42..52ff45b 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -747,6 +747,58 @@ def test_unravel_index(): # output which is inline with Tensorflow # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype) +def test_sparse_to_dense(): + def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected): + sparse_indices_data = np.array(sparse_indices) + sparse_values_data = np.array(sparse_values) + default_value_data = np.array(default_value) + + a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype))) + b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype))) + if default_value is None: + args = [a, b] + d = relay.sparse_to_dense(a, output_shape, b) + else: + c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype))) + args = [a, b, c] + d = relay.sparse_to_dense(a, output_shape, b, c) + + zz = run_infer_type(d) + assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype)) + + func = relay.Function(args, d) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + if default_value is None: + op_res = intrp.evaluate(func)(sparse_indices_data, sparse_values_data) + else: + op_res = intrp.evaluate(func)( + sparse_indices_data, sparse_values_data, default_value_data + ) + tvm.testing.assert_allclose(op_res.asnumpy(), xpected, rtol=1e-5) + + + verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) # scalar + verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) # vector + verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]) # nXd + verify_sparse_to_dense( + [[0, 0, 0], [1, 2, 3]], + [1, 2], + 4, + [2, 3, 4], + [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]] + ) # nXd + verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) # floats + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + + #negative test cases + #sparse indices should be ints + #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_values should be 0d or 1d only + #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_indices should not be > 2d tensor + #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) if __name__ == "__main__": test_arange() @@ -780,4 +832,5 @@ if __name__ == "__main__": test_gather_nd() test_isfinite() test_isinf() - test_unravel_index() \ No newline at end of file + test_unravel_index() + test_sparse_to_dense() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 400cd1e..813d7d7 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1312,5 +1312,53 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim name, tag); } +/*! + * \brief Get a dense tensor. + * \param sparse_indices sparse_indices[i] contains sparse_values[i] will be placed. + * \param output_shape is the shape of the dense output tensor . + * \param sparse_values is a 0-D or 1-D tensor. Values for each row of sparse_indices. + * \param default_value is a 0-D tensor. Defaults to zero. + * \param name output tensor name. + * \param tag output tensor tag. + * \return Tensor of output_shape. + */ +inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& output_shape, + const Tensor& sparse_values, const PrimExpr& default_value, + const std::string name = "T_sparse_to_dense", + const std::string tag = kInjective) { + CHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; + CHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices tensor should be 0D, 1D, or 2D only"; + CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only"; + + const auto rank_sparse_indices = static_cast(sparse_indices->shape.size()); + Array oshape; + for (auto l : output_shape) { + oshape.push_back(l); + } + return compute( + oshape, + [&](const Array& indices) { + PrimExpr ret = default_value; + if (0 == rank_sparse_indices) { + ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); + } else if (1 == rank_sparse_indices) { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); + } + } else { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + PrimExpr aggregate_condition; + for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) { + PrimExpr comparision = indices[k] == sparse_indices[j][k]; + aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision; + } + ret = if_then_else(aggregate_condition, sparse_values[j], ret); + } + } + return ret; + }, + name, tag); +} + } // namespace topi #endif // TOPI_TRANSFORM_H_ diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index ef54560..e0f5c59 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -676,3 +676,32 @@ def unravel_index(indices, shape): """ return cpp.unravel_index(indices, shape) + +def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0): + """Converts a sparse representation into a dense tensor. + + Example:: + - sparse_to_dense([[0, 0], [1, 1]], [2, 2], [3, 3], 0) = [[3, 0], [0, 3]] + + Parameters + ---------- + sparse_indices : tvm.te.Tensor + A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values. + + output_shape : A list of integers + Shape of the dense output tensor. + + sparse_values : tvm.te.Tensor + A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + default_value : tvm.te.Tensor + A 0-D tensor containing the default value for the remaining locations. + Defaults to 0. + + Returns + ------- + result : tvm.te.Tensor + Dense tensor of shape output_shape. Has the same type as sparse_values. + """ + + return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) diff --git a/topi/src/transform.cc b/topi/src/transform.cc index fa27b99..4af491e 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -120,6 +120,10 @@ TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* *rv = unravel_index(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = sparse_to_dense(args[0], args[1], args[2], args[3]); +}); + TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) { switch (args.size()) { case 2: diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index b98ce09..47ea8d7 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -595,6 +595,47 @@ def verify_unravel_index(indices, shape, dtype): for device in get_all_backend(): check_device(device) +def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected): + sparse_indices_data = np.array(sparse_indices) + sparse_values_data = np.array(sparse_values) + output_shape_data = np.array(output_shape) + default_value_data = np.array(default_value) + + A = te.placeholder(shape=sparse_indices_data.shape, name="sparse_indices", dtype=str(sparse_indices_data.dtype)) + B = te.placeholder(shape=sparse_values_data.shape, name="sparse_values", dtype=str(sparse_values_data.dtype)) + if default_value is None: + args = [A, B] + D = topi.sparse_to_dense(A, output_shape, B) + else: + C = te.placeholder(shape=(), name="default_value", dtype=str(default_value_data.dtype)) + args = [A, B, C] + D = topi.sparse_to_dense(A, output_shape, B, C) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(D) + + foo = tvm.build(s, args + [D], device, name="sparse_to_dense") + + sparse_indices_nd = tvm.nd.array(sparse_indices_data, ctx) + sparse_values_nd = tvm.nd.array(sparse_values_data, ctx) + out_nd = tvm.nd.empty(output_shape_data, ctx=ctx, dtype=B.dtype) + + if default_value is None: + foo(sparse_indices_nd, sparse_values_nd, out_nd) + else: + default_value_nd = tvm.nd.array(default_value_data, ctx) + foo(sparse_indices_nd, sparse_values_nd, default_value_nd, out_nd) + + tvm.testing.assert_allclose(out_nd.asnumpy(), np.array(xpected)) + + for device in get_all_backend(): + check_device(device) def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) @@ -924,6 +965,27 @@ def test_unravel_index(): verify_unravel_index(144, [5, 5, 5, 2], dtype) verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype) +def test_sparse_to_dense(): + verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) #scalar + verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) #vector + verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0],[0, 0, 2, 0],[0, 0, 0, 0]]) #nXd + verify_sparse_to_dense( + [[0, 0, 0], [1, 2, 3]], + [1, 2], + 4, + [2, 3, 4], + [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]] + ) #nXd + verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) #floats + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + + #negative test cases + #sparse indices should be ints + #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_values should be 0d or 1d only + #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_indices should not be > 2d tensor + #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) if __name__ == "__main__": test_strided_slice() @@ -949,3 +1011,4 @@ if __name__ == "__main__": test_where_fusion() test_one_hot() test_unravel_index() + test_sparse_to_dense()