Release 18.08
diff --git a/src/armnn/layers/MultiplicationLayer.cpp b/src/armnn/layers/MultiplicationLayer.cpp
index af40a23..ed7683d 100644
--- a/src/armnn/layers/MultiplicationLayer.cpp
+++ b/src/armnn/layers/MultiplicationLayer.cpp
@@ -31,41 +31,51 @@ MultiplicationLayer* MultiplicationLayer::Clone(Graph& graph) const
     return CloneBase<MultiplicationLayer>(graph, GetName());
 }
 
-void MultiplicationLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> MultiplicationLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
 {
-    auto& input0 = GetInputSlot(0).GetConnection()->GetTensorInfo();
-    auto& input1 = GetInputSlot(1).GetConnection()->GetTensorInfo();
+    BOOST_ASSERT(inputShapes.size() == 2);
+    auto& input0 = inputShapes[0];
+    auto& input1 = inputShapes[1];
 
-    // Get the max of the inputs
+    // Get the max of the inputs.
     BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
     unsigned int numDims = input0.GetNumDimensions();
     std::vector<unsigned int> dims(numDims);
 
-    // validate inputs are broadcast compatible
-#if !NDEBUG
     for (unsigned int i = 0; i < numDims; i++)
     {
-        unsigned int dim0 = input0.GetShape()[i];
-        unsigned int dim1 = input1.GetShape()[i];
+        unsigned int dim0 = input0[i];
+        unsigned int dim1 = input1[i];
+
+    // Validates inputs are broadcast compatible.
+#if !NDEBUG
         if (dim0 != dim1)
         {
             BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1.");
         }
-    }
 #endif
 
-    for (unsigned int i = 0; i < numDims; i++)
-    {
-        unsigned int dim0 = input0.GetShape()[i];
-        unsigned int dim1 = input1.GetShape()[i];
         dims[i] = std::max(dim0, dim1);
     }
 
-    TensorShape outShape(numDims, dims.data());
+    return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) });
+}
+
+void MultiplicationLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(2, CHECK_LOCATION());
+
+    auto inferredShapes = InferOutputShapes({
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
+    });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
     ConditionalThrowIfNotEqual<LayerValidationException>(
         "MultiplicationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
         GetOutputSlot(0).GetTensorInfo().GetShape(),
-        outShape);
+        inferredShapes[0]);
 }
 
 } // namespace armnn
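
To make the shape-inference rule introduced by the new InferOutputShapes easier to follow in isolation, here is a minimal standalone sketch of the same broadcast logic, using plain std::vector<unsigned int> instead of armnn::TensorShape. The function name InferBroadcastShape and the example shapes are illustrative only and are not part of the ArmNN API.

    // Sketch of the broadcast rule used by InferOutputShapes above:
    // per dimension, the sizes must match or one must be 1, and the
    // output takes the larger of the two.
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical helper mirroring the layer's shape inference.
    std::vector<unsigned int> InferBroadcastShape(const std::vector<unsigned int>& input0,
                                                  const std::vector<unsigned int>& input1)
    {
        assert(input0.size() == input1.size() && "Inputs must have the same rank.");

        std::vector<unsigned int> outShape(input0.size());
        for (std::size_t i = 0; i < input0.size(); ++i)
        {
            const unsigned int dim0 = input0[i];
            const unsigned int dim1 = input1[i];

            // Broadcast compatibility: dimensions either match or one is 1.
            assert((dim0 == dim1 || dim0 == 1 || dim1 == 1) &&
                   "Dimensions should either match or one should be of size 1.");

            // The output dimension is the larger of the two.
            outShape[i] = std::max(dim0, dim1);
        }
        return outShape;
    }

    int main()
    {
        // e.g. multiplying a [1,4,4,3] tensor by a per-channel [1,1,1,3]
        // tensor yields a [1,4,4,3] result.
        const auto out = InferBroadcastShape({1, 4, 4, 3}, {1, 1, 1, 3});
        for (unsigned int d : out)
        {
            std::cout << d << ' ';
        }
        std::cout << '\n'; // prints: 1 4 4 3
    }

Note that in the layer itself the compatibility check is wrapped in #if !NDEBUG, so it only runs in debug builds; the sketch keeps it unconditional for clarity.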