return _op.nn.dense(data0, data1_t)
return _impl
+
def _expand():
def _impl(inputs, input_types):
data_in = inputs[0]
if isinstance(data_in, _expr.Expr):
- shape = _infer_shape(data_in)
+ shape = list(_infer_shape(data_in))
ndims = len(shape)
sizes = _infer_shape(inputs[1])
out = inputs[0]
+ out_dims = len(sizes)
+ if ndims < out_dims:
+ num_newaxis = out_dims - ndims
+ out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
+ shape = [1] * num_newaxis + shape
+
for i in range(ndims):
- if sizes[i] in {-1, shape[i]}:
+ if sizes[i] == -1 or sizes[i] == shape[i]:
continue
data = list()
for temp in range(sizes[i]):
data.append(out)
- call = _op.tensor.concatenate(data, i)
- return call
+ out = _op.tensor.concatenate(data, i)
+
+ return out
return _impl
+
def _int():
def _impl(inputs, input_types):
if isinstance(inputs[0], _expr.Expr):
def test_forward_expand():
torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
class Expand1(Module):
def forward(self, *args):
return args[0].expand((3, -1, -1, -1))
+ input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Expand1().float().eval(), input_data=input_data)
+ class Expand2(Module):
+ def forward(self, *args):
+ return args[0].expand((3, 3, 3, 1))
+
+ input_shape = [3, 1]
+ input_data = torch.rand(input_shape).float()
+ verify_model(Expand2().float().eval(), input_data=input_data)
+
+
def test_forward_pow():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]