Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / MatrixBandPart.h
index 9f49c8f..5674ff3 100644 (file)
@@ -32,10 +32,10 @@ void MatrixBandPart(const T num_lower_diags, const T num_upper_diags, const Shap
 {
   auto last_dim = input_shape.DimensionsCount() - 1;
 
-  T batch_num = 0;
-  for (int dim = 0; dim < last_dim - 2; dim++)
+  T batch_num = 1;
+  for (int dim = 0; dim < input_shape.DimensionsCount() - 2; dim++)
   {
-    batch_num += input_shape.Dims(dim);
+    batch_num *= input_shape.Dims(dim);
   }
 
   const T row_num = input_shape.Dims(last_dim - 1);