Fix bytes_written and bytes_read (#64040)
authorTanvir Zaman <motanv@fb.com>
Mon, 30 Aug 2021 19:56:15 +0000 (12:56 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 19:57:31 +0000 (12:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64040

In operator cost inference functions, in many places we are using sizeof(x.data_type()). Since data_type() returns a 32 bit integer from [this enum](https://www.internalfb.com/code/fbsource/[15e7ffe4073cf08c61077c7c24a4839504b964a2]/fbcode/caffe2/caffe2/proto/caffe2.proto?lines=20), we are basically always getting 4 for sizeof(x.data_type()) no matter what actual data type x has. Big thanks to Jack Langman for specifically pointing to this bug.

We would instead use the size in bytes based on actual data type.

Test Plan:
Added unit tests BatchMatMulMemCostTest:

buck test //caffe2/caffe2/fb/fbgemm:batch_matmul_op_test -- BatchMatMulMemCostTest

Extended existing unit test test_columnwise_concat for different data types:

buck test //caffe2/caffe2/python/operator_test:concat_op_cost_test -- test_columnwise_concat

Differential Revision: D30561459

fbshipit-source-id: 976fa5167097a35af548498480001aafd7851d93

caffe2/core/operator_schema.h
caffe2/operators/batch_matmul_op.cc
caffe2/operators/concat_split_op.cc
caffe2/operators/conv_pool_op_base.h
caffe2/operators/distance_op.cc
caffe2/operators/fc_inference.cc
caffe2/operators/one_hot_ops.cc
caffe2/operators/utility_ops.cc
caffe2/python/operator_test/concat_op_cost_test.py
caffe2/python/workspace_test.py
caffe2/sgd/adagrad_op.cc

index 64f5ef3..0d048eb 100644 (file)
@@ -6,12 +6,13 @@
 #include <initializer_list>
 #include <ostream>
 #include <set>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 #include "c10/util/Registry.h"
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/filler.h"
 #include "caffe2/utils/proto_utils.h"
@@ -273,8 +274,8 @@ class TORCH_API OpSchema {
   OpSchema&
   Arg(const char* name, const char* description, bool required = false);
 
-#define DECLARE_STANDARD_ARG(name, str)     \
-  static const char* Arg_##name; \
+#define DECLARE_STANDARD_ARG(name, str) \
+  static const char* Arg_##name;        \
   OpSchema& Arg##name(const char* description);
 
   DECLARE_STANDARD_ARG(IsTest, is_test)
@@ -339,7 +340,9 @@ class TORCH_API OpSchema {
     return inplace_enforced_(x, y);
   }
 
-  TORCH_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& out,
+      const OpSchema& schema);
 
   const std::vector<Argument>& args() const {
     return args_;
@@ -562,8 +565,10 @@ OpSchema::Cost PointwiseCostInference(
   }
 
   c.flops = nElemX * OpsPerPoint;
-  c.bytes_read = nElemRead * sizeof(X.data_type());
-  c.bytes_written = nElemX * sizeof(X.data_type());
+  auto const& X_element_size_byte =
+      DataTypeToTypeMeta(X.data_type()).itemsize();
+  c.bytes_read = nElemRead * X_element_size_byte;
+  c.bytes_written = nElemX * X_element_size_byte;
   return c;
 }
 
index 32799ce..205acf7 100644 (file)
@@ -1,6 +1,7 @@
 #include "caffe2/operators/batch_matmul_op.h"
 
 #include "caffe2/core/operator_schema.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -116,9 +117,13 @@ OpSchema::Cost CostInferenceForBatchMatMul(
     K = in[0].dims(ndims_A - 1);
   }
 
+  auto const& A_element_size_byte =
+      DataTypeToTypeMeta(A.data_type()).itemsize();
+  auto const& Y_element_size_byte =
+      DataTypeToTypeMeta(Y.data_type()).itemsize();
   c.flops = 2 * nElemY * K;
-  c.bytes_read = (nElemA + nElemB) * sizeof(A.data_type());
-  c.bytes_written = nElemY * sizeof(Y.data_type());
+  c.bytes_read = (nElemA + nElemB) * A_element_size_byte;
+  c.bytes_written = nElemY * Y_element_size_byte;
   c.params_bytes = 0;
   return c;
 }
@@ -180,72 +185,76 @@ class GetBatchMatMulGradient : public GradientMakerBase {
     auto no_trans_arg = vector<Argument>();
     auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)};
     auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)};
-    auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
-                                           MakeArgument<int>("trans_b", 1)};
+    auto trans_both_arg = vector<Argument>{
+        MakeArgument<int>("trans_a", 1), MakeArgument<int>("trans_b", 1)};
 
     if (trans_a) {
       if (trans_b) {
         // A'B':
         // dA = B'G', dB = G'A'
-        return vector<OperatorDef>{CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{I(1), GO(0)},
-                                       vector<string>{GI(0)},
-                                       trans_both_arg),
-                                   CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{GO(0), I(0)},
-                                       vector<string>{GI(1)},
-                                       trans_both_arg)};
+        return vector<OperatorDef>{
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{I(1), GO(0)},
+                vector<string>{GI(0)},
+                trans_both_arg),
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{GO(0), I(0)},
+                vector<string>{GI(1)},
+                trans_both_arg)};
       } else {
         // A'B:
         // dA = BG', dB = AG
-        return vector<OperatorDef>{CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{I(1), GO(0)},
-                                       vector<string>{GI(0)},
-                                       trans_b_arg),
-                                   CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{I(0), GO(0)},
-                                       vector<string>{GI(1)},
-                                       no_trans_arg)};
+        return vector<OperatorDef>{
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{I(1), GO(0)},
+                vector<string>{GI(0)},
+                trans_b_arg),
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{I(0), GO(0)},
+                vector<string>{GI(1)},
+                no_trans_arg)};
       }
     } else {
       if (trans_b) {
         // AB':
         // dA = GB, dB = G'A
-        return vector<OperatorDef>{CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{GO(0), I(1)},
-                                       vector<string>{GI(0)},
-                                       no_trans_arg),
-                                   CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{GO(0), I(0)},
-                                       vector<string>{GI(1)},
-                                       trans_a_arg)};
+        return vector<OperatorDef>{
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{GO(0), I(1)},
+                vector<string>{GI(0)},
+                no_trans_arg),
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{GO(0), I(0)},
+                vector<string>{GI(1)},
+                trans_a_arg)};
       } else {
         // AB:
         // dA = GB', dB = A'G
-        return vector<OperatorDef>{CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{GO(0), I(1)},
-                                       vector<string>{GI(0)},
-                                       trans_b_arg),
-                                   CreateOperatorDef(
-                                       "BatchMatMul",
-                                       "",
-                                       vector<string>{I(0), GO(0)},
-                                       vector<string>{GI(1)},
-                                       trans_a_arg)};
+        return vector<OperatorDef>{
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{GO(0), I(1)},
+                vector<string>{GI(0)},
+                trans_b_arg),
+            CreateOperatorDef(
+                "BatchMatMul",
+                "",
+                vector<string>{I(0), GO(0)},
+                vector<string>{GI(1)},
+                trans_a_arg)};
       }
     }
   }
index 8eceb5a..8aa9e28 100644 (file)
@@ -101,9 +101,12 @@ OpSchema::Cost CostInferenceForSplit(
   CAFFE_ENFORCE_GT(in.size(), 0);
   struct OpSchema::Cost cost;
   cost.flops = 0;
-  auto input_bytes_count = nElemFromDim(in[0]) * sizeof(in[0].data_type());
-  auto split_bytes_count =
-      (in.size() == 1) ? 0 : nElemFromDim(in[1]) * sizeof(in[1].data_type());
+  auto const& input_0_element_size_byte =
+      DataTypeToTypeMeta(in[0].data_type()).itemsize();
+  auto const& input_1_element_size_byte =
+      (in.size() > 1) ? DataTypeToTypeMeta(in[1].data_type()).itemsize() : 0;
+  auto input_bytes_count = nElemFromDim(in[0]) * input_0_element_size_byte;
+  auto split_bytes_count = nElemFromDim(in[1]) * input_1_element_size_byte;
   // There can be two input blobs:
   // (1) actual tensor to be split
   // (2) lengths of outputs along split axis
@@ -329,11 +332,13 @@ OpSchema::Cost CostInferenceForConcat(
   }
   auto split_info_bytes_count = in.size() * sizeof(int);
 
+  auto const& input_0_element_size_byte =
+      DataTypeToTypeMeta(in[0].data_type()).itemsize();
   struct OpSchema::Cost cost;
   cost.flops = 0;
-  cost.bytes_read = nElemRead * sizeof(in[0].data_type());
+  cost.bytes_read = nElemRead * input_0_element_size_byte;
   cost.bytes_written =
-      size * sizeof(in[0].data_type()) + split_info_bytes_count;
+      size * input_0_element_size_byte + split_info_bytes_count;
   cost.params_bytes = 0;
   return cost;
 }
index 25bd99a..b356ef9 100644 (file)
@@ -7,6 +7,7 @@
 #include "caffe2/core/context.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_legacy.pb.h"
 #include "caffe2/utils/math.h"
 
@@ -519,14 +520,20 @@ class ConvPoolOpBase : public Operator<Context> {
     uint64_t nElemW = nElemFromDim(W);
     uint64_t nElemBias = inputs.size() > 2 ? nElemFromDim(inputs[2]) : 0;
 
+    auto const& X_elemenet_size_byte =
+        DataTypeToTypeMeta(X.data_type()).itemsize();
+    auto const& Y_element_size_byte =
+        DataTypeToTypeMeta(Y.data_type()).itemsize();
+    auto const& W_element_size_byte =
+        DataTypeToTypeMeta(W.data_type()).itemsize();
+
     // grouping is NOT properly handled yet
     c.flops = N * Y_t * Y_h * Y_w * kernel_t * kernel_w * kernel_h *
         in_channels * out_channels * 2;
-    c.bytes_read = (nElemX + nElemW + nElemBias) * sizeof(X.data_type());
-    c.bytes_written =
-        N * out_channels * Y_t * Y_h * Y_w * sizeof(Y.data_type());
+    c.bytes_read = (nElemX + nElemW + nElemBias) * X_elemenet_size_byte;
+    c.bytes_written = N * out_channels * Y_t * Y_h * Y_w * Y_element_size_byte;
     c.params_bytes = out_channels * in_channels * kernel_t * kernel_h *
-        kernel_w * sizeof(W.data_type());
+        kernel_w * W_element_size_byte;
     return c;
   }
 
index 1529534..9ea8eea 100644 (file)
@@ -1,4 +1,5 @@
 #include "caffe2/operators/distance_op.h"
+#include "caffe2/core/types.h"
 #include "caffe2/utils/eigen_utils.h"
 #ifdef CAFFE2_USE_MKLDNN
 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
@@ -7,7 +8,7 @@
 
 namespace caffe2 {
 
-template<>
+template <>
 bool SquaredL2DistanceOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& Y = Input(1);
@@ -257,7 +258,9 @@ OpSchema::Cost CostInferenceForDotProduct(
   CAFFE_ENFORCE_EQ(out[0].dims().size(), 1);
 
   struct OpSchema::Cost c = PointwiseCostInference<2>(def, in);
-  c.bytes_written = out[0].dims(0) * sizeof(out[0].data_type());
+  auto const& out_0_element_size_byte =
+      DataTypeToTypeMeta(out[0].data_type()).itemsize();
+  c.bytes_written = out[0].dims(0) * out_0_element_size_byte;
   c.params_bytes = 0;
   return c;
 }
@@ -379,10 +382,12 @@ bool DotProductWithPaddingOp<float, CPUContext>::RunOnDevice() {
 }
 
 // L2
-REGISTER_CPU_OPERATOR(SquaredL2Distance,
-                      SquaredL2DistanceOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(SquaredL2DistanceGradient,
-                      SquaredL2DistanceGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(
+    SquaredL2Distance,
+    SquaredL2DistanceOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(
+    SquaredL2DistanceGradient,
+    SquaredL2DistanceGradientOp<float, CPUContext>);
 
 OPERATOR_SCHEMA(SquaredL2Distance)
     .NumInputs(2)
@@ -402,7 +407,8 @@ class GetSquaredL2DistanceGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
     return SingleGradientDef(
-        "SquaredL2DistanceGradient", "",
+        "SquaredL2DistanceGradient",
+        "",
         vector<string>{I(0), I(1), GO(0)},
         vector<string>{GI(0), GI(1)});
   }
@@ -762,9 +768,9 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase {
       replicate = GetArgument(Def(), "replicate").i();
     }
 
-    const auto dot_arg =
-        vector<Argument>{MakeArgument<float>("pad_value", pad_value),
-                         MakeArgument<bool>("replicate", replicate)};
+    const auto dot_arg = vector<Argument>{
+        MakeArgument<float>("pad_value", pad_value),
+        MakeArgument<bool>("replicate", replicate)};
 
     return SingleGradientDef(
         "DotProductWithPaddingGradient",
@@ -775,4 +781,4 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase {
   }
 };
 REGISTER_GRADIENT(DotProductWithPadding, GetDotProductWithPaddingGradient);
-}  // namespace caffe2
+} // namespace caffe2
index a44c230..ba1b712 100644 (file)
@@ -1,4 +1,5 @@
 #include "caffe2/operators/fc_inference.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 std::vector<TensorShape> FCShapeInference(
@@ -51,11 +52,12 @@ OpSchema::Cost CostInferenceForFC(
       ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
       : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
 
-  const auto& X = in[0];
+  auto const& X_element_size_byte =
+      DataTypeToTypeMeta(in[0].data_type()).itemsize();
   c.flops = M * N * (2 * K + 1);
-  c.bytes_read = (K * (M + N) + N) * sizeof(X.data_type());
-  c.bytes_written = M * N * sizeof(X.data_type());
-  c.params_bytes = (K * N + N) * sizeof(X.data_type());
+  c.bytes_read = (K * (M + N) + N) * X_element_size_byte;
+  c.bytes_written = M * N * X_element_size_byte;
+  c.params_bytes = (K * N + N) * X_element_size_byte;
   return c;
 }
 
@@ -94,7 +96,11 @@ OpSchema::Cost CostInferenceForFCGradient(
 
   CAFFE_ENFORCE_LT(0, out.size());
   const TensorShape dW = out[0];
+  auto const& dW_element_size_byte =
+      DataTypeToTypeMeta(dW.data_type()).itemsize();
   const TensorShape db = out[1];
+  auto const& db_element_size_byte =
+      DataTypeToTypeMeta(db.data_type()).itemsize();
 
   auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
   const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
@@ -111,15 +117,17 @@ OpSchema::Cost CostInferenceForFCGradient(
   uint64_t size_db = nElemFromDim(db);
 
   c.flops = M * N * (2 * K + 1);
-  c.bytes_written = (size_dW + size_db) * sizeof(float);
+  c.bytes_written =
+      size_dW * dW_element_size_byte + size_db * db_element_size_byte;
   c.params_bytes = (K * N + N) * sizeof(float);
 
   if (out.size() == 3) {
     const TensorShape dX = out[2];
     uint64_t size_dX = nElemFromDim(dX);
-
+    auto const& dX_element_size_byte =
+        DataTypeToTypeMeta(dX.data_type()).itemsize();
     c.flops += 2 * M * N * K;
-    c.bytes_written += size_dX * sizeof(float);
+    c.bytes_written += size_dX * dX_element_size_byte;
   }
   return c;
 }
index c3eaf05..55c73a5 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "caffe2/core/operator.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -78,12 +79,21 @@ OpSchema::Cost CostInferenceForBatchOneHot(
   const auto& length = in[1];
   const auto& values = in[2];
 
-  uint64_t nBytesData = nElemFromDim(data) * sizeof(data.data_type());
-  uint64_t nBytesLength = nElemFromDim(length) * sizeof(length.data_type());
-  uint64_t nBytesValues = nElemFromDim(values) * sizeof(values.data_type());
+  auto const& data_element_size_byte =
+      DataTypeToTypeMeta(data.data_type()).itemsize();
+  auto const& length_element_size_byte =
+      DataTypeToTypeMeta(length.data_type()).itemsize();
+  auto const& values_element_size_byte =
+      DataTypeToTypeMeta(values.data_type()).itemsize();
+  auto const& output_element_size_byte =
+      DataTypeToTypeMeta(output.data_type()).itemsize();
+
+  uint64_t nBytesData = nElemFromDim(data) * data_element_size_byte;
+  uint64_t nBytesLength = nElemFromDim(length) * length_element_size_byte;
+  uint64_t nBytesValues = nElemFromDim(values) * values_element_size_byte;
   c.flops = 0;
   c.bytes_read = nBytesData + nBytesLength + nBytesValues;
-  c.bytes_written = nElemFromDim(output) * sizeof(output.data_type());
+  c.bytes_written = nElemFromDim(output) * output_element_size_byte;
   c.params_bytes = 0;
   return c;
 }
@@ -145,15 +155,15 @@ bool BatchBucketOneHotOp<CPUContext>::RunOnDevice() {
     for (int64_t j = 0; j < D; j++) {
       // here we assume the boundary values for each feature are sorted
       int64_t lower_bucket_idx = std::lower_bound(
-                                    boundaries_offset,
-                                    boundaries_offset + lens_data[j],
-                                    input_data[pos]) -
+                                     boundaries_offset,
+                                     boundaries_offset + lens_data[j],
+                                     input_data[pos]) -
           boundaries_offset;
 
       int64_t upper_bucket_idx = std::upper_bound(
-                                    boundaries_offset,
-                                    boundaries_offset + lens_data[j],
-                                    input_data[pos]) -
+                                     boundaries_offset,
+                                     boundaries_offset + lens_data[j],
+                                     input_data[pos]) -
           boundaries_offset;
 
       int64_t bucket_idx = (lower_bucket_idx + upper_bucket_idx) / 2;
index 8b5e116..561da91 100644 (file)
@@ -1,6 +1,7 @@
 #include "caffe2/operators/utility_ops.h"
 #include <cmath>
 #include <iostream>
+#include "caffe2/core/types.h"
 #include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
@@ -34,9 +35,11 @@ OpSchema::Cost CostInferenceForWeightedSum(
   const auto& nElem = nElemFromDim(X0);
   const auto& nInputs = in.size();
   c.flops = (nInputs - 1) * nElem;
-  c.bytes_read = (nInputs / 2) * (nElem + 1) * sizeof(X0.data_type());
-  c.bytes_written = nElem * sizeof(X0.data_type());
-  c.params_bytes = (nInputs / 2) * sizeof(X0.data_type());
+  auto const& X0_element_size_byte =
+      DataTypeToTypeMeta(X0.data_type()).itemsize();
+  c.bytes_read = (nInputs / 2) * (nElem + 1) * X0_element_size_byte;
+  c.bytes_written = nElem * X0_element_size_byte;
+  c.params_bytes = (nInputs / 2) * X0_element_size_byte;
   return c;
 }
 
@@ -48,9 +51,7 @@ REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
 REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(WeightedSumGradient, WeightedSumGradientOp<CPUContext>);
-REGISTER_CPU_OPERATOR(
-    ScatterWeightedSum,
-    ScatterWeightedSumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(ScatterWeightedSum, ScatterWeightedSumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
 REGISTER_CPU_OPERATOR(Scatter, ScatterOp<CPUContext>);
 
index 996b330..7dab4d6 100644 (file)
@@ -7,33 +7,39 @@ from caffe2.python.test_util import TestCase
 
 class TestConcatOpCost(TestCase):
     def test_columnwise_concat(self):
-        workspace.ResetWorkspace()
-        workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
-        workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=np.int32))
-        concat_op = core.CreateOperator(
-            "Concat",
-            ["input_1", "input_2"],
-            ["output", "split_info"],
-        )
-        workspace.RunOperatorOnce(concat_op)
+        def _test_columnwise_concat_for_type(dtype):
+            workspace.ResetWorkspace()
+            workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
+            workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=dtype))
+            concat_op = core.CreateOperator(
+                "Concat",
+                ["input_1", "input_2"],
+                ["output", "split_info"],
+            )
+            workspace.RunOperatorOnce(concat_op)
 
-        output = workspace.FetchBlob("output")
-        self.assertTupleEqual(output.shape, (2, 4))
-        np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])
+            output = workspace.FetchBlob("output")
+            self.assertTupleEqual(output.shape, (2, 4))
+            np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])
 
-        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
-            concat_op, concat_op.input
-        )
+            flops, bytes_written, bytes_read = workspace.GetOperatorCost(
+                concat_op, concat_op.input
+            )
 
-        self.assertEqual(flops, 0)
-        self.assertEqual(
-            bytes_read,
-            sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
-        )
-        self.assertEqual(
-            bytes_written,
-            sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
-        )
+            self.assertEqual(flops, 0)
+            self.assertEqual(
+                bytes_read,
+                sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
+            )
+            self.assertEqual(
+                bytes_written,
+                sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
+            )
+
+        [
+            _test_columnwise_concat_for_type(t)
+            for t in [np.int64, np.float, np.half, np.int8]
+        ]
 
     def test_split_then_concat(self):
         workspace.ResetWorkspace()
index afb2065..1bf7b60 100644 (file)
@@ -60,7 +60,7 @@ class TestWorkspace(unittest.TestCase):
         self.assertTupleEqual(
             op_cost,
             namedtuple("Cost", ["flops", "bytes_written", "bytes_read"])(
-                1152, 256, 2084
+                1152, 256, 4168
             ),
         )
 
index 0de50f0..0b6f604 100644 (file)
@@ -1,4 +1,5 @@
 #include "adagrad_op.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -23,22 +24,30 @@ static OpSchema::Cost CostInferenceForAdagrad(
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   c.flops = grad_size * 10;
 
+  auto const& moment_element_size_byte =
+      DataTypeToTypeMeta(moment.data_type()).itemsize();
+  auto const& param_element_size_byte =
+      DataTypeToTypeMeta(param.data_type()).itemsize();
+  auto const& grad_element_size_byte =
+      DataTypeToTypeMeta(grad.data_type()).itemsize();
+  auto const& lr_element_size_byte =
+      DataTypeToTypeMeta(lr.data_type()).itemsize();
   uint64_t bytes_written =
-      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+      grad_size * param_element_size_byte + moment_element_size_byte;
 
   if (output_size == 3) {
     // also need to output effective learning rate in this case
     // assume it's the same data type as lr
-    bytes_written += grad_size * sizeof(lr.data_type());
+    bytes_written += grad_size * lr_element_size_byte;
   } else if (output_size == 4) {
     // also need to output effective learning rate and updates in this case
     // assume update is the same data type as param
     bytes_written +=
-        grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type()));
+        grad_size * (lr_element_size_byte + param_element_size_byte);
   }
   c.bytes_written = bytes_written;
   c.bytes_read = c.bytes_written +
-      grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type()));
+      grad_size * (grad_element_size_byte + lr_element_size_byte);
 
   return c;
 }
@@ -102,10 +111,18 @@ static OpSchema::Cost CostInferenceForSparseAdagrad(
   // (optimistically count sqrt as one flop).
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   c.flops = grad_size * 7;
+  auto const& param_element_size_byte =
+      DataTypeToTypeMeta(param.data_type()).itemsize();
+  auto const& moment_element_size_byte =
+      DataTypeToTypeMeta(moment.data_type()).itemsize();
   c.bytes_written =
-      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
-  c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) +
-      n * sizeof(indices.data_type());
+      grad_size * (param_element_size_byte + moment_element_size_byte);
+  auto const& grad_element_size_byte =
+      DataTypeToTypeMeta(grad.data_type()).itemsize();
+  auto const& indices_element_size_byte =
+      DataTypeToTypeMeta(indices.data_type()).itemsize();
+  c.bytes_read = c.bytes_written + grad_size * grad_element_size_byte +
+      n * indices_element_size_byte;
 
   return c;
 }
@@ -153,6 +170,16 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
   OpSchema::Cost c;
 
   if (n > 0) {
+    auto const& param_element_size_byte =
+        DataTypeToTypeMeta(param.data_type()).itemsize();
+    auto const& moment_element_size_byte =
+        DataTypeToTypeMeta(moment.data_type()).itemsize();
+    auto const& grad_element_size_byte =
+        DataTypeToTypeMeta(grad.data_type()).itemsize();
+    auto const& indices_element_size_byte =
+        DataTypeToTypeMeta(indices.data_type()).itemsize();
+    auto const& lr_element_size_byte =
+        DataTypeToTypeMeta(lr.data_type()).itemsize();
     auto block_size = grad_size / n;
     if (block_size == 1) {
       // +2: applying weight decay and add to grads
@@ -161,22 +188,22 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       c.flops = n * 9;
       c.bytes_written =
-          n * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+          n * (param_element_size_byte + moment_element_size_byte);
       c.bytes_read = c.bytes_written +
           n *
-              (sizeof(grad.data_type()) + sizeof(indices.data_type()) +
-               sizeof(lr.data_type()));
+              (grad_element_size_byte + indices_element_size_byte +
+               lr_element_size_byte);
     } else {
       // 5 per block (not counting index transforms)
       // 8 for each value of a block
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       c.flops = n * (5 + (block_size * 8));
-      c.bytes_written =
-          n * sizeof(moment.data_type()) + n * block_size * (param.data_type());
+      c.bytes_written = n * moment_element_size_byte +
+          n * block_size * param_element_size_byte;
 
-      c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) +
+      c.bytes_read = c.bytes_written + n * lr_element_size_byte +
           2 * n * block_size *
-              (sizeof(grad.data_type()) + sizeof(param.data_type()));
+              (grad_element_size_byte + param_element_size_byte);
     }
   }
   return c;