unsigned int split_dimension = std::get<props::SplitDimension>(split_props);
+ const TensorDim &in_dim = context.getInputDimensions()[0];
+
if (std::get<props::SplitNumber>(split_props).empty()) {
std::get<props::SplitNumber>(split_props)
- .set(context.getNumRequestedOutputs());
+ .set(in_dim.getTensorDim(split_dimension));
}
unsigned int split_number = std::get<props::SplitNumber>(split_props);
/**
* The split is only done along the split_dimension dimension.
+ * (Assumes input data is contiguous)
* For example, consider input dimension [b,c,h,w], split_number = n
- * 1. axis = 1, output_dim = [b,n,h,w], num_outputs = c//n
- * 2. axis = 2, output_dim = [b,c,n,w], num_outputs = h//n
- * 3. axis = 3, output_dim = [b,c,h,n], num_outputs = w//n
+ * 1. axis = 1, output_dim = [b,c//n,h,w], num_outputs = n
+ * 2. axis = 2, output_dim = [b,c,h//n,w], num_outputs = n
+ * 3. axis = 3, output_dim = [b,c,h,w//n], num_outputs = n
*/
- const TensorDim &in_dim = context.getInputDimensions()[0];
NNTR_THROW_IF(split_number != context.getNumRequestedOutputs(),
std::invalid_argument)
<< "Given split number does not match with number of outputs";
from recorder_v2 import record_v2, inspect_file
import torch
+class Split(torch.nn.Module):
+ def __init__(self, axis, split_number, channel):
+ super().__init__()
+ self.axis = axis
+ self.split_number = split_number
+ self.conv = torch.nn.Conv2d(channel, channel, 1)
+ self.loss = torch.nn.MSELoss()
+
+ def forward(self, inputs, labels):
+ outs = self.conv(inputs[0])
+ split_size = outs.size(self.axis) // self.split_number
+ *outs, = torch.split(outs, split_size, self.axis)
+ out = torch.clone(outs[0])
+ for i in range(1, len(outs)):
+ out += outs[i]
+
+ loss = self.loss(out, labels[0])
+ return out, loss
class SplitAndJoin(torch.nn.Module):
def __init__(self):
if __name__ == "__main__":
record_v2(
+ Split(3, 5, 3),
+ iteration=2,
+ input_dims=[(2, 3, 4, 5)],
+ label_dims=[(2, 3, 4, 1)],
+ name="split_axis3_split_number5"
+ )
+
+ record_v2(
+ Split(2, 4, 3),
+ iteration=2,
+ input_dims=[(2, 3, 4, 5)],
+ label_dims=[(2, 3, 1, 5)],
+ name="split_axis2_split_number4"
+ )
+
+ record_v2(
+ Split(2, 2, 3),
+ iteration=2,
+ input_dims=[(2, 3, 4, 5)],
+ label_dims=[(2, 3, 2, 5)],
+ name="split_axis2_split_number2"
+ )
+
+ record_v2(
SplitAndJoin(),
iteration=2,
input_dims=[(5, 3)],
using namespace nntrainer;
+static std::unique_ptr<NeuralNetwork> split_axis3_split_number5() {
+ std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+ nn->setProperty({"batch_size=2"});
+
+ auto graph = makeGraph({
+ {"conv2d",
+ {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}},
+ {"split", {"name=split", "input_layers=conv", "axis=3", "split_number=5"}},
+ {"addition",
+ {"name=add", "input_layers=split(0),split(1),split(2),split(3),split(4)"}},
+ {"mse", {"name=loss", "input_layers=add"}},
+ });
+ for (auto &node : graph) {
+ nn->addLayer(node);
+ }
+
+ nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+ return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> split_axis2_split_number4() {
+ std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+ nn->setProperty({"batch_size=2"});
+
+ auto graph = makeGraph({
+ {"conv2d",
+ {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}},
+ {"split", {"name=split", "input_layers=conv", "axis=2", "split_number=4"}},
+ {"addition",
+ {"name=add", "input_layers=split(0),split(1),split(2),split(3)"}},
+ {"mse", {"name=loss", "input_layers=add"}},
+ });
+ for (auto &node : graph) {
+ nn->addLayer(node);
+ }
+
+ nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+ return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> split_axis2_split_number2() {
+ std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+ nn->setProperty({"batch_size=2"});
+
+ auto graph = makeGraph({
+ {"conv2d",
+ {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}},
+ {"split", {"name=split", "input_layers=conv", "axis=2", "split_number=2"}},
+ {"addition", {"name=add", "input_layers=split(0),split(1)"}},
+ {"mse", {"name=loss", "input_layers=add"}},
+ });
+ for (auto &node : graph) {
+ nn->addLayer(node);
+ }
+
+ nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+ return nn;
+}
+
/// A has two output tensor a1, a2 and B, C takes it
/// A
/// (a0, a1)
GTEST_PARAMETER_TEST(
multiInoutModels, nntrainerModelTest,
::testing::ValuesIn({
+ mkModelTc_V2(split_axis3_split_number5, "split_axis3_split_number5",
+ ModelTestOption::ALL_V2),
+ mkModelTc_V2(split_axis2_split_number4, "split_axis2_split_number4",
+ ModelTestOption::ALL_V2),
+ mkModelTc_V2(split_axis2_split_number2, "split_axis2_split_number2",
+ ModelTestOption::ALL_V2),
mkModelTc_V2(split_and_join, "split_and_join", ModelTestOption::ALL_V2),
mkModelTc_V2(one_to_one, "one_to_one", ModelTestOption::ALL_V2),
mkModelTc_V2(one_to_one_reversed, "one_to_one__reversed",