Release 18.08
[platform/upstream/armnn.git] / src / armnn / layers / NormalizationLayer.cpp
index cacd348..261b16a 100644 (file)
@@ -31,14 +31,16 @@ NormalizationLayer* NormalizationLayer::Clone(Graph& graph) const
 
 void NormalizationLayer::ValidateTensorShapesFromInputs()
 {
-    ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
-                                               "NormalizationLayer: Input slot must be connected.");
+    VerifyLayerConnections(1, CHECK_LOCATION());
+
+    auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
 
-    const TensorShape& outShape = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape();
     ConditionalThrowIfNotEqual<LayerValidationException>(
         "NormalizationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
         GetOutputSlot(0).GetTensorInfo().GetShape(),
-        outShape);
+        inferredShapes[0]);
 }
 
 } // namespace armnn